xmtp_api_d14n/middleware/
auth.rs

1use arc_swap::ArcSwap;
2use prost::bytes::Bytes;
3use std::sync::Arc;
4use tokio::sync::OnceCell;
5use xmtp_common::{BoxDynError, MaybeSend, MaybeSync};
6use xmtp_proto::api::{ApiClientError, Client, IsConnectedCheck};
7
8#[cfg(not(test))]
9use xmtp_common::time::now_secs;
// Tests replace the real clock with a fixed instant so that expiry
// comparisons are deterministic and never flaky.
#[cfg(test)]
fn now_secs() -> i64 {
    const FROZEN_NOW_SECS: i64 = 1_000_000;
    FROZEN_NOW_SECS
}
15
16#[derive(Clone, Debug, PartialEq, Eq)]
17pub struct Credential {
18    name: http::header::HeaderName,
19    value: http::header::HeaderValue,
20    expires_at_seconds: i64,
21}
22
23impl Credential {
24    pub fn new(
25        name: Option<http::header::HeaderName>,
26        value: http::header::HeaderValue,
27        expires_at_seconds: i64,
28    ) -> Self {
29        Self {
30            name: name.unwrap_or(http::header::AUTHORIZATION),
31            value,
32            expires_at_seconds,
33        }
34    }
35}
36
37#[derive(Default)]
38struct AuthInner {
39    handle: OnceCell<ArcSwap<Credential>>,
40    mutex: tokio::sync::Mutex<()>,
41}
42
43#[derive(Default, Clone)]
44pub struct AuthHandle {
45    inner: Arc<AuthInner>,
46}
47
48impl AuthHandle {
49    pub fn new() -> Self {
50        Self::default()
51    }
52    pub async fn set(&self, credential: Credential) {
53        let mut new = Some(credential);
54        let inner = self
55            .inner
56            .handle
57            .get_or_init(|| async {
58                ArcSwap::from_pointee(new.take().expect("Credential not set"))
59            })
60            .await;
61        if let Some(new) = new {
62            inner.store(Arc::new(new));
63        }
64    }
65    pub fn id(&self) -> usize {
66        Arc::as_ptr(&self.inner) as usize
67    }
68}
69
70#[xmtp_common::async_trait]
71pub trait AuthCallback: MaybeSend + MaybeSync {
72    async fn on_auth_required(&self) -> Result<Credential, BoxDynError>;
73}
74
75/// Middleware for adding authentication headers to requests.
76///
77/// This middleware will add authentication headers to requests if a callback or handle is provided.
78///
79/// If a callback is provided, it will be called to get the credential when it is expired.
80/// If a handle is provided, it can be used to set the credential.
81///
82/// If only providing a handle, then expired credentials will still be used until the credential is set.
83///
84/// If creating multiple clients, if they share the same handle, then the credential will be shared between them
85/// resulting in less auth callbacks. Auth callbacks are debounced internally to prevent excessive calls.
86#[derive(Clone)]
87pub struct AuthMiddleware<C> {
88    inner: C,
89    handle: AuthHandle,
90    callback: Option<Arc<dyn AuthCallback>>,
91}
92
93impl<C> AuthMiddleware<C> {
94    #[track_caller]
95    pub fn new(
96        inner: C,
97        callback: Option<Arc<dyn AuthCallback>>,
98        handle: Option<AuthHandle>,
99    ) -> Self {
100        assert!(
101            callback.is_some() || handle.is_some(),
102            "Either a callback or a handle must be provided"
103        );
104        Self {
105            inner,
106            handle: handle.unwrap_or_default(),
107            callback,
108        }
109    }
110    async fn get_credential(&self) -> Result<Option<&ArcSwap<Credential>>, BoxDynError> {
111        let arc_swap = if let Some(callback) = &self.callback {
112            let arc_swap = self
113                .handle
114                .inner
115                .handle
116                .get_or_try_init(|| async {
117                    let credential = callback.on_auth_required().await?;
118                    let arc_swap = ArcSwap::from_pointee(credential);
119                    Ok::<_, BoxDynError>(arc_swap)
120                })
121                .await?;
122            Some(arc_swap)
123        } else {
124            self.handle.inner.handle.get()
125        };
126
127        let Some(arc_swap) = arc_swap else {
128            return Err("No auth callback provided and no credentials set. Please set credentials by calling `AuthHandle::set`.".into());
129        };
130
131        let needs_refresh = || arc_swap.load().expires_at_seconds <= now_secs();
132
133        if let Some(callback) = &self.callback
134            && needs_refresh()
135        {
136            // Multiple threads may be racing to run this, so this may require a lock in the future.
137            let _guard = self.handle.inner.mutex.lock().await;
138            // after acquiring the lock, we need to check again if the credential needs to be refreshed so that
139            // if another thread has already refreshed the credential, we don't need to do it again.
140            if needs_refresh() {
141                let new_header = callback.on_auth_required().await?;
142                arc_swap.store(Arc::new(new_header));
143            }
144        }
145        Ok(Some(arc_swap))
146    }
147    async fn modify_request<E: std::error::Error>(
148        &self,
149        mut request: http::request::Builder,
150    ) -> Result<http::request::Builder, ApiClientError<E>> {
151        let maybe_credential = self
152            .get_credential()
153            .await
154            .map_err(ApiClientError::<E>::OtherUnretryable)?;
155        if let Some(credential) = maybe_credential {
156            let credential = credential.load();
157            request = request.header(credential.name.clone(), credential.value.clone());
158        }
159        Ok(request)
160    }
161}
162
163#[xmtp_common::async_trait]
164impl<C: Client> Client for AuthMiddleware<C> {
165    type Error = C::Error;
166
167    type Stream = C::Stream;
168
169    async fn request(
170        &self,
171        request: http::request::Builder,
172        path: http::uri::PathAndQuery,
173        body: Bytes,
174    ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
175        let request = self.modify_request(request).await?;
176        self.inner.request(request, path, body).await
177    }
178
179    async fn stream(
180        &self,
181        request: http::request::Builder,
182        path: http::uri::PathAndQuery,
183        body: Bytes,
184    ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
185        let request = self.modify_request(request).await?;
186        self.inner.stream(request, path, body).await
187    }
188
189    fn fake_stream(&self) -> http::Response<Self::Stream> {
190        self.inner.fake_stream()
191    }
192}
193
194#[xmtp_common::async_trait]
195impl<C: IsConnectedCheck> IsConnectedCheck for AuthMiddleware<C> {
196    async fn is_connected(&self) -> bool {
197        self.inner.is_connected().await
198    }
199}
200
201#[cfg(test)]
202mod tests {
203
204    use super::*;
205    use futures::StreamExt;
206
207    fn credential(offset: i64) -> Credential {
208        let random_name = xmtp_common::rand_string::<16>().to_lowercase();
209        let header_name =
210            http::header::HeaderName::try_from(format!("x-test-header-{random_name}")).unwrap();
211        let random = xmtp_common::rand_string::<16>();
212        let header_value = http::header::HeaderValue::try_from(format!("Bearer {random}")).unwrap();
213        let now = now_secs();
214        Credential::new(Some(header_name), header_value.clone(), now + offset)
215    }
216
217    #[xmtp_common::test]
218    async fn test_auth_handle() {
219        let credential = credential(0);
220        let auth_handle = AuthHandle::new();
221        auth_handle.set(credential.clone()).await;
222        let inner = auth_handle
223            .inner
224            .handle
225            .get()
226            .map(|c| c.load_full())
227            .unwrap();
228        assert_eq!(inner.name, credential.name);
229        assert_eq!(inner.value, credential.value);
230        assert_eq!(inner.expires_at_seconds, credential.expires_at_seconds);
231    }
232
233    struct TestClient {
234        expected_credential: Option<Credential>,
235    }
236
237    impl TestClient {
238        pub fn new(expected_credential: Option<Credential>) -> Self {
239            Self {
240                expected_credential,
241            }
242        }
243    }
244
245    #[xmtp_common::async_trait]
246    impl Client for TestClient {
247        type Error = core::convert::Infallible;
248        type Stream = futures::stream::Once<
249            core::pin::Pin<Box<dyn Future<Output = Result<Bytes, Self::Error>> + Send + Sync>>,
250        >;
251
252        async fn request(
253            &self,
254            request: http::request::Builder,
255            _path: http::uri::PathAndQuery,
256            body: Bytes,
257        ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
258            let headers = request.headers_ref().unwrap();
259            if let Some(expected_credential) = &self.expected_credential {
260                assert_eq!(
261                    headers.get(&expected_credential.name).unwrap(),
262                    &expected_credential.value
263                );
264            } else {
265                assert!(headers.is_empty());
266            }
267            Ok(http::Response::new(body))
268        }
269
270        async fn stream(
271            &self,
272            request: http::request::Builder,
273            _path: http::uri::PathAndQuery,
274            body: Bytes,
275        ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
276            let headers = request.headers_ref().unwrap();
277            if let Some(expected_credential) = &self.expected_credential {
278                assert_eq!(
279                    headers.get(&expected_credential.name).unwrap(),
280                    &expected_credential.value
281                );
282            } else {
283                assert!(headers.is_empty());
284            }
285            Ok(http::Response::new(futures::stream::once(Box::pin(
286                async move { Ok::<_, Self::Error>(body) },
287            ))))
288        }
289
290        fn fake_stream(&self) -> http::Response<Self::Stream> {
291            http::Response::new(futures::stream::once(Box::pin(async move {
292                Ok::<_, Self::Error>(Default::default())
293            })))
294        }
295    }
296
297    impl<C: Client> AuthMiddleware<C> {
298        pub async fn make_requests(&self, expected: Result<(), String>) {
299            let request = http::request::Builder::new();
300            let path = http::uri::PathAndQuery::from_static("/");
301            let body = Bytes::new();
302            let result = self.request(request, path.clone(), body.clone()).await;
303            match (&expected, result) {
304                (Ok(()), Ok(response)) => {
305                    assert_eq!(response.status(), http::StatusCode::OK);
306                }
307                (Err(e), Ok(response)) => {
308                    panic!("Expected error: {e}, got response: {response:?}");
309                }
310                (Ok(()), Err(e)) => {
311                    panic!("Expected Ok, got error: {e}");
312                }
313                (Err(e), Err(res)) => {
314                    assert_eq!(e, &res.to_string());
315                }
316            }
317
318            let request = http::request::Builder::new();
319            let result = self.stream(request, path, body).await;
320            match (&expected, result) {
321                (Ok(()), Ok(response)) => {
322                    assert_eq!(response.status(), http::StatusCode::OK);
323                }
324                (Err(e), Ok(_)) => {
325                    panic!("Expected error: {e}, got Ok");
326                }
327                (Ok(()), Err(e)) => {
328                    panic!("Expected Ok, got error: {e}");
329                }
330                (Err(e), Err(res)) => {
331                    assert_eq!(e, &res.to_string());
332                }
333            }
334        }
335    }
336
337    struct TestCallback {
338        inner: Credential,
339        count: Arc<std::sync::atomic::AtomicI64>,
340    }
341
342    #[xmtp_common::async_trait]
343    impl AuthCallback for TestCallback {
344        async fn on_auth_required(&self) -> Result<Credential, BoxDynError> {
345            // Add sleeps so we can test concurrent requests
346            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
347            let mut credential = self.inner.clone();
348            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
349            let count = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
350            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
351            credential.expires_at_seconds += count;
352            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
353            tracing::debug!("credential: {credential:?}, {}, {count}", now_secs());
354            Ok(credential)
355        }
356    }
357
358    impl TestCallback {
359        pub fn new(credential: Credential, count: Arc<std::sync::atomic::AtomicI64>) -> Self {
360            Self {
361                inner: credential,
362                count,
363            }
364        }
365    }
366
367    // Only run this test on native where we can catch the panic
368    // This should never panic in practice because we only create auth middleware if there is a callback or handle.
369    xmtp_common::if_native! {
370        #[xmtp_common::test]
371        async fn test_auth_middleware_no_callback_or_handle() {
372            // expect a panic when creating the middleware without a callback or handle
373            std::panic::catch_unwind(|| {
374                AuthMiddleware::new(TestClient::new(None), None, None);
375            })
376            .unwrap_err();
377        }
378    }
379
380    #[xmtp_common::test]
381    async fn test_auth_middleware_with_no_callback_and_handle() {
382        let credential = credential(0);
383        let auth_handle = AuthHandle::new();
384        let mut middleware =
385            AuthMiddleware::new(TestClient::new(None), None, Some(auth_handle.clone()));
386        middleware
387            .make_requests(Err("No auth callback provided and no credentials set. Please set credentials by calling `AuthHandle::set`.".into()))
388            .await;
389
390        auth_handle.set(credential.clone()).await;
391        middleware.inner.expected_credential = Some(credential.clone());
392        middleware.make_requests(Ok(())).await;
393    }
394
395    #[xmtp_common::test]
396    async fn test_auth_middleware_with_callback_and_no_handle() {
397        let credential = credential(-1);
398        let count = Arc::new(std::sync::atomic::AtomicI64::new(0));
399        let callback = TestCallback::new(credential.clone(), count.clone());
400        let middleware = AuthMiddleware::new(
401            TestClient::new(Some(credential.clone())),
402            Some(Arc::new(callback)),
403            None,
404        );
405        middleware.make_requests(Ok(())).await;
406        middleware.make_requests(Ok(())).await;
407        middleware.make_requests(Ok(())).await;
408        // 3 calls are expected because the credential starts out being one
409        // second past expiry, then the second of expiry, then has one
410        // second until expiry, so it doesn't need to be refreshed.
411        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
412    }
413
414    #[xmtp_common::test]
415    async fn test_auth_middleware_with_callback_and_handle() {
416        let cred = credential(-1);
417        let count = Arc::new(std::sync::atomic::AtomicI64::new(0));
418        let auth_handle = AuthHandle::new();
419        let callback = TestCallback::new(cred.clone(), count.clone());
420        let mut middleware = AuthMiddleware::new(
421            TestClient::new(Some(cred.clone())),
422            Some(Arc::new(callback)),
423            Some(auth_handle.clone()),
424        );
425        middleware.make_requests(Ok(())).await;
426        middleware.make_requests(Ok(())).await;
427        middleware.make_requests(Ok(())).await;
428        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
429        let handle_credential = credential(1);
430        auth_handle.set(handle_credential.clone()).await;
431        middleware.inner.expected_credential = Some(handle_credential.clone());
432        middleware.make_requests(Ok(())).await;
433        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
434        auth_handle.set(cred.clone()).await;
435        middleware.inner.expected_credential = Some(cred.clone());
436        middleware.make_requests(Ok(())).await;
437        middleware.make_requests(Ok(())).await;
438        middleware.make_requests(Ok(())).await;
439        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 4);
440    }
441
442    #[xmtp_common::test]
443    async fn test_auth_middleware_with_callback_and_handle_concurrent_requests() {
444        let cred = credential(-1);
445        let count = Arc::new(std::sync::atomic::AtomicI64::new(0));
446        let auth_handle = AuthHandle::new();
447        let mut middlewares = vec![];
448        for _ in 0..10 {
449            let middleware = AuthMiddleware::new(
450                TestClient::new(Some(cred.clone())),
451                Some(Arc::new(TestCallback::new(cred.clone(), count.clone()))),
452                Some(auth_handle.clone()),
453            );
454            middlewares.push(middleware);
455        }
456
457        let mut tasks = middlewares
458            .iter()
459            .map(|middleware| async {
460                middleware.make_requests(Ok(())).await;
461                middleware.make_requests(Ok(())).await;
462                middleware.make_requests(Ok(())).await;
463            })
464            .collect::<futures::stream::FuturesUnordered<_>>();
465
466        while let Some(task) = tasks.next().await {
467            let () = task;
468        }
469        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
470    }
471}