xmtp_api_d14n/queries/stream/
extractor.rs

//! Extracts and flattens items from a `TryStream` whose items implement [`EnvelopeCollection`], using the extractor type `E`.
2
3use futures::{Stream, TryStream};
4use pin_project_lite::pin_project;
5use std::{
6    collections::VecDeque,
7    marker::PhantomData,
8    task::{Poll, ready},
9};
10
11use crate::protocol::{
12    EnvelopeCollection, EnvelopeError, EnvelopeVisitor, TryEnvelopeCollectionExt, TryExtractor,
13};
14
15pin_project! {
16    pub struct TryExtractorStream<S, E: TryExtractor> {
17        #[pin] inner: S,
18        buffered: VecDeque<<E as TryExtractor>::Ok>,
19        _marker: PhantomData<E>,
20    }
21}
22
23/// Wrap a `TryStream<T>` such that it converts its 'item' to T
24// _NOTE_: extractor accepted as argument to avoid a requirement on
25// specifying `Stream` type
26pub fn try_extractor<S, E>(s: S) -> TryExtractorStream<S, E>
27where
28    E: TryExtractor,
29{
30    TryExtractorStream::<S, E> {
31        inner: s,
32        buffered: Default::default(),
33        _marker: PhantomData,
34    }
35}
36
37impl<S, E> Stream for TryExtractorStream<S, E>
38where
39    S: TryStream,
40    for<'a> S::Ok: EnvelopeCollection<'a> + std::fmt::Debug,
41    for<'a> S::Error: From<EnvelopeError>,
42    for<'a> E: TryExtractor + EnvelopeVisitor<'a> + Default,
43    <E as TryExtractor>::Error: std::fmt::Debug,
44    for<'a> EnvelopeError:
45        From<<E as EnvelopeVisitor<'a>>::Error> + From<<E as TryExtractor>::Error>,
46{
47    type Item = Result<E::Ok, S::Error>;
48
49    fn poll_next(
50        mut self: std::pin::Pin<&mut Self>,
51        cx: &mut std::task::Context<'_>,
52    ) -> std::task::Poll<Option<Self::Item>> {
53        let this = self.as_mut().project();
54        if let Some(item) = this.buffered.pop_front() {
55            return Poll::Ready(Some(Ok(item)));
56        }
57        let envelope = ready!(this.inner.try_poll_next(cx));
58        match envelope {
59            Some(item) => {
60                let item = item?;
61                let (success, _failure) = item.try_consume::<E>()?;
62                let mut consumed = success.into_iter();
63                let ready_item = consumed.next();
64                this.buffered.extend(consumed);
65                if let Some(item) = ready_item {
66                    return Poll::Ready(Some(Ok(item)));
67                }
68                cx.waker().wake_by_ref();
69                Poll::Pending
70            }
71            None => Poll::Ready(None),
72        }
73    }
74}
75
#[cfg(test)]
mod test {
    use super::*;
    use crate::protocol::{EnvelopeError, EnvelopeVisitor, Extractor};
    use futures::{StreamExt, stream};
    use rstest::rstest;

    /// Item type carried by the mock streams.
    /// `ProtocolEnvelope` is implemented for `u32` for testing only.
    type StreamItem = u32;

    /// Mock extractor that records the last `u32` it visited.
    #[derive(Default)]
    struct MockExtractor {
        value: StreamItem,
    }

    impl<'a> EnvelopeVisitor<'a> for MockExtractor {
        type Error = EnvelopeError;
        fn test_visit_u32(&mut self, n: &u32) -> Result<(), Self::Error> {
            self.value = *n;
            Ok(())
        }
    }

    impl Extractor for MockExtractor {
        type Output = Result<StreamItem, EnvelopeError>;

        fn get(self) -> Self::Output {
            Ok(self.value)
        }
    }

    /// Extractor whose `get` always fails, for exercising error cases.
    #[derive(Default)]
    struct ErrorExtractor;

    impl<'a> EnvelopeVisitor<'a> for ErrorExtractor {
        type Error = EnvelopeError;
    }

    impl Extractor for ErrorExtractor {
        type Output = Result<String, EnvelopeError>;

        fn get(self) -> Self::Output {
            Err(EnvelopeError::NotFound("extractor error"))
        }
    }

    #[rstest]
    #[case(vec![Ok(vec![1]), Ok(vec![2])], vec![1, 2], "happy_path")]
    #[case(vec![], vec![], "empty_stream")]
    #[case(vec![Ok(vec![1])], vec![1], "single_item_stream")]
    #[case(vec![Ok(vec![1]), Ok(vec![2]), Ok(vec![3])], vec![1, 2, 3], "multiple_items_stream")]
    #[case(vec![Ok(vec![]), Ok(vec![1])], vec![1], "empty_collection")]
    #[case(vec![Ok(vec![1, 2]), Ok(vec![3])], vec![1, 2, 3], "buffering")]
    #[xmtp_common::test]
    async fn test_content_scenarios(
        #[case] input: Vec<Result<Vec<u32>, EnvelopeError>>,
        #[case] expected: Vec<u32>,
        #[case] _description: &str,
    ) {
        let stream = stream::iter(input);
        let extractor_stream = try_extractor::<_, MockExtractor>(stream);

        let results: Vec<_> = extractor_stream.map(Result::unwrap).collect().await;
        assert_eq!(results, expected);
    }

    #[xmtp_common::test]
    async fn test_stream_error_propagation() {
        let items: Vec<Result<Vec<u32>, EnvelopeError>> =
            vec![Ok(vec![1]), Err(EnvelopeError::NotFound("test error"))];
        let stream = stream::iter(items);
        let extractor_stream = try_extractor::<_, MockExtractor>(stream);

        let results: Vec<_> = extractor_stream.collect().await;
        assert_eq!(results.len(), 2);
        assert_eq!(*results[0].as_ref().unwrap(), 1);
        assert!(results[1].is_err());
    }

    /// Values buffered from an earlier collection must all be delivered before a
    /// subsequent inner-stream error is surfaced.
    #[xmtp_common::test]
    async fn test_stream_error_after_buffered_items() {
        let items: Vec<Result<Vec<u32>, EnvelopeError>> =
            vec![Ok(vec![1, 2]), Err(EnvelopeError::NotFound("test error"))];
        let stream = stream::iter(items);
        let extractor_stream = try_extractor::<_, MockExtractor>(stream);

        let results: Vec<_> = extractor_stream.collect().await;
        assert_eq!(results.len(), 3);
        assert!(matches!(results[0], Ok(1)));
        assert!(matches!(results[1], Ok(2)));
        assert!(results[2].is_err());
    }

    //TODO ignored until https://github.com/xmtp/libxmtp/issues/2604
    #[ignore]
    #[xmtp_common::test]
    async fn test_extraction_error_propagation() {
        let items: Vec<Result<Vec<u32>, EnvelopeError>> = vec![Ok(vec![1])];
        let stream = stream::iter(items);
        let extractor_stream = try_extractor::<_, ErrorExtractor>(stream);

        let results: Vec<_> = extractor_stream.collect().await;
        assert_eq!(results.len(), 1);
        assert!(results[0].is_err());
    }

    #[xmtp_common::test]
    fn stream_can_finish() {
        let items: Vec<Result<Vec<u32>, EnvelopeError>> = vec![Ok(vec![1])];
        let stream = stream::iter(items);
        let stream = try_extractor::<_, MockExtractor>(stream);
        futures::pin_mut!(stream);

        // A single no-op context is enough: every poll below completes synchronously.
        let mut cx = futures_test::task::noop_context();
        assert!(matches!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(Some(Ok(1)))
        ));
        assert!(matches!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(None)
        ));
    }
}