xmtp_api_d14n/middleware/
readonly_client.rs

//! Read-only client middleware.
//!
//! This is a deliberately simple strategy for disabling writes on certain clients:
//! [`ReadonlyClient`] wraps another API [`Client`] and rejects any request or stream
//! whose URI path contains one of the write-endpoint names in `DENY`. Such calls
//! return [`ApiClientError::WritesDisabled`] without being forwarded. Every other
//! call is passed through to the wrapped client unchanged.

// Additional out-of-line tests live in a `test` submodule. `if_test!` presumably
// compiles it only for test builds (TODO confirm against `xmtp_common`).
xmtp_common::if_test! {
    mod test;
}
6
7use derive_builder::Builder;
8use prost::bytes::Bytes;
9use xmtp_proto::api::IsConnectedCheck;
10use xmtp_proto::api::{ApiClientError, Client};
11
/// Endpoint names that count as network writes.
///
/// A request or stream whose URI path *contains* any of these names is rejected
/// before it reaches the wrapped client. Matching is a plain substring test, and
/// entry order is irrelevant; entries are kept alphabetical for easy review.
const DENY: &[&str] = &[
    "BatchPublishCommitLog",
    "PublishIdentityUpdate",
    "RegisterInstallation",
    "RevokeInstallation",
    "SendWelcomeMessages",
    "UploadKeyPackage",
];
20
21/// A client that will error on requests that write to the network.
22#[derive(Debug, Builder, Default, Clone)]
23#[builder(public)]
24pub struct ReadonlyClient<Client> {
25    #[builder(public)]
26    pub(super) inner: Client,
27}
28
29impl<C: Clone> ReadonlyClient<C> {
30    pub fn builder() -> ReadonlyClientBuilder<C> {
31        ReadonlyClientBuilder::default()
32    }
33}
34
35#[xmtp_common::async_trait]
36impl<C> Client for ReadonlyClient<C>
37where
38    C: Client,
39{
40    type Error = <C as Client>::Error;
41    type Stream = <C as Client>::Stream;
42
43    async fn request(
44        &self,
45        request: http::request::Builder,
46        path: http::uri::PathAndQuery,
47        body: Bytes,
48    ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
49        let p = path.path();
50        if DENY.iter().any(|d| p.contains(d)) {
51            return Err(ApiClientError::WritesDisabled);
52        }
53
54        self.inner.request(request, path, body).await
55    }
56
57    async fn stream(
58        &self,
59        request: http::request::Builder,
60        path: http::uri::PathAndQuery,
61        body: Bytes,
62    ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
63        let p = path.path();
64        if DENY.iter().any(|d| p.contains(d)) {
65            return Err(ApiClientError::WritesDisabled);
66        }
67
68        self.inner.stream(request, path, body).await
69    }
70
71    fn fake_stream(&self) -> http::Response<Self::Stream> {
72        self.inner.fake_stream()
73    }
74}
75
76#[xmtp_common::async_trait]
77impl<C> IsConnectedCheck for ReadonlyClient<C>
78where
79    C: IsConnectedCheck,
80{
81    async fn is_connected(&self) -> bool {
82        self.inner.is_connected().await
83    }
84}
85
86xmtp_common::if_test! {
87    use derive_builder::UninitializedFieldError;
88    use xmtp_proto::prelude::ApiBuilder;
89    #[allow(clippy::unwrap_used)]
90    impl<C> ReadonlyClientBuilder<C>
91    where
92        C: ApiBuilder,
93    {
94        pub(crate) fn build_builder(
95            self,
96        ) -> Result<ReadonlyClient<C::Output>, UninitializedFieldError> {
97            Ok(ReadonlyClient {
98                inner: <C as ApiBuilder>::build(
99                    self.inner
100                        .ok_or(UninitializedFieldError::new("read"))
101                        .unwrap(),
102                )
103                .unwrap(),
104            })
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use crate::{d14n::PublishClientEnvelopes, v3::PublishIdentityUpdate};

    use super::*;
    use rstest::*;

    use xmtp_proto::{
        api::{Query, mock::MockNetworkClient},
        xmtp::xmtpv4::envelopes::ClientEnvelope,
    };
    type MockClient = ReadonlyClient<MockNetworkClient>;

    /// A read-only client around a fresh mock with no expectations set.
    ///
    /// Presumably the mock fails the test if it is called without a matching
    /// expectation (mockall-style; confirm in `MockNetworkClient`).
    #[fixture]
    fn ro() -> MockClient {
        ReadonlyClient {
            inner: MockNetworkClient::default(),
        }
    }

    /// Non-write endpoints are forwarded to the inner client exactly once.
    #[rstest]
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_forwards_to_inner(mut ro: MockClient) {
        ro.inner
            .expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![].into())));
        let mut e = PublishClientEnvelopes::builder()
            .envelope(ClientEnvelope::default())
            .build()?;
        e.query(&ro).await?;
    }

    /// A denied write endpoint is rejected on the `request` path.
    #[rstest]
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_errors_on_write(ro: MockClient) {
        let mut e = PublishIdentityUpdate::builder().build()?;
        let result = e.query(&ro).await;
        assert!(matches!(result, Err(ApiClientError::WritesDisabled)));
    }

    /// A denied write endpoint is also rejected on the `stream` path. No
    /// expectation is set on the mock, so the call must be stopped before it
    /// reaches the inner client.
    #[rstest]
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_errors_on_write_stream(ro: MockClient) {
        let path =
            http::uri::PathAndQuery::from_static("/xmtp.mls.api.v1.MlsApi/SendWelcomeMessages");
        let result =
            <MockClient as Client>::stream(&ro, http::Request::builder(), path, Bytes::new())
                .await;
        assert!(matches!(result, Err(ApiClientError::WritesDisabled)));
    }
}