xmtp_api_grpc/streams/try_from_item.rs

//! Adapts a `TryStream` so that each of its `Ok` items is converted into another type via `TryFrom`.
2
3use futures::{Stream, TryStream};
4use pin_project_lite::pin_project;
5use std::{
6    marker::PhantomData,
7    task::{Poll, ready},
8};
9
10pin_project! {
11    pub struct TryFromItem<S, T> {
12        #[pin] inner: S,
13        _marker: PhantomData<T>,
14    }
15}
16
17/// Wrap a `TryStream<T>` such that it converts its 'item' to T
18pub fn try_from_stream<S, T>(s: S) -> TryFromItem<S, T> {
19    TryFromItem::<S, T> {
20        inner: s,
21        _marker: PhantomData,
22    }
23}
24
25impl<S, T> Stream for TryFromItem<S, T>
26where
27    S: TryStream,
28    T: TryFrom<S::Ok>,
29    S::Error: From<<T as TryFrom<S::Ok>>::Error>,
30{
31    type Item = Result<T, S::Error>;
32
33    fn poll_next(
34        mut self: std::pin::Pin<&mut Self>,
35        cx: &mut std::task::Context<'_>,
36    ) -> std::task::Poll<Option<Self::Item>> {
37        let this = self.as_mut().project();
38        let item = ready!(this.inner.try_poll_next(cx));
39        match item {
40            Some(i) => Poll::Ready(Some(i.and_then(|i| i.try_into().map_err(S::Error::from)))),
41            None => Poll::Ready(None),
42        }
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, stream};
    use rstest::rstest;

    /// Target type of the conversions exercised below.
    #[derive(Debug, PartialEq, Clone)]
    struct TestItem {
        value: u32,
    }

    /// Error type used both by the inner stream and by `TryFrom<u32> for TestItem`.
    #[derive(Debug, PartialEq)]
    enum TestError {
        ConversionError(String),
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestError::ConversionError(msg) => write!(f, "Conversion error: {}", msg),
            }
        }
    }

    impl std::error::Error for TestError {}

    /// Succeeds for every value except `999`. That value lets tests force a conversion failure
    /// in the middle of a stream.
    impl TryFrom<u32> for TestItem {
        type Error = TestError;

        fn try_from(value: u32) -> Result<Self, Self::Error> {
            if value == 999 {
                Err(TestError::ConversionError("Invalid value 999".to_string()))
            } else {
                Ok(TestItem { value })
            }
        }
    }

    #[rstest]
    #[case(vec![Ok(1), Ok(2), Ok(3)], vec![TestItem { value: 1 }, TestItem { value: 2 }, TestItem { value: 3 }], "happy_path")]
    #[case(vec![], vec![], "empty_stream")]
    #[case(vec![Ok(42)], vec![TestItem { value: 42 }], "single_item_stream")]
    #[case(vec![Ok(1), Ok(2), Ok(3), Ok(4), Ok(5)], vec![TestItem { value: 1 }, TestItem { value: 2 }, TestItem { value: 3 }, TestItem { value: 4 }, TestItem { value: 5 }], "multiple_items_stream")]
    #[xmtp_common::test]
    async fn test_successful_conversions(
        #[case] input: Vec<Result<u32, TestError>>,
        #[case] expected: Vec<TestItem>,
        #[case] _description: &str,
    ) {
        let stream = stream::iter(input);
        let try_from_stream = try_from_stream::<_, TestItem>(stream);

        let results: Vec<_> = try_from_stream.map(Result::unwrap).collect().await;
        assert_eq!(results, expected);
    }

    #[xmtp_common::test]
    async fn test_conversion_error_propagation() {
        let items: Vec<Result<u32, TestError>> = vec![Ok(1), Ok(999), Ok(3)];
        let stream = stream::iter(items);
        let try_from_stream = try_from_stream::<_, TestItem>(stream);

        // A failed conversion surfaces as an error item but does not end the stream.
        let results: Vec<_> = try_from_stream.collect().await;
        assert_eq!(
            results,
            vec![
                Ok(TestItem { value: 1 }),
                Err(TestError::ConversionError("Invalid value 999".to_string())),
                Ok(TestItem { value: 3 }),
            ]
        );
    }

    #[xmtp_common::test]
    async fn test_inner_error_passthrough() {
        let items: Vec<Result<u32, TestError>> = vec![
            Ok(1),
            Err(TestError::ConversionError("upstream".to_string())),
            Ok(2),
        ];
        let stream = stream::iter(items);
        let try_from_stream = try_from_stream::<_, TestItem>(stream);

        // Errors produced by the inner stream must be forwarded untouched.
        let results: Vec<_> = try_from_stream.collect().await;
        assert_eq!(
            results,
            vec![
                Ok(TestItem { value: 1 }),
                Err(TestError::ConversionError("upstream".to_string())),
                Ok(TestItem { value: 2 }),
            ]
        );
    }

    #[xmtp_common::test]
    fn stream_can_finish() {
        let items: Vec<Result<u32, TestError>> = vec![Ok(42)];
        let stream = stream::iter(items);
        let stream = try_from_stream::<_, TestItem>(stream);
        futures::pin_mut!(stream);

        let mut cx = futures_test::task::noop_context();
        assert_eq!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(Some(Ok(TestItem { value: 42 })))
        );
        assert_eq!(stream.poll_next(&mut cx), Poll::Ready(None));
    }

    #[xmtp_common::test]
    fn happy_path() {
        let items: Vec<Result<u32, TestError>> = vec![Ok(1), Ok(2)];
        let stream = stream::iter(items);
        let stream = try_from_stream::<_, TestItem>(stream);
        futures::pin_mut!(stream);

        let mut cx = futures_test::task::noop_context();

        // Each poll yields the next converted item, then the stream terminates.
        assert_eq!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(Some(Ok(TestItem { value: 1 })))
        );
        assert_eq!(
            stream.as_mut().poll_next(&mut cx),
            Poll::Ready(Some(Ok(TestItem { value: 2 })))
        );
        assert_eq!(stream.poll_next(&mut cx), Poll::Ready(None));
    }
}