xmtp_proto/traits/combinators/
v3_paged.rs1use std::marker::PhantomData;
2
3use xmtp_common::{MaybeSend, MaybeSync};
4use xmtp_configuration::MAX_PAGE_SIZE;
5
6use crate::{
7 api::{ApiClientError, Client, Endpoint, Pageable, Query},
8 api_client::Paged,
9};
10
/// Combinator that wraps an endpoint `E` so that querying it walks every page
/// of a V3 paged response.
///
/// `T` is the response type produced by the wrapped endpoint; it is only
/// tracked at the type level through [`PhantomData`].
pub struct V3Paged<E, T> {
    /// The wrapped endpoint that is queried once per page.
    endpoint: E,
    /// Cursor to start paging from; `None` means start from `0`.
    id_cursor: Option<u64>,
    /// Ties the response type `T` to this combinator without storing a value.
    _marker: PhantomData<T>,
}
19
20#[xmtp_common::async_trait]
21impl<E, T, C> Query<C> for V3Paged<E, T>
22where
23 E: Query<C, Output = T> + Pageable,
24 C: Client,
25 C::Error: std::error::Error,
26 T: Default + prost::Message + Paged + 'static,
27{
28 type Output = Vec<<T as Paged>::Message>;
29 async fn query(
30 &mut self,
31 client: &C,
32 ) -> Result<Vec<<T as Paged>::Message>, ApiClientError<C::Error>> {
33 let mut out: Vec<<T as Paged>::Message> = vec![];
34 self.endpoint.set_cursor(self.id_cursor.unwrap_or(0));
35 loop {
36 let result: T = self.endpoint.query(client).await?;
37 let info = *result.info();
38 let mut messages = result.messages();
39 let num_messages = messages.len();
40 out.append(&mut messages);
41
42 if num_messages < MAX_PAGE_SIZE as usize || info.is_none() {
43 break;
44 }
45
46 let paging_info = info.expect("Empty paging info");
47 if paging_info.id_cursor == 0 {
48 break;
49 }
50
51 self.endpoint.set_cursor(paging_info.id_cursor);
52 }
53 Ok(out)
54 }
55}
56
/// Zero-sized marker type used as the specialization parameter of the
/// `Endpoint` implementation for [`V3Paged`]; `S` is the specialization of the
/// wrapped endpoint.
pub struct V3PagedSpecialized<S> {
    /// Carries `S` at the type level only.
    _marker: PhantomData<S>,
}
60
61impl<S, E: Endpoint<S>, T: MaybeSend + MaybeSync> Endpoint<V3PagedSpecialized<S>>
62 for V3Paged<E, T>
63{
64 type Output = <E as Endpoint<S>>::Output;
65
66 fn grpc_endpoint(&self) -> std::borrow::Cow<'static, str> {
67 self.endpoint.grpc_endpoint()
68 }
69
70 fn body(&self) -> Result<bytes::Bytes, crate::api::BodyError> {
71 self.endpoint.body()
72 }
73}
74
75pub fn v3_paged<E, T>(endpoint: E, id_cursor: Option<u64>) -> V3Paged<E, T> {
77 V3Paged {
78 endpoint,
79 id_cursor,
80 _marker: PhantomData,
81 }
82}
83
84#[cfg(test)]
85mod tests {
86
87 use std::borrow::Cow;
88
89 use prost::Message;
90
91 use crate::{
92 api::{self, Endpoint, EndpointExt, mock::MockNetworkClient},
93 mls_v1::{PagingInfo, SortDirection},
94 };
95
96 use super::*;
97 use rstest::*;
98
99 #[derive(prost::Message)]
100 struct TestV3Pageable {
101 #[prost(message, optional, tag = "1")]
102 info: Option<PagingInfo>,
103 #[prost(int32, repeated, tag = "2")]
104 msgs: Vec<i32>,
105 }
106
107 impl Paged for TestV3Pageable {
108 type Message = i32;
109
110 fn info(&self) -> &Option<PagingInfo> {
111 &self.info
112 }
113
114 fn messages(self) -> Vec<Self::Message> {
115 self.msgs
116 }
117 }
118
119 #[derive(Default)]
120 struct PageableTestEndpoint {
121 inner: TestV3Pageable,
122 }
123
124 impl Endpoint for PageableTestEndpoint {
125 type Output = TestV3Pageable;
126
127 fn grpc_endpoint(&self) -> std::borrow::Cow<'static, str> {
128 Cow::Borrowed("")
129 }
130
131 fn body(&self) -> Result<bytes::Bytes, api::BodyError> {
132 Ok(self.inner.encode_to_vec().into())
133 }
134 }
135
136 impl Pageable for PageableTestEndpoint {
137 fn set_cursor(&mut self, cursor: u64) {
138 if let Some(ref mut info) = self.inner.info {
139 info.id_cursor = cursor;
140 }
141 }
142 }
143
144 #[fixture]
145 fn client() -> MockNetworkClient {
146 let mut client = MockNetworkClient::new();
147 client.expect_request().times(1).returning(|_, _, b| {
148 let body = TestV3Pageable::decode(b.clone()).unwrap();
149 assert_eq!(
150 body.info.unwrap().id_cursor,
151 1,
152 "expected 1 got {}",
153 body.info.unwrap().id_cursor
154 );
155 Ok(http::Response::new(
156 TestV3Pageable {
157 info: Some(PagingInfo {
158 direction: SortDirection::Ascending as i32,
159 limit: 100,
160 id_cursor: 4,
161 }),
162 msgs: vec![0; MAX_PAGE_SIZE as usize],
163 }
164 .encode_to_vec()
165 .into(),
166 ))
167 });
168 client.expect_request().times(1).returning(|_, _, b| {
169 let body = TestV3Pageable::decode(b.clone()).unwrap();
170 assert_eq!(
171 body.info.unwrap().id_cursor,
172 4,
173 "expected 4 got {}",
174 body.info.unwrap().id_cursor
175 );
176 Ok(http::Response::new(
177 TestV3Pageable {
178 info: Some(PagingInfo {
179 direction: SortDirection::Ascending as i32,
180 limit: 100,
181 id_cursor: 6,
182 }),
183 msgs: vec![1; MAX_PAGE_SIZE as usize],
184 }
185 .encode_to_vec()
186 .into(),
187 ))
188 });
189 client.expect_request().times(1).returning(|_, _, b| {
190 let body = TestV3Pageable::decode(b.clone()).unwrap();
191 assert_eq!(
192 body.info.unwrap().id_cursor,
193 6,
194 "expected 6 got {}",
195 body.info.unwrap().id_cursor
196 );
197 Ok(http::Response::new(
198 TestV3Pageable {
199 info: None,
200 msgs: vec![7],
201 }
202 .encode_to_vec()
203 .into(),
204 ))
205 });
206 client
207 }
208
209 #[rstest]
210 #[xmtp_common::test]
211 async fn pages_endpoint(client: MockNetworkClient) {
212 let endpoint = PageableTestEndpoint {
213 inner: TestV3Pageable {
214 info: Some(PagingInfo {
215 direction: SortDirection::Ascending as i32,
216 limit: 100,
217 id_cursor: 2,
218 }),
219 msgs: vec![],
220 },
221 };
222 let result = endpoint.v3_paged(Some(1)).query(&client).await;
224 assert!(result.is_ok());
225 let result = result.unwrap();
226 let msgs = std::iter::repeat_n(0, MAX_PAGE_SIZE as usize)
227 .chain(std::iter::repeat_n(1, MAX_PAGE_SIZE as usize))
228 .chain(vec![7])
229 .collect::<Vec<_>>();
230 assert_eq!(result, msgs, "{:?}", result);
231 }
232
233 #[rstest]
234 #[xmtp_common::test]
235 async fn pages_endpoint_can_be_retried(client: MockNetworkClient) {
236 let endpoint = PageableTestEndpoint {
237 inner: TestV3Pageable {
238 info: Some(PagingInfo {
239 direction: SortDirection::Ascending as i32,
240 limit: 100,
241 id_cursor: 2,
242 }),
243 msgs: vec![],
244 },
245 };
246 let result = api::v3_paged(api::retry(endpoint), Some(1))
247 .query(&client)
248 .await;
249 assert!(result.is_ok());
250 let result = result.unwrap();
251 let msgs = std::iter::repeat_n(0, MAX_PAGE_SIZE as usize)
252 .chain(std::iter::repeat_n(1, MAX_PAGE_SIZE as usize))
253 .chain(vec![7])
254 .collect::<Vec<_>>();
255 assert_eq!(result, msgs, "{:?}", result);
256 }
257
258 #[xmtp_common::test]
259 fn test_grpc_endpoint_delegates_to_wrapped_endpoint() {
260 let base_endpoint = PageableTestEndpoint::default();
261 let paged_endpoint: V3Paged<PageableTestEndpoint, TestV3Pageable> =
262 v3_paged(base_endpoint, Some(0));
263 assert_eq!(paged_endpoint.grpc_endpoint(), "");
264 }
265
266 #[xmtp_common::test]
267 fn test_body_delegates_to_wrapped_endpoint() {
268 let base_endpoint = PageableTestEndpoint::default();
269 let paged_endpoint: V3Paged<PageableTestEndpoint, TestV3Pageable> =
270 v3_paged(base_endpoint, Some(0));
271 let result = paged_endpoint.body();
272 assert!(result.is_ok());
273 assert_eq!(
274 result.unwrap(),
275 bytes::Bytes::from(TestV3Pageable::default().encode_to_vec())
276 );
277 }
278
279 #[xmtp_common::test]
280 fn test_pageable_test_endpoint_body_encodes_protobuf_message() {
281 let endpoint = PageableTestEndpoint {
282 inner: TestV3Pageable {
283 info: Some(PagingInfo {
284 direction: SortDirection::Ascending as i32,
285 limit: 100,
286 id_cursor: 42,
287 }),
288 msgs: vec![1, 2, 3],
289 },
290 };
291 let result = endpoint.body();
292 assert!(result.is_ok());
293 let expected_bytes = endpoint.inner.encode_to_vec();
294 assert_eq!(result.unwrap(), bytes::Bytes::from(expected_bytes));
295 }
296
297 #[xmtp_common::test]
299 async fn endpoints_can_be_chained() {
300 let client = MockNetworkClient::new();
301 std::mem::drop(
302 PageableTestEndpoint::default()
303 .v3_paged(Some(0))
304 .retry()
305 .query(&client),
306 );
307 }
308}