xmtp_api_d14n/queries/stream/
extractor.rs1use futures::{Stream, TryStream};
4use pin_project_lite::pin_project;
5use std::{
6 collections::VecDeque,
7 marker::PhantomData,
8 task::{Poll, ready},
9};
10
11use crate::protocol::{
12 EnvelopeCollection, EnvelopeError, EnvelopeVisitor, TryEnvelopeCollectionExt, TryExtractor,
13};
14
15pin_project! {
16 pub struct TryExtractorStream<S, E: TryExtractor> {
17 #[pin] inner: S,
18 buffered: VecDeque<<E as TryExtractor>::Ok>,
19 _marker: PhantomData<E>,
20 }
21}
22
23pub fn try_extractor<S, E>(s: S) -> TryExtractorStream<S, E>
27where
28 E: TryExtractor,
29{
30 TryExtractorStream::<S, E> {
31 inner: s,
32 buffered: Default::default(),
33 _marker: PhantomData,
34 }
35}
36
37impl<S, E> Stream for TryExtractorStream<S, E>
38where
39 S: TryStream,
40 for<'a> S::Ok: EnvelopeCollection<'a> + std::fmt::Debug,
41 for<'a> S::Error: From<EnvelopeError>,
42 for<'a> E: TryExtractor + EnvelopeVisitor<'a> + Default,
43 <E as TryExtractor>::Error: std::fmt::Debug,
44 for<'a> EnvelopeError:
45 From<<E as EnvelopeVisitor<'a>>::Error> + From<<E as TryExtractor>::Error>,
46{
47 type Item = Result<E::Ok, S::Error>;
48
49 fn poll_next(
50 mut self: std::pin::Pin<&mut Self>,
51 cx: &mut std::task::Context<'_>,
52 ) -> std::task::Poll<Option<Self::Item>> {
53 let this = self.as_mut().project();
54 if let Some(item) = this.buffered.pop_front() {
55 return Poll::Ready(Some(Ok(item)));
56 }
57 let envelope = ready!(this.inner.try_poll_next(cx));
58 match envelope {
59 Some(item) => {
60 let item = item?;
61 let (success, _failure) = item.try_consume::<E>()?;
62 let mut consumed = success.into_iter();
63 let ready_item = consumed.next();
64 this.buffered.extend(consumed);
65 if let Some(item) = ready_item {
66 return Poll::Ready(Some(Ok(item)));
67 }
68 cx.waker().wake_by_ref();
69 Poll::Pending
70 }
71 None => Poll::Ready(None),
72 }
73 }
74}
75
#[cfg(test)]
mod test {
    use super::*;
    use crate::protocol::{EnvelopeError, EnvelopeVisitor, Extractor};
    use futures::{StreamExt, stream};
    use rstest::rstest;

    type StreamItem = u32;

    /// Test extractor that remembers the last `u32` it visited and returns it
    /// from `get`.
    #[derive(Default)]
    struct MockExtractor {
        value: StreamItem,
    }

    impl<'a> EnvelopeVisitor<'a> for MockExtractor {
        type Error = EnvelopeError;

        fn test_visit_u32(&mut self, visited: &u32) -> Result<(), Self::Error> {
            self.value = *visited;
            Ok(())
        }
    }

    impl Extractor for MockExtractor {
        type Output = Result<StreamItem, EnvelopeError>;

        fn get(self) -> Self::Output {
            let Self { value } = self;
            Ok(value)
        }
    }

    /// Test extractor whose `get` always fails with `NotFound`.
    #[derive(Default)]
    struct ErrorExtractor;

    impl<'a> EnvelopeVisitor<'a> for ErrorExtractor {
        type Error = EnvelopeError;
    }

    impl Extractor for ErrorExtractor {
        type Output = Result<String, EnvelopeError>;

        fn get(self) -> Self::Output {
            Err(EnvelopeError::NotFound("extractor error"))
        }
    }

    #[rstest]
    #[case(vec![Ok(vec![1]), Ok(vec![2])], vec![1, 2], "happy_path")]
    #[case(vec![], vec![], "empty_stream")]
    #[case(vec![Ok(vec![1])], vec![1], "single_item_stream")]
    #[case(vec![Ok(vec![1]), Ok(vec![2]), Ok(vec![3])], vec![1, 2, 3], "multiple_items_stream")]
    #[case(vec![Ok(vec![]), Ok(vec![1])], vec![1], "empty_collection")]
    #[case(vec![Ok(vec![1, 2]), Ok(vec![3])], vec![1, 2, 3], "buffering")]
    #[xmtp_common::test]
    async fn test_content_scenarios(
        #[case] input: Vec<Result<Vec<u32>, EnvelopeError>>,
        #[case] expected: Vec<u32>,
        #[case] _description: &str,
    ) {
        let extracted = try_extractor::<_, MockExtractor>(stream::iter(input))
            .map(Result::unwrap)
            .collect::<Vec<_>>()
            .await;
        assert_eq!(extracted, expected);
    }

    #[xmtp_common::test]
    async fn test_stream_error_propagation() {
        let input: Vec<Result<Vec<u32>, EnvelopeError>> =
            vec![Ok(vec![1]), Err(EnvelopeError::NotFound("test error"))];
        let outcomes = try_extractor::<_, MockExtractor>(stream::iter(input))
            .collect::<Vec<_>>()
            .await;

        assert_eq!(outcomes.len(), 2);
        assert!(matches!(outcomes[0], Ok(1)));
        assert!(outcomes[1].is_err());
    }

    // NOTE(review): presumably ignored because `TryExtractorStream::poll_next`
    // drops the failures returned by `try_consume`, so an extractor error never
    // reaches the consumer — confirm.
    #[ignore]
    #[xmtp_common::test]
    async fn test_extraction_error_propagation() {
        let input: Vec<Result<Vec<u32>, EnvelopeError>> = vec![Ok(vec![1])];
        let outcomes = try_extractor::<_, ErrorExtractor>(stream::iter(input))
            .collect::<Vec<_>>()
            .await;

        assert_eq!(outcomes.len(), 1);
        assert!(outcomes[0].is_err());
    }

    #[xmtp_common::test]
    fn stream_can_finish() {
        let input: Vec<Result<Vec<u32>, EnvelopeError>> = vec![Ok(vec![1])];
        let extractor_stream = try_extractor::<_, MockExtractor>(stream::iter(input));
        futures::pin_mut!(extractor_stream);
        // A no-op waker is enough: every poll below must complete synchronously.
        let mut cx = futures_test::task::noop_context();

        let first = extractor_stream.as_mut().poll_next(&mut cx);
        assert!(matches!(first, Poll::Ready(Some(_))));

        let second = extractor_stream.poll_next(&mut cx);
        assert!(matches!(second, Poll::Ready(None)));
    }
}