// xmtp_api_grpc/streams/default.rs

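//! Default XMTP streams.
//!
//! Adapts raw gRPC response byte streams into streams of decoded protobuf
//! messages, mapping transport errors to `ApiClientError`.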
2
3use prost::bytes::Bytes;
4use std::{
5    marker::PhantomData,
6    pin::Pin,
7    task::{Context, Poll, ready},
8};
9
10use crate::error::GrpcError;
11use futures::{Stream, TryStream};
12use pin_project_lite::pin_project;
13use xmtp_proto::{
14    ApiEndpoint,
15    api::{ApiClientError, Client},
16};
17
18pin_project! {
19    /// A stream which maps the tonic error to ApiClientError, and attaches endpoint metadata
20    pub struct XmtpTonicStream<S, T> {
21        #[pin] inner: S,
22        endpoint: ApiEndpoint,
23        _marker: PhantomData<T>,
24    }
25}
26
27impl<S, T> XmtpTonicStream<S, T> {
28    pub fn new(inner: S, endpoint: ApiEndpoint) -> Self {
29        Self {
30            inner,
31            endpoint,
32            _marker: PhantomData,
33        }
34    }
35}
36
37impl<T> XmtpTonicStream<crate::GrpcStream, T> {
38    /// create a stream from the body of a request
39    /// makes the request and starts the stream
40    pub async fn from_body<B: prost::Name>(
41        body: B,
42        client: crate::GrpcClient,
43        endpoint: ApiEndpoint,
44    ) -> Result<Self, ApiClientError<GrpcError>> {
45        let pnq = xmtp_proto::path_and_query::<B>();
46        let request = http::Request::builder();
47        let path = http::uri::PathAndQuery::try_from(pnq.as_ref())?;
48        let s = client
49            .stream(request, path, body.encode_to_vec().into())
50            .await?;
51        Ok(Self::new(s.into_body(), endpoint))
52    }
53}
54
55impl<S, T> Stream for XmtpTonicStream<S, T>
56where
57    S: TryStream<Ok = Bytes, Error = GrpcError>,
58    GrpcError: From<<S as TryStream>::Error>,
59    T: prost::Message + Default,
60{
61    type Item = Result<T, ApiClientError<GrpcError>>;
62
63    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
64        let this = self.as_mut().project();
65        if let Some(item) = ready!(this.inner.try_poll_next(cx)) {
66            let res = item
67                .map_err(|e| ApiClientError::new(self.endpoint.clone(), e))
68                .and_then(|i| T::decode(i).map_err(GrpcError::from).map_err(Into::into));
69            Poll::Ready(Some(res))
70        } else {
71            Poll::Ready(None)
72        }
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, stream};
    use prost::Message;
    use rstest::rstest;

    #[derive(Clone, PartialEq, Message)]
    struct TestMessage {
        #[prost(string, tag = "1")]
        pub content: String,
    }

    impl prost::Name for TestMessage {
        const NAME: &'static str = "TestMessage";
        const PACKAGE: &'static str = "test";
        fn full_name() -> String {
            format!("{}.{}", Self::PACKAGE, Self::NAME)
        }
    }

    /// Encodes a `TestMessage` with the given content into a single frame.
    fn create_test_message_bytes(content: &str) -> Bytes {
        let msg = TestMessage {
            content: content.to_string(),
        };
        Bytes::from(msg.encode_to_vec())
    }

    #[rstest]
    #[case::empty_stream(vec![], vec![])]
    #[case::single_message(
        vec![Ok(create_test_message_bytes("test1"))],
        vec![TestMessage { content: "test1".to_string() }],
    )]
    #[case::multiple_messages(
        vec![
            Ok(create_test_message_bytes("msg1")),
            Ok(create_test_message_bytes("msg2")),
            Ok(create_test_message_bytes("msg3"))
        ],
        vec![
            TestMessage { content: "msg1".to_string() },
            TestMessage { content: "msg2".to_string() },
            TestMessage { content: "msg3".to_string() }
        ],
    )]
    #[xmtp_common::test]
    async fn test_successful_message_decoding(
        #[case] input: Vec<Result<Bytes, GrpcError>>,
        #[case] expected: Vec<TestMessage>,
    ) {
        let stream = stream::iter(input);
        let endpoint = ApiEndpoint::SubscribeGroupMessages;
        let stream = XmtpTonicStream::<_, TestMessage>::new(stream, endpoint);

        let results: Vec<_> = stream.map(Result::unwrap).collect().await;
        assert_eq!(results, expected);
    }

    #[xmtp_common::test]
    async fn test_error_propagation() {
        let grpc_error = GrpcError::Status(tonic::Status::unavailable("Connection failed"));
        let input = vec![
            Ok(create_test_message_bytes("msg1")),
            Err(grpc_error),
            Ok(create_test_message_bytes("msg3")),
        ];

        let stream = stream::iter(input);
        let endpoint = ApiEndpoint::QueryGroupMessages;
        let stream = XmtpTonicStream::<_, TestMessage>::new(stream, endpoint.clone());

        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 3);

        assert_eq!(
            results[0].as_ref().unwrap(),
            &TestMessage {
                content: "msg1".to_string()
            }
        );

        let api_error = results[1].as_ref().unwrap_err();
        let xmtp_proto::api::ApiClientError::ClientWithEndpoint {
            endpoint: err_endpoint,
            ..
        } = api_error
        else {
            panic!("Expected ClientWithEndpoint error variant");
        };
        assert_eq!(*err_endpoint, endpoint.to_string());

        // A transport error must not terminate the stream.
        assert_eq!(
            results[2].as_ref().unwrap(),
            &TestMessage {
                content: "msg3".to_string()
            }
        );
    }

    #[xmtp_common::test]
    async fn test_decode_error_is_reported_and_stream_continues() {
        // Field 1 declares a 5-byte string but only 1 byte follows, so decoding fails.
        let truncated = Bytes::from_static(&[0x0a, 0x05, b'a']);
        let input: Vec<Result<Bytes, GrpcError>> =
            vec![Ok(truncated), Ok(create_test_message_bytes("after"))];
        let stream = XmtpTonicStream::<_, TestMessage>::new(
            stream::iter(input),
            ApiEndpoint::SubscribeGroupMessages,
        );

        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 2);
        assert!(results[0].is_err());
        assert_eq!(
            results[1].as_ref().unwrap(),
            &TestMessage {
                content: "after".to_string()
            }
        );
    }

    #[xmtp_common::test]
    fn stream_ends() {
        let input = vec![Ok(create_test_message_bytes("test"))];
        let stream = stream::iter(input);
        let endpoint = ApiEndpoint::SendGroupMessages;
        let stream = XmtpTonicStream::<_, TestMessage>::new(stream, endpoint);

        futures::pin_mut!(stream);
        let mut cx = futures_test::task::noop_context();

        let Poll::Ready(Some(Ok(msg))) = stream.as_mut().poll_next(&mut cx) else {
            panic!("expected the first poll to yield a decoded message");
        };
        assert_eq!(msg.content, "test");

        let end_poll = stream.poll_next(&mut cx);
        assert!(matches!(end_poll, Poll::Ready(None)));
    }
}