xmtp_common/
test.rs

1//! Common Test Utilities
2use crate::time::Expired;
3use rand::distributions::DistString;
4use rand::{Rng, distributions::Alphanumeric, seq::IteratorRandom};
5use std::collections::HashMap;
6use std::sync::LazyLock;
7use std::{future::Future, sync::OnceLock};
8use tokio::sync;
9
10mod macros;
11
12mod openmls;
13pub use openmls::*;
14
15crate::if_native! {
16    use parking_lot::Mutex;
17    pub mod traced_test;
18    pub use traced_test::TestWriter;
19    mod logger;
20
21    use once_cell::sync::Lazy;
22    static REPLACE_IDS: Lazy<Mutex<HashMap<String, String>>> = Lazy::new(|| Mutex::new(HashMap::new()));
23}
24
/// One-shot guard ensuring `logger()` installs the global test subscriber only once.
static INIT: OnceLock<()> = OnceLock::new();
26
27use toxiproxy_rust::TOXIPROXY;
28
29static TOXIPROXY_TEST_LOCK: LazyLock<sync::Mutex<()>> = LazyLock::new(|| sync::Mutex::new(()));
30
31// TODO: can add this to the macro
32pub async fn toxiproxy_test<T, F: AsyncFn() -> T>(f: F) -> T {
33    let _g = TOXIPROXY_TEST_LOCK.lock().await;
34    TOXIPROXY.reset().await.unwrap();
35    f().await
36}
37
/// Test-data constructor: implementors build an instance filled with random data.
pub trait Generate {
    /// generate a struct containing random data
    fn generate() -> Self;
}
42
43/// Replace inbox id in Contextual output with a name (i.e Alix, Bo, etc.)
44#[derive(Default)]
45pub struct TestLogReplace {
46    #[allow(unused)]
47    ids: HashMap<String, String>,
48}
49
50impl TestLogReplace {
51    pub fn add(&mut self, id: &str, name: &str) {
52        crate::wasm_or_native! {
53            wasm => { let _ = (id, name); },
54            native => {
55                self.ids.insert(id.to_string(), name.to_string());
56                let mut ids = REPLACE_IDS.lock();
57                ids.insert(id.to_string(), name.to_string());
58            },
59        }
60    }
61}
62
63// remove ids for replacement from map on drop
64impl Drop for TestLogReplace {
65    fn drop(&mut self) {
66        crate::wasm_or_native! {
67            wasm => {},
68            native => {
69                let mut ids = REPLACE_IDS.lock();
70                for id in self.ids.keys() {
71                    let _ = ids.remove(id.as_str());
72                }
73            },
74        }
75    }
76}
77
78#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
79pub fn logger_layer<S>() -> impl tracing_subscriber::Layer<S>
80where
81    S: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>,
82{
83    use tracing_subscriber::{
84        EnvFilter, Layer,
85        fmt::{self, format},
86    };
87    let structured = std::env::var("STRUCTURED");
88    let contextual = std::env::var("CONTEXTUAL");
89    let show_spans = std::env::var("SHOW_SPAN_FIELDS");
90
91    let is_structured = matches!(structured, Ok(s) if s == "true" || s == "1");
92    let is_contextual = matches!(contextual, Ok(c) if c == "true" || c == "1");
93    let show_spans = matches!(show_spans, Ok(c) if c == "true" || c == "1");
94    let filter = || {
95        EnvFilter::builder()
96            .with_default_directive(tracing::metadata::LevelFilter::INFO.into())
97            .from_env()
98            .expect("invalid environment log filter")
99    };
100
101    vec![
102        is_structured
103            .then(|| {
104                tracing_subscriber::fmt::layer()
105                    .json()
106                    .with_filter(filter())
107            })
108            .boxed(),
109        is_contextual
110            .then(|| {
111                let processor =
112                    tracing_forest::printer::Printer::new().formatter(logger::Contextual);
113                tracing_forest::ForestLayer::new(processor, tracing_forest::tag::NoTag)
114                    .with_filter(filter())
115            })
116            .boxed(),
117        // default logger
118        (!is_structured && !is_contextual)
119            .then(|| {
120                fmt::layer()
121                    .compact()
122                    .with_ansi(true)
123                    .without_time()
124                    .with_test_writer()
125                    .fmt_fields({
126                        format::debug_fn(move |writer, field, value| {
127                            if show_spans && (field.name() != "message") {
128                                write!(writer, ", {}={:?}", field.name(), value)?;
129                            } else if field.name() == "message" {
130                                let mut message = format!("{value:?}");
131                                let ids = REPLACE_IDS.lock();
132                                for (id, name) in ids.iter() {
133                                    message = message.replace(id, name);
134                                    message = message.replace(&crate::fmt::truncate_hex(id), name);
135                                }
136                                write!(writer, "{message}")?;
137                            }
138                            Ok(())
139                        })
140                    })
141                    .with_filter(filter())
142            })
143            .boxed(),
144    ]
145}
146
147/// A simple test logger that defaults to the INFO level
148pub fn logger() {
149    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
150
151    crate::wasm_or_native! {
152        wasm => {
153            use tracing_wasm::{ConsoleConfig, WASMLayerConfigBuilder, WASMLayer};
154            INIT.get_or_init(|| {
155                let filter = tracing_subscriber::EnvFilter::builder().parse("debug").unwrap();
156
157                // this makes error logs in CI a little easier to read
158                let config = if cfg!(feature = "test-utils") {
159                    WASMLayerConfigBuilder::new()
160                        .set_console_config(ConsoleConfig::ReportWithoutConsoleColor)
161                        .build()
162                } else {
163                    WASMLayerConfigBuilder::new()
164                        .set_console_config(ConsoleConfig::ReportWithConsoleColor)
165                        .build()
166                };
167                tracing_subscriber::registry()
168                    .with(WASMLayer::new(config))
169                    .with(filter)
170                    .init();
171
172                console_error_panic_hook::set_once();
173                wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);
174            });
175        },
176        native => {
177            INIT.get_or_init(|| {
178                let _ = tracing_subscriber::registry()
179                    .with(logger_layer())
180                    .try_init();
181            });
182        },
183    }
184}
185
// Execute once before any tests are run
// (`#[ctor::ctor]` runs this when the test binary is loaded, before libtest's
// main). Only compiled for native test builds with the `test-utils` feature.
#[cfg(all(test, not(target_arch = "wasm32"), feature = "test-utils"))]
#[ctor::ctor]
fn ctor_logging_setup() {
    // Install the global test logger up front so every test gets log output.
    crate::logger();
    // Best effort: raise the open-file-descriptor limit; the result is ignored.
    // NOTE(review): presumably because tests open many sockets/DB files — confirm.
    let _ = fdlimit::raise_fd_limit();
}
193
194// must be in an arc so we only ever have one subscriber
195crate::if_native! {
196    static SCOPED_SUBSCRIBER: LazyLock<std::sync::Arc<Box<dyn tracing::Subscriber + Send + Sync>>> =
197        LazyLock::new(|| {
198            use tracing_subscriber::layer::SubscriberExt;
199
200            std::sync::Arc::new(Box::new(
201                tracing_subscriber::registry().with(logger_layer()),
202            ))
203        });
204
205    pub fn subscriber() -> impl tracing::Subscriber {
206        (*SCOPED_SUBSCRIBER).clone()
207    }
208}
209
210pub fn rand_hexstring() -> String {
211    let mut rng = crate::rng();
212    let hex_chars = "0123456789abcdef";
213    let v: String = (0..40)
214        .map(|_| hex_chars.chars().choose(&mut rng).unwrap())
215        .collect();
216
217    format!("0x{v}")
218}
219
220pub fn rand_account_address() -> String {
221    Alphanumeric.sample_string(&mut crate::rng(), 42)
222}
223
224pub fn rand_u64() -> u64 {
225    crate::rng().r#gen()
226}
227
228pub fn rand_i64() -> i64 {
229    crate::rng().r#gen()
230}
231
232pub fn tmp_path() -> String {
233    let db_name = crate::rand_string::<24>();
234    crate::wasm_or_native_expr! {
235        native => format!("{}/{db_name}.db3", std::env::temp_dir().to_str().unwrap()),
236        wasm => format!("test_db/{db_name}.db3"),
237    }
238}
239
240pub fn rand_time() -> i64 {
241    let mut rng = rand::thread_rng();
242    rng.gen_range(0..1_000_000_000)
243}
244
245pub async fn wait_for_some<F, Fut, T>(f: F) -> Option<T>
246where
247    F: Fn() -> Fut,
248    Fut: Future<Output = Option<T>>,
249{
250    crate::time::timeout(crate::time::Duration::from_secs(20), async {
251        loop {
252            if let Some(r) = f().await {
253                return r;
254            } else {
255                crate::task::yield_now().await;
256            }
257        }
258    })
259    .await
260    .ok()
261}
262
263pub async fn wait_for_ok<F, Fut, T, E>(f: F) -> Result<T, Expired>
264where
265    F: Fn() -> Fut,
266    Fut: Future<Output = Result<T, E>>,
267{
268    crate::time::timeout(crate::time::Duration::from_secs(20), async {
269        loop {
270            if let Ok(r) = f().await {
271                return r;
272            } else {
273                crate::task::yield_now().await;
274            }
275        }
276    })
277    .await
278}
279
280pub async fn wait_for_eq<F, Fut, T>(f: F, expected: T) -> Result<(), Expired>
281where
282    F: Fn() -> Fut,
283    Fut: Future<Output = T>,
284    T: std::fmt::Debug + PartialEq,
285{
286    let result = crate::time::timeout(crate::time::Duration::from_secs(20), async {
287        loop {
288            let result = f().await;
289            if expected == result {
290                return result;
291            } else {
292                crate::task::yield_now().await;
293            }
294        }
295    })
296    .await?;
297
298    assert_eq!(expected, result);
299    Ok(())
300}
301
302pub async fn wait_for_ge<F, Fut, T>(f: F, expected: T) -> Result<(), Expired>
303where
304    F: Fn() -> Fut,
305    Fut: Future<Output = T>,
306    T: std::fmt::Debug + PartialEq + PartialOrd,
307{
308    crate::time::timeout(crate::time::Duration::from_secs(20), async {
309        loop {
310            let result = f().await;
311            if result >= expected {
312                return result;
313            } else {
314                crate::task::yield_now().await;
315            }
316        }
317    })
318    .await?;
319
320    Ok(())
321}
322
/// Test helpers that render a slice of `Debug` values as multi-line text.
pub trait DebugDisplay {
    /// Render each item with `{:?}`, one item per line.
    fn format_list(&self) -> String;

    /// Render each item as `"<index> -- <item:?>"`, one per line, indices from 0.
    fn format_enumerated(&self) -> String;
}

/// Join pre-rendered lines with `'\n'`: no trailing newline, and `""` for no lines.
fn join_debug_lines(mut lines: impl Iterator<Item = String>) -> String {
    let Some(first) = lines.next() else {
        return String::new();
    };
    lines.fold(first, |mut acc, line| {
        acc.push('\n');
        acc.push_str(&line);
        acc
    })
}

impl<T: std::fmt::Debug> DebugDisplay for [T] {
    fn format_list(&self) -> String {
        join_debug_lines(self.iter().map(|item| format!("{item:?}")))
    }

    fn format_enumerated(&self) -> String {
        join_debug_lines(
            self.iter()
                .enumerate()
                .map(|(idx, item)| format!("{idx} -- {item:?}")),
        )
    }
}