xmtp_api_d14n/middleware/multi_node_client/
client.rs1use crate::middleware::multi_node_client::{errors::MultiNodeClientError, gateway_api::*};
2use prost::bytes::Bytes;
3use tokio::sync::OnceCell;
4use xmtp_api_grpc::{ClientBuilder, GrpcClient, error::GrpcError};
5use xmtp_common::time::Duration;
6use xmtp_proto::api::{ApiClientError, Client, IsConnectedCheck};
7
8#[derive(Clone)]
11pub struct MultiNodeClient {
12 pub gateway_client: GrpcClient,
13 pub inner: OnceCell<GrpcClient>,
14 pub timeout: Duration,
15 pub node_client_template: ClientBuilder,
16}
17
18impl MultiNodeClient {
22 async fn init_inner(&self) -> Result<&GrpcClient, ApiClientError<MultiNodeClientError>> {
23 self.inner
24 .get_or_try_init(|| async {
25 let nodes = get_nodes(&self.gateway_client, &self.node_client_template).await?;
26 let fastest_node = get_fastest_node(nodes, self.timeout).await?;
27 Ok(fastest_node)
28 })
29 .await
30 }
31}
32
33#[xmtp_common::async_trait]
36impl Client for MultiNodeClient {
37 type Error = GrpcError;
38 type Stream = <GrpcClient as Client>::Stream;
39
40 async fn request(
41 &self,
42 request: http::request::Builder,
43 path: http::uri::PathAndQuery,
44 body: Bytes,
45 ) -> Result<http::Response<Bytes>, ApiClientError<Self::Error>> {
46 let inner = self
47 .init_inner()
48 .await
49 .map_err(|e| ApiClientError::<GrpcError>::Other(Box::new(e)))?;
50
51 inner.request(request, path, body).await
52 }
53
54 async fn stream(
55 &self,
56 request: http::request::Builder,
57 path: http::uri::PathAndQuery,
58 body: Bytes,
59 ) -> Result<http::Response<Self::Stream>, ApiClientError<Self::Error>> {
60 let inner = self
61 .init_inner()
62 .await
63 .map_err(|e| ApiClientError::<GrpcError>::Other(Box::new(e)))?;
64
65 inner.stream(request, path, body).await
66 }
67
68 fn fake_stream(&self) -> http::Response<Self::Stream> {
69 self.gateway_client.fake_stream()
70 }
71}
72
73#[xmtp_common::async_trait]
74impl IsConnectedCheck for MultiNodeClient {
75 async fn is_connected(&self) -> bool {
76 self.gateway_client.is_connected().await
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ReadWriteClient,
        middleware::{MiddlewareBuilder, MultiNodeClientBuilder},
        protocol::{InMemoryCursorStore, NoCursorStore},
        queries::D14nClient,
    };
    use std::sync::Arc;
    use xmtp_configuration::{GrpcUrls, PAYER_WRITE_FILTER};
    use xmtp_proto::api::Query;
    use xmtp_proto::api_client::{ApiBuilder, NetConnectConfig};
    use xmtp_proto::prelude::XmtpMlsClient;
    use xmtp_proto::types::GroupId;

    /// True when the configured gateway URL uses the `https` scheme.
    fn is_tls_enabled() -> bool {
        url::Url::parse(GrpcUrls::GATEWAY)
            .expect("valid gateway url")
            .scheme()
            == "https"
    }

    fn create_in_memory_cursor_store() -> Arc<InMemoryCursorStore> {
        Arc::new(InMemoryCursorStore::default())
    }

    /// gRPC builder pointed at the configured gateway, TLS matching its scheme.
    fn create_gateway_builder() -> ClientBuilder {
        let mut gateway_builder = GrpcClient::builder();
        gateway_builder.set_host(GrpcUrls::GATEWAY.to_string());
        gateway_builder.set_tls(is_tls_enabled());
        gateway_builder
    }

    /// gRPC builder used as the node template; only TLS is configured here.
    fn create_node_builder() -> ClientBuilder {
        let mut node_builder = GrpcClient::builder();
        node_builder.set_tls(is_tls_enabled());
        node_builder
    }

    /// Multi-node builder wired to the test gateway/node builders with the
    /// given node-selection timeout. Shared by every test that needs one so
    /// the setup lives in one place.
    fn multinode_builder_with_timeout(timeout: Duration) -> MultiNodeClientBuilder {
        let mut multi_node_builder = MultiNodeClientBuilder::default();
        multi_node_builder
            .set_gateway_builder(create_gateway_builder())
            .expect("gateway set on multi-node");
        multi_node_builder
            .set_node_client_builder(create_node_builder())
            .expect("node set on multi-node");
        multi_node_builder
            .set_timeout(timeout)
            .expect("timeout set on multi-node");
        multi_node_builder
    }

    /// Multi-node builder with the default 1s test timeout.
    fn create_multinode_client_builder() -> MultiNodeClientBuilder {
        multinode_builder_with_timeout(Duration::from_millis(1000))
    }

    fn create_multinode_client() -> MultiNodeClient {
        create_multinode_client_builder()
            .build()
            .expect("failed to build multi-node client")
    }

    /// D14n client that reads through the multi-node client and writes
    /// through the gateway.
    fn create_d14n_client()
    -> D14nClient<ReadWriteClient<MultiNodeClient, GrpcClient>, NoCursorStore> {
        let rw = ReadWriteClient::builder()
            .read(create_multinode_client())
            .write(create_gateway_builder().build().unwrap())
            .filter(PAYER_WRITE_FILTER)
            .build()
            .unwrap();

        D14nClient::new(rw, NoCursorStore).unwrap()
    }

    /// Node template with a placeholder host and the requested TLS setting.
    fn create_node_client_template(tls: bool) -> ClientBuilder {
        let mut client_builder = GrpcClient::builder();
        client_builder.set_tls(tls);
        client_builder.set_host("http://placeholder".to_string());
        client_builder
    }

    #[xmtp_common::test]
    fn tls_guard_accepts_matching_https_tls_true() {
        let t = create_node_client_template(true);
        validate_tls_guard(&t, "https://example.com:443").expect("should accept");
    }

    #[xmtp_common::test]
    fn tls_guard_accepts_matching_http_tls_false() {
        let t = create_node_client_template(false);
        validate_tls_guard(&t, "http://example.com:80").expect("should accept");
    }

    #[xmtp_common::test]
    fn tls_guard_rejects_https_with_plain_template() {
        let t = create_node_client_template(false);
        let err = validate_tls_guard(&t, "https://example.com:443")
            .err()
            .expect("https url with a non-tls template should be rejected");
        let msg = err.to_string();
        assert!(msg.contains("tls channel"), "unexpected error: {msg}");
    }

    #[xmtp_common::test]
    fn tls_guard_rejects_http_with_tls_template() {
        let t = create_node_client_template(true);
        let err = validate_tls_guard(&t, "http://example.com:80")
            .err()
            .expect("http url with a tls template should be rejected");
        let msg = err.to_string();
        assert!(msg.contains("tls channel"), "unexpected error: {msg}");
    }

    #[xmtp_common::test]
    async fn build_multinode_as_d14n() {
        let rw = ReadWriteClient::builder()
            .read(create_multinode_client())
            .write(
                create_gateway_builder()
                    .build()
                    .expect("failed to build gateway client"),
            )
            .filter(PAYER_WRITE_FILTER)
            .build()
            .unwrap();
        let _d14n = D14nClient::new(rw, create_in_memory_cursor_store()).unwrap();
    }

    #[xmtp_common::test]
    async fn build_multinode_as_standalone() {
        let _ = multinode_builder_with_timeout(Duration::from_millis(100))
            .build()
            .expect("failed to build multi-node client");
    }

    #[xmtp_common::test]
    async fn d14n_request_latest_group_message() {
        let client = create_d14n_client();
        let id: GroupId = GroupId::from(vec![]);
        let Err(e) = client.query_latest_group_message(id).await else {
            panic!("expected error for empty group id");
        };
        let err_str = e.to_string();
        assert!(
            err_str.contains("missing field group_message"),
            "unexpected error: {err_str}"
        );
    }

    #[xmtp_common::test]
    async fn multinode_request_latest_group_message() {
        use crate::d14n::GetNewestEnvelopes;
        let client = create_multinode_client();
        let mut endpoint = GetNewestEnvelopes::builder().topic(vec![]).build().unwrap();
        let response = endpoint.query(&client).await.unwrap();
        assert!(
            !response.results.is_empty(),
            "expected GetNewestEnvelopes to return at least one result"
        );
    }
}