1use futures::{Stream, TryFuture, TryStream, stream::FusedStream};
30use pin_project_lite::pin_project;
31use std::{
32 future::Future,
33 pin::Pin,
34 task::{Context, Poll, ready},
35};
36use tonic::Status;
37
38use crate::streams::{FakeEmptyStream, IntoInner};
39
pin_project! {
    /// Future adapter that awaits the future establishing a grpc-web stream.
    ///
    /// Its output is the inner [`TryFuture`]'s result, passed through unchanged;
    /// failures are additionally logged with `tracing::error!` when polled.
    struct StreamEstablish<F> {
        // The establishing future; structurally pinned so it can be polled in place.
        #[pin] inner: F,
    }
}
46
47impl<F> StreamEstablish<F> {
48 fn new(inner: F) -> Self {
49 Self { inner }
50 }
51}
52
53impl<F> Future for StreamEstablish<F>
54where
55 F: TryFuture<Error = Status>,
56{
57 type Output = Result<F::Ok, Status>;
58
59 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60 use Poll::*;
61 let this = self.as_mut().project();
62 let response = ready!(this.inner.try_poll(cx));
63 let response = response.inspect_err(|e| {
64 tracing::error!("Error during grpc-web subscription establishment {e}");
65 })?;
66 Ready(Ok(response))
67 }
68}
69
pin_project! {
    /// Lifecycle of a [`NonBlockingWebStream`].
    ///
    /// Transitions driven by `poll_next`: `NotStarted` -> `Started` when the
    /// establishing future succeeds, `NotStarted` -> `Terminated` when it fails,
    /// and `Started` -> `Terminated` once the inner stream yields `None`.
    #[project = ProjectStream]
    enum StreamState<F, S> {
        // Waiting for the stream-establishing future to resolve.
        NotStarted {
            #[pin] future: StreamEstablish<F>,
        },
        // Established; items are forwarded from the inner stream.
        Started {
            #[pin] stream: S,
        },
        // Finished (exhausted or failed to establish); only yields `None` from here on.
        Terminated
    }
}
83
pin_project! {
    /// A stream whose underlying grpc-web stream is obtained from a future.
    ///
    /// While the establishing future is pending, `poll_next` returns `Poll::Pending`.
    /// When it succeeds, the inner stream is installed, the task is woken and one
    /// `Poll::Pending` is returned, so the inner stream is first polled on the next
    /// call. An establishment error is yielded once as `Some(Err(..))`; afterwards
    /// (and after the inner stream is exhausted) the stream only yields `None`.
    pub struct NonBlockingWebStream<F, S> {
        // Current lifecycle state; structurally pinned because it owns the future/stream.
        #[pin] state: StreamState<F, S>,
    }
}
89
90impl<F, S> NonBlockingWebStream<F, S>
91where
92 F: TryFuture<Error = Status>,
93{
94 pub fn new(request: F) -> Self {
95 Self {
96 state: StreamState::NotStarted {
97 future: StreamEstablish::new(request),
98 },
99 }
100 }
101
102 pub fn started(stream: S) -> Self {
104 Self {
105 state: StreamState::Started { stream },
106 }
107 }
108}
109
110impl<F> NonBlockingWebStream<F, FakeEmptyStream<Status>> {
111 pub fn empty() -> Self {
112 Self {
113 state: StreamState::Started {
114 stream: FakeEmptyStream::<Status>::new(),
115 },
116 }
117 }
118}
119
120impl<F, S> Stream for NonBlockingWebStream<F, S>
121where
122 S: TryStream<Error = Status>,
123 F: TryFuture<Error = Status>,
124 F::Ok: IntoInner<Out = S>,
125{
126 type Item = Result<S::Ok, Status>;
127
128 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 use ProjectStream::*;
130 let mut this = self.as_mut().project();
131 match this.state.as_mut().project() {
132 NotStarted { future } => {
133 match ready!(future.poll(cx)) {
134 Ok(stream) => {
135 this.state.set(StreamState::Started {
136 stream: stream.into_inner(),
137 });
138 }
139 Err(e) => {
140 this.state.set(StreamState::Terminated);
141 return Poll::Ready(Some(Err(e)));
142 }
143 }
144 tracing::trace!("stream ready, polling for the first time...");
145 cx.waker().wake_by_ref();
146 Poll::Pending
147 }
148 Started { mut stream } => {
149 let next = stream.as_mut().try_poll_next(cx);
150 if let Poll::Ready(None) = next {
151 this.state.set(StreamState::Terminated);
152 }
153 next
154 }
155 Terminated => Poll::Ready(None),
156 }
157 }
158}
159
160impl<F, S> FusedStream for NonBlockingWebStream<F, S>
161where
162 F: TryFuture<Error = Status>,
163 S: TryStream<Error = Status> + FusedStream,
164 F::Ok: IntoInner<Out = S>,
165{
166 fn is_terminated(&self) -> bool {
167 match &self.state {
168 StreamState::Started { stream } => stream.is_terminated(),
169 StreamState::Terminated => true,
170 _ => false,
171 }
172 }
173}
174
175impl<F, S> std::fmt::Debug for NonBlockingWebStream<F, S> {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
177 match self.state {
178 StreamState::NotStarted { .. } => write!(f, "not started"),
179 StreamState::Started { .. } => write!(f, "started"),
180 StreamState::Terminated => write!(f, "terminated"),
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use futures::{Stream, stream};
    use futures_test::future::FutureTestExt;
    use prost::bytes::Bytes;
    use tonic::{Response, Streaming};

    use super::*;

    // Placeholder stream that only fixes the `S` type parameter in
    // `handles_err_on_establish`; every method panics if actually called.
    struct TestStream;
    impl Stream for TestStream {
        type Item = Result<Response<Bytes>, Status>;

        fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            unreachable!()
        }
    }

    impl FusedStream for TestStream {
        fn is_terminated(&self) -> bool {
            unreachable!()
        }
    }

    impl<T> From<Streaming<T>> for TestStream {
        fn from(_: Streaming<T>) -> Self {
            unreachable!()
        }
    }

    // Placeholder "established" value: exists only to satisfy the
    // `F::Ok: IntoInner<Out = S>` bound; the test using it never produces an `Ok`,
    // so it is never polled or converted.
    struct MockFut;
    impl Future for MockFut {
        type Output = Result<TestStream, Status>;

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            unimplemented!()
        }
    }

    impl IntoInner for MockFut {
        type Out = TestStream;

        fn into_inner(self) -> Self::Out {
            todo!()
        }
    }

    // An establishment error is yielded exactly once; afterwards the stream is
    // terminated and keeps yielding `None`.
    #[xmtp_common::test]
    fn handles_err_on_establish() {
        let stream: NonBlockingWebStream<_, TestStream> =
            NonBlockingWebStream::new(futures::future::ready({
                Err::<MockFut, _>(Status::internal("test error"))
            }));
        futures::pin_mut!(stream);

        assert!(matches!(stream.state, StreamState::NotStarted { .. }));
        let cx = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&cx);
        // First poll: the ready future resolves to `Err`, which is surfaced as an item.
        assert!(matches!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(Some(Err(_)))
        ));

        // The error moved the state to `Terminated`, so the stream reports itself
        // terminated (without touching `TestStream`) and yields `None`.
        assert!(FusedStream::is_terminated(&stream));
        assert!(matches!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(None)
        ));
    }

    // `StreamEstablish` passes through a `Pending` and then the successful output.
    #[xmtp_common::test]
    fn happy_path_future() {
        let fut = futures::future::ready(Ok(()));
        // Inject one `Poll::Pending` before the ready future completes.
        let fut = fut.pending_once();
        let fut = StreamEstablish::new(fut);
        futures::pin_mut!(fut);
        let mut context = futures_test::task::noop_context();
        assert_eq!(
            Poll::Pending,
            fut.as_mut().poll(&mut context).map(Result::unwrap)
        );
        assert_eq!(Poll::Ready(()), fut.poll(&mut context).map(Result::unwrap));
    }

    // Minimal `Future` wrapper that forwards polling to an inner `TryFuture`.
    struct FakeFuture<T>(T);

    impl<T> FakeFuture<T> {
        fn inner(self: Pin<&mut Self>) -> Pin<&mut T> {
            // SAFETY: the field is structurally pinned: `FakeFuture` never moves it out
            // of a pinned reference, has no `Drop` impl, and (auto trait) is `Unpin`
            // only when `T: Unpin`.
            unsafe { self.map_unchecked_mut(|s| &mut s.0) }
        }
    }

    impl<T> Future for FakeFuture<T>
    where
        T: TryFuture<Error = Status>,
    {
        type Output = Result<T::Ok, Status>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.inner().try_poll(cx)
        }
    }

    // Minimal `Stream` wrapper that forwards polling to an inner `TryStream`; its
    // `IntoInner` impl returns itself, so it can serve as the established value.
    struct FakeStream<T>(T);

    impl<T> FakeStream<T> {
        fn inner(self: Pin<&mut Self>) -> Pin<&mut T> {
            // SAFETY: the field is structurally pinned: `FakeStream` never moves it out
            // of a pinned reference, has no `Drop` impl, and (auto trait) is `Unpin`
            // only when `T: Unpin`.
            unsafe { self.map_unchecked_mut(|s| &mut s.0) }
        }
    }

    impl<T> Stream for FakeStream<T>
    where
        T: TryStream<Error = Status>,
    {
        type Item = Result<T::Ok, Status>;

        fn poll_next(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<T::Ok, Status>>> {
            self.inner().try_poll_next(cx)
        }
    }

    impl<T> IntoInner for FakeStream<T> {
        type Out = FakeStream<T>;

        fn into_inner(self) -> Self::Out {
            self
        }
    }

    impl<T: TryStream<Error = Status>> FusedStream for FakeStream<T> {
        fn is_terminated(&self) -> bool {
            unreachable!()
        }
    }

    // Shorthand for a successful stream item.
    fn item<T>(i: T) -> Result<T, Status> {
        Ok(i)
    }

    // Walks the full lifecycle: NotStarted -> Started -> items -> Terminated.
    #[xmtp_common::test]
    fn establish_changes_state_to_started() {
        let s = FakeStream(stream::iter(vec![item(0usize), item(1), item(2)]));
        let fut = futures::future::ready(Ok(s));
        let fut = FakeFuture(fut);
        let fut = fut.pending_once();
        let s =
            NonBlockingWebStream::<_, FakeStream<stream::Iter<std::vec::IntoIter<_>>>>::new(fut);

        futures::pin_mut!(s);
        let mut context = futures_test::task::noop_context();
        // First poll: the injected `Pending` keeps the stream in `NotStarted`.
        assert_eq!(
            Poll::Pending,
            s.as_mut()
                .poll_next(&mut context)
                .map(Option::unwrap)
                .map(Result::unwrap)
        );
        assert!(matches!(s.state, StreamState::NotStarted { .. }));
        // Second poll: the future resolves and the stream is installed, but
        // `poll_next` still returns `Pending` (it wakes itself instead of polling the
        // new stream in the same call).
        assert_eq!(
            Poll::Pending,
            s.as_mut()
                .poll_next(&mut context)
                .map(Option::unwrap)
                .map(Result::unwrap)
        );
        assert!(matches!(s.state, StreamState::Started { .. }));
        // The inner items are now forwarded in order.
        for i in 0..3 {
            assert_eq!(
                Poll::Ready(i),
                s.as_mut()
                    .poll_next(&mut context)
                    .map(Option::unwrap)
                    .map(Result::unwrap)
            );
        }
        // Exhaustion yields `None` and keeps yielding `None` on later polls.
        assert_eq!(
            Poll::Ready(None),
            s.as_mut()
                .poll_next(&mut context)
                .map(|o| o.map(Result::unwrap))
        );
        assert_eq!(
            Poll::Ready(None),
            s.as_mut()
                .poll_next(&mut context)
                .map(|o| o.map(Result::unwrap))
        );
        assert!(matches!(s.state, StreamState::Terminated));
    }
}