xmtp_proto/traits/
stream.rs

1//! Default XMTP Stream
2
3use prost::bytes::Bytes;
4use std::{
5    marker::PhantomData,
6    pin::Pin,
7    task::{Context, Poll, ready},
8};
9
10use crate::{ApiEndpoint, api::ApiClientError};
11use futures::{Stream, TryStream};
12use pin_project_lite::pin_project;
13
14pin_project! {
15    /// A stream which maps the tonic error to ApiClientError, and attaches endpoint metadata
16    pub struct XmtpStream<S, T> {
17        #[pin] inner: S,
18        endpoint: ApiEndpoint,
19        _marker: PhantomData<T>,
20    }
21}
22
23impl<S, T> XmtpStream<S, T> {
24    pub fn new(inner: S, endpoint: ApiEndpoint) -> Self {
25        Self {
26            inner,
27            endpoint,
28            _marker: PhantomData,
29        }
30    }
31}
32
33impl<S, T> Stream for XmtpStream<S, T>
34where
35    S: TryStream<Ok = Bytes>,
36    T: prost::Message + Default,
37    S::Error: std::error::Error + 'static,
38{
39    type Item = Result<T, ApiClientError<S::Error>>;
40
41    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
42        let this = self.as_mut().project();
43        if let Some(item) = ready!(this.inner.try_poll_next(cx)) {
44            let res = item
45                .map_err(|e| ApiClientError::new(self.endpoint.clone(), e))
46                .and_then(|i| T::decode(i).map_err(ApiClientError::<S::Error>::DecodeError));
47            Poll::Ready(Some(res))
48        } else {
49            Poll::Ready(None)
50        }
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, pin_mut, stream};
    use prost::Message;

    #[derive(prost::Message)]
    struct TestMessage {
        #[prost(string, tag = "1")]
        content: String,
    }

    #[derive(thiserror::Error, Debug)]
    enum TestError {
        #[error("mock stream error")]
        StreamError,
    }

    #[xmtp_common::test]
    async fn test_poll_next_successful_decode() {
        let test_message = TestMessage {
            content: "test content".to_string(),
        };
        let encoded_bytes = test_message.encode_to_vec();

        let inner_stream =
            stream::once(async move { Ok::<Bytes, TestError>(Bytes::from(encoded_bytes)) });
        let xmtp_stream =
            XmtpStream::<_, TestMessage>::new(inner_stream, ApiEndpoint::SubscribeGroupMessages);
        pin_mut!(xmtp_stream);

        let decoded_message = xmtp_stream
            .next()
            .await
            .expect("stream should yield one item")
            .expect("item should decode successfully");
        assert_eq!(decoded_message.content, "test content");
        // stream ends
        let n = xmtp_stream.next().await;
        assert!(n.is_none());
    }

    #[xmtp_common::test]
    async fn test_poll_next_error_mapping() {
        let inner_stream = stream::once(async { Err::<Bytes, TestError>(TestError::StreamError) });
        let xmtp_stream =
            XmtpStream::<_, TestMessage>::new(inner_stream, ApiEndpoint::SubscribeGroupMessages);
        pin_mut!(xmtp_stream);

        let result = xmtp_stream
            .next()
            .await
            .expect("stream should yield one item");

        match result {
            Err(ApiClientError::ClientWithEndpoint { endpoint, .. }) => {
                assert_eq!(endpoint, "subscribe_group_messages");
            }
            _ => panic!("Expected ClientWithEndpoint error"),
        }
        // stream ends
        let n = xmtp_stream.next().await;
        assert!(n.is_none());
    }

    #[xmtp_common::test]
    async fn test_poll_next_decode_error() {
        // Key 0x0a = field 1, wire type 2 (length-delimited), followed by a
        // declared length of 5 with no payload bytes: prost must reject this as
        // a truncated buffer.
        let truncated = Bytes::from_static(&[0x0a, 0x05]);
        let inner_stream = stream::once(async move { Ok::<Bytes, TestError>(truncated) });
        let xmtp_stream =
            XmtpStream::<_, TestMessage>::new(inner_stream, ApiEndpoint::SubscribeGroupMessages);
        pin_mut!(xmtp_stream);

        let result = xmtp_stream
            .next()
            .await
            .expect("stream should yield one item");
        assert!(
            matches!(result, Err(ApiClientError::DecodeError(_))),
            "Expected DecodeError"
        );
        // stream ends
        let n = xmtp_stream.next().await;
        assert!(n.is_none());
    }
}