xmtp_proto/traits/combinators/v3_paged.rs

1use std::marker::PhantomData;
2
3use xmtp_common::{MaybeSend, MaybeSync};
4use xmtp_configuration::MAX_PAGE_SIZE;
5
6use crate::{
7    api::{ApiClientError, Client, Endpoint, Pageable, Query},
8    api_client::Paged,
9};
10
/// Endpoint combinator that pages through results using
/// [`PagingInfo`](crate::mls_v1::PagingInfo) cursors.
///
/// Implements the v3 backend paging algorithm for endpoints that implement
/// the [`Pageable`] trait. Querying it re-issues the wrapped endpoint with an
/// advancing cursor and concatenates the messages of every page.
///
/// Construct it with, e.g., [`v3_paged`].
pub struct V3Paged<E, T> {
    /// The wrapped endpoint whose cursor is advanced between requests.
    endpoint: E,
    /// Cursor the first request starts from; `None` is treated as `0`.
    id_cursor: Option<u64>,
    /// Ties the response type `T` to this wrapper without storing a value.
    _marker: PhantomData<T>,
}
19
20#[xmtp_common::async_trait]
21impl<E, T, C> Query<C> for V3Paged<E, T>
22where
23    E: Query<C, Output = T> + Pageable,
24    C: Client,
25    C::Error: std::error::Error,
26    T: Default + prost::Message + Paged + 'static,
27{
28    type Output = Vec<<T as Paged>::Message>;
29    async fn query(
30        &mut self,
31        client: &C,
32    ) -> Result<Vec<<T as Paged>::Message>, ApiClientError<C::Error>> {
33        let mut out: Vec<<T as Paged>::Message> = vec![];
34        self.endpoint.set_cursor(self.id_cursor.unwrap_or(0));
35        loop {
36            let result: T = self.endpoint.query(client).await?;
37            let info = *result.info();
38            let mut messages = result.messages();
39            let num_messages = messages.len();
40            out.append(&mut messages);
41
42            if num_messages < MAX_PAGE_SIZE as usize || info.is_none() {
43                break;
44            }
45
46            let paging_info = info.expect("Empty paging info");
47            if paging_info.id_cursor == 0 {
48                break;
49            }
50
51            self.endpoint.set_cursor(paging_info.id_cursor);
52        }
53        Ok(out)
54    }
55}
56
/// Type-level marker used as the [`Endpoint`] specialization parameter for
/// [`V3Paged`].
///
/// `S` is the specialization of the wrapped endpoint. The marker holds only a
/// `PhantomData<S>`, so it is zero-sized, and this module never constructs it.
/// Presumably it exists so the `V3Paged` impl does not overlap with other
/// `Endpoint` impls. NOTE(review): confirm against the `Endpoint` trait
/// definition.
pub struct V3PagedSpecialized<S> {
    _marker: PhantomData<S>,
}
60
61impl<S, E: Endpoint<S>, T: MaybeSend + MaybeSync> Endpoint<V3PagedSpecialized<S>>
62    for V3Paged<E, T>
63{
64    type Output = <E as Endpoint<S>>::Output;
65
66    fn grpc_endpoint(&self) -> std::borrow::Cow<'static, str> {
67        self.endpoint.grpc_endpoint()
68    }
69
70    fn body(&self) -> Result<bytes::Bytes, crate::api::BodyError> {
71        self.endpoint.body()
72    }
73}
74
75/// Set an endpoint to be paged with v3 paging info
76pub fn v3_paged<E, T>(endpoint: E, id_cursor: Option<u64>) -> V3Paged<E, T> {
77    V3Paged {
78        endpoint,
79        id_cursor,
80        _marker: PhantomData,
81    }
82}
83
#[cfg(test)]
mod tests {

    use std::borrow::Cow;

    use prost::Message;

    use crate::{
        api::{self, Endpoint, EndpointExt, mock::MockNetworkClient},
        mls_v1::{PagingInfo, SortDirection},
    };

    use super::*;
    use rstest::*;

    /// Minimal pageable message: optional paging info plus a flat list of messages.
    /// It is used both as the request body and as the mocked response.
    #[derive(prost::Message)]
    struct TestV3Pageable {
        #[prost(message, optional, tag = "1")]
        info: Option<PagingInfo>,
        #[prost(int32, repeated, tag = "2")]
        msgs: Vec<i32>,
    }

    impl Paged for TestV3Pageable {
        type Message = i32;

        fn info(&self) -> &Option<PagingInfo> {
            &self.info
        }

        fn messages(self) -> Vec<Self::Message> {
            self.msgs
        }
    }

    /// Endpoint whose request body is `inner`, so the mock can read back the cursor.
    #[derive(Default)]
    struct PageableTestEndpoint {
        inner: TestV3Pageable,
    }

    impl Endpoint for PageableTestEndpoint {
        type Output = TestV3Pageable;

        fn grpc_endpoint(&self) -> std::borrow::Cow<'static, str> {
            Cow::Borrowed("")
        }

        fn body(&self) -> Result<bytes::Bytes, api::BodyError> {
            Ok(self.inner.encode_to_vec().into())
        }
    }

    impl Pageable for PageableTestEndpoint {
        fn set_cursor(&mut self, cursor: u64) {
            // Only endpoints that already carry paging info can have a cursor set.
            if let Some(info) = &mut self.inner.info {
                info.id_cursor = cursor;
            }
        }
    }

    /// Ascending paging info with a fixed limit and the given cursor.
    fn paging_info(id_cursor: u64) -> PagingInfo {
        PagingInfo {
            direction: SortDirection::Ascending as i32,
            limit: 100,
            id_cursor,
        }
    }

    /// Decodes an outgoing request body and returns the cursor it was sent with.
    fn requested_cursor(body: impl bytes::Buf) -> u64 {
        TestV3Pageable::decode(body)
            .expect("request body should decode")
            .info
            .expect("request should carry paging info")
            .id_cursor
    }

    /// All messages served by the `client` fixture's three pages, in order.
    fn expected_messages() -> Vec<i32> {
        std::iter::repeat_n(0, MAX_PAGE_SIZE as usize)
            .chain(std::iter::repeat_n(1, MAX_PAGE_SIZE as usize))
            .chain(std::iter::once(7))
            .collect()
    }

    /// Mock client that serves two full pages (cursors 1 -> 4 -> 6) followed
    /// by a short final page without paging info.
    #[fixture]
    fn client() -> MockNetworkClient {
        let mut client = MockNetworkClient::new();
        client.expect_request().times(1).returning(|_, _, b| {
            assert_eq!(requested_cursor(b.clone()), 1);
            Ok(http::Response::new(
                TestV3Pageable {
                    info: Some(paging_info(4)),
                    msgs: vec![0; MAX_PAGE_SIZE as usize],
                }
                .encode_to_vec()
                .into(),
            ))
        });
        client.expect_request().times(1).returning(|_, _, b| {
            assert_eq!(requested_cursor(b.clone()), 4);
            Ok(http::Response::new(
                TestV3Pageable {
                    info: Some(paging_info(6)),
                    msgs: vec![1; MAX_PAGE_SIZE as usize],
                }
                .encode_to_vec()
                .into(),
            ))
        });
        client.expect_request().times(1).returning(|_, _, b| {
            assert_eq!(requested_cursor(b.clone()), 6);
            Ok(http::Response::new(
                TestV3Pageable {
                    info: None,
                    msgs: vec![7],
                }
                .encode_to_vec()
                .into(),
            ))
        });
        client
    }

    #[rstest]
    #[xmtp_common::test]
    async fn pages_endpoint(client: MockNetworkClient) {
        let endpoint = PageableTestEndpoint {
            inner: TestV3Pageable {
                info: Some(paging_info(2)),
                msgs: vec![],
            },
        };
        let result = endpoint
            .v3_paged(Some(1))
            .query(&client)
            .await
            .expect("paged query should succeed");
        assert_eq!(result, expected_messages());
    }

    #[rstest]
    #[xmtp_common::test]
    async fn pages_endpoint_can_be_retried(client: MockNetworkClient) {
        let endpoint = PageableTestEndpoint {
            inner: TestV3Pageable {
                info: Some(paging_info(2)),
                msgs: vec![],
            },
        };
        let result = api::v3_paged(api::retry(endpoint), Some(1))
            .query(&client)
            .await
            .expect("paged retried query should succeed");
        assert_eq!(result, expected_messages());
    }

    #[xmtp_common::test]
    fn test_grpc_endpoint_delegates_to_wrapped_endpoint() {
        let base_endpoint = PageableTestEndpoint::default();
        let paged_endpoint: V3Paged<PageableTestEndpoint, TestV3Pageable> =
            v3_paged(base_endpoint, Some(0));
        assert_eq!(paged_endpoint.grpc_endpoint(), "");
    }

    #[xmtp_common::test]
    fn test_body_delegates_to_wrapped_endpoint() {
        let base_endpoint = PageableTestEndpoint::default();
        let paged_endpoint: V3Paged<PageableTestEndpoint, TestV3Pageable> =
            v3_paged(base_endpoint, Some(0));
        let body = paged_endpoint.body().expect("body should encode");
        assert_eq!(
            body,
            bytes::Bytes::from(TestV3Pageable::default().encode_to_vec())
        );
    }

    #[xmtp_common::test]
    fn test_pageable_test_endpoint_body_encodes_protobuf_message() {
        let endpoint = PageableTestEndpoint {
            inner: TestV3Pageable {
                info: Some(paging_info(42)),
                msgs: vec![1, 2, 3],
            },
        };
        let body = endpoint.body().expect("body should encode");
        let expected_bytes = endpoint.inner.encode_to_vec();
        assert_eq!(body, bytes::Bytes::from(expected_bytes));
    }

    // This test exists only to ensure the combinators chain and compile.
    #[xmtp_common::test]
    async fn endpoints_can_be_chained() {
        let client = MockNetworkClient::new();
        std::mem::drop(
            PageableTestEndpoint::default()
                .v3_paged(Some(0))
                .retry()
                .query(&client),
        );
    }
}