xmtp_api_d14n/middleware/read_write_client/
client.rs

//! We define a very simple strategy for separating reads from writes across
//! different gRPC calls.
//! If more control is required, we could extend or modify this client implementation
//! to filter with a regex, or let the consumer pass in a closure instead of a static
//! string filter.
6
7use derive_builder::Builder;
8use prost::bytes::Bytes;
9use xmtp_proto::api::IsConnectedCheck;
10use xmtp_proto::api::{ApiClientError, Client};
11
12/// A client which holds two clients
13/// and decides on a read/write strategy based on a given service str
14/// if the query path contains a match for the given filter,
15/// the client will write with the write client.
16/// For all other queries it does a read.
17#[derive(Debug, Builder, Default, Clone)]
18#[builder(public)]
19pub struct ReadWriteClient<Read, Write> {
20    #[builder(public)]
21    pub(super) read: Read,
22    #[builder(public)]
23    pub(super) write: Write,
24    #[builder(setter(into), public)]
25    pub(super) filter: String,
26}
27
28impl<Read: Clone, Write: Clone> ReadWriteClient<Read, Write> {
29    pub fn builder() -> ReadWriteClientBuilder<Read, Write> {
30        ReadWriteClientBuilder::default()
31    }
32}
33
34#[xmtp_common::async_trait]
35impl<Read, Write> Client for ReadWriteClient<Read, Write>
36where
37    Read: Client<Error = Write::Error, Stream = Write::Stream>,
38    Write: Client,
39{
40    type Error = <Read as Client>::Error;
41
42    type Stream = <Read as Client>::Stream;
43
44    async fn request(
45        &self,
46        request: http::request::Builder,
47        path: http::uri::PathAndQuery,
48        body: Bytes,
49    ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
50        if path.path().contains(&self.filter) {
51            self.write.request(request, path, body).await
52        } else {
53            self.read.request(request, path, body).await
54        }
55    }
56
57    async fn stream(
58        &self,
59        request: http::request::Builder,
60        path: http::uri::PathAndQuery,
61        body: Bytes,
62    ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
63        if path.path().contains(&self.filter) {
64            self.write.stream(request, path, body).await
65        } else {
66            self.read.stream(request, path, body).await
67        }
68    }
69
70    fn fake_stream(&self) -> http::Response<Self::Stream> {
71        self.read.fake_stream()
72    }
73}
74
75#[xmtp_common::async_trait]
76impl<R, W> IsConnectedCheck for ReadWriteClient<R, W>
77where
78    R: IsConnectedCheck,
79    W: IsConnectedCheck,
80{
81    async fn is_connected(&self) -> bool {
82        // This implementation gives concurrent execution with early return.
83        let to_result = |connected: bool| if connected { Ok(()) } else { Err(()) };
84        let read = async { to_result(self.read.is_connected().await) };
85        let write = async { to_result(self.write.is_connected().await) };
86        let result = futures::future::try_join(read, write).await;
87        result.is_ok()
88    }
89}
90
91xmtp_common::if_test! {
92    use derive_builder::UninitializedFieldError;
93    use xmtp_proto::prelude::ApiBuilder;
94    #[allow(clippy::unwrap_used)]
95    impl<R, W> ReadWriteClientBuilder<R, W>
96    where
97        R: ApiBuilder,
98        W: ApiBuilder,
99    {
100        pub(crate) fn build_builder(
101            self,
102        ) -> Result<ReadWriteClient<R::Output, W::Output>, UninitializedFieldError> {
103            Ok(ReadWriteClient {
104                read: <R as ApiBuilder>::build(
105                    self.read
106                        .ok_or(UninitializedFieldError::new("read"))
107                        .unwrap(),
108                )
109                .unwrap(),
110                write: <W as ApiBuilder>::build(
111                    self.write
112                        .ok_or(UninitializedFieldError::new("write"))
113                        .unwrap(),
114                )
115                .unwrap(),
116                filter: self
117                    .filter
118                    .ok_or(UninitializedFieldError::new("filter"))
119                    .unwrap(),
120            })
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use crate::d14n::{PublishClientEnvelopes, QueryEnvelope};

    use super::*;
    use rstest::*;

    use xmtp_proto::{
        api::{Query, mock::MockNetworkClient},
        types::TopicKind,
        xmtp::xmtpv4::envelopes::ClientEnvelope,
    };
    /// Path substring that routes a call to the write client in these tests.
    const FILTER: &str = "xmtp.xmtpv4.payer_api.PayerApi";
    type MockClient = ReadWriteClient<MockNetworkClient, MockNetworkClient>;

    /// A client whose read and write sides are independent mocks with no
    /// expectations. Each test registers only the call it expects, so a call
    /// routed to the wrong side presumably fails on the unconfigured mock.
    #[fixture]
    fn rw() -> MockClient {
        ReadWriteClient {
            read: MockNetworkClient::default(),
            write: MockNetworkClient::default(),
            filter: FILTER.to_string(),
        }
    }

    /// A publish must be served by the write client.
    /// Assumes `PublishClientEnvelopes` targets a path containing `FILTER`.
    #[rstest]
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_writes_when_matches(mut rw: MockClient) {
        rw.write
            .expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![].into())));
        let mut e = PublishClientEnvelopes::builder()
            .envelope(ClientEnvelope::default())
            .build()?;
        e.query(&rw).await?;
    }

    /// An envelope query must be served by the read client.
    /// Assumes `QueryEnvelope` targets a path that does not contain `FILTER`.
    #[rstest]
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_reads_when_not_matching(mut rw: MockClient) {
        rw.read
            .expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![].into())));
        let mut e = QueryEnvelope::builder()
            .topic(TopicKind::GroupMessagesV1.create(vec![]))
            .last_seen(Default::default())
            .limit(0)
            .build()?;
        e.query(&rw).await?;
    }
}