1use crate::time::Duration;
20use crate::{MaybeSend, MaybeSync};
21use rand::Rng;
22use std::error::Error;
23use std::sync::Arc;
24
25impl From<Box<dyn RetryableError>> for Box<dyn Error> {
28 fn from(retryable: Box<dyn RetryableError>) -> Box<dyn Error> {
29 retryable
30 }
31}
32
33pub fn arc_retryable_to_error(retryable: Arc<dyn RetryableError>) -> Arc<dyn Error> {
37 retryable
38}
39
40pub type BoxedRetry = Retry<Box<dyn Strategy>>;
41
/// Marker used as the default `SP` type parameter of `RetryableError`.
///
/// It carries no data. NOTE(review): presumably the parameter exists so a type can provide a
/// `RetryableError<OtherMarker>` impl alongside its default one — confirm intent.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotSpecialized;
43
44pub trait RetryableError<SP = NotSpecialized>: std::error::Error + MaybeSend + MaybeSync {
47 fn is_retryable(&self) -> bool;
48}
49
50impl<T> RetryableError for &'_ T
51where
52 T: RetryableError,
53{
54 fn is_retryable(&self) -> bool {
55 (**self).is_retryable()
56 }
57}
58
59impl<E: RetryableError> RetryableError for Box<E> {
60 fn is_retryable(&self) -> bool {
61 (**self).is_retryable()
62 }
63}
64
65impl RetryableError for core::convert::Infallible {
66 fn is_retryable(&self) -> bool {
67 unreachable!()
68 }
69}
70
71#[derive(Debug, Clone)]
73pub struct Retry<S = ExponentialBackoff> {
74 retries: usize,
75 strategy: S,
76}
77
78impl Default for Retry {
79 fn default() -> Retry {
80 Retry {
81 retries: 5,
82 strategy: ExponentialBackoff::default(),
83 }
84 }
85}
86
87impl<S: Strategy> Retry<S> {
88 pub fn retries(&self) -> usize {
90 self.retries
91 }
92
93 pub fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration> {
94 self.strategy.backoff(attempts, time_spent)
95 }
96}
97
98impl<S: Strategy + 'static> Retry<S> {
99 pub fn boxed(self) -> Retry<Box<dyn Strategy>> {
100 Retry {
101 strategy: Box::new(self.strategy),
102 retries: self.retries,
103 }
104 }
105}
106
107pub trait Strategy: MaybeSend + MaybeSync {
109 fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration>;
113}
114
115impl Strategy for () {
116 fn backoff(&self, _attempts: usize, _time_spent: crate::time::Instant) -> Option<Duration> {
117 Some(Duration::ZERO)
118 }
119}
120
121impl<S: ?Sized + Strategy> Strategy for Box<S> {
122 fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration> {
123 (**self).backoff(attempts, time_spent)
124 }
125}
126
/// Exponential backoff with random jitter, a per-wait cap and a total time budget.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Factor the delay is multiplied by for each successive attempt.
    multiplier: u32,
    /// Base delay, before any multiplication or jitter.
    duration: Duration,
    /// Upper bound of the uniformly random jitter added to each delay.
    max_jitter: Duration,
    /// Once more than this has elapsed since the retry loop started, no more delays are produced.
    total_wait_max: Duration,
    /// Upper limit the multiplied delay is clamped to (jitter is added afterwards).
    individual_wait_max: Duration,
}
140
141impl ExponentialBackoff {
142 pub fn builder() -> ExponentialBackoffBuilder {
143 ExponentialBackoffBuilder::default()
144 }
145}
146
147impl Default for ExponentialBackoff {
148 fn default() -> Self {
149 Self {
150 multiplier: 3,
152 duration: Duration::from_millis(50),
153 total_wait_max: Duration::from_secs(120),
154 individual_wait_max: Duration::from_secs(30),
155 max_jitter: Duration::from_millis(25),
156 }
157 }
158}
159
160#[derive(Default)]
161pub struct ExponentialBackoffBuilder {
162 duration: Option<Duration>,
163 max_jitter: Option<Duration>,
164 multiplier: Option<u32>,
165 total_wait_max: Option<Duration>,
166}
167
168impl ExponentialBackoffBuilder {
169 pub fn duration(mut self, duration: Duration) -> Self {
170 self.duration = Some(duration);
171 self
172 }
173
174 pub fn max_jitter(mut self, max_jitter: Duration) -> Self {
175 self.max_jitter = Some(max_jitter);
176 self
177 }
178
179 pub fn multiplier(mut self, multiplier: u32) -> Self {
180 self.multiplier = Some(multiplier);
181 self
182 }
183
184 pub fn total_wait_max(mut self, total_wait_max: Duration) -> Self {
185 self.total_wait_max = Some(total_wait_max);
186 self
187 }
188
189 pub fn build(self) -> ExponentialBackoff {
190 ExponentialBackoff {
191 duration: self.duration.unwrap_or(Duration::from_millis(25)),
192 max_jitter: self.max_jitter.unwrap_or(Duration::from_millis(25)),
193 multiplier: self.multiplier.unwrap_or(3),
194 total_wait_max: self.total_wait_max.unwrap_or(Duration::from_secs(120)),
195 individual_wait_max: Duration::from_secs(30),
196 }
197 }
198}
199
200impl Strategy for ExponentialBackoff {
201 fn backoff(&self, attempts: usize, time_spent: crate::time::Instant) -> Option<Duration> {
202 if time_spent.elapsed() > self.total_wait_max {
203 return None;
204 }
205 let mut duration = self.duration;
206 for _ in 0..(attempts.saturating_sub(1)) {
207 duration *= self.multiplier;
208 if duration > self.individual_wait_max {
209 duration = self.individual_wait_max;
210 }
211 }
212 let distr = rand::distributions::Uniform::new_inclusive(Duration::ZERO, self.max_jitter);
213 let jitter = rand::thread_rng().sample(distr);
214 let wait = duration + jitter;
215 Some(wait)
216 }
217}
218
/// Builder for [`Retry`].
///
/// `retries` stays `None` until set; `build` then falls back to 5. `strategy` is the
/// [`Strategy`] the finished `Retry` will use.
#[derive(Clone, Copy, Debug, Default)]
pub struct RetryBuilder<S> {
    retries: Option<usize>,
    strategy: S,
}
225
226impl RetryBuilder<ExponentialBackoff> {
227 pub fn new() -> Self {
228 Self {
229 retries: Some(5),
230 strategy: ExponentialBackoff::default(),
231 }
232 }
233}
234
235impl<S: Strategy> RetryBuilder<S> {
247 pub fn build(self) -> Retry<S> {
248 let mut retry = Retry {
249 retries: 5usize,
250 strategy: self.strategy,
251 };
252
253 if let Some(retries) = self.retries {
254 retry.retries = retries;
255 }
256
257 retry
258 }
259
260 pub fn retries(mut self, retries: usize) -> Self {
262 self.retries = Some(retries);
263 self
264 }
265
266 pub fn with_strategy<St: Strategy>(self, strategy: St) -> RetryBuilder<St> {
267 RetryBuilder {
268 retries: self.retries,
269 strategy,
270 }
271 }
272}
273
274impl Retry {
275 pub fn builder() -> RetryBuilder<ExponentialBackoff> {
277 RetryBuilder::new()
278 }
279}
280
/// Runs an async operation, retrying it while it fails with a retryable error.
///
/// * `$retry` — an expression providing `retries()` and `backoff(attempts, start)`, such as a
///   `Retry`. It is re-evaluated on every loop iteration, so keep it cheap and side-effect free.
/// * `$code` — a single token tree (typically a parenthesised `async` block) that evaluates to
///   a future of `Result<T, E>`, where `E` implements `RetryableError` and `Display`. A fresh
///   future is built from it for every attempt and instrumented with a `retry` trace span.
///
/// The operation runs once and is then retried at most `$retry.retries()` times. Retrying
/// stops early when an error is not retryable or when `backoff` returns `None`. The macro
/// evaluates to the `Result` of the last attempt.
#[macro_export]
macro_rules! retry_async {
    ($retry: expr, $code: tt) => {{
        use tracing::Instrument as _;
        #[allow(unused)]
        use $crate::retry::RetryableError;
        // Number of retries performed so far (the initial attempt is not counted).
        let mut attempts = 0;
        let time_spent = $crate::time::Instant::now();
        let span = tracing::trace_span!("retry");
        loop {
            let span = span.clone();
            #[allow(clippy::redundant_closure_call)]
            let res = $code.instrument(span).await;
            match res {
                Ok(v) => break Ok(v),
                Err(e) => {
                    if (&e).is_retryable() && attempts < $retry.retries() {
                        tracing::warn!("retrying function that failed with error={}", e);
                        // Count this retry before asking for its delay. Strategies number attempts
                        // from 1 (ExponentialBackoff applies its multiplier `attempts - 1` times),
                        // so passing 0 here made the first two retries wait the same base delay.
                        attempts += 1;
                        if let Some(d) = $retry.backoff(attempts, time_spent) {
                            $crate::time::sleep(d).await;
                        } else {
                            tracing::warn!("retry strategy exceeded max wait time");
                            break Err(e);
                        }
                    } else {
                        tracing::trace!("error is not retryable. {}", e);
                        break Err(e);
                    }
                }
            }
        }
    }};
}
361
/// Evaluates to the result of `is_retryable()` on the given error.
///
/// `RetryableError` is brought into scope inside the expansion so the trait method is callable
/// without the caller importing it. Accepts a bare identifier or any expression.
#[macro_export]
macro_rules! retryable {
    ($error: ident) => {{
        #[allow(unused)]
        use $crate::retry::RetryableError;
        $error.is_retryable()
    }};
    ($error: expr) => {{
        // Same as the ident arm: if `is_retryable` resolves to an inherent method, the import is
        // unused and would otherwise raise an `unused_imports` warning at the call site.
        #[allow(unused)]
        use $crate::retry::RetryableError;
        $error.is_retryable()
    }};
}
374
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    use thiserror::Error;
    use tokio::sync::mpsc;

    /// Test error whose retryability depends on the variant.
    #[derive(Debug, Error)]
    enum SomeError {
        #[error("this is a retryable error")]
        ARetryableError,
        #[error("Dont retry")]
        DontRetryThis,
    }

    impl RetryableError for SomeError {
        fn is_retryable(&self) -> bool {
            matches!(self, Self::ARetryableError)
        }
    }

    fn retry_error_fn() -> Result<(), SomeError> {
        Err(SomeError::ARetryableError)
    }

    fn retryable_with_args(foo: usize, name: String, list: &[String]) -> Result<(), SomeError> {
        println!("I am {foo} of {name} with items {list:?}");
        Err(SomeError::ARetryableError)
    }

    #[xmtp_macro::test]
    async fn it_retries_twice_and_succeeds() {
        let mut i = 0;
        let mut test_fn = || -> Result<(), SomeError> {
            if i == 2 {
                return Ok(());
            }
            i += 1;
            retry_error_fn()?;
            Ok(())
        };

        retry_async!(Retry::default(), (async { test_fn() })).unwrap();
        assert_eq!(i, 2);
    }

    #[xmtp_macro::test]
    async fn it_works_with_random_args() {
        let mut i = 0;
        let list: Vec<String> = vec!["String".into(), "Foo".into()];
        let mut test_fn = || -> Result<(), SomeError> {
            if i == 2 {
                return Ok(());
            }
            i += 1;
            retryable_with_args(i, "Hello".to_string(), &list)
        };

        retry_async!(Retry::default(), (async { test_fn() })).unwrap();
        assert_eq!(i, 2);
    }

    #[xmtp_macro::test]
    async fn it_fails_on_three_retries() {
        let mut calls = 0;
        let mut closure = || -> Result<(), SomeError> {
            calls += 1;
            retry_error_fn()
        };
        // Zero-delay strategy keeps the test fast; three retries follow the initial call.
        let retry = Retry::builder().retries(3).with_strategy(()).build();
        let result: Result<(), SomeError> = retry_async!(retry, (async { closure() }));

        assert!(matches!(result, Err(SomeError::ARetryableError)));
        assert_eq!(calls, 4);
    }

    #[xmtp_macro::test]
    async fn it_only_runs_non_retryable_once() {
        let mut attempts = 0;
        let mut test_fn = || -> Result<(), SomeError> {
            attempts += 1;
            Err(SomeError::DontRetryThis)
        };

        let _r = retry_async!(Retry::default(), (async { test_fn() }));

        assert_eq!(attempts, 1);
    }

    #[xmtp_macro::test]
    async fn it_works_async() {
        async fn retryable_async_fn(rx: &mut mpsc::Receiver<usize>) -> Result<(), SomeError> {
            let val = rx.recv().await.unwrap();
            if val == 2 {
                return Ok(());
            }
            crate::time::sleep(core::time::Duration::from_nanos(100)).await;
            Err(SomeError::ARetryableError)
        }

        let (tx, mut rx) = mpsc::channel(3);

        for i in 0..3 {
            tx.send(i).await.unwrap();
        }
        retry_async!(
            Retry::default(),
            (async { retryable_async_fn(&mut rx).await })
        )
        .unwrap();
        assert!(rx.is_empty());
    }

    #[xmtp_macro::test]
    async fn it_works_async_mut() {
        async fn retryable_async_fn(data: &mut usize) -> Result<(), SomeError> {
            if *data == 2 {
                return Ok(());
            }
            *data += 1;
            crate::time::sleep(core::time::Duration::from_nanos(100)).await;
            Err(SomeError::ARetryableError)
        }

        let mut data: usize = 0;
        retry_async!(
            Retry::default(),
            (async { retryable_async_fn(&mut data).await })
        )
        .unwrap();
        assert_eq!(data, 2);
    }

    #[xmtp_macro::test]
    fn backoff_retry() {
        let backoff_retry = Retry::default();
        let time_spent = crate::time::Instant::now();
        let millis = |attempt: usize| {
            backoff_retry
                .backoff(attempt, time_spent)
                .unwrap()
                .as_millis()
        };
        // Default schedule: 50ms base, tripled per attempt, plus up to 25ms of jitter.
        assert!((50..=75).contains(&millis(1)));
        assert!((150..=175).contains(&millis(2)));
        assert!((450..=475).contains(&millis(3)));
    }
}