xmtp_api_d14n/protocol/extractors/topics.rs

1use hex::FromHexError;
2use openmls::framing::errors::ProtocolMessageError;
3use xmtp_common::RetryableError;
4use xmtp_proto::ConversionError;
5
6use crate::protocol::ExtractionError;
7
8use super::{EnvelopeError, Extractor};
9use crate::protocol::traits::EnvelopeVisitor;
10use openmls::prelude::KeyPackageVerifyError;
11use openmls::{
12    framing::MlsMessageIn,
13    prelude::{KeyPackageIn, ProtocolMessage, tls_codec::Deserialize},
14};
15use openmls_rust_crypto::RustCrypto;
16use xmtp_proto::types::{Topic, TopicKind};
17use xmtp_proto::xmtp::identity::api::v1::get_identity_updates_request;
18use xmtp_proto::xmtp::identity::associations::IdentityUpdate;
19use xmtp_proto::xmtp::mls::api::v1::KeyPackageUpload;
20use xmtp_proto::xmtp::mls::api::v1::UploadKeyPackageRequest;
21use xmtp_proto::xmtp::mls::api::v1::{
22    group_message_input::V1 as GroupMessageV1,
23    welcome_message_input::{
24        V1 as WelcomeMessageV1, WelcomePointer as WelcomeMessageWelcomePointer,
25    },
26};
27
28/// Extract Topics from Envelopes
29#[derive(Default, Clone, Debug)]
30pub struct TopicExtractor {
31    topic: Option<Topic>,
32}
33
34impl TopicExtractor {
35    pub fn new() -> Self {
36        Default::default()
37    }
38}
39impl Extractor for TopicExtractor {
40    type Output = Result<Topic, TopicExtractionError>;
41
42    fn get(self) -> Self::Output {
43        self.topic.ok_or(TopicExtractionError::Failed)
44    }
45}
46
47impl TopicExtractor {
48    pub fn get(self) -> Result<Topic, TopicExtractionError> {
49        self.topic.ok_or(TopicExtractionError::Failed)
50    }
51}
52
53#[derive(thiserror::Error, Debug)]
54pub enum TopicExtractionError {
55    #[error("Topic extraction failed, no topic available")]
56    Failed,
57    #[error(transparent)]
58    KeyPackageVerify(#[from] KeyPackageVerifyError),
59    #[error(transparent)]
60    Mls(#[from] openmls::prelude::Error),
61    #[error(transparent)]
62    Protocol(#[from] ProtocolMessageError),
63    #[error(transparent)]
64    FromHex(#[from] FromHexError),
65    #[error(transparent)]
66    Conversion(#[from] ConversionError),
67}
68
69impl RetryableError for TopicExtractionError {
70    fn is_retryable(&self) -> bool {
71        false
72    }
73}
74
75impl From<TopicExtractionError> for EnvelopeError {
76    fn from(err: TopicExtractionError) -> EnvelopeError {
77        EnvelopeError::Extraction(ExtractionError::Topic(err))
78    }
79}
80
81impl EnvelopeVisitor<'_> for TopicExtractor {
82    type Error = TopicExtractionError;
83
84    fn visit_group_message_v1(&mut self, message: &GroupMessageV1) -> Result<(), Self::Error> {
85        let msg_result = MlsMessageIn::tls_deserialize(&mut message.data.as_slice())?;
86        let protocol_message: ProtocolMessage = msg_result.try_into_protocol_message()?;
87        self.topic =
88            Some(TopicKind::GroupMessagesV1.create(protocol_message.group_id().as_slice()));
89        Ok(())
90    }
91
92    fn visit_welcome_message_version(
93        &mut self,
94        version: &xmtp_proto::mls_v1::welcome_message_input::Version,
95    ) -> Result<(), Self::Error> {
96        match version {
97            xmtp_proto::mls_v1::welcome_message_input::Version::V1(v1) => {
98                self.visit_welcome_message_v1(v1)
99            }
100            xmtp_proto::mls_v1::welcome_message_input::Version::WelcomePointer(wp) => {
101                self.visit_welcome_pointer(wp)
102            }
103        }
104    }
105
106    fn visit_welcome_message_v1(&mut self, message: &WelcomeMessageV1) -> Result<(), Self::Error> {
107        self.topic = Some(TopicKind::WelcomeMessagesV1.create(message.installation_key.as_slice()));
108        Ok(())
109    }
110
111    fn visit_welcome_pointer(
112        &mut self,
113        message: &WelcomeMessageWelcomePointer,
114    ) -> Result<(), Self::Error> {
115        self.topic = Some(TopicKind::WelcomeMessagesV1.create(message.installation_key.as_slice()));
116        Ok(())
117    }
118
119    fn visit_upload_key_package(
120        &mut self,
121        kp: &UploadKeyPackageRequest,
122    ) -> Result<(), Self::Error> {
123        let upload = kp.key_package.as_ref().ok_or(ConversionError::Missing {
124            item: "key_package",
125            r#type: std::any::type_name::<KeyPackageUpload>(),
126        })?;
127        let kp_in: KeyPackageIn =
128            KeyPackageIn::tls_deserialize_exact(upload.key_package_tls_serialized.as_slice())?;
129        let rust_crypto = RustCrypto::default();
130        let kp = kp_in.validate(
131            &rust_crypto,
132            xmtp_configuration::MLS_PROTOCOL_VERSION,
133            openmls::prelude::LeafNodeLifetimePolicy::Verify,
134        )?;
135        let installation_key = kp.leaf_node().signature_key().as_slice();
136        self.topic = Some(TopicKind::KeyPackagesV1.create(installation_key));
137        Ok(())
138    }
139
140    fn visit_identity_update(&mut self, update: &IdentityUpdate) -> Result<(), Self::Error> {
141        let decoded_id = hex::decode(&update.inbox_id)?;
142        self.topic = Some(TopicKind::IdentityUpdatesV1.create(&decoded_id));
143        Ok(())
144    }
145
146    fn visit_identity_updates_request(
147        &mut self,
148        update: &get_identity_updates_request::Request,
149    ) -> Result<(), Self::Error> {
150        let decoded_id = hex::decode(&update.inbox_id)?;
151        self.topic = Some(TopicKind::IdentityUpdatesV1.create(&decoded_id));
152        Ok(())
153    }
154}
155
#[cfg(test)]
mod tests {
    use xmtp_cryptography::XmtpInstallationCredential;

    use super::*;
    use crate::protocol::Envelope;
    use crate::protocol::extractors::test_utils::*;

    // Each test builds an envelope (or request) and checks that `topic()` yields the
    // topic kind and key the extractor is expected to record, or that it errors.

    #[xmtp_common::test]
    fn test_extract_group_message_topic() {
        let expected = TopicKind::GroupMessagesV1.create([1, 2, 3]);
        let envelope = TestEnvelopeBuilder::new()
            .with_application_message(vec![1, 2, 3])
            .build();
        assert_eq!(envelope.topic().unwrap(), expected);
    }

    #[xmtp_common::test]
    fn test_extract_welcome_message_topic() {
        let expected = TopicKind::WelcomeMessagesV1.create([5, 6, 7, 8]);
        let envelope = TestEnvelopeBuilder::new()
            .with_welcome_message(vec![5, 6, 7, 8])
            .build();
        assert_eq!(envelope.topic().unwrap(), expected);
    }

    #[xmtp_common::test]
    fn test_extract_key_package_topic() {
        let installation = XmtpInstallationCredential::default();
        let expected = TopicKind::KeyPackagesV1.create(installation.public_slice());
        let envelope = TestEnvelopeBuilder::new()
            .with_key_package("test".to_string(), installation.clone())
            .build();
        assert_eq!(envelope.topic().unwrap(), expected);
    }

    #[xmtp_common::test]
    fn test_extract_identity_update_topic() {
        let inbox_bytes = hex::decode("abcd1234").unwrap();
        let expected = TopicKind::IdentityUpdatesV1.create(&inbox_bytes);
        let envelope = TestEnvelopeBuilder::new().with_identity_update().build();
        assert_eq!(envelope.topic().unwrap(), expected);
    }

    #[xmtp_common::test]
    fn test_extract_missing_key_package_fails() {
        let envelope = TestEnvelopeBuilder::new()
            .with_invalid_key_package()
            .build();
        assert!(envelope.topic().is_err());
    }

    #[xmtp_common::test]
    fn test_extract_invalid_hex_identity_fails() {
        let envelope = TestEnvelopeBuilder::new()
            .with_invalid_identity_update()
            .build();
        assert!(envelope.topic().is_err());
    }

    #[xmtp_common::test]
    fn test_extract_no_topic_fails() {
        let outcome = TopicExtractor::new().get();
        assert!(matches!(outcome, Err(TopicExtractionError::Failed)));
    }

    #[xmtp_common::test]
    fn test_extraction_from_identity_update_req() {
        let expected = TopicKind::IdentityUpdatesV1.create(b"test_id");
        let req = get_identity_updates_request::Request {
            inbox_id: hex::encode(b"test_id"),
            sequence_id: 0,
        };
        assert_eq!(req.topic().unwrap(), expected);
    }
}