xmtp_proto/traits/
stream.rs1use prost::bytes::Bytes;
4use std::{
5 marker::PhantomData,
6 pin::Pin,
7 task::{Context, Poll, ready},
8};
9
10use crate::{ApiEndpoint, api::ApiClientError};
11use futures::{Stream, TryStream};
12use pin_project_lite::pin_project;
13
14pin_project! {
15 pub struct XmtpStream<S, T> {
17 #[pin] inner: S,
18 endpoint: ApiEndpoint,
19 _marker: PhantomData<T>,
20 }
21}
22
23impl<S, T> XmtpStream<S, T> {
24 pub fn new(inner: S, endpoint: ApiEndpoint) -> Self {
25 Self {
26 inner,
27 endpoint,
28 _marker: PhantomData,
29 }
30 }
31}
32
33impl<S, T> Stream for XmtpStream<S, T>
34where
35 S: TryStream<Ok = Bytes>,
36 T: prost::Message + Default,
37 S::Error: std::error::Error + 'static,
38{
39 type Item = Result<T, ApiClientError<S::Error>>;
40
41 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
42 let this = self.as_mut().project();
43 if let Some(item) = ready!(this.inner.try_poll_next(cx)) {
44 let res = item
45 .map_err(|e| ApiClientError::new(self.endpoint.clone(), e))
46 .and_then(|i| T::decode(i).map_err(ApiClientError::<S::Error>::DecodeError));
47 Poll::Ready(Some(res))
48 } else {
49 Poll::Ready(None)
50 }
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, pin_mut, stream};
    use prost::Message;

    /// Minimal protobuf message used as the decode target in these tests.
    #[derive(prost::Message)]
    struct TestMessage {
        #[prost(string, tag = "1")]
        content: String,
    }

    /// Error type emitted by the mock inner streams.
    #[derive(thiserror::Error, Debug)]
    enum TestError {
        #[error("mock stream error")]
        StreamError,
    }

    // A well-formed frame decodes into the message, after which the stream ends.
    #[xmtp_common::test]
    async fn test_poll_next_successful_decode() {
        let test_message = TestMessage {
            content: "test content".to_string(),
        };
        let encoded_bytes = test_message.encode_to_vec();

        let inner_stream =
            stream::once(async move { Ok::<Bytes, TestError>(Bytes::from(encoded_bytes)) });
        let xmtp_stream =
            XmtpStream::<_, TestMessage>::new(inner_stream, ApiEndpoint::SubscribeGroupMessages);
        pin_mut!(xmtp_stream);

        let result = xmtp_stream.next().await.unwrap();
        let decoded_message = result.expect("well-formed frame should decode");
        assert_eq!(decoded_message.content, "test content");
        let n = xmtp_stream.next().await;
        assert!(n.is_none());
    }

    // An error from the inner stream is wrapped together with the stream's endpoint.
    #[xmtp_common::test]
    async fn test_poll_next_error_mapping() {
        let inner_stream = stream::once(async { Err::<Bytes, TestError>(TestError::StreamError) });
        let xmtp_stream =
            XmtpStream::<_, TestMessage>::new(inner_stream, ApiEndpoint::SubscribeGroupMessages);
        pin_mut!(xmtp_stream);

        let result = xmtp_stream.next().await.unwrap();
        assert!(result.is_err());

        match result {
            Err(ApiClientError::ClientWithEndpoint { endpoint, .. }) => {
                assert_eq!(endpoint, "subscribe_group_messages");
            }
            _ => panic!("Expected ClientWithEndpoint error"),
        }
        let n = xmtp_stream.next().await;
        assert!(n.is_none());
    }

    // A frame that is not valid protobuf surfaces as a DecodeError rather than
    // panicking or being skipped; the stream then ends normally.
    #[xmtp_common::test]
    async fn test_poll_next_decode_error() {
        // 0xFF has the varint continuation bit set but no following byte, so the
        // field key cannot be decoded.
        let inner_stream =
            stream::once(async { Ok::<Bytes, TestError>(Bytes::from_static(&[0xFF])) });
        let xmtp_stream =
            XmtpStream::<_, TestMessage>::new(inner_stream, ApiEndpoint::SubscribeGroupMessages);
        pin_mut!(xmtp_stream);

        let result = xmtp_stream
            .next()
            .await
            .expect("stream should yield one item");
        assert!(
            matches!(result, Err(ApiClientError::DecodeError(_))),
            "expected DecodeError, got {result:?}"
        );
        let n = xmtp_stream.next().await;
        assert!(n.is_none());
    }
}