xmtp_api_d14n/middleware/read_write_client/
client.rs1use derive_builder::Builder;
8use prost::bytes::Bytes;
9use xmtp_proto::api::IsConnectedCheck;
10use xmtp_proto::api::{ApiClientError, Client};
11
12#[derive(Debug, Builder, Default, Clone)]
18#[builder(public)]
19pub struct ReadWriteClient<Read, Write> {
20 #[builder(public)]
21 pub(super) read: Read,
22 #[builder(public)]
23 pub(super) write: Write,
24 #[builder(setter(into), public)]
25 pub(super) filter: String,
26}
27
28impl<Read: Clone, Write: Clone> ReadWriteClient<Read, Write> {
29 pub fn builder() -> ReadWriteClientBuilder<Read, Write> {
30 ReadWriteClientBuilder::default()
31 }
32}
33
34#[xmtp_common::async_trait]
35impl<Read, Write> Client for ReadWriteClient<Read, Write>
36where
37 Read: Client<Error = Write::Error, Stream = Write::Stream>,
38 Write: Client,
39{
40 type Error = <Read as Client>::Error;
41
42 type Stream = <Read as Client>::Stream;
43
44 async fn request(
45 &self,
46 request: http::request::Builder,
47 path: http::uri::PathAndQuery,
48 body: Bytes,
49 ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
50 if path.path().contains(&self.filter) {
51 self.write.request(request, path, body).await
52 } else {
53 self.read.request(request, path, body).await
54 }
55 }
56
57 async fn stream(
58 &self,
59 request: http::request::Builder,
60 path: http::uri::PathAndQuery,
61 body: Bytes,
62 ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
63 if path.path().contains(&self.filter) {
64 self.write.stream(request, path, body).await
65 } else {
66 self.read.stream(request, path, body).await
67 }
68 }
69
70 fn fake_stream(&self) -> http::Response<Self::Stream> {
71 self.read.fake_stream()
72 }
73}
74
75#[xmtp_common::async_trait]
76impl<R, W> IsConnectedCheck for ReadWriteClient<R, W>
77where
78 R: IsConnectedCheck,
79 W: IsConnectedCheck,
80{
81 async fn is_connected(&self) -> bool {
82 let to_result = |connected: bool| if connected { Ok(()) } else { Err(()) };
84 let read = async { to_result(self.read.is_connected().await) };
85 let write = async { to_result(self.write.is_connected().await) };
86 let result = futures::future::try_join(read, write).await;
87 result.is_ok()
88 }
89}
90
91xmtp_common::if_test! {
92 use derive_builder::UninitializedFieldError;
93 use xmtp_proto::prelude::ApiBuilder;
94 #[allow(clippy::unwrap_used)]
95 impl<R, W> ReadWriteClientBuilder<R, W>
96 where
97 R: ApiBuilder,
98 W: ApiBuilder,
99 {
100 pub(crate) fn build_builder(
101 self,
102 ) -> Result<ReadWriteClient<R::Output, W::Output>, UninitializedFieldError> {
103 Ok(ReadWriteClient {
104 read: <R as ApiBuilder>::build(
105 self.read
106 .ok_or(UninitializedFieldError::new("read"))
107 .unwrap(),
108 )
109 .unwrap(),
110 write: <W as ApiBuilder>::build(
111 self.write
112 .ok_or(UninitializedFieldError::new("write"))
113 .unwrap(),
114 )
115 .unwrap(),
116 filter: self
117 .filter
118 .ok_or(UninitializedFieldError::new("filter"))
119 .unwrap(),
120 })
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use crate::d14n::{PublishClientEnvelopes, QueryEnvelope};

    use super::*;
    use rstest::*;

    use xmtp_proto::{
        api::{Query, mock::MockNetworkClient},
        types::TopicKind,
        xmtp::xmtpv4::envelopes::ClientEnvelope,
    };
    /// Paths containing this service name are routed to the write client.
    const FILTER: &str = "xmtp.xmtpv4.payer_api.PayerApi";
    type MockClient = ReadWriteClient<MockNetworkClient, MockNetworkClient>;

    /// Read and write mocks with no expectations; each test registers only the
    /// call it expects. NOTE(review): presumably these are mockall mocks, so an
    /// unexpected call on the other side fails the test — confirm.
    #[fixture]
    fn rw() -> MockClient {
        ReadWriteClient {
            read: MockNetworkClient::default(),
            write: MockNetworkClient::default(),
            filter: FILTER.to_string(),
        }
    }

    /// A publish request, whose path is expected to contain `FILTER`, must be
    /// sent to the write client exactly once.
    #[rstest]
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_writes_when_matches(mut rw: MockClient) {
        rw.write
            .expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![].into())));
        let mut e = PublishClientEnvelopes::builder()
            .envelope(ClientEnvelope::default())
            .build()?;
        e.query(&rw).await?;
    }

    /// A query request, whose path is expected not to contain `FILTER`, must be
    /// sent to the read client exactly once.
    #[rstest]
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_reads_when_not_matches(mut rw: MockClient) {
        rw.read
            .expect_request()
            .times(1)
            .returning(|_, _, _| Ok(http::Response::new(vec![].into())));
        let mut e = QueryEnvelope::builder()
            .topic(TopicKind::GroupMessagesV1.create(vec![]))
            .last_seen(Default::default())
            .limit(0)
            .build()?;
        e.query(&rw).await?;
    }
}
176}