xmtp_api_grpc/streams/
default.rs1use prost::bytes::Bytes;
4use std::{
5 marker::PhantomData,
6 pin::Pin,
7 task::{Context, Poll, ready},
8};
9
10use crate::error::GrpcError;
11use futures::{Stream, TryStream};
12use pin_project_lite::pin_project;
13use xmtp_proto::{
14 ApiEndpoint,
15 api::{ApiClientError, Client},
16};
17
18pin_project! {
19 pub struct XmtpTonicStream<S, T> {
21 #[pin] inner: S,
22 endpoint: ApiEndpoint,
23 _marker: PhantomData<T>,
24 }
25}
26
27impl<S, T> XmtpTonicStream<S, T> {
28 pub fn new(inner: S, endpoint: ApiEndpoint) -> Self {
29 Self {
30 inner,
31 endpoint,
32 _marker: PhantomData,
33 }
34 }
35}
36
37impl<T> XmtpTonicStream<crate::GrpcStream, T> {
38 pub async fn from_body<B: prost::Name>(
41 body: B,
42 client: crate::GrpcClient,
43 endpoint: ApiEndpoint,
44 ) -> Result<Self, ApiClientError<GrpcError>> {
45 let pnq = xmtp_proto::path_and_query::<B>();
46 let request = http::Request::builder();
47 let path = http::uri::PathAndQuery::try_from(pnq.as_ref())?;
48 let s = client
49 .stream(request, path, body.encode_to_vec().into())
50 .await?;
51 Ok(Self::new(s.into_body(), endpoint))
52 }
53}
54
55impl<S, T> Stream for XmtpTonicStream<S, T>
56where
57 S: TryStream<Ok = Bytes, Error = GrpcError>,
58 GrpcError: From<<S as TryStream>::Error>,
59 T: prost::Message + Default,
60{
61 type Item = Result<T, ApiClientError<GrpcError>>;
62
63 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
64 let this = self.as_mut().project();
65 if let Some(item) = ready!(this.inner.try_poll_next(cx)) {
66 let res = item
67 .map_err(|e| ApiClientError::new(self.endpoint.clone(), e))
68 .and_then(|i| T::decode(i).map_err(GrpcError::from).map_err(Into::into));
69 Poll::Ready(Some(res))
70 } else {
71 Poll::Ready(None)
72 }
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, stream};
    use prost::Message;
    use rstest::rstest;

    /// Minimal protobuf message used as the decode target in these tests.
    #[derive(Clone, PartialEq, Message)]
    struct TestMessage {
        #[prost(string, tag = "1")]
        pub content: String,
    }

    impl prost::Name for TestMessage {
        const NAME: &'static str = "TestMessage";
        const PACKAGE: &'static str = "test";
        fn full_name() -> String {
            format!("{}.{}", Self::PACKAGE, Self::NAME)
        }
    }

    /// Builds a `TestMessage` carrying `content`.
    fn test_message(content: &str) -> TestMessage {
        TestMessage {
            content: content.to_string(),
        }
    }

    /// Protobuf-encodes a `TestMessage` carrying `content`.
    fn create_test_message_bytes(content: &str) -> Bytes {
        Bytes::from(test_message(content).encode_to_vec())
    }

    #[rstest]
    #[case::empty_stream(vec![], vec![])]
    #[case::single_message(
        vec![Ok(create_test_message_bytes("test1"))],
        vec![test_message("test1")],
    )]
    #[case::multiple_messages(
        vec![
            Ok(create_test_message_bytes("msg1")),
            Ok(create_test_message_bytes("msg2")),
            Ok(create_test_message_bytes("msg3")),
        ],
        vec![test_message("msg1"), test_message("msg2"), test_message("msg3")],
    )]
    #[xmtp_common::test]
    async fn test_successful_message_decoding(
        #[case] input: Vec<Result<Bytes, GrpcError>>,
        #[case] expected: Vec<TestMessage>,
    ) {
        let stream = XmtpTonicStream::<_, TestMessage>::new(
            stream::iter(input),
            ApiEndpoint::SubscribeGroupMessages,
        );

        let results: Vec<_> = stream.map(Result::unwrap).collect().await;
        assert_eq!(results, expected);
    }

    /// An inner-stream error is surfaced in place with the endpoint attached,
    /// and later frames are still decoded.
    #[xmtp_common::test]
    async fn test_error_propagation() {
        let grpc_error = GrpcError::Status(tonic::Status::unavailable("Connection failed"));
        let input = vec![
            Ok(create_test_message_bytes("msg1")),
            Err(grpc_error),
            Ok(create_test_message_bytes("msg3")),
        ];

        let endpoint = ApiEndpoint::QueryGroupMessages;
        let stream =
            XmtpTonicStream::<_, TestMessage>::new(stream::iter(input), endpoint.clone());

        let results: Vec<_> = stream.collect().await;
        assert_eq!(results.len(), 3);

        assert_eq!(results[0].as_ref().unwrap(), &test_message("msg1"));

        let Err(xmtp_proto::api::ApiClientError::ClientWithEndpoint {
            endpoint: err_endpoint,
            ..
        }) = &results[1]
        else {
            panic!("Expected ClientWithEndpoint error variant");
        };
        assert_eq!(*err_endpoint, endpoint.to_string());

        assert_eq!(results[2].as_ref().unwrap(), &test_message("msg3"));
    }

    /// After the inner stream is exhausted, the adapter reports end-of-stream.
    #[xmtp_common::test]
    fn stream_ends() {
        let input: Vec<Result<Bytes, GrpcError>> = vec![Ok(create_test_message_bytes("test"))];
        let stream = XmtpTonicStream::<_, TestMessage>::new(
            stream::iter(input),
            ApiEndpoint::SendGroupMessages,
        );

        futures::pin_mut!(stream);
        let mut cx = futures_test::task::noop_context();

        let Poll::Ready(Some(Ok(msg))) = stream.as_mut().poll_next(&mut cx) else {
            panic!("expected the first poll to yield a decoded message");
        };
        assert_eq!(msg.content, "test");

        assert!(matches!(stream.as_mut().poll_next(&mut cx), Poll::Ready(None)));
    }
}