xmtp_proto/traits/combinators/
retry.rs

1use std::marker::PhantomData;
2
3use xmtp_common::{
4    ExponentialBackoff, MaybeSend, MaybeSync, Retry, RetryableError, Strategy as RetryStrategy,
5    retry_async,
6};
7
8use crate::api::{ApiClientError, Client, Endpoint, Pageable, Query, QueryRaw};
9
10/// The concrete type of a [`crate::api::retry`] Combinators.
11/// Generally using the concrete type can be avoided with type inference
12/// or impl Trait.
13pub struct RetryQuery<E, S = ExponentialBackoff> {
14    endpoint: E,
15    pub(crate) retry: Retry<S>,
16}
17
18impl<E> RetryQuery<E> {
19    pub fn new(endpoint: E) -> Self {
20        Self {
21            endpoint,
22            retry: Default::default(),
23        }
24    }
25}
26
27impl<E> Pageable for RetryQuery<E>
28where
29    E: Pageable,
30{
31    fn set_cursor(&mut self, cursor: u64) {
32        self.endpoint.set_cursor(cursor)
33    }
34}
35
36#[xmtp_common::async_trait]
37impl<E, C, S> Query<C> for RetryQuery<E, S>
38where
39    E: Query<C>,
40    C: Client,
41    C::Error: RetryableError,
42    S: RetryStrategy,
43{
44    type Output = E::Output;
45    async fn query(&mut self, client: &C) -> Result<Self::Output, ApiClientError<C::Error>> {
46        retry_async!(
47            self.retry,
48            (async { Query::<C>::query(&mut self.endpoint, client).await })
49        )
50    }
51}
52
53#[xmtp_common::async_trait]
54impl<E, C, S> QueryRaw<C> for RetryQuery<E, S>
55where
56    E: Endpoint,
57    C: Client,
58    C::Error: RetryableError,
59    S: RetryStrategy,
60{
61    async fn query_raw(&mut self, client: &C) -> Result<bytes::Bytes, ApiClientError<C::Error>> {
62        retry_async!(
63            self.retry,
64            (async { QueryRaw::<C>::query_raw(&mut self.endpoint, client).await })
65        )
66    }
67}
68
/// Zero-sized, type-level marker used as the `Specialized` parameter of
/// [`Endpoint`] when implementing it for [`RetryQuery`].
///
/// `Spec` records the specialization of the wrapped endpoint. The marker
/// carries no data; `PhantomData` only keeps the type parameter in use.
pub struct RetrySpecialized<Spec> {
    _marker: PhantomData<Spec>,
}
72
73impl<E, Spec> Endpoint<RetrySpecialized<Spec>> for RetryQuery<E>
74where
75    E: Endpoint<Spec>,
76    Spec: MaybeSend + MaybeSync,
77{
78    type Output = <E as Endpoint<Spec>>::Output;
79
80    fn grpc_endpoint(&self) -> std::borrow::Cow<'static, str> {
81        self.endpoint.grpc_endpoint()
82    }
83
84    fn body(&self) -> Result<bytes::Bytes, crate::api::BodyError> {
85        self.endpoint.body()
86    }
87}
88
89/// retry with the default retry strategy (ExponentialBackoff)
90pub fn retry<E>(endpoint: E) -> RetryQuery<E, ExponentialBackoff> {
91    RetryQuery::<E, _> {
92        endpoint,
93        retry: Retry::default(),
94    }
95}
96
97/// Retry the endpoint, indicating a specific strategy to retry with
98pub fn retry_with_strategy<E, S>(endpoint: E, retry: Retry<S>) -> RetryQuery<E, S> {
99    RetryQuery::<E, S> { endpoint, retry }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::{
        EndpointExt,
        mock::{MockError, MockNetworkClient, TestEndpoint},
    };

    /// Three retryable failures followed by a success: the default policy is
    /// expected to absorb the failures and return `Ok`.
    #[xmtp_common::test]
    async fn retries_endpoint_three_times() {
        let mut mock = MockNetworkClient::new();
        // Three requests fail with an error flagged as retryable...
        mock.expect_request().times(3).returning(|_, _, _| {
            tracing::info!("error");
            Err(ApiClientError::Client {
                source: MockError::ARetryableError,
            })
        });
        // ...and one request succeeds with an empty body.
        mock.expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![].into())));

        let outcome: Result<(), _> = retry(TestEndpoint).query(&mock).await;
        assert!(outcome.is_ok());
    }

    /// A non-retryable failure surfaces after exactly one request.
    #[xmtp_common::test]
    async fn does_not_retry_non_retryable() {
        let mut mock = MockNetworkClient::new();
        mock.expect_request().times(1).returning(|_, _, _| {
            Err(ApiClientError::Client {
                source: MockError::ANonRetryableError,
            })
        });

        let outcome: Result<(), _> = retry(TestEndpoint).query(&mock).await;
        assert!(outcome.is_err());
        // The caller sees `ClientWithEndpoint`, not the `Client` variant the
        // mock produced.
        assert!(
            matches!(
                outcome,
                Err(ApiClientError::ClientWithEndpoint {
                    source: MockError::ANonRetryableError,
                    ..
                })
            ),
            "{:?}",
            outcome.unwrap_err()
        );
    }

    #[xmtp_common::test]
    fn test_grpc_endpoint_delegates_to_wrapped_endpoint() {
        let wrapped = retry(TestEndpoint);
        assert_eq!(wrapped.grpc_endpoint(), "");
    }

    #[xmtp_common::test]
    fn test_body_delegates_to_wrapped_endpoint() {
        let wrapped = retry(TestEndpoint);
        let body = wrapped.body();
        assert!(body.is_ok());
        assert_eq!(body.unwrap(), bytes::Bytes::from(vec![]));
    }

    /// A custom policy allowing two retries absorbs two retryable failures
    /// before a third attempt succeeds.
    #[xmtp_common::test]
    async fn retries_with_strategy() {
        let mut mock = MockNetworkClient::new();
        mock.expect_request().times(2).returning(|_, _, _| {
            Err(ApiClientError::Client {
                source: MockError::ARetryableError,
            })
        });
        mock.expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![1].into())));

        let outcome: Result<(), _> = TestEndpoint
            .ignore_response() // discard output: the `[1]` body is not valid protobuf
            .retry_with_strategy(Retry::builder().retries(2).build())
            .query(&mock)
            .await;
        assert!(outcome.is_ok(), "{:?}", outcome.unwrap_err());
    }
}