xmtp_api_d14n/protocol/extractors/
payloads.rs

1use xmtp_common::RetryableError;
2
3use crate::protocol::traits::EnvelopeVisitor;
4use crate::protocol::{EnvelopeError, ExtractionError};
5use xmtp_proto::xmtp::identity::associations::IdentityUpdate;
6use xmtp_proto::xmtp::mls::api::v1::UploadKeyPackageRequest;
7use xmtp_proto::xmtp::mls::api::v1::{GroupMessageInput, WelcomeMessageInput};
8use xmtp_proto::xmtp::xmtpv4::envelopes::client_envelope::Payload;
9
10/// Extract Topics from Envelopes
11#[derive(Default, Clone, Debug)]
12pub struct PayloadExtractor {
13    payload: Option<Payload>,
14}
15
16impl PayloadExtractor {
17    pub fn new() -> Self {
18        Default::default()
19    }
20
21    pub fn get(self) -> Result<Payload, PayloadExtractionError> {
22        self.payload.ok_or(PayloadExtractionError::Failed)
23    }
24}
25
26#[derive(thiserror::Error, Debug)]
27pub enum PayloadExtractionError {
28    #[error("Failed to extract payload, wrong ProtocolMessage?")]
29    Failed,
30}
31
32impl RetryableError for PayloadExtractionError {
33    fn is_retryable(&self) -> bool {
34        false
35    }
36}
37
38impl From<PayloadExtractionError> for EnvelopeError {
39    fn from(err: PayloadExtractionError) -> EnvelopeError {
40        EnvelopeError::Extraction(ExtractionError::Payload(err))
41    }
42}
43
// TODO: at some point it's possible to figure out how to borrow the input
// from the Envelope and return it, but that probably requires an entirely new
// 'accept_borrowed' path, as well as some work to deal with ::decode
// returning a newly allocated type. Not worth the effort yet.
48impl EnvelopeVisitor<'_> for PayloadExtractor {
49    type Error = PayloadExtractionError; // mostly is infallible
50    fn visit_group_message_input(
51        &mut self,
52        message: &GroupMessageInput,
53    ) -> Result<(), Self::Error> {
54        self.payload = Some(Payload::GroupMessage(message.clone()));
55        Ok(())
56    }
57
58    fn visit_welcome_message_input(
59        &mut self,
60        message: &WelcomeMessageInput,
61    ) -> Result<(), Self::Error> {
62        self.payload = Some(Payload::WelcomeMessage(message.clone()));
63        Ok(())
64    }
65
66    fn visit_upload_key_package(
67        &mut self,
68        kp: &UploadKeyPackageRequest,
69    ) -> Result<(), Self::Error> {
70        self.payload = Some(Payload::UploadKeyPackage(kp.clone()));
71        Ok(())
72    }
73
74    fn visit_identity_update(&mut self, update: &IdentityUpdate) -> Result<(), Self::Error> {
75        self.payload = Some(Payload::IdentityUpdate(update.clone()));
76        Ok(())
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Envelope;
    use crate::protocol::extractors::test_utils::*;
    use xmtp_proto::mls_v1::{group_message_input, welcome_message_input};
    use xmtp_proto::xmtp::xmtpv4::envelopes::client_envelope::Payload;

    #[xmtp_common::test]
    fn test_extract_group_message_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_group_message_custom(vec![1, 2, 3], vec![4, 5, 6])
            .build();
        let payload = envelope.payload().unwrap();

        match payload {
            Payload::GroupMessage(msg) => {
                assert!(msg.version.is_some());
                let m = msg.version.unwrap();
                let group_message_input::Version::V1(group_message_input::V1 { data, .. }) = m;
                assert_eq!(data, vec![1, 2, 3]);
            }
            _ => panic!("Expected GroupMessage payload"),
        }
    }

    #[xmtp_common::test]
    fn test_extract_welcome_message_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_welcome_message(vec![1, 2, 3])
            .build();
        let payload = envelope.payload().unwrap();

        // TODO: test this using the WelcomeMessageExtractor
        match payload {
            Payload::WelcomeMessage(msg) => {
                assert!(msg.version.is_some());
                let m = msg.version.unwrap();
                let welcome_message_input::Version::V1(welcome_message_input::V1 {
                    installation_key,
                    ..
                }) = m
                else {
                    panic!("Expected WelcomeMessageVersion::V1");
                };
                assert_eq!(installation_key, vec![1, 2, 3]);
            }
            _ => panic!("Expected WelcomeMessage payload"),
        }
    }

    #[xmtp_common::test]
    fn test_extract_welcome_pointer_payload() {
        let installation_key = xmtp_common::rand_vec::<32>();
        let welcome_pointer = xmtp_common::rand_vec::<32>();
        let hpke_public_key = xmtp_common::rand_vec::<32>();
        let wrapper_algorithm = 2;
        let envelope = TestEnvelopeBuilder::new()
            .with_welcome_pointer(
                installation_key.clone(),
                welcome_pointer.clone(),
                hpke_public_key.clone(),
                wrapper_algorithm,
            )
            .build();
        let payload = envelope.payload().unwrap();

        // TODO: test this using the WelcomeMessageExtractor
        match payload {
            Payload::WelcomeMessage(msg) => {
                assert!(msg.version.is_some());
                let m = msg.version.unwrap();
                let welcome_message_input::Version::WelcomePointer(wp) = m else {
                    panic!("Expected WelcomeMessageVersion::WelcomePointer");
                };
                assert_eq!(wp.installation_key, installation_key);
                assert_eq!(wp.welcome_pointer, welcome_pointer);
                assert_eq!(wp.hpke_public_key, hpke_public_key);
                assert_eq!(wp.wrapper_algorithm, wrapper_algorithm);
            }
            _ => panic!("Expected WelcomeMessage payload"),
        }
    }

    #[xmtp_common::test]
    fn test_extract_key_package_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_key_package_custom(vec![1, 2, 3])
            .build();
        let payload = envelope.payload().unwrap();

        match payload {
            Payload::UploadKeyPackage(kp) => {
                assert!(kp.key_package.is_some());
                assert!(!kp.is_inbox_id_credential);
            }
            _ => panic!("Expected UploadKeyPackage payload"),
        }
    }

    #[xmtp_common::test]
    fn test_extract_identity_update_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_identity_update_custom("test_inbox".to_string())
            .build();
        let payload = envelope.payload().unwrap();

        match payload {
            Payload::IdentityUpdate(update) => {
                assert_eq!(update.inbox_id, "test_inbox");
            }
            _ => panic!("Expected IdentityUpdate payload"),
        }
    }

    #[xmtp_common::test]
    fn test_extract_no_payload_fails() {
        let envelope = TestEnvelopeBuilder::new().with_empty_payload().build();
        let result = envelope.payload();
        assert!(result.is_err());
        // `matches!` only evaluates to a bool; without `assert!` the variant
        // check would be silently discarded and the test could never fail here.
        assert!(matches!(
            result.unwrap_err(),
            EnvelopeError::Extraction(ExtractionError::Payload(PayloadExtractionError::Failed))
        ));
    }
}