xmtp_api_d14n/protocol/extractors/welcomes.rs

1use chrono::{DateTime, Utc};
2use xmtp_proto::ConversionError;
3use xmtp_proto::types::{
4    Cursor, WelcomeMessage, WelcomeMessageBuilder, WelcomeMessageV1, WelcomePointer,
5};
6
7use crate::protocol::traits::EnvelopeVisitor;
8use crate::protocol::{ExtractionError, Extractor};
9use xmtp_proto::mls_v1::welcome_message::WelcomePointer as V3ProtoWelcomePointer;
10use xmtp_proto::mls_v1::welcome_message_input::{
11    V1 as ProtoWelcomeMessageV1, WelcomePointer as WelcomeMessageWelcomePointer,
12};
13use xmtp_proto::xmtp::xmtpv4::envelopes::UnsignedOriginatorEnvelope;
14
/// Extracts a [`WelcomeMessage`] from an originator envelope.
///
/// The cursor and creation time are taken from the unsigned originator envelope. The
/// payload comes from whichever welcome variant is visited: a V1 welcome message or a
/// welcome pointer. Everything is combined into a single message in [`Extractor::get`].
#[derive(Default)]
pub struct WelcomeMessageExtractor {
    /// Built from the envelope's `originator_sequence_id` and `originator_node_id`.
    cursor: Cursor,
    /// Built from the envelope's `originator_ns` (nanoseconds since the Unix epoch).
    created_ns: DateTime<Utc>,
    /// Builder for the visited welcome variant; `None` until a variant has been visited.
    welcome_message: Option<WelcomeMessageBuilder>,
}
22
23impl Extractor for WelcomeMessageExtractor {
24    type Output = Result<WelcomeMessage, ExtractionError>;
25
26    fn get(self) -> Self::Output {
27        let Self {
28            cursor,
29            created_ns,
30            welcome_message,
31        } = self;
32        if let Some(mut gm) = welcome_message {
33            gm.cursor(cursor);
34            gm.created_ns(created_ns);
35            Ok(gm.build()?)
36        } else {
37            Err(ExtractionError::Conversion(ConversionError::Missing {
38                item: "welcome_message",
39                r#type: std::any::type_name::<WelcomeMessage>(),
40            }))
41        }
42    }
43}
44
45impl EnvelopeVisitor<'_> for WelcomeMessageExtractor {
46    type Error = ConversionError;
47
48    fn visit_unsigned_originator(
49        &mut self,
50        envelope: &UnsignedOriginatorEnvelope,
51    ) -> Result<(), Self::Error> {
52        self.cursor = Cursor::new(envelope.originator_sequence_id, envelope.originator_node_id);
53        self.created_ns = DateTime::from_timestamp_nanos(envelope.originator_ns);
54        Ok(())
55    }
56
57    fn visit_welcome_message_v1(
58        &mut self,
59        message: &ProtoWelcomeMessageV1,
60    ) -> Result<(), Self::Error> {
61        let mut builder = WelcomeMessage::builder();
62        builder.variant(WelcomeMessageV1 {
63            installation_key: message.installation_key.as_slice().try_into()?,
64            data: message.data.clone(),
65            hpke_public_key: message.hpke_public_key.clone(),
66            wrapper_algorithm: message.wrapper_algorithm.try_into()?,
67            welcome_metadata: message.welcome_metadata.clone(),
68        });
69        self.welcome_message = Some(builder);
70        Ok(())
71    }
72
73    fn visit_welcome_pointer(
74        &mut self,
75        message: &WelcomeMessageWelcomePointer,
76    ) -> Result<(), Self::Error> {
77        let mut builder = WelcomeMessage::builder();
78        builder.variant(WelcomePointer {
79            installation_key: message.installation_key.as_slice().try_into()?,
80            welcome_pointer: message.welcome_pointer.clone(),
81            hpke_public_key: message.hpke_public_key.clone(),
82            wrapper_algorithm: message.wrapper_algorithm.try_into()?,
83        });
84        self.welcome_message = Some(builder);
85        Ok(())
86    }
87}
88
/// Extracts a [`WelcomeMessage`] from V3 (`mls_v1`) welcome protos.
///
/// The cursor, creation time and variant are all written directly into the builder by the
/// `visit_v3_*` methods. [`Extractor::get`] only builds the result.
#[derive(Default)]
pub struct V3WelcomeMessageExtractor {
    /// Builder populated by `visit_v3_welcome_message` or `visit_v3_welcome_pointer`.
    welcome_message: WelcomeMessageBuilder,
}
93
94impl Extractor for V3WelcomeMessageExtractor {
95    type Output = Result<WelcomeMessage, ConversionError>;
96
97    fn get(self) -> Self::Output {
98        self.welcome_message.build()
99    }
100}
101
102impl EnvelopeVisitor<'_> for V3WelcomeMessageExtractor {
103    type Error = ConversionError;
104
105    fn visit_v3_welcome_message(
106        &mut self,
107        message: &xmtp_proto::mls_v1::welcome_message::V1,
108    ) -> Result<(), Self::Error> {
109        let originator_node_id = xmtp_configuration::Originators::WELCOME_MESSAGES;
110
111        self.welcome_message
112            .cursor(Cursor::new(message.id, originator_node_id))
113            .created_ns(DateTime::from_timestamp_nanos(message.created_ns as i64))
114            .variant(
115                WelcomeMessageV1::builder()
116                    .installation_key(message.installation_key.as_slice().try_into()?)
117                    .data(message.data.clone())
118                    .hpke_public_key(message.hpke_public_key.clone())
119                    .wrapper_algorithm(message.wrapper_algorithm.try_into()?)
120                    .welcome_metadata(message.welcome_metadata.clone())
121                    .build()?,
122            );
123        Ok(())
124    }
125
126    fn visit_v3_welcome_pointer(
127        &mut self,
128        message: &V3ProtoWelcomePointer,
129    ) -> Result<(), Self::Error> {
130        let originator_node_id = xmtp_configuration::Originators::WELCOME_MESSAGES;
131        self.welcome_message
132            .cursor(Cursor::new(message.id, originator_node_id))
133            .created_ns(DateTime::from_timestamp_nanos(message.created_ns as i64))
134            .variant(
135                WelcomePointer::builder()
136                    .installation_key(message.installation_key.as_slice().try_into()?)
137                    .welcome_pointer(message.welcome_pointer.clone())
138                    .hpke_public_key(message.hpke_public_key.clone())
139                    .wrapper_algorithm(message.wrapper_algorithm.try_into()?)
140                    .build()?,
141            );
142        Ok(())
143    }
144}
145
#[cfg(test)]
mod tests {
    use xmtp_proto::xmtp::mls::message_contents::WelcomeWrapperAlgorithm;

    use super::*;
    use crate::protocol::ProtocolEnvelope;
    use crate::protocol::extractors::test_utils::*;

    /// An originator envelope carrying a V1 welcome message yields a `WelcomeMessage` that
    /// has the envelope's cursor and timestamp and an unchanged payload.
    #[xmtp_common::test]
    fn test_extract_welcome_message() {
        let key = xmtp_common::rand_vec::<32>();
        let payload = xmtp_common::rand_vec::<64>();
        let hpke_key = xmtp_common::rand_vec::<32>();
        let metadata = vec![1, 2, 3];

        let envelope = TestEnvelopeBuilder::new()
            .with_originator_node_id(123)
            .with_originator_sequence_id(456)
            .with_originator_ns(789)
            .with_welcome_message_full(
                key.clone(),
                payload.clone(),
                hpke_key.clone(),
                WelcomeWrapperAlgorithm::XwingMlkem768Draft6.into(),
                metadata.clone(),
            )
            .build();

        let mut extractor = WelcomeMessageExtractor::default();
        envelope.accept(&mut extractor).unwrap();
        let extracted = extractor.get().unwrap();

        // Envelope-derived metadata.
        assert_eq!(extracted.cursor, Cursor::new(456u64, 123u32));
        assert_eq!(extracted.created_ns.timestamp_nanos_opt(), Some(789));

        // Payload fields round-trip unchanged.
        let v1 = extracted.as_v1().unwrap();
        assert_eq!(v1.installation_key, key);
        assert_eq!(v1.data, payload);
        assert_eq!(v1.hpke_public_key, hpke_key);
        assert_eq!(
            v1.wrapper_algorithm,
            WelcomeWrapperAlgorithm::XwingMlkem768Draft6
        );
        assert_eq!(v1.welcome_metadata, metadata);
    }
}