xmtp_common/
retry.rs

//! A retry strategy that works with Rust's native [`std::error::Error`] type.
2//!
3//! TODO: Could make the impl of `RetryableError` trait into a proc-macro to auto-derive Retryable
4//! on annotated enum variants.
5//! ```ignore
6//! #[derive(Debug, Error)]
7//! enum ErrorFoo {
8//!     #[error("I am retryable")]
9//!     #[retryable]
10//!     Retryable,
11//!     #[error("Nested errors are retryable")]
12//!     #[retryable(inherit)]
13//!     NestedRetryable(AnotherErrorWithRetryableVariants),
14//!     #[error("Always fail")]
15//!     NotRetryable
16//! }
17//! ```
18
19use crate::time::Duration;
20use crate::{MaybeSend, MaybeSync};
21use rand::Rng;
22use std::error::Error;
23use std::sync::Arc;
24
25// Rust 1.86 added Trait upcasting, so we can add these infallible conversions
26// which is useful when getting error messages
27impl From<Box<dyn RetryableError>> for Box<dyn Error> {
28    fn from(retryable: Box<dyn RetryableError>) -> Box<dyn Error> {
29        retryable
30    }
31}
32
33// NOTE: From<> implementation is not possible here b/c of rust orphan rules (relaxed for Box
34// types)
35/// Convert an `Arc<[RetryableError]>` to a Standard Library `Arc<Error>`
36pub fn arc_retryable_to_error(retryable: Arc<dyn RetryableError>) -> Arc<dyn Error> {
37    retryable
38}
39
40pub type BoxedRetry = Retry<Box<dyn Strategy>>;
41
/// Default marker type for the `SP` parameter of [`RetryableError`].
///
/// It carries no data. NOTE(review): presumably it exists so that other code can add
/// differently-specialized `RetryableError<SP>` impls without coherence conflicts —
/// confirm.
pub struct NotSpecialized;
43
44/// Specifies which errors are retryable.
45/// All Errors are not retryable by-default.
46pub trait RetryableError<SP = NotSpecialized>: std::error::Error + MaybeSend + MaybeSync {
47    fn is_retryable(&self) -> bool;
48}
49
50impl<T> RetryableError for &'_ T
51where
52    T: RetryableError,
53{
54    fn is_retryable(&self) -> bool {
55        (**self).is_retryable()
56    }
57}
58
59impl<E: RetryableError> RetryableError for Box<E> {
60    fn is_retryable(&self) -> bool {
61        (**self).is_retryable()
62    }
63}
64
65impl RetryableError for core::convert::Infallible {
66    fn is_retryable(&self) -> bool {
67        unreachable!()
68    }
69}
70
71/// Options to specify how to retry a function
72#[derive(Debug, Clone)]
73pub struct Retry<S = ExponentialBackoff> {
74    retries: usize,
75    strategy: S,
76}
77
78impl Default for Retry {
79    fn default() -> Retry {
80        Retry {
81            retries: 5,
82            strategy: ExponentialBackoff::default(),
83        }
84    }
85}
86
87impl<S: Strategy> Retry<S> {
88    /// Get the number of retries this is configured with.
89    pub fn retries(&self) -> usize {
90        self.retries
91    }
92
93    pub fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration> {
94        self.strategy.backoff(attempts, time_spent)
95    }
96}
97
98impl<S: Strategy + 'static> Retry<S> {
99    pub fn boxed(self) -> Retry<Box<dyn Strategy>> {
100        Retry {
101            strategy: Box::new(self.strategy),
102            retries: self.retries,
103        }
104    }
105}
106
107/// The strategy interface
108pub trait Strategy: MaybeSend + MaybeSync {
109    /// A time that this retry should backoff
110    /// Returns None when we should no longer backoff,
111    /// despite possibly being below attempts
112    fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration>;
113}
114
115impl Strategy for () {
116    fn backoff(&self, _attempts: usize, _time_spent: crate::time::Instant) -> Option<Duration> {
117        Some(Duration::ZERO)
118    }
119}
120
121impl<S: ?Sized + Strategy> Strategy for Box<S> {
122    fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration> {
123        (**self).backoff(attempts, time_spent)
124    }
125}
126
/// Exponential backoff with random jitter.
///
/// The wait before retry `n` is roughly `duration * multiplier^(n - 1)`, capped at
/// `individual_wait_max`, plus a uniformly random jitter in `[0, max_jitter]`. Once more
/// than `total_wait_max` has elapsed since retrying began, the strategy asks to stop.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// The amount to multiply the duration on each subsequent attempt
    multiplier: u32,
    /// Base duration to be multiplied
    duration: Duration,
    /// Upper bound of the random jitter added to every wait
    max_jitter: Duration,
    /// Upper limit on the total time spent retrying
    total_wait_max: Duration,
    /// Upper limit on the wait between two consecutive retries
    individual_wait_max: Duration,
}

impl ExponentialBackoff {
    /// Start building an [`ExponentialBackoff`]. Unset options fall back to
    /// [`ExponentialBackoff::default`].
    pub fn builder() -> ExponentialBackoffBuilder {
        ExponentialBackoffBuilder::default()
    }
}

impl Default for ExponentialBackoff {
    fn default() -> Self {
        Self {
            multiplier: 3,
            duration: Duration::from_millis(50),
            // Give up once retrying has taken more than two minutes in total.
            total_wait_max: Duration::from_secs(120),
            individual_wait_max: Duration::from_secs(30),
            max_jitter: Duration::from_millis(25),
        }
    }
}

/// Builder for [`ExponentialBackoff`].
///
/// Any option left unset takes the value used by [`ExponentialBackoff::default`], so
/// `ExponentialBackoff::builder().build()` is equivalent to the default.
#[derive(Default)]
pub struct ExponentialBackoffBuilder {
    duration: Option<Duration>,
    max_jitter: Option<Duration>,
    multiplier: Option<u32>,
    total_wait_max: Option<Duration>,
    individual_wait_max: Option<Duration>,
}

impl ExponentialBackoffBuilder {
    /// Base wait, multiplied by `multiplier` on each subsequent attempt.
    pub fn duration(mut self, duration: Duration) -> Self {
        self.duration = Some(duration);
        self
    }

    /// Upper bound of the random jitter added to every wait.
    pub fn max_jitter(mut self, max_jitter: Duration) -> Self {
        self.max_jitter = Some(max_jitter);
        self
    }

    /// Factor applied to the wait on each subsequent attempt.
    pub fn multiplier(mut self, multiplier: u32) -> Self {
        self.multiplier = Some(multiplier);
        self
    }

    /// Total time after which the strategy stops retrying.
    pub fn total_wait_max(mut self, total_wait_max: Duration) -> Self {
        self.total_wait_max = Some(total_wait_max);
        self
    }

    /// Cap on the (pre-jitter) wait between two consecutive retries.
    pub fn individual_wait_max(mut self, individual_wait_max: Duration) -> Self {
        self.individual_wait_max = Some(individual_wait_max);
        self
    }

    /// Build the strategy, filling unset options from [`ExponentialBackoff::default`].
    pub fn build(self) -> ExponentialBackoff {
        // A single source of truth for defaults, so the builder and `Default` agree.
        let defaults = ExponentialBackoff::default();
        ExponentialBackoff {
            duration: self.duration.unwrap_or(defaults.duration),
            max_jitter: self.max_jitter.unwrap_or(defaults.max_jitter),
            multiplier: self.multiplier.unwrap_or(defaults.multiplier),
            total_wait_max: self.total_wait_max.unwrap_or(defaults.total_wait_max),
            individual_wait_max: self
                .individual_wait_max
                .unwrap_or(defaults.individual_wait_max),
        }
    }
}
199
200impl Strategy for ExponentialBackoff {
201    fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration> {
202        if time_spent.elapsed() > self.total_wait_max {
203            return None;
204        }
205        let mut duration = self.duration;
206        for _ in 0..(attempts.saturating_sub(1)) {
207            duration *= self.multiplier;
208            if duration > self.individual_wait_max {
209                duration = self.individual_wait_max;
210            }
211        }
212        let distr = rand::distributions::Uniform::new_inclusive(Duration::ZERO, self.max_jitter);
213        let jitter = rand::thread_rng().sample(distr);
214        let wait = duration + jitter;
215        Some(wait)
216    }
217}
218
/// Builder for [`Retry`].
///
/// The retry count stays `None` until set explicitly; the strategy is stored as-is.
#[derive(Clone, Copy, Debug, Default)]
pub struct RetryBuilder<S> {
    retries: Option<usize>,
    strategy: S,
}
225
226impl RetryBuilder<ExponentialBackoff> {
227    pub fn new() -> Self {
228        Self {
229            retries: Some(5),
230            strategy: ExponentialBackoff::default(),
231        }
232    }
233}
234
235/// Builder for [`Retry`].
236///
237/// # Example
238/// ```ignore
239/// use xmtp_common::retry::RetryBuilder;
240///
241/// RetryBuilder::default()
242///     .retries(5)
243///     .with_strategy(xmtp_common::ExponentialBackoff::default())
244///     .build();
245/// ```
246impl<S: Strategy> RetryBuilder<S> {
247    pub fn build(self) -> Retry<S> {
248        let mut retry = Retry {
249            retries: 5usize,
250            strategy: self.strategy,
251        };
252
253        if let Some(retries) = self.retries {
254            retry.retries = retries;
255        }
256
257        retry
258    }
259
260    /// Specify the  of retries to allow
261    pub fn retries(mut self, retries: usize) -> Self {
262        self.retries = Some(retries);
263        self
264    }
265
266    pub fn with_strategy<St: Strategy>(self, strategy: St) -> RetryBuilder<St> {
267        RetryBuilder {
268            retries: self.retries,
269            strategy,
270        }
271    }
272}
273
274impl Retry {
275    /// Get the builder for [`Retry`]
276    pub fn builder() -> RetryBuilder<ExponentialBackoff> {
277        RetryBuilder::new()
278    }
279}
280
/// Retry an async operation according to a [`Retry`](crate::retry::Retry) policy.
///
/// * `$retry` is an expression evaluating to a `Retry<S>`. It is re-evaluated every time
///   the policy is consulted, so prefer a cheap expression such as a variable or
///   `Retry::default()`.
/// * `$code` is a single token tree, typically a parenthesized `async` block. It yields
///   a `Result<T, E>` whose error implements `RetryableError`.
///
/// An `Ok` value is returned immediately. A retryable `Err` is retried only while both
/// conditions hold:
/// * fewer than `retries()` retries have happened;
/// * the strategy keeps returning a backoff duration.
///
/// Otherwise the last error is returned.
///
/// ```
/// use xmtp_common::{retry_async, retry::{RetryableError, Retry}};
/// use thiserror::Error;
/// use tokio::sync::mpsc;
///
/// #[derive(Debug, Error)]
/// enum MyError {
///     #[error("A retryable error")]
///     Retryable,
///     #[error("An error we don't want to retry")]
///     NotRetryable
/// }
///
/// impl RetryableError for MyError {
///     fn is_retryable(&self) -> bool {
///         match self {
///             Self::Retryable => true,
///             Self::NotRetryable => false,
///         }
///     }
/// }
///
/// async fn fallible_fn(rx: &mut mpsc::Receiver<usize>) -> Result<(), MyError> {
///     if rx.recv().await.unwrap() == 2 {
///         return Ok(());
///     }
///     Err(MyError::Retryable)
/// }
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() -> Result<(), MyError> {
///
///     let (tx, mut rx) = mpsc::channel(3);
///
///     for i in 0..3 {
///         tx.send(i).await.unwrap();
///     }
///     retry_async!(Retry::default(), (async {
///         fallible_fn(&mut rx).await
///     }))
/// }
/// ```
#[macro_export]
macro_rules! retry_async {
    ($retry: expr, $code: tt) => {{
        use tracing::Instrument as _;
        #[allow(unused)]
        use $crate::retry::RetryableError;
        // Number of retries performed so far (the initial call is not a retry).
        let mut attempts = 0;
        // Instant retrying began; strategies measure total elapsed time from here.
        let time_spent = $crate::time::Instant::now();
        let span = tracing::trace_span!("retry");
        loop {
            let span = span.clone();
            #[allow(clippy::redundant_closure_call)]
            let res = $code.instrument(span).await;
            match res {
                Ok(v) => break Ok(v),
                Err(e) => {
                    if (&e).is_retryable() && attempts < $retry.retries() {
                        tracing::warn!(
                            "retrying function that failed with error={}",
                            e.to_string()
                        );
                        // Strategies treat `attempts` as the 1-based number of the upcoming
                        // retry (ExponentialBackoff multiplies `attempts - 1` times), so
                        // count this retry *before* asking for its backoff. Otherwise the
                        // first two retries would both wait the base duration.
                        attempts += 1;
                        if let Some(d) = $retry.backoff(attempts, time_spent) {
                            $crate::time::sleep(d).await;
                        } else {
                            tracing::warn!("retry strategy exceeded max wait time");
                            break Err(e);
                        }
                    } else {
                        tracing::trace!("error is not retryable. {}", e);
                        break Err(e);
                    }
                }
            }
        }
    }};
}
361
/// Evaluate whether an error is retryable.
///
/// The macro brings [`RetryableError`](crate::retry::RetryableError) into scope, so
/// callers do not need to import the trait themselves. It accepts an identifier or any
/// expression whose value implements the trait.
#[macro_export]
macro_rules! retryable {
    // An identifier is also a valid `expr`, so a single arm covers both forms.
    ($error: expr) => {{
        #[allow(unused)]
        use $crate::retry::RetryableError;
        $error.is_retryable()
    }};
}
374
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    use thiserror::Error;
    use tokio::sync::mpsc;

    #[derive(Debug, Error)]
    enum SomeError {
        #[error("this is a retryable error")]
        ARetryableError,
        #[error("Dont retry")]
        DontRetryThis,
    }

    impl RetryableError for SomeError {
        fn is_retryable(&self) -> bool {
            matches!(self, Self::ARetryableError)
        }
    }

    /// Always fails with a retryable error.
    fn retry_error_fn() -> Result<(), SomeError> {
        Err(SomeError::ARetryableError)
    }

    /// Prints its arguments, then fails with a retryable error.
    fn retryable_with_args(foo: usize, name: String, list: &[String]) -> Result<(), SomeError> {
        println!("I am {foo} of {name} with items {list:?}");
        Err(SomeError::ARetryableError)
    }

    #[xmtp_macro::test]
    async fn it_retries_twice_and_succeeds() {
        let mut i = 0;
        let mut test_fn = || -> Result<(), SomeError> {
            if i == 2 {
                return Ok(());
            }
            i += 1;
            retry_error_fn()?;
            Ok(())
        };

        retry_async!(Retry::default(), (async { test_fn() })).unwrap();
    }

    #[xmtp_macro::test]
    async fn it_works_with_random_args() {
        let mut i = 0;
        let list: Vec<String> = vec!["String".into(), "Foo".into()];
        let mut test_fn = || -> Result<(), SomeError> {
            if i == 2 {
                return Ok(());
            }
            i += 1;
            retryable_with_args(i, "Hello".to_string(), &list)
        };

        retry_async!(Retry::default(), (async { test_fn() })).unwrap();
    }

    #[xmtp_macro::test]
    async fn it_fails_after_exhausting_retries() {
        let mut calls = 0;
        let mut closure = || -> Result<(), SomeError> {
            calls += 1;
            retry_error_fn()?;
            Ok(())
        };
        // The `()` strategy never sleeps, keeping this test fast.
        let retry = Retry::builder().retries(3).with_strategy(()).build();
        let result: Result<(), SomeError> = retry_async!(retry, (async { closure() }));

        assert!(result.is_err());
        // One initial call plus three retries.
        assert_eq!(calls, 4);
    }

    #[xmtp_macro::test]
    async fn it_only_runs_non_retryable_once() {
        let mut attempts = 0;
        let mut test_fn = || -> Result<(), SomeError> {
            attempts += 1;
            Err(SomeError::DontRetryThis)
        };

        let _r = retry_async!(Retry::default(), (async { test_fn() }));

        assert_eq!(attempts, 1);
    }

    #[xmtp_macro::test]
    async fn it_works_async() {
        async fn retryable_async_fn(rx: &mut mpsc::Receiver<usize>) -> Result<(), SomeError> {
            let val = rx.recv().await.unwrap();
            if val == 2 {
                return Ok(());
            }
            // do some work
            crate::time::sleep(core::time::Duration::from_nanos(100)).await;
            Err(SomeError::ARetryableError)
        }

        let (tx, mut rx) = mpsc::channel(3);

        for i in 0..3 {
            tx.send(i).await.unwrap();
        }
        retry_async!(
            Retry::default(),
            (async { retryable_async_fn(&mut rx).await })
        )
        .unwrap();
        assert!(rx.is_empty());
    }

    #[xmtp_macro::test]
    async fn it_works_async_mut() {
        async fn retryable_async_fn(data: &mut usize) -> Result<(), SomeError> {
            if *data == 2 {
                return Ok(());
            }
            *data += 1;
            // do some work
            crate::time::sleep(core::time::Duration::from_nanos(100)).await;
            Err(SomeError::ARetryableError)
        }

        let mut data: usize = 0;
        retry_async!(
            Retry::default(),
            (async { retryable_async_fn(&mut data).await })
        )
        .unwrap();
    }

    #[xmtp_macro::test]
    fn backoff_retry() {
        let backoff_retry = Retry::default();
        let time_spent = crate::time::Instant::now();
        // Base waits for attempts 1..=3 are 50ms, 150ms and 450ms, each plus up to 25ms
        // of jitter. A range check avoids the u128 underflow panic of `ms - base`.
        for (attempt, base_ms) in [(1, 50u128), (2, 150), (3, 450)] {
            let ms = backoff_retry.backoff(attempt, time_spent).unwrap().as_millis();
            assert!(
                (base_ms..=base_ms + 25).contains(&ms),
                "attempt {attempt}: waited {ms}ms"
            );
        }
    }
}