xmtp_api_grpc/streams/
non_blocking_stream.rs

1//! Compatibility layer for JS-Fetch POST streams & gRPC Tonic Web
2//!
//! A web ['fetch' request](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API/Using_Fetch)
//! may complete successfully, but the fetch promise does not resolve until the first bytes of the
//! body are received by the browser ([issue](https://github.com/devashishdxt/tonic-web-wasm-client/issues/22)).
6//!
//! This poses a behavior inconsistency between native gRPC (HTTP/2) and gRPC-web (HTTP/1.1). On
//! native, gRPC streams do not block on receipt of the first body bytes, while web streams do.
//! This is particularly obvious in tests, where:
//! 1. a stream is started
//! 2. data is sent (for instance, group messages)
//! 3. the sent data is inspected
//!
//! On web, we never get past step 1.
15//!
//! This solution models the web stream request as part of the stream itself.
//! Once the initial request promise resolves, the stream continues by polling the
//! resulting response object.
19//!
//! This problem is not unique to gRPC-web, and must be solved for grpc-gateway streams as well
//! ([code example for grpc-gateway](https://github.com/xmtp/libxmtp/blob/87338b819730ade4c292937e3243b16e3cdee248/xmtp_api_http/src/http_stream.rs#L165)).
22//!
//! In the context of gRPC, this should not break anything that already works natively: gRPC
//! requests, even unary ones, are all modeled as streams (a unary request is a stream with a
//! single message), and none block on receipt of the body. Ideally, we would check the header
//! status and ensure the initial response is 200 (OK), but the browser environment does not
//! allow for this.
28
29use futures::{Stream, TryFuture, TryStream, stream::FusedStream};
30use pin_project_lite::pin_project;
31use std::{
32    future::Future,
33    pin::Pin,
34    task::{Context, Poll, ready},
35};
36use tonic::Status;
37
38use crate::streams::{FakeEmptyStream, IntoInner};
39
40pin_project! {
41    /// The establish future for the http post stream
42    struct StreamEstablish<F> {
43        #[pin] inner: F,
44    }
45}
46
47impl<F> StreamEstablish<F> {
48    fn new(inner: F) -> Self {
49        Self { inner }
50    }
51}
52
53impl<F> Future for StreamEstablish<F>
54where
55    F: TryFuture<Error = Status>,
56{
57    type Output = Result<F::Ok, Status>;
58
59    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60        use Poll::*;
61        let this = self.as_mut().project();
62        let response = ready!(this.inner.try_poll(cx));
63        let response = response.inspect_err(|e| {
64            tracing::error!("Error during grpc-web subscription establishment {e}");
65        })?;
66        Ready(Ok(response))
67    }
68}
69
70pin_project! {
71    /// The establish future for the http post stream
72    #[project = ProjectStream]
73    enum StreamState< F, S> {
74        NotStarted {
75            #[pin] future: StreamEstablish<F>,
76        },
77        Started {
78            #[pin] stream: S,
79        },
80        Terminated
81    }
82}
83
84pin_project! {
85    pub struct NonBlockingWebStream<F, S> {
86        #[pin] state: StreamState<F, S>,
87    }
88}
89
90impl<F, S> NonBlockingWebStream<F, S>
91where
92    F: TryFuture<Error = Status>,
93{
94    pub fn new(request: F) -> Self {
95        Self {
96            state: StreamState::NotStarted {
97                future: StreamEstablish::new(request),
98            },
99        }
100    }
101
102    /// Internal API to construct a started variant
103    pub fn started(stream: S) -> Self {
104        Self {
105            state: StreamState::Started { stream },
106        }
107    }
108}
109
110impl<F> NonBlockingWebStream<F, FakeEmptyStream<Status>> {
111    pub fn empty() -> Self {
112        Self {
113            state: StreamState::Started {
114                stream: FakeEmptyStream::<Status>::new(),
115            },
116        }
117    }
118}
119
120impl<F, S> Stream for NonBlockingWebStream<F, S>
121where
122    S: TryStream<Error = Status>,
123    F: TryFuture<Error = Status>,
124    F::Ok: IntoInner<Out = S>,
125{
126    type Item = Result<S::Ok, Status>;
127
128    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        use ProjectStream::*;
130        let mut this = self.as_mut().project();
131        match this.state.as_mut().project() {
132            NotStarted { future } => {
133                match ready!(future.poll(cx)) {
134                    Ok(stream) => {
135                        this.state.set(StreamState::Started {
136                            stream: stream.into_inner(),
137                        });
138                    }
139                    Err(e) => {
140                        this.state.set(StreamState::Terminated);
141                        return Poll::Ready(Some(Err(e)));
142                    }
143                }
144                tracing::trace!("stream ready, polling for the first time...");
145                cx.waker().wake_by_ref();
146                Poll::Pending
147            }
148            Started { mut stream } => {
149                let next = stream.as_mut().try_poll_next(cx);
150                if let Poll::Ready(None) = next {
151                    this.state.set(StreamState::Terminated);
152                }
153                next
154            }
155            Terminated => Poll::Ready(None),
156        }
157    }
158}
159
160impl<F, S> FusedStream for NonBlockingWebStream<F, S>
161where
162    F: TryFuture<Error = Status>,
163    S: TryStream<Error = Status> + FusedStream,
164    F::Ok: IntoInner<Out = S>,
165{
166    fn is_terminated(&self) -> bool {
167        match &self.state {
168            StreamState::Started { stream } => stream.is_terminated(),
169            StreamState::Terminated => true,
170            _ => false,
171        }
172    }
173}
174
175impl<F, S> std::fmt::Debug for NonBlockingWebStream<F, S> {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
177        match self.state {
178            StreamState::NotStarted { .. } => write!(f, "not started"),
179            StreamState::Started { .. } => write!(f, "started"),
180            StreamState::Terminated => write!(f, "terminated"),
181        }
182    }
183}
184
#[cfg(test)]
mod tests {
    use futures::{Stream, stream};
    use futures_test::future::FutureTestExt;
    use prost::bytes::Bytes;
    use tonic::{Response, Streaming};

    use super::*;

    /// Inner-stream stand-in for tests that must never reach the `Started`
    /// state; every method panics if it is actually invoked.
    struct TestStream;
    impl Stream for TestStream {
        type Item = Result<Response<Bytes>, Status>;

        fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            unreachable!()
        }
    }

    impl FusedStream for TestStream {
        fn is_terminated(&self) -> bool {
            unreachable!()
        }
    }

    impl<T> From<Streaming<T>> for TestStream {
        fn from(_: Streaming<T>) -> Self {
            unreachable!()
        }
    }

    /// Type-level stand-in for the establishing request's `Ok` output. Its
    /// `IntoInner` conversion names [`TestStream`] as the inner stream. No
    /// value of this type is constructed by these tests.
    struct MockFut;
    impl Future for MockFut {
        type Output = Result<TestStream, Status>;

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            unimplemented!()
        }
    }

    impl IntoInner for MockFut {
        type Out = TestStream;

        fn into_inner(self) -> Self::Out {
            todo!()
        }
    }

    #[xmtp_common::test]
    fn handles_err_on_establish() {
        // The establishing future is already resolved to an error, so the
        // very first poll must surface that error and terminate the stream.
        // `MockFut`/`TestStream` only fix the types; neither is ever used.
        let stream: NonBlockingWebStream<_, TestStream> = NonBlockingWebStream::new(
            futures::future::ready(Err::<MockFut, _>(Status::internal("test error"))),
        );
        futures::pin_mut!(stream);

        assert!(matches!(stream.state, StreamState::NotStarted { .. }));
        let waker = futures::task::noop_waker();
        let mut cx = std::task::Context::from_waker(&waker);
        assert!(matches!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(Some(Err(_)))
        ));

        // After reporting the error, the stream is fused and stays finished.
        assert!(FusedStream::is_terminated(&stream));
        assert!(matches!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(None)
        ));
    }

    #[xmtp_common::test]
    fn happy_path_future() {
        // `pending_once` makes the wrapped future return `Pending` on its
        // first poll, so `StreamEstablish` must pass that through before
        // resolving.
        let fut = futures::future::ready(Ok(()));
        let fut = fut.pending_once();
        let fut = StreamEstablish::new(fut);
        futures::pin_mut!(fut);
        let mut context = futures_test::task::noop_context();
        assert_eq!(
            Poll::Pending,
            fut.as_mut().poll(&mut context).map(Result::unwrap)
        );
        assert_eq!(Poll::Ready(()), fut.poll(&mut context).map(Result::unwrap));
    }

    /// `Future` adapter over a `TryFuture` with a `Status` error, used as the
    /// establishing request in tests.
    struct FakeFuture<T>(T);

    impl<T> FakeFuture<T> {
        fn inner(self: Pin<&mut Self>) -> Pin<&mut T> {
            // SAFETY: structural pinning of `.0`. `FakeFuture` never moves `.0`
            // out of a pinned reference, has no `Drop` impl, and is `Unpin`
            // only when `T: Unpin` (auto trait), so pinning `self` pins `.0`.
            unsafe { self.map_unchecked_mut(|s| &mut s.0) }
        }
    }

    impl<T> Future for FakeFuture<T>
    where
        T: TryFuture<Error = Status>,
    {
        type Output = Result<T::Ok, Status>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.inner().try_poll(cx)
        }
    }

    /// `Stream` adapter over a `TryStream` with a `Status` error. Its
    /// `IntoInner` is the identity, so it serves both as the request's output
    /// and as the inner stream.
    struct FakeStream<T>(T);

    impl<T> FakeStream<T> {
        fn inner(self: Pin<&mut Self>) -> Pin<&mut T> {
            // SAFETY: structural pinning of `.0`. `FakeStream` never moves `.0`
            // out of a pinned reference, has no `Drop` impl, and is `Unpin`
            // only when `T: Unpin` (auto trait), so pinning `self` pins `.0`.
            unsafe { self.map_unchecked_mut(|s| &mut s.0) }
        }
    }

    impl<T> Stream for FakeStream<T>
    where
        T: TryStream<Error = Status>,
    {
        type Item = Result<T::Ok, Status>;

        fn poll_next(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<T::Ok, Status>>> {
            self.inner().try_poll_next(cx)
        }
    }

    impl<T> IntoInner for FakeStream<T> {
        type Out = FakeStream<T>;

        fn into_inner(self) -> Self::Out {
            self
        }
    }

    impl<T: TryStream<Error = Status>> FusedStream for FakeStream<T> {
        fn is_terminated(&self) -> bool {
            unreachable!()
        }
    }

    /// Wraps `i` in `Ok`, fixing the error type to `Status`.
    fn item<T>(i: T) -> Result<T, Status> {
        Ok(i)
    }

    #[xmtp_common::test]
    fn establish_changes_state_to_started() {
        let s = FakeStream(stream::iter(vec![item(0usize), item(1), item(2)]));
        let fut = futures::future::ready(Ok(s));
        let fut = FakeFuture(fut);
        let fut = fut.pending_once();
        let s =
            NonBlockingWebStream::<_, FakeStream<stream::Iter<std::vec::IntoIter<_>>>>::new(fut);

        futures::pin_mut!(s);
        let mut context = futures_test::task::noop_context();
        // First poll: the establishing future is still pending (`pending_once`),
        // so the state is unchanged.
        assert_eq!(
            Poll::Pending,
            s.as_mut()
                .poll_next(&mut context)
                .map(Option::unwrap)
                .map(Result::unwrap)
        );
        assert!(matches!(s.state, StreamState::NotStarted { .. }));
        // Second poll: the request resolves and the state switches to
        // `Started`, but `poll_next` yields `Pending` once (with a self-wake)
        // before touching the inner stream.
        assert_eq!(
            Poll::Pending,
            s.as_mut()
                .poll_next(&mut context)
                .map(Option::unwrap)
                .map(Result::unwrap)
        );
        assert!(matches!(s.state, StreamState::Started { .. }));
        // Subsequent polls forward the inner stream's items in order.
        for i in 0..3 {
            assert_eq!(
                Poll::Ready(i),
                s.as_mut()
                    .poll_next(&mut context)
                    .map(Option::unwrap)
                    .map(Result::unwrap)
            );
        }
        // The stream ends after all items, and keeps returning `None`.
        assert_eq!(
            Poll::Ready(None),
            s.as_mut()
                .poll_next(&mut context)
                .map(|o| o.map(Result::unwrap))
        );
        assert_eq!(
            Poll::Ready(None),
            s.as_mut()
                .poll_next(&mut context)
                .map(|o| o.map(Result::unwrap))
        );
        // The exhausted inner stream moved the state to `Terminated`.
        assert!(matches!(s.state, StreamState::Terminated));
    }
}
384}