xmtp_macro/
lib.rs

1extern crate proc_macro;
2
3mod logging;
4
5use proc_macro2::*;
6use quote::{quote, quote_spanned};
7use syn::{Data, DeriveInput, parse_macro_input};
8
9use crate::logging::{LogEventInput, get_context_fields, get_doc_comment};
10
11/// A proc macro attribute that wraps the input in an `async_trait` implementation,
12/// delegating to the appropriate `async_trait` implementation based on the target architecture.
13///
14/// On wasm32 architecture, it delegates to `async_trait::async_trait(?Send)`.
15/// On all other architectures, it delegates to `async_trait::async_trait`.
16#[proc_macro_attribute]
17pub fn async_trait(
18    _attr: proc_macro::TokenStream,
19    input: proc_macro::TokenStream,
20) -> proc_macro::TokenStream {
21    let input = syn::parse_macro_input!(input as syn::Item);
22    quote! {
23        #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
24        #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
25        #input
26    }
27    .into()
28}
29
// Whether `#[test]` should skip inserting the `xmtp_common::logger();` call.
//
// This is read from the environment once, while the macro is being expanded on the
// host, because the generated test may run on wasm where env variables cannot be read.
// Logging is disabled when `CI=true` or `XMTP_TEST_LOGGING=false`.
static DISABLE_LOGGING: std::sync::LazyLock<bool> = std::sync::LazyLock::new(|| {
    let env_equals = |key: &str, wanted: &str| {
        matches!(std::env::var(key).as_deref(), Ok(value) if value == wanted)
    };
    env_equals("CI", "true") || env_equals("XMTP_TEST_LOGGING", "false")
});
35
36/// A test macro that delegates to the appropriate test framework based on the target architecture.
37///
38/// On wasm32 architecture, it delegates to `wasm_bindgen_test::wasm_bindgen_test`.
39/// On all other architectures, it delegates to `tokio::test`.
40///
41/// When using with 'rstest', ensure any other test invocations come after rstest invocation.
42/// # Example
43///
44/// ```ignore
45/// #[test]
46/// async fn test_something() {
47///     assert_eq!(2 + 2, 4);
48/// }
49/// ```
50#[proc_macro_attribute]
51pub fn test(
52    attr: proc_macro::TokenStream,
53    body: proc_macro::TokenStream,
54) -> proc_macro::TokenStream {
55    // Parse the input function attributes
56    let mut attributes = Attributes::default();
57    let attribute_parser = syn::meta::parser(|meta| attributes.parse(&meta));
58    syn::parse_macro_input!(attr with attribute_parser);
59
60    // Parse the function as an ItemFn
61    let mut input_fn = syn::parse_macro_input!(body as syn::ItemFn);
62    let is_async = input_fn.sig.asyncness.is_some();
63
64    // Generate the appropriate test attributes
65    let test_attrs = if is_async {
66        let flavor = attributes.flavor();
67
68        if &flavor.value() != "current_thread" {
69            let workers = attributes.worker_threads();
70            quote! {
71                #[cfg_attr(not(all(target_arch = "wasm32", any(target_os = "unknown", target_os = "none"))), tokio::test(flavor = #flavor, worker_threads = #workers))]
72                #[cfg_attr(all(target_arch = "wasm32", any(target_os = "unknown", target_os = "none")), wasm_bindgen_test::wasm_bindgen_test)]
73            }
74        } else {
75            quote! {
76                #[cfg_attr(not(all(target_arch = "wasm32", any(target_os = "unknown", target_os = "none"))), tokio::test(flavor = #flavor))]
77                #[cfg_attr(all(target_arch = "wasm32", any(target_os = "unknown", target_os = "none")), wasm_bindgen_test::wasm_bindgen_test)]
78            }
79        }
80    } else {
81        quote! {
82            #[cfg_attr(not(all(target_arch = "wasm32", any(target_os = "unknown", target_os = "none"))), test)]
83            #[cfg_attr(all(target_arch = "wasm32", any(target_os = "unknown", target_os = "none")), wasm_bindgen_test::wasm_bindgen_test)]
84        }
85    };
86
87    // Transform ? to .unwrap() on functions that return ()
88    let should_transform = attributes.unwrap_try() && returns_unit(&input_fn.sig.output);
89    if should_transform {
90        let input_fn_tokens = quote!(#input_fn);
91        let transformed_tokens = transform_question_marks(input_fn_tokens.into());
92        input_fn = syn::parse_macro_input!(transformed_tokens as syn::ItemFn);
93    }
94
95    let disable_logging = attributes.disable_logging || *DISABLE_LOGGING;
96    if !disable_logging {
97        let init = syn::parse_quote!(xmtp_common::logger(););
98        input_fn.block.stmts.insert(0, init);
99    }
100
101    proc_macro::TokenStream::from(quote! {
102        #test_attrs
103        #input_fn
104    })
105}
106
107#[proc_macro_attribute]
108pub fn build_logging_metadata(
109    _attr: proc_macro::TokenStream,
110    item: proc_macro::TokenStream,
111) -> proc_macro::TokenStream {
112    let input = parse_macro_input!(item as DeriveInput);
113
114    let enum_name = &input.ident;
115    let visibility = &input.vis;
116    let attrs = &input.attrs;
117
118    let Data::Enum(data_enum) = &input.data else {
119        return syn::Error::new_spanned(&input, "log_event_macro can only be used on enums")
120            .to_compile_error()
121            .into();
122    };
123
124    let mut display_arms = Vec::new();
125    let mut metadata_entries = Vec::new();
126    let mut cleaned_variants = Vec::new();
127    let mut metadata_match_arms = Vec::new();
128
129    for variant in &data_enum.variants {
130        let variant_name = &variant.ident;
131        let variant_name_str = variant_name.to_string();
132        let doc_comment = match get_doc_comment(variant) {
133            Ok(dc) => dc,
134            Err(err) => return err.to_compile_error().into(),
135        };
136        let context_fields = get_context_fields(&variant.attrs);
137
138        // Filter out #[context(...)] attributes for the output enum
139        let filtered_attrs: Vec<_> = variant
140            .attrs
141            .iter()
142            .filter(|a| !a.path().is_ident("context"))
143            .collect();
144
145        // Rebuild variant without context attribute
146        let variant_fields = &variant.fields;
147        let variant_discriminant = variant
148            .discriminant
149            .as_ref()
150            .map(|(eq, expr)| quote! { #eq #expr });
151
152        cleaned_variants.push(quote! {
153            #(#filtered_attrs)*
154            #variant_name #variant_fields #variant_discriminant
155        });
156
157        // Display impl arm
158        display_arms.push(quote! {
159            #enum_name::#variant_name => write!(f, #doc_comment),
160        });
161
162        // Metadata entry for the const array
163        let context_fields_tokens: Vec<_> = context_fields.iter().map(|f| quote! { #f }).collect();
164        metadata_entries.push(quote! {
165            crate::EventMetadata {
166                name: #variant_name_str,
167                event: #enum_name::#variant_name,
168                doc: #doc_comment,
169                context_fields: &[#(#context_fields_tokens),*],
170            }
171        });
172
173        // Match arm for the metadata() method
174        metadata_match_arms.push(quote! {
175            #enum_name::#variant_name => &Self::METADATA[#enum_name::#variant_name as usize],
176        });
177    }
178
179    let variant_count = cleaned_variants.len();
180
181    let expanded = quote! {
182        #(#attrs)*
183        #[repr(usize)]
184        #[derive(Clone, Copy, Debug, PartialEq, Eq)]
185        #visibility enum #enum_name {
186            #(#cleaned_variants),*
187        }
188
189        impl #enum_name {
190            /// Metadata for all variants of this enum, indexed by variant discriminant.
191            pub const METADATA: [crate::EventMetadata; #variant_count] = [
192                #(#metadata_entries),*
193            ];
194
195            /// Returns the metadata for this event variant.
196            pub const fn metadata(&self) -> &'static crate::EventMetadata {
197                match self {
198                    #(#metadata_match_arms)*
199                }
200            }
201        }
202
203        impl ::core::fmt::Display for #enum_name {
204            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
205                match self {
206                    #(#display_arms)*
207                }
208            }
209        }
210    };
211
212    expanded.into()
213}
214
215#[proc_macro]
216pub fn log_event(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
217    let input = parse_macro_input!(input as LogEventInput);
218    let event = &input.event;
219    let installation_id = &input.installation_id;
220
221    let provided_names: Vec<String> = input.fields.iter().map(|f| f.name.to_string()).collect();
222    let tracing_fields: Vec<TokenStream> =
223        input.fields.iter().map(|f| f.to_tracing_tokens()).collect();
224
225    // Generate match arms for building context string (non-structured logging only)
226    let context_match_arms: Vec<TokenStream> = input
227        .fields
228        .iter()
229        .enumerate()
230        .map(|(i, f)| {
231            let name_str = &provided_names[i];
232            let value = f.value_tokens();
233            if matches!(f.sigil, Some('%')) {
234                quote! {
235                    #name_str => Some(format!("{}: {}", #name_str, #value))
236                }
237            } else {
238                quote! {
239                    #name_str => Some(format!("{}: {:?}", #name_str, #value))
240                }
241            }
242        })
243        .collect();
244
245    let provided_names_tokens = provided_names.into_iter().map(|n| quote! { #n });
246
247    // Generate the appropriate tracing level
248    let level = match input.level {
249        logging::LogLevel::Info => quote! { ::tracing::Level::INFO },
250        logging::LogLevel::Warn => quote! { ::tracing::Level::WARN },
251        logging::LogLevel::Error => quote! { ::tracing::Level::ERROR },
252    };
253
254    let tracing_call = quote! {
255        ::tracing::event!(
256            #level,
257            #(#tracing_fields,)*
258            "{}",
259            __message
260        );
261    };
262
263    quote! {
264        {
265            const PROVIDED: &[&str] = &[#(#provided_names_tokens),*];
266
267            // Compile-time validation: ensure all required context fields are provided
268            const _: () = #event.metadata().validate_fields(PROVIDED);
269
270            let __meta = #event.metadata();
271
272            // Bind installation_id to a variable to extend its lifetime
273            let __installation_id = #installation_id;
274            // Hex encode last 4 bytes of installation_id
275            let __installation_bytes: &[u8] = __installation_id.as_ref();
276            let __installation_len = __installation_bytes.len();
277            let __installation_last_4 = if __installation_len >= 4 { &__installation_bytes[__installation_len - 4..] } else { __installation_bytes };
278            let __installation_truncated = hex::encode(__installation_last_4);
279
280            // Build message with context for non-structured logging
281            let __message = if ::xmtp_common::is_structured_logging() {
282                // Structured logging: include installation_id and timestamp in message
283                format!("➣ {} {{installation_id: {}, timestamp: {}}}", __meta.doc, __installation_truncated, xmtp_common::time::now_ns())
284            } else {
285                // Non-structured logging: embed context in message for readability
286                let __context_parts: ::std::vec::Vec<String> = __meta.context_fields
287                    .iter()
288                    .filter_map(|&field_name| {
289                        match field_name {
290                            #(#context_match_arms,)*
291                            _ => None,
292                        }
293                    })
294                    .collect();
295
296                let __context_str = __context_parts.join(", ");
297                if __context_str.is_empty() {
298                    format!("➣ {} {{installation_id: {}, timestamp: {}}}", __meta.doc, __installation_truncated, xmtp_common::time::now_ns())
299                } else {
300                    format!("➣ {} {{{__context_str}, installation_id: {}, timestamp: {}}}", __meta.doc, __installation_truncated, xmtp_common::time::now_ns())
301                }
302            };
303
304            #tracing_call
305        }
306    }
307    .into()
308}
309
310// Check if a function's return type is () (unit)
311fn returns_unit(return_type: &syn::ReturnType) -> bool {
312    match return_type {
313        // No explicit return type means it returns ()
314        syn::ReturnType::Default => true,
315
316        // Explicit return type, check if it's ()
317        syn::ReturnType::Type(_, ty) => {
318            if let syn::Type::Tuple(tuple) = &**ty {
319                // Empty tuple () is the unit type
320                tuple.elems.is_empty()
321            } else {
322                false
323            }
324        }
325    }
326}
327
328// Transform ? operators to .unwrap() calls at the token level
329fn transform_question_marks(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
330    let mut result = proc_macro2::TokenStream::new();
331    let tokens = proc_macro2::TokenStream::from(tokens)
332        .into_iter()
333        .peekable();
334
335    for token in tokens {
336        match &token {
337            proc_macro2::TokenTree::Punct(p) if p.as_char() == '?' => {
338                // Get the span from the question mark token
339                let span = p.span();
340
341                // Use quote_spanned! to generate .unwrap() with the original span
342                let unwrap_tokens = quote_spanned! {span=>
343                    .unwrap()
344                };
345
346                result.extend(unwrap_tokens);
347            }
348            proc_macro2::TokenTree::Group(g) => {
349                // Recursively transform tokens in groups
350                let transformed_stream = transform_question_marks(g.stream().into());
351
352                let mut transformed_group = proc_macro2::Group::new(
353                    g.delimiter(),
354                    proc_macro2::TokenStream::from(transformed_stream),
355                );
356
357                // Preserve the span
358                let span = g.span();
359                transformed_group.set_span(span);
360                result.extend(quote!(#transformed_group));
361            }
362            _ => {
363                // Keep other tokens as is
364                result.extend([token]);
365            }
366        }
367    }
368
369    result.into()
370}
371
372#[derive(Default)]
373struct Attributes {
374    flavor: Option<syn::LitStr>,
375    worker_threads: Option<syn::LitInt>,
376    unwrap_try: Option<bool>,
377    disable_logging: bool,
378}
379
380impl Attributes {
381    fn flavor(&self) -> syn::LitStr {
382        self.flavor
383            .as_ref()
384            .cloned()
385            .unwrap_or(syn::LitStr::new("current_thread", Span::call_site()))
386    }
387
388    fn unwrap_try(&self) -> bool {
389        self.unwrap_try.as_ref().is_some_and(|v| *v)
390    }
391
392    fn worker_threads(&self) -> syn::LitInt {
393        self.worker_threads
394            .as_ref()
395            .cloned()
396            .unwrap_or(syn::LitInt::new(
397                &num_cpus::get().to_string(),
398                Span::call_site(),
399            ))
400    }
401}
402
403impl Attributes {
404    fn parse(&mut self, meta: &syn::meta::ParseNestedMeta) -> syn::Result<()> {
405        if meta.path.is_ident("flavor") {
406            self.flavor = Some(meta.value()?.parse()?);
407            return Ok(());
408        } else if meta.path.is_ident("worker_threads") {
409            self.worker_threads = Some(meta.value()?.parse()?);
410            return Ok(());
411        } else if meta.path.is_ident("unwrap_try") {
412            self.unwrap_try = Some(meta.value()?.parse::<syn::LitBool>()?.value());
413            return Ok(());
414        } else if meta.path.is_ident("disable_logging") {
415            self.disable_logging = meta.value()?.parse::<syn::LitBool>()?.value();
416            return Ok(());
417        }
418
419        Err(meta.error("unknown attribute"))
420    }
421}