xmtp_common/test/
traced_test.rs

//! Helpers for tests that assert on `tracing` logs emitted while running on a tokio current-thread runtime.
2use std::{io, sync::Arc};
3
4use parking_lot::Mutex;
5use tracing_subscriber::fmt;
6
7thread_local! {
8    pub static LOG_BUFFER: TestWriter = TestWriter::new();
9}
10
11/// Thread local writer which stores logs in memory
12#[derive(Default)]
13pub struct TestWriter(Arc<Mutex<Vec<u8>>>);
14
15impl TestWriter {
16    pub fn new() -> Self {
17        Self(Arc::new(Mutex::new(vec![])))
18    }
19
20    pub fn as_string(&self) -> String {
21        let buf = self.0.lock();
22        String::from_utf8(buf.clone()).expect("Not valid UTF-8")
23    }
24
25    pub fn clear(&self) {
26        let mut buf = self.0.lock();
27        buf.clear();
28    }
29    pub fn flush(&self) {
30        let mut buf = self.0.lock();
31        std::io::Write::flush(&mut *buf).unwrap();
32    }
33}
34
35impl io::Write for TestWriter {
36    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
37        let mut this = self.0.lock();
38        // still print logs for tests
39        print!("{}", String::from_utf8_lossy(buf));
40        Vec::<u8>::write(&mut this, buf)
41    }
42
43    fn flush(&mut self) -> io::Result<()> {
44        let mut this = self.0.lock();
45        Vec::<u8>::flush(&mut this)
46    }
47}
48
49impl Clone for TestWriter {
50    fn clone(&self) -> Self {
51        Self(self.0.clone())
52    }
53}
54
55impl fmt::MakeWriter<'_> for TestWriter {
56    type Writer = TestWriter;
57
58    fn make_writer(&self) -> Self::Writer {
59        self.clone()
60    }
61}
62/*
63/// Only works with current-thread
64#[inline]
65pub fn traced_test<Fut>(f: impl Fn() -> Fut)
66where
67    Fut: futures::Future<Output = ()>,
68{
69    LOG_BUFFER.with(|buf| {
70        let rt = tokio::runtime::Builder::new_current_thread()
71            .thread_name("tracing-test")
72            .enable_time()
73            .enable_io()
74            .build()
75            .unwrap();
76        buf.clear();
77
78        let subscriber = fmt::Subscriber::builder()
79            .with_env_filter(format!("{}=debug", env!("CARGO_PKG_NAME")))
80            .with_writer(buf.clone())
81            .with_level(true)
82            .with_ansi(false)
83            .finish();
84
85        let dispatch = tracing::Dispatch::new(subscriber);
86        tracing::dispatcher::with_default(&dispatch, || {
87            rt.block_on(f());
88        });
89
90        buf.clear();
91    });
92}
93*/
94
/// Runs an async test body while capturing its `tracing` output in `LOG_BUFFER`.
///
/// `$f` must be a future (e.g. an `async { .. }` block). It is driven by a freshly
/// built tokio current-thread runtime inside `tracing::dispatcher::with_default`.
/// That call installs a `fmt` subscriber that writes into this thread's
/// `LOG_BUFFER`, with `debug` enabled for `xmtp_db`, `xmtp_api`, `xmtp_id` and
/// the calling package.
///
/// The default dispatcher is set for the current thread only. A current-thread
/// runtime keeps the test's tasks on that thread so their logs are captured. Check
/// the captured logs with `assert_logged!` from inside `$f`.
///
/// The buffer is cleared before `$f` runs and again after it completes normally.
///
/// # Panics
///
/// Panics if the tokio runtime cannot be built, and propagates any panic from `$f`.
#[macro_export]
macro_rules! traced_test {
    ( $f:expr ) => {{
        use tracing_subscriber::fmt;

        $crate::traced_test::LOG_BUFFER.with(|buf| {
            let rt = tokio::runtime::Builder::new_current_thread()
                .thread_name("tracing-test")
                .enable_time()
                .enable_io()
                .build()
                .expect("failed to build tokio current-thread runtime for traced_test!");
            buf.clear();

            // `env!` expands in the calling crate. Tracing targets are module paths and
            // never contain '-', but package names may. Normalize like cargo does for the
            // library crate name, otherwise the directive would silently match nothing.
            let filter = format!(
                "xmtp_db=debug,xmtp_api=debug,xmtp_id=debug,{}=debug",
                env!("CARGO_PKG_NAME").replace('-', "_")
            );

            let subscriber = fmt::Subscriber::builder()
                .with_env_filter(filter)
                .with_writer(buf.clone())
                .with_level(true)
                .with_ansi(false)
                .finish();

            let dispatch = tracing::Dispatch::new(subscriber);
            tracing::dispatcher::with_default(&dispatch, || {
                rt.block_on($f);
            });

            buf.clear();
        });
    }};
}
129
/// Asserts that exactly `$occurrences` captured log lines contain `$search`.
///
/// It reads this thread's `LOG_BUFFER`, so it must be invoked inside the future passed
/// to `traced_test!`. That macro installs the capturing subscriber and runs the future
/// on tokio's current-thread runtime.
///
/// `$search` is any string pattern that is also `Display` (e.g. `&str`, `&String`,
/// `char`). `$occurrences` is compared against the matching-line count (a `usize`).
/// Each argument is evaluated exactly once. A trailing comma is accepted.
///
/// # Panics
///
/// Panics with the expected and actual counts when they differ.
#[macro_export]
macro_rules! assert_logged {
    ( $search:expr , $occurrences:expr $(,)? ) => {
        $crate::traced_test::LOG_BUFFER.with(|buf| {
            buf.flush();
            let logs = buf.as_string();
            // Bind the arguments once instead of re-evaluating `$search` for every log
            // line. A `match` (as in `assert_eq!`) keeps any temporaries the argument
            // expressions create alive for the whole check.
            match ($search, $occurrences) {
                (search, expected) => {
                    let actual = logs.lines().filter(|line| line.contains(search)).count();
                    if actual != expected {
                        panic!(
                            "Expected '{}' to be logged {} times, but was logged {} times instead",
                            search, expected, actual
                        );
                    }
                }
            }
        })
    };
}