xmtp_common/
stream_handles.rs1use crate::{MaybeSend, MaybeSync, if_native, if_wasm};
5
6pub type GenericStreamHandle<O> = dyn StreamHandle<StreamOutput = O>;
7
8#[derive(thiserror::Error, Debug)]
9pub enum StreamHandleError {
10 #[error("Result Channel closed")]
11 ChannelClosed,
12 #[error("The stream was closed")]
13 StreamClosed,
14 #[error("Stream Cancelled")]
15 Cancelled,
16 #[error("Stream Panicked With {0}")]
17 Panicked(String),
18 #[cfg(not(target_arch = "wasm32"))]
19 #[error(transparent)]
20 JoinHandleError(#[from] tokio::task::JoinError),
21}
22#[xmtp_macro::async_trait]
26pub trait StreamHandle: MaybeSend + MaybeSync {
27 type StreamOutput;
29
30 async fn wait_for_ready(&mut self);
32 fn end(&self);
35
36 async fn join(self) -> Result<Self::StreamOutput, StreamHandleError>;
45
46 async fn end_and_wait(&mut self) -> Result<Self::StreamOutput, StreamHandleError>;
49 fn abort_handle(&self) -> Box<dyn AbortHandle>;
53}
54
55pub trait AbortHandle: crate::MaybeSend + crate::MaybeSync {
57 fn end(&self);
59 fn is_finished(&self) -> bool;
60}
61
62if_wasm! {
63pub use wasm::*;
64mod wasm {
65 use std::{
66 future::Future,
67 pin::Pin,
68 task::{Context, Poll},
69 };
70
71 use futures::future::Either;
72
73 use super::*;
74
75 pub struct WasmStreamHandle<T> {
76 result: tokio::sync::oneshot::Receiver<T>,
77 closer: tokio::sync::mpsc::Sender<()>,
80 ready: Option<tokio::sync::oneshot::Receiver<()>>,
81 }
82
83 impl<T> Future for WasmStreamHandle<Result<T, StreamHandleError>> {
84 type Output = Result<T, StreamHandleError>;
85
86 fn poll(
87 self: std::pin::Pin<&mut Self>,
88 cx: &mut std::task::Context<'_>,
89 ) -> std::task::Poll<Self::Output> {
90 let result = unsafe { self.map_unchecked_mut(|r| &mut r.result) };
93 result.poll(cx).map(|r| match r {
94 Ok(r) => r,
95 Err(_) => Err(StreamHandleError::ChannelClosed),
96 })
97 }
98 }
99
100 #[xmtp_common::async_trait]
101 impl<T> StreamHandle for WasmStreamHandle<Result<T, StreamHandleError>> {
102 type StreamOutput = T;
103
104 async fn wait_for_ready(&mut self) {
105 if let Some(s) = self.ready.take() {
106 let _ = s.await;
107 }
108 }
109
110 async fn end_and_wait(&mut self) -> Result<Self::StreamOutput, StreamHandleError> {
111 self.end();
112 self.await
113 }
114
115 fn end(&self) {
116 let _ = self.closer.try_send(());
117 }
118
119 fn abort_handle(&self) -> Box<dyn AbortHandle> {
120 Box::new(CloseHandle(self.closer.clone()))
121 }
122
123 async fn join(self) -> Result<Self::StreamOutput, StreamHandleError> {
124 self.await
125 }
126 }
127
128 #[derive(Clone)]
129 pub struct CloseHandle(tokio::sync::mpsc::Sender<()>);
130 impl AbortHandle for CloseHandle {
131 fn end(&self) {
132 let _ = self.0.try_send(());
133 }
134
135 fn is_finished(&self) -> bool {
136 self.0.is_closed()
137 }
138 }
139
140 pub fn spawn<F>(
144 ready: Option<tokio::sync::oneshot::Receiver<()>>,
145 future: F,
146 ) -> impl StreamHandle<StreamOutput = F::Output>
147 where
148 F: Future + 'static,
149 F::Output: 'static,
150 {
151 let (res_tx, res_rx) = tokio::sync::oneshot::channel();
152 let (closer_tx, closer_rx) = tokio::sync::mpsc::channel::<()>(1);
153 let closer_handle = CloserHandle::new(closer_rx);
154
155 let handle = WasmStreamHandle {
156 result: res_rx,
157 closer: closer_tx,
158 ready,
159 };
160 tracing::info!("Spawning local task on web executor");
161 wasm_bindgen_futures::spawn_local(async move {
162 futures::pin_mut!(closer_handle);
163 futures::pin_mut!(future);
164 let value = match futures::future::select(closer_handle, future).await {
165 Either::Left((_, _)) => {
166 tracing::warn!("stream closed");
167 Err(StreamHandleError::StreamClosed)
168 }
169 Either::Right((v, _)) => {
170 tracing::debug!("Future ended with value");
171 Ok(v)
172 }
173 };
174 let _ = res_tx.send(value);
175 tracing::info!("spawned local future closing");
176 });
177
178 handle
179 }
180
181 struct CloserHandle(tokio::sync::mpsc::Receiver<()>);
182
183 impl CloserHandle {
184 fn new(receiver: tokio::sync::mpsc::Receiver<()>) -> Self {
185 Self(receiver)
186 }
187 }
188
189 impl Future for CloserHandle {
190 type Output = ();
191
192 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193 match self.0.poll_recv(cx) {
194 Poll::Pending => Poll::Pending,
195 Poll::Ready(Some(_)) => Poll::Ready(()),
196 Poll::Ready(None) => Poll::Pending,
199 }
200 }
201 }
202}}
203
204if_native! {
205pub use native::*;
206mod native {
207 use super::*;
208 use std::future::Future;
209 use tokio::task::JoinHandle;
210
211 pub struct TokioStreamHandle<T> {
212 inner: JoinHandle<T>,
213 ready: Option<tokio::sync::oneshot::Receiver<()>>,
214 }
215
216 impl<T> Future for TokioStreamHandle<T> {
217 type Output = Result<T, StreamHandleError>;
218
219 fn poll(
220 self: std::pin::Pin<&mut Self>,
221 cx: &mut std::task::Context<'_>,
222 ) -> std::task::Poll<Self::Output> {
223 let inner = unsafe { self.map_unchecked_mut(|v| &mut v.inner) };
226 inner.poll(cx).map_err(StreamHandleError::from)
227 }
228 }
229
230 #[xmtp_common::async_trait]
231 impl<T: Send> StreamHandle for TokioStreamHandle<T> {
232 type StreamOutput = T;
233
234 async fn wait_for_ready(&mut self) {
235 if let Some(s) = self.ready.take() {
236 let _ = s.await;
237 }
238 }
239
240 fn end(&self) {
241 self.inner.abort();
242 }
243
244 async fn end_and_wait(&mut self) -> Result<Self::StreamOutput, StreamHandleError> {
245 use crate::StreamHandleError::*;
246
247 self.end();
248 match self.await {
249 Err(JoinHandleError(e)) if e.is_panic() => Err(Panicked(e.to_string())),
250 Err(JoinHandleError(e)) if e.is_cancelled() => Err(Cancelled),
251 Ok(t) => Ok(t),
252 Err(e) => Err(e),
253 }
254 }
255
256 fn abort_handle(&self) -> Box<dyn AbortHandle> {
257 Box::new(self.inner.abort_handle())
258 }
259
260 async fn join(self) -> Result<Self::StreamOutput, StreamHandleError> {
261 self.await
262 }
263 }
264
265 impl AbortHandle for tokio::task::AbortHandle {
266 fn end(&self) {
267 self.abort()
268 }
269
270 fn is_finished(&self) -> bool {
271 self.is_finished()
272 }
273 }
274
275 pub fn spawn<F>(
276 ready: Option<tokio::sync::oneshot::Receiver<()>>,
277 future: F,
278 ) -> impl StreamHandle<StreamOutput = F::Output>
279 where
280 F: Future + Send + 'static,
281 F::Output: Send + 'static,
282 {
283 TokioStreamHandle {
284 inner: tokio::task::spawn(future),
285 ready,
286 }
287 }
288
289 crate::if_test! {
290 pub fn spawn_instrumented<F>(
291 ready: Option<tokio::sync::oneshot::Receiver<()>>,
292 future: F,
293 ) -> impl StreamHandle<StreamOutput = F::Output>
294 where
295 F: Future + Send + 'static,
296 F::Output: Send + 'static,
297 {
298 TokioStreamHandle {
299 inner: tokio::task::spawn(future),
300 ready,
301 }
302 }
303 }
304}}