xmtp_common/
stream_handles.rs

//! Consistent stream-spawning behavior between WebAssembly and native targets, using
//! `tokio::task::spawn` on native and `wasm_bindgen_futures::spawn_local` on the web.
3
4use crate::{MaybeSend, MaybeSync, if_native, if_wasm};
5
6pub type GenericStreamHandle<O> = dyn StreamHandle<StreamOutput = O>;
7
8#[derive(thiserror::Error, Debug)]
9pub enum StreamHandleError {
10    #[error("Result Channel closed")]
11    ChannelClosed,
12    #[error("The stream was closed")]
13    StreamClosed,
14    #[error("Stream Cancelled")]
15    Cancelled,
16    #[error("Stream Panicked With {0}")]
17    Panicked(String),
18    #[cfg(not(target_arch = "wasm32"))]
19    #[error(transparent)]
20    JoinHandleError(#[from] tokio::task::JoinError),
21}
22/// A handle to a spawned Stream
23/// the spawned stream can be 'joined` by awaiting its Future implementation.
24/// All spawned tasks are detached, so waiting the handle is not required.
25#[xmtp_macro::async_trait]
26pub trait StreamHandle: MaybeSend + MaybeSync {
27    /// The Output type for the stream
28    type StreamOutput;
29
30    /// Asynchronously waits for the stream to be fully spawned
31    async fn wait_for_ready(&mut self);
32    /// Signal the stream to end
33    /// Does not wait for the stream to end, so will not receive the result of stream.
34    fn end(&self);
35
36    // Its better to:
37    // `StreamHandle: Future<Output = Result<Self::StreamOutput,StreamHandleError>>`
38    // but then crate::spawn` generates `Unused future must be used` since
39    // `async fn` desugars to `fn() -> impl Future`. There's no way
40    // to get rid of that warning, so we separate the future impl to here.
41    // See this rust-playground for an example:
42    // https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=a2a88b144c9459176e8fae41ee569553
43    /// Join the task back to the current thread, waiting until it ends.
44    async fn join(self) -> Result<Self::StreamOutput, StreamHandleError>;
45
46    /// End the stream and asynchronously wait for it to shutdown, getting the result of its
47    /// execution.
48    async fn end_and_wait(&mut self) -> Result<Self::StreamOutput, StreamHandleError>;
49    /// Get an Abort Handle to the stream.
50    /// This handle may be cloned/sent/etc easily
51    /// and many handles may exist at once.
52    fn abort_handle(&self) -> Box<dyn AbortHandle>;
53}
54
55/// A handle that can be moved/cloned/sent, but can only close the stream.
56pub trait AbortHandle: crate::MaybeSend + crate::MaybeSync {
57    /// Send a signal to end the stream, without waiting for a result.
58    fn end(&self);
59    fn is_finished(&self) -> bool;
60}
61
62if_wasm! {
63pub use wasm::*;
64mod wasm {
65    use std::{
66        future::Future,
67        pin::Pin,
68        task::{Context, Poll},
69    };
70
71    use futures::future::Either;
72
73    use super::*;
74
75    pub struct WasmStreamHandle<T> {
76        result: tokio::sync::oneshot::Receiver<T>,
77        // we only send once but oneshot senders aren't cloneable
78        // so we use mpsc here to keep the `&self` on `end`.
79        closer: tokio::sync::mpsc::Sender<()>,
80        ready: Option<tokio::sync::oneshot::Receiver<()>>,
81    }
82
83    impl<T> Future for WasmStreamHandle<Result<T, StreamHandleError>> {
84        type Output = Result<T, StreamHandleError>;
85
86        fn poll(
87            self: std::pin::Pin<&mut Self>,
88            cx: &mut std::task::Context<'_>,
89        ) -> std::task::Poll<Self::Output> {
90            // safe because we consider `result` to be structurally pinned
91            // pinning: https://doc.rust-lang.org/std/pin/#choosing-pinning-to-be-structural-for-field
92            let result = unsafe { self.map_unchecked_mut(|r| &mut r.result) };
93            result.poll(cx).map(|r| match r {
94                Ok(r) => r,
95                Err(_) => Err(StreamHandleError::ChannelClosed),
96            })
97        }
98    }
99
100    #[xmtp_common::async_trait]
101    impl<T> StreamHandle for WasmStreamHandle<Result<T, StreamHandleError>> {
102        type StreamOutput = T;
103
104        async fn wait_for_ready(&mut self) {
105            if let Some(s) = self.ready.take() {
106                let _ = s.await;
107            }
108        }
109
110        async fn end_and_wait(&mut self) -> Result<Self::StreamOutput, StreamHandleError> {
111            self.end();
112            self.await
113        }
114
115        fn end(&self) {
116            let _ = self.closer.try_send(());
117        }
118
119        fn abort_handle(&self) -> Box<dyn AbortHandle> {
120            Box::new(CloseHandle(self.closer.clone()))
121        }
122
123        async fn join(self) -> Result<Self::StreamOutput, StreamHandleError> {
124            self.await
125        }
126    }
127
128    #[derive(Clone)]
129    pub struct CloseHandle(tokio::sync::mpsc::Sender<()>);
130    impl AbortHandle for CloseHandle {
131        fn end(&self) {
132            let _ = self.0.try_send(());
133        }
134
135        fn is_finished(&self) -> bool {
136            self.0.is_closed()
137        }
138    }
139
140    /// Spawn a future on the `wasm-bindgen` local current-thread executer
141    ///  future does not require `Send`.
142    ///  optionally pass in `ready` to signal when stream will be ready.
143    pub fn spawn<F>(
144        ready: Option<tokio::sync::oneshot::Receiver<()>>,
145        future: F,
146    ) -> impl StreamHandle<StreamOutput = F::Output>
147    where
148        F: Future + 'static,
149        F::Output: 'static,
150    {
151        let (res_tx, res_rx) = tokio::sync::oneshot::channel();
152        let (closer_tx, closer_rx) = tokio::sync::mpsc::channel::<()>(1);
153        let closer_handle = CloserHandle::new(closer_rx);
154
155        let handle = WasmStreamHandle {
156            result: res_rx,
157            closer: closer_tx,
158            ready,
159        };
160        tracing::info!("Spawning local task on web executor");
161        wasm_bindgen_futures::spawn_local(async move {
162            futures::pin_mut!(closer_handle);
163            futures::pin_mut!(future);
164            let value = match futures::future::select(closer_handle, future).await {
165                Either::Left((_, _)) => {
166                    tracing::warn!("stream closed");
167                    Err(StreamHandleError::StreamClosed)
168                }
169                Either::Right((v, _)) => {
170                    tracing::debug!("Future ended with value");
171                    Ok(v)
172                }
173            };
174            let _ = res_tx.send(value);
175            tracing::info!("spawned local future closing");
176        });
177
178        handle
179    }
180
181    struct CloserHandle(tokio::sync::mpsc::Receiver<()>);
182
183    impl CloserHandle {
184        fn new(receiver: tokio::sync::mpsc::Receiver<()>) -> Self {
185            Self(receiver)
186        }
187    }
188
189    impl Future for CloserHandle {
190        type Output = ();
191
192        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193            match self.0.poll_recv(cx) {
194                Poll::Pending => Poll::Pending,
195                Poll::Ready(Some(_)) => Poll::Ready(()),
196                // if the channel is closed, the task has detached and must
197                // be kept alive for the duration for the program
198                Poll::Ready(None) => Poll::Pending,
199            }
200        }
201    }
202}}
203
204if_native! {
205pub use native::*;
206mod native {
207    use super::*;
208    use std::future::Future;
209    use tokio::task::JoinHandle;
210
211    pub struct TokioStreamHandle<T> {
212        inner: JoinHandle<T>,
213        ready: Option<tokio::sync::oneshot::Receiver<()>>,
214    }
215
216    impl<T> Future for TokioStreamHandle<T> {
217        type Output = Result<T, StreamHandleError>;
218
219        fn poll(
220            self: std::pin::Pin<&mut Self>,
221            cx: &mut std::task::Context<'_>,
222        ) -> std::task::Poll<Self::Output> {
223            // safe because we consider `inner` to be structurally pinned
224            // https://doc.rust-lang.org/std/pin/#choosing-pinning-to-be-structural-for-field
225            let inner = unsafe { self.map_unchecked_mut(|v| &mut v.inner) };
226            inner.poll(cx).map_err(StreamHandleError::from)
227        }
228    }
229
230    #[xmtp_common::async_trait]
231    impl<T: Send> StreamHandle for TokioStreamHandle<T> {
232        type StreamOutput = T;
233
234        async fn wait_for_ready(&mut self) {
235            if let Some(s) = self.ready.take() {
236                let _ = s.await;
237            }
238        }
239
240        fn end(&self) {
241            self.inner.abort();
242        }
243
244        async fn end_and_wait(&mut self) -> Result<Self::StreamOutput, StreamHandleError> {
245            use crate::StreamHandleError::*;
246
247            self.end();
248            match self.await {
249                Err(JoinHandleError(e)) if e.is_panic() => Err(Panicked(e.to_string())),
250                Err(JoinHandleError(e)) if e.is_cancelled() => Err(Cancelled),
251                Ok(t) => Ok(t),
252                Err(e) => Err(e),
253            }
254        }
255
256        fn abort_handle(&self) -> Box<dyn AbortHandle> {
257            Box::new(self.inner.abort_handle())
258        }
259
260        async fn join(self) -> Result<Self::StreamOutput, StreamHandleError> {
261            self.await
262        }
263    }
264
265    impl AbortHandle for tokio::task::AbortHandle {
266        fn end(&self) {
267            self.abort()
268        }
269
270        fn is_finished(&self) -> bool {
271            self.is_finished()
272        }
273    }
274
275    pub fn spawn<F>(
276        ready: Option<tokio::sync::oneshot::Receiver<()>>,
277        future: F,
278    ) -> impl StreamHandle<StreamOutput = F::Output>
279    where
280        F: Future + Send + 'static,
281        F::Output: Send + 'static,
282    {
283        TokioStreamHandle {
284            inner: tokio::task::spawn(future),
285            ready,
286        }
287    }
288
289    crate::if_test! {
290        pub fn spawn_instrumented<F>(
291            ready: Option<tokio::sync::oneshot::Receiver<()>>,
292            future: F,
293        ) -> impl StreamHandle<StreamOutput = F::Output>
294        where
295            F: Future + Send + 'static,
296            F::Output: Send + 'static,
297        {
298            TokioStreamHandle {
299                inner: tokio::task::spawn(future),
300                ready,
301            }
302        }
303    }
304}}