xmtp_api_d14n/middleware/
auth.rs1use arc_swap::ArcSwap;
2use prost::bytes::Bytes;
3use std::sync::Arc;
4use tokio::sync::OnceCell;
5use xmtp_common::{BoxDynError, MaybeSend, MaybeSync};
6use xmtp_proto::api::{ApiClientError, Client, IsConnectedCheck};
7
8#[cfg(not(test))]
9use xmtp_common::time::now_secs;
#[cfg(test)]
/// Test-only clock that always reports the same instant, so expiry arithmetic
/// (`now_secs() + offset`) in the tests is deterministic.
fn now_secs() -> i64 {
    const FIXED_NOW_SECS: i64 = 1_000_000;
    FIXED_NOW_SECS
}
15
16#[derive(Clone, Debug, PartialEq, Eq)]
17pub struct Credential {
18 name: http::header::HeaderName,
19 value: http::header::HeaderValue,
20 expires_at_seconds: i64,
21}
22
23impl Credential {
24 pub fn new(
25 name: Option<http::header::HeaderName>,
26 value: http::header::HeaderValue,
27 expires_at_seconds: i64,
28 ) -> Self {
29 Self {
30 name: name.unwrap_or(http::header::AUTHORIZATION),
31 value,
32 expires_at_seconds,
33 }
34 }
35}
36
37#[derive(Default)]
38struct AuthInner {
39 handle: OnceCell<ArcSwap<Credential>>,
40 mutex: tokio::sync::Mutex<()>,
41}
42
43#[derive(Default, Clone)]
44pub struct AuthHandle {
45 inner: Arc<AuthInner>,
46}
47
48impl AuthHandle {
49 pub fn new() -> Self {
50 Self::default()
51 }
52 pub async fn set(&self, credential: Credential) {
53 let mut new = Some(credential);
54 let inner = self
55 .inner
56 .handle
57 .get_or_init(|| async {
58 ArcSwap::from_pointee(new.take().expect("Credential not set"))
59 })
60 .await;
61 if let Some(new) = new {
62 inner.store(Arc::new(new));
63 }
64 }
65 pub fn id(&self) -> usize {
66 Arc::as_ptr(&self.inner) as usize
67 }
68}
69
70#[xmtp_common::async_trait]
71pub trait AuthCallback: MaybeSend + MaybeSync {
72 async fn on_auth_required(&self) -> Result<Credential, BoxDynError>;
73}
74
75#[derive(Clone)]
87pub struct AuthMiddleware<C> {
88 inner: C,
89 handle: AuthHandle,
90 callback: Option<Arc<dyn AuthCallback>>,
91}
92
93impl<C> AuthMiddleware<C> {
94 #[track_caller]
95 pub fn new(
96 inner: C,
97 callback: Option<Arc<dyn AuthCallback>>,
98 handle: Option<AuthHandle>,
99 ) -> Self {
100 assert!(
101 callback.is_some() || handle.is_some(),
102 "Either a callback or a handle must be provided"
103 );
104 Self {
105 inner,
106 handle: handle.unwrap_or_default(),
107 callback,
108 }
109 }
110 async fn get_credential(&self) -> Result<Option<&ArcSwap<Credential>>, BoxDynError> {
111 let arc_swap = if let Some(callback) = &self.callback {
112 let arc_swap = self
113 .handle
114 .inner
115 .handle
116 .get_or_try_init(|| async {
117 let credential = callback.on_auth_required().await?;
118 let arc_swap = ArcSwap::from_pointee(credential);
119 Ok::<_, BoxDynError>(arc_swap)
120 })
121 .await?;
122 Some(arc_swap)
123 } else {
124 self.handle.inner.handle.get()
125 };
126
127 let Some(arc_swap) = arc_swap else {
128 return Err("No auth callback provided and no credentials set. Please set credentials by calling `AuthHandle::set`.".into());
129 };
130
131 let needs_refresh = || arc_swap.load().expires_at_seconds <= now_secs();
132
133 if let Some(callback) = &self.callback
134 && needs_refresh()
135 {
136 let _guard = self.handle.inner.mutex.lock().await;
138 if needs_refresh() {
141 let new_header = callback.on_auth_required().await?;
142 arc_swap.store(Arc::new(new_header));
143 }
144 }
145 Ok(Some(arc_swap))
146 }
147 async fn modify_request<E: std::error::Error>(
148 &self,
149 mut request: http::request::Builder,
150 ) -> Result<http::request::Builder, ApiClientError<E>> {
151 let maybe_credential = self
152 .get_credential()
153 .await
154 .map_err(ApiClientError::<E>::OtherUnretryable)?;
155 if let Some(credential) = maybe_credential {
156 let credential = credential.load();
157 request = request.header(credential.name.clone(), credential.value.clone());
158 }
159 Ok(request)
160 }
161}
162
163#[xmtp_common::async_trait]
164impl<C: Client> Client for AuthMiddleware<C> {
165 type Error = C::Error;
166
167 type Stream = C::Stream;
168
169 async fn request(
170 &self,
171 request: http::request::Builder,
172 path: http::uri::PathAndQuery,
173 body: Bytes,
174 ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
175 let request = self.modify_request(request).await?;
176 self.inner.request(request, path, body).await
177 }
178
179 async fn stream(
180 &self,
181 request: http::request::Builder,
182 path: http::uri::PathAndQuery,
183 body: Bytes,
184 ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
185 let request = self.modify_request(request).await?;
186 self.inner.stream(request, path, body).await
187 }
188
189 fn fake_stream(&self) -> http::Response<Self::Stream> {
190 self.inner.fake_stream()
191 }
192}
193
194#[xmtp_common::async_trait]
195impl<C: IsConnectedCheck> IsConnectedCheck for AuthMiddleware<C> {
196 async fn is_connected(&self) -> bool {
197 self.inner.is_connected().await
198 }
199}
200
#[cfg(test)]
mod tests {

    use super::*;
    use futures::StreamExt;

    /// Builds a credential with a random header name and a random `Bearer` value that expires
    /// `offset` seconds after the test clock's fixed `now_secs()`.
    fn credential(offset: i64) -> Credential {
        let random_name = xmtp_common::rand_string::<16>().to_lowercase();
        let header_name =
            http::header::HeaderName::try_from(format!("x-test-header-{random_name}")).unwrap();
        let random = xmtp_common::rand_string::<16>();
        let header_value = http::header::HeaderValue::try_from(format!("Bearer {random}")).unwrap();
        let now = now_secs();
        Credential::new(Some(header_name), header_value.clone(), now + offset)
    }

    /// `AuthHandle::set` on an empty handle initialises the shared slot with the credential.
    #[xmtp_common::test]
    async fn test_auth_handle() {
        let credential = credential(0);
        let auth_handle = AuthHandle::new();
        auth_handle.set(credential.clone()).await;
        let inner = auth_handle
            .inner
            .handle
            .get()
            .map(|c| c.load_full())
            .unwrap();
        assert_eq!(inner.name, credential.name);
        assert_eq!(inner.value, credential.value);
        assert_eq!(inner.expires_at_seconds, credential.expires_at_seconds);
    }

    /// Fake inner client that asserts on the headers the middleware attached.
    /// `None` means no header at all is expected.
    struct TestClient {
        expected_credential: Option<Credential>,
    }

    impl TestClient {
        pub fn new(expected_credential: Option<Credential>) -> Self {
            Self {
                expected_credential,
            }
        }
    }

    #[xmtp_common::async_trait]
    impl Client for TestClient {
        type Error = core::convert::Infallible;
        type Stream = futures::stream::Once<
            core::pin::Pin<Box<dyn Future<Output = Result<Bytes, Self::Error>> + Send + Sync>>,
        >;

        /// Checks the expected header, then echoes `body` back in a default (200) response.
        async fn request(
            &self,
            request: http::request::Builder,
            _path: http::uri::PathAndQuery,
            body: Bytes,
        ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
            let headers = request.headers_ref().unwrap();
            if let Some(expected_credential) = &self.expected_credential {
                assert_eq!(
                    headers.get(&expected_credential.name).unwrap(),
                    &expected_credential.value
                );
            } else {
                assert!(headers.is_empty());
            }
            Ok(http::Response::new(body))
        }

        /// Same header check as `request`; returns a one-item stream yielding `body`.
        async fn stream(
            &self,
            request: http::request::Builder,
            _path: http::uri::PathAndQuery,
            body: Bytes,
        ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
            let headers = request.headers_ref().unwrap();
            if let Some(expected_credential) = &self.expected_credential {
                assert_eq!(
                    headers.get(&expected_credential.name).unwrap(),
                    &expected_credential.value
                );
            } else {
                assert!(headers.is_empty());
            }
            Ok(http::Response::new(futures::stream::once(Box::pin(
                async move { Ok::<_, Self::Error>(body) },
            ))))
        }

        fn fake_stream(&self) -> http::Response<Self::Stream> {
            http::Response::new(futures::stream::once(Box::pin(async move {
                Ok::<_, Self::Error>(Default::default())
            })))
        }
    }

    impl<C: Client> AuthMiddleware<C> {
        /// Issues one `request` and one `stream` call, so each invocation goes through the
        /// credential lookup twice. Both outcomes are checked against `expected`. Errors are
        /// compared by their `Display` text; this presumably relies on `ApiClientError`
        /// rendering only the inner message for `OtherUnretryable`.
        pub async fn make_requests(&self, expected: Result<(), String>) {
            let request = http::request::Builder::new();
            let path = http::uri::PathAndQuery::from_static("/");
            let body = Bytes::new();
            let result = self.request(request, path.clone(), body.clone()).await;
            match (&expected, result) {
                (Ok(()), Ok(response)) => {
                    assert_eq!(response.status(), http::StatusCode::OK);
                }
                (Err(e), Ok(response)) => {
                    panic!("Expected error: {e}, got response: {response:?}");
                }
                (Ok(()), Err(e)) => {
                    panic!("Expected Ok, got error: {e}");
                }
                (Err(e), Err(res)) => {
                    assert_eq!(e, &res.to_string());
                }
            }

            let request = http::request::Builder::new();
            let result = self.stream(request, path, body).await;
            match (&expected, result) {
                (Ok(()), Ok(response)) => {
                    assert_eq!(response.status(), http::StatusCode::OK);
                }
                (Err(e), Ok(_)) => {
                    panic!("Expected error: {e}, got Ok");
                }
                (Ok(()), Err(e)) => {
                    panic!("Expected Ok, got error: {e}");
                }
                (Err(e), Err(res)) => {
                    assert_eq!(e, &res.to_string());
                }
            }
        }
    }

    /// Callback returning `inner` with its expiry pushed forward by the number of previous
    /// calls. Starting from a credential that expires at `now - 1`, call 1 returns `now - 1`,
    /// call 2 returns `now`, and call 3 returns `now + 1`. Because `expires_at <= now` counts
    /// as expired, the third call yields the first usable credential. `count` records the
    /// total number of calls.
    struct TestCallback {
        inner: Credential,
        count: Arc<std::sync::atomic::AtomicI64>,
    }

    #[xmtp_common::async_trait]
    impl AuthCallback for TestCallback {
        async fn on_auth_required(&self) -> Result<Credential, BoxDynError> {
            // The sleeps between steps presumably give concurrent callers a chance to
            // interleave, which exercises the middleware's locking.
            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
            let mut credential = self.inner.clone();
            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
            let count = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
            credential.expires_at_seconds += count;
            xmtp_common::time::sleep(std::time::Duration::from_millis(10)).await;
            tracing::debug!("credential: {credential:?}, {}, {count}", now_secs());
            Ok(credential)
        }
    }

    impl TestCallback {
        pub fn new(credential: Credential, count: Arc<std::sync::atomic::AtomicI64>) -> Self {
            Self {
                inner: credential,
                count,
            }
        }
    }

    // Native-only: `catch_unwind` is used to observe the constructor's panic.
    // NOTE(review): presumably excluded elsewhere because unwinding is unavailable on wasm.
    xmtp_common::if_native! {
        #[xmtp_common::test]
        async fn test_auth_middleware_no_callback_or_handle() {
            std::panic::catch_unwind(|| {
                AuthMiddleware::new(TestClient::new(None), None, None);
            })
            .unwrap_err();
        }
    }

    /// Handle without callback: requests fail until `AuthHandle::set` is called, then carry
    /// the manually set credential.
    #[xmtp_common::test]
    async fn test_auth_middleware_with_no_callback_and_handle() {
        let credential = credential(0);
        let auth_handle = AuthHandle::new();
        let mut middleware =
            AuthMiddleware::new(TestClient::new(None), None, Some(auth_handle.clone()));
        middleware
            .make_requests(Err("No auth callback provided and no credentials set. Please set credentials by calling `AuthHandle::set`.".into()))
            .await;

        auth_handle.set(credential.clone()).await;
        middleware.inner.expected_credential = Some(credential.clone());
        middleware.make_requests(Ok(())).await;
    }

    /// Callback only: the first `make_requests` burns three callback calls (initial fetch plus
    /// two refreshes; see `TestCallback`). Later calls reuse the unexpired credential.
    #[xmtp_common::test]
    async fn test_auth_middleware_with_callback_and_no_handle() {
        let credential = credential(-1);
        let count = Arc::new(std::sync::atomic::AtomicI64::new(0));
        let callback = TestCallback::new(credential.clone(), count.clone());
        let middleware = AuthMiddleware::new(
            TestClient::new(Some(credential.clone())),
            Some(Arc::new(callback)),
            None,
        );
        middleware.make_requests(Ok(())).await;
        middleware.make_requests(Ok(())).await;
        middleware.make_requests(Ok(())).await;
        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    /// Callback plus handle. A manually set, unexpired credential is used without calling the
    /// callback. Setting an expired one triggers exactly one more refresh: the fourth call
    /// returns `now - 1 + 3`, which is already valid.
    #[xmtp_common::test]
    async fn test_auth_middleware_with_callback_and_handle() {
        let cred = credential(-1);
        let count = Arc::new(std::sync::atomic::AtomicI64::new(0));
        let auth_handle = AuthHandle::new();
        let callback = TestCallback::new(cred.clone(), count.clone());
        let mut middleware = AuthMiddleware::new(
            TestClient::new(Some(cred.clone())),
            Some(Arc::new(callback)),
            Some(auth_handle.clone()),
        );
        middleware.make_requests(Ok(())).await;
        middleware.make_requests(Ok(())).await;
        middleware.make_requests(Ok(())).await;
        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
        let handle_credential = credential(1);
        auth_handle.set(handle_credential.clone()).await;
        middleware.inner.expected_credential = Some(handle_credential.clone());
        middleware.make_requests(Ok(())).await;
        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
        auth_handle.set(cred.clone()).await;
        middleware.inner.expected_credential = Some(cred.clone());
        middleware.make_requests(Ok(())).await;
        middleware.make_requests(Ok(())).await;
        middleware.make_requests(Ok(())).await;
        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 4);
    }

    /// Ten middlewares, each with its own callback, share one handle and a single call
    /// counter, and run concurrently. The shared `OnceCell` and refresh mutex must keep the
    /// total callback calls at the same three a single middleware would need.
    #[xmtp_common::test]
    async fn test_auth_middleware_with_callback_and_handle_concurrent_requests() {
        let cred = credential(-1);
        let count = Arc::new(std::sync::atomic::AtomicI64::new(0));
        let auth_handle = AuthHandle::new();
        let mut middlewares = vec![];
        for _ in 0..10 {
            let middleware = AuthMiddleware::new(
                TestClient::new(Some(cred.clone())),
                Some(Arc::new(TestCallback::new(cred.clone(), count.clone()))),
                Some(auth_handle.clone()),
            );
            middlewares.push(middleware);
        }

        let mut tasks = middlewares
            .iter()
            .map(|middleware| async {
                middleware.make_requests(Ok(())).await;
                middleware.make_requests(Ok(())).await;
                middleware.make_requests(Ok(())).await;
            })
            .collect::<futures::stream::FuturesUnordered<_>>();

        // Drive all middlewares to completion concurrently on this task.
        while let Some(task) = tasks.next().await {
            let () = task;
        }
        assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 3);
    }
}