xmtp_api_d14n/protocol/extractors/
payloads.rs1use xmtp_common::RetryableError;
2
3use crate::protocol::traits::EnvelopeVisitor;
4use crate::protocol::{EnvelopeError, ExtractionError};
5use xmtp_proto::xmtp::identity::associations::IdentityUpdate;
6use xmtp_proto::xmtp::mls::api::v1::UploadKeyPackageRequest;
7use xmtp_proto::xmtp::mls::api::v1::{GroupMessageInput, WelcomeMessageInput};
8use xmtp_proto::xmtp::xmtpv4::envelopes::client_envelope::Payload;
9
10#[derive(Default, Clone, Debug)]
12pub struct PayloadExtractor {
13 payload: Option<Payload>,
14}
15
16impl PayloadExtractor {
17 pub fn new() -> Self {
18 Default::default()
19 }
20
21 pub fn get(self) -> Result<Payload, PayloadExtractionError> {
22 self.payload.ok_or(PayloadExtractionError::Failed)
23 }
24}
25
26#[derive(thiserror::Error, Debug)]
27pub enum PayloadExtractionError {
28 #[error("Failed to extract payload, wrong ProtocolMessage?")]
29 Failed,
30}
31
32impl RetryableError for PayloadExtractionError {
33 fn is_retryable(&self) -> bool {
34 false
35 }
36}
37
38impl From<PayloadExtractionError> for EnvelopeError {
39 fn from(err: PayloadExtractionError) -> EnvelopeError {
40 EnvelopeError::Extraction(ExtractionError::Payload(err))
41 }
42}
43
44impl EnvelopeVisitor<'_> for PayloadExtractor {
49 type Error = PayloadExtractionError; fn visit_group_message_input(
51 &mut self,
52 message: &GroupMessageInput,
53 ) -> Result<(), Self::Error> {
54 self.payload = Some(Payload::GroupMessage(message.clone()));
55 Ok(())
56 }
57
58 fn visit_welcome_message_input(
59 &mut self,
60 message: &WelcomeMessageInput,
61 ) -> Result<(), Self::Error> {
62 self.payload = Some(Payload::WelcomeMessage(message.clone()));
63 Ok(())
64 }
65
66 fn visit_upload_key_package(
67 &mut self,
68 kp: &UploadKeyPackageRequest,
69 ) -> Result<(), Self::Error> {
70 self.payload = Some(Payload::UploadKeyPackage(kp.clone()));
71 Ok(())
72 }
73
74 fn visit_identity_update(&mut self, update: &IdentityUpdate) -> Result<(), Self::Error> {
75 self.payload = Some(Payload::IdentityUpdate(update.clone()));
76 Ok(())
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Envelope;
    use crate::protocol::extractors::test_utils::*;
    use xmtp_proto::mls_v1::{group_message_input, welcome_message_input};
    use xmtp_proto::xmtp::xmtpv4::envelopes::client_envelope::Payload;

    // A group-message envelope must yield `Payload::GroupMessage` whose V1
    // `data` equals the bytes given to the builder.
    #[xmtp_common::test]
    fn test_extract_group_message_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_group_message_custom(vec![1, 2, 3], vec![4, 5, 6])
            .build();
        let payload = envelope.payload().unwrap();

        let Payload::GroupMessage(msg) = payload else {
            panic!("Expected GroupMessage payload");
        };
        let version = msg
            .version
            .expect("group message input should carry a version");
        // Plain `let`: this pattern is irrefutable for the group message
        // version enum, so `let ... else` would only add a warning.
        let group_message_input::Version::V1(group_message_input::V1 { data, .. }) = version;
        assert_eq!(data, vec![1, 2, 3]);
    }

    // A V1 welcome-message envelope must yield `Payload::WelcomeMessage`
    // whose `installation_key` equals the bytes given to the builder.
    #[xmtp_common::test]
    fn test_extract_welcome_message_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_welcome_message(vec![1, 2, 3])
            .build();
        let payload = envelope.payload().unwrap();

        let Payload::WelcomeMessage(msg) = payload else {
            panic!("Expected WelcomeMessage payload");
        };
        let version = msg
            .version
            .expect("welcome message input should carry a version");
        let welcome_message_input::Version::V1(welcome_message_input::V1 {
            installation_key,
            ..
        }) = version
        else {
            panic!("Expected WelcomeMessageVersion::V1");
        };
        assert_eq!(installation_key, vec![1, 2, 3]);
    }

    // A welcome-pointer envelope must yield `Payload::WelcomeMessage` with the
    // `WelcomePointer` version, and every field must equal the builder input.
    #[xmtp_common::test]
    fn test_extract_welcome_pointer_payload() {
        let installation_key = xmtp_common::rand_vec::<32>();
        let welcome_pointer = xmtp_common::rand_vec::<32>();
        let hpke_public_key = xmtp_common::rand_vec::<32>();
        let wrapper_algorithm = 2;
        let envelope = TestEnvelopeBuilder::new()
            .with_welcome_pointer(
                installation_key.clone(),
                welcome_pointer.clone(),
                hpke_public_key.clone(),
                wrapper_algorithm,
            )
            .build();
        let payload = envelope.payload().unwrap();

        let Payload::WelcomeMessage(msg) = payload else {
            panic!("Expected WelcomeMessage payload");
        };
        let version = msg
            .version
            .expect("welcome message input should carry a version");
        let welcome_message_input::Version::WelcomePointer(wp) = version else {
            panic!("Expected WelcomeMessageVersion::WelcomePointer");
        };
        assert_eq!(wp.installation_key, installation_key);
        assert_eq!(wp.welcome_pointer, welcome_pointer);
        assert_eq!(wp.hpke_public_key, hpke_public_key);
        assert_eq!(wp.wrapper_algorithm, wrapper_algorithm);
    }

    // A key-package envelope must yield `Payload::UploadKeyPackage` with a key
    // package present and `is_inbox_id_credential` unset.
    #[xmtp_common::test]
    fn test_extract_key_package_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_key_package_custom(vec![1, 2, 3])
            .build();
        let payload = envelope.payload().unwrap();

        let Payload::UploadKeyPackage(kp) = payload else {
            panic!("Expected UploadKeyPackage payload");
        };
        assert!(kp.key_package.is_some());
        assert!(!kp.is_inbox_id_credential);
    }

    // An identity-update envelope must yield `Payload::IdentityUpdate`
    // carrying the inbox id given to the builder.
    #[xmtp_common::test]
    fn test_extract_identity_update_payload() {
        let envelope = TestEnvelopeBuilder::new()
            .with_identity_update_custom("test_inbox".to_string())
            .build();
        let payload = envelope.payload().unwrap();

        let Payload::IdentityUpdate(update) = payload else {
            panic!("Expected IdentityUpdate payload");
        };
        assert_eq!(update.inbox_id, "test_inbox");
    }

    // An envelope with an empty payload must fail with the payload extraction
    // error.
    #[xmtp_common::test]
    fn test_extract_no_payload_fails() {
        let envelope = TestEnvelopeBuilder::new().with_empty_payload().build();
        let err = envelope
            .payload()
            .expect_err("extracting from an empty payload should fail");
        // The bare `matches!` used previously discarded its result, so the
        // error variant was never checked; it has to be wrapped in `assert!`.
        assert!(
            matches!(
                err,
                EnvelopeError::Extraction(ExtractionError::Payload(PayloadExtractionError::Failed))
            ),
            "unexpected error: {err:?}"
        );
    }
}