xmtp_proto/traits/combinators/
retry.rs1use std::marker::PhantomData;
2
3use xmtp_common::{
4 ExponentialBackoff, MaybeSend, MaybeSync, Retry, RetryableError, Strategy as RetryStrategy,
5 retry_async,
6};
7
8use crate::api::{ApiClientError, Client, Endpoint, Pageable, Query, QueryRaw};
9
10pub struct RetryQuery<E, S = ExponentialBackoff> {
14 endpoint: E,
15 pub(crate) retry: Retry<S>,
16}
17
18impl<E> RetryQuery<E> {
19 pub fn new(endpoint: E) -> Self {
20 Self {
21 endpoint,
22 retry: Default::default(),
23 }
24 }
25}
26
27impl<E> Pageable for RetryQuery<E>
28where
29 E: Pageable,
30{
31 fn set_cursor(&mut self, cursor: u64) {
32 self.endpoint.set_cursor(cursor)
33 }
34}
35
36#[xmtp_common::async_trait]
37impl<E, C, S> Query<C> for RetryQuery<E, S>
38where
39 E: Query<C>,
40 C: Client,
41 C::Error: RetryableError,
42 S: RetryStrategy,
43{
44 type Output = E::Output;
45 async fn query(&mut self, client: &C) -> Result<Self::Output, ApiClientError<C::Error>> {
46 retry_async!(
47 self.retry,
48 (async { Query::<C>::query(&mut self.endpoint, client).await })
49 )
50 }
51}
52
53#[xmtp_common::async_trait]
54impl<E, C, S> QueryRaw<C> for RetryQuery<E, S>
55where
56 E: Endpoint,
57 C: Client,
58 C::Error: RetryableError,
59 S: RetryStrategy,
60{
61 async fn query_raw(&mut self, client: &C) -> Result<bytes::Bytes, ApiClientError<C::Error>> {
62 retry_async!(
63 self.retry,
64 (async { QueryRaw::<C>::query_raw(&mut self.endpoint, client).await })
65 )
66 }
67}
68
/// Marker type used as the `Spec` parameter of [`RetryQuery`]'s [`Endpoint`]
/// implementation. It wraps the inner endpoint's own `Spec`.
///
/// It carries no data; the `PhantomData` field only records the wrapped
/// `Spec` type.
pub struct RetrySpecialized<Spec> {
    _marker: PhantomData<Spec>,
}
72
73impl<E, Spec> Endpoint<RetrySpecialized<Spec>> for RetryQuery<E>
74where
75 E: Endpoint<Spec>,
76 Spec: MaybeSend + MaybeSync,
77{
78 type Output = <E as Endpoint<Spec>>::Output;
79
80 fn grpc_endpoint(&self) -> std::borrow::Cow<'static, str> {
81 self.endpoint.grpc_endpoint()
82 }
83
84 fn body(&self) -> Result<bytes::Bytes, crate::api::BodyError> {
85 self.endpoint.body()
86 }
87}
88
89pub fn retry<E>(endpoint: E) -> RetryQuery<E, ExponentialBackoff> {
91 RetryQuery::<E, _> {
92 endpoint,
93 retry: Retry::default(),
94 }
95}
96
97pub fn retry_with_strategy<E, S>(endpoint: E, retry: Retry<S>) -> RetryQuery<E, S> {
99 RetryQuery::<E, S> { endpoint, retry }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::{
        EndpointExt,
        mock::{MockError, MockNetworkClient, TestEndpoint},
    };

    // Three retryable failures followed by one success must end in `Ok`.
    #[xmtp_common::test]
    async fn retries_endpoint_three_times() {
        let mut mock_client = MockNetworkClient::new();
        mock_client.expect_request().times(3).returning(|_, _, _| {
            tracing::info!("error");
            Err(ApiClientError::Client {
                source: MockError::ARetryableError,
            })
        });
        mock_client
            .expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![].into())));

        let outcome: Result<(), _> = retry(TestEndpoint).query(&mock_client).await;
        assert!(outcome.is_ok());
    }

    // A non-retryable error is requested exactly once and surfaces wrapped
    // with endpoint information.
    #[xmtp_common::test]
    async fn does_not_retry_non_retryable() {
        let mut mock_client = MockNetworkClient::new();
        mock_client.expect_request().times(1).returning(|_, _, _| {
            Err(ApiClientError::Client {
                source: MockError::ANonRetryableError,
            })
        });

        let outcome: Result<(), _> = retry(TestEndpoint).query(&mock_client).await;
        assert!(outcome.is_err());

        let wrapped_with_endpoint = matches!(
            outcome,
            Err(ApiClientError::ClientWithEndpoint {
                source: MockError::ANonRetryableError,
                ..
            })
        );
        assert!(wrapped_with_endpoint, "{:?}", outcome.unwrap_err());
    }

    #[xmtp_common::test]
    fn test_grpc_endpoint_delegates_to_wrapped_endpoint() {
        let wrapped = retry(TestEndpoint);
        let path = wrapped.grpc_endpoint();
        assert_eq!(path, "");
    }

    #[xmtp_common::test]
    fn test_body_delegates_to_wrapped_endpoint() {
        let wrapped = retry(TestEndpoint);
        let body = wrapped.body();
        assert!(body.is_ok());
        assert_eq!(body.unwrap(), bytes::Bytes::from(vec![]));
    }

    // A custom strategy allowing two retries absorbs two failures before the
    // successful third request.
    #[xmtp_common::test]
    async fn retries_with_strategy() {
        let mut mock_client = MockNetworkClient::new();
        mock_client.expect_request().times(2).returning(|_, _, _| {
            Err(ApiClientError::Client {
                source: MockError::ARetryableError,
            })
        });
        mock_client
            .expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![1].into())));

        let outcome: Result<(), _> = TestEndpoint
            .ignore_response()
            .retry_with_strategy(Retry::builder().retries(2).build())
            .query(&mock_client)
            .await;
        assert!(outcome.is_ok(), "{:?}", outcome.unwrap_err());
    }
}