xmtp_api_d14n/protocol/extractors/group_messages.rs

1use xmtp_cryptography::hash::sha256_bytes;
2use xmtp_proto::{
3    ConversionError,
4    mls_v1::group_message,
5    types::{Cursor, GlobalCursor, GroupMessage, GroupMessageBuilder},
6};
7
8use crate::protocol::traits::EnvelopeVisitor;
9use crate::protocol::{ExtractionError, Extractor};
10use chrono::{DateTime, Utc};
11use openmls::{
12    framing::MlsMessageIn,
13    prelude::{ContentType, ProtocolMessage, tls_codec::Deserialize},
14};
15use xmtp_proto::xmtp::mls::api::v1::group_message_input;
16use xmtp_proto::xmtp::xmtpv4::envelopes::UnsignedOriginatorEnvelope;
17
/// Type to extract a Group Message from Originator Envelopes
///
/// Collects metadata from the unsigned originator envelope, the client envelope's
/// AAD, and the MLS payload of a `group_message_input::V1`. It combines them into a
/// [`GroupMessage`] when [`Extractor::get`] is called.
#[derive(Default, Clone, Debug)]
pub struct GroupMessageExtractor {
    /// Cursor built from the originator envelope's sequence id and node id.
    cursor: Cursor,
    /// Creation time taken from the originator envelope's `originator_ns`.
    created_ns: DateTime<Utc>,
    /// Builder populated from a `group_message_input::V1` payload; `None` until one
    /// has been visited successfully.
    group_message: Option<GroupMessageBuilder>,
    /// Cursor this message depends on, read from the client envelope's AAD when an
    /// AAD is present.
    depends_on: GlobalCursor,
}
26
27impl Extractor for GroupMessageExtractor {
28    type Output = Result<GroupMessage, ExtractionError>;
29
30    fn get(self) -> Self::Output {
31        let Self {
32            cursor,
33            created_ns,
34            group_message,
35            depends_on,
36        } = self;
37        if let Some(mut gm) = group_message {
38            gm.cursor(cursor);
39            gm.created_ns(created_ns);
40            gm.depends_on(depends_on);
41            Ok(gm.build()?)
42        } else {
43            Err(ExtractionError::Conversion(ConversionError::Missing {
44                item: "group_message",
45                r#type: std::any::type_name::<GroupMessage>(),
46            }))
47        }
48    }
49}
50
51impl EnvelopeVisitor<'_> for GroupMessageExtractor {
52    type Error = ConversionError;
53
54    fn visit_unsigned_originator(
55        &mut self,
56        envelope: &UnsignedOriginatorEnvelope,
57    ) -> Result<(), Self::Error> {
58        self.cursor = Cursor::new(envelope.originator_sequence_id, envelope.originator_node_id);
59        self.created_ns = DateTime::from_timestamp_nanos(envelope.originator_ns);
60        Ok(())
61    }
62
63    fn visit_client(
64        &mut self,
65        e: &xmtp_proto::xmtp::xmtpv4::envelopes::ClientEnvelope,
66    ) -> Result<(), Self::Error> {
67        if let Some(ref aad) = e.aad {
68            self.depends_on = aad.depends_on.clone().unwrap_or_default().into();
69        }
70        Ok(())
71    }
72
73    fn visit_group_message_v1(
74        &mut self,
75        message: &group_message_input::V1,
76    ) -> Result<(), Self::Error> {
77        let mut gm = GroupMessageBuilder::default();
78        let payload_hash = sha256_bytes(message.data.as_slice());
79        gm.sender_hmac(message.sender_hmac.clone())
80            .should_push(message.should_push)
81            .payload_hash(payload_hash);
82        extract_common_mls(&mut gm, &message.data)?;
83        self.group_message = Some(gm);
84        Ok(())
85    }
86}
87
88#[derive(Default)]
89pub struct V3GroupMessageExtractor {
90    group_message: Option<GroupMessageBuilder>,
91}
92
93impl Extractor for V3GroupMessageExtractor {
94    type Output = Result<Option<GroupMessage>, ConversionError>;
95
96    fn get(self) -> Self::Output {
97        if let Some(gm) = self.group_message {
98            Ok(Some(gm.build()?))
99        } else {
100            Ok(None)
101        }
102    }
103}
104
105impl EnvelopeVisitor<'_> for V3GroupMessageExtractor {
106    type Error = ConversionError;
107
108    fn visit_v3_group_message(&mut self, message: &group_message::V1) -> Result<(), Self::Error> {
109        let mut group_message = GroupMessage::builder();
110        let payload_hash = sha256_bytes(message.data.as_slice());
111        // commits are stored inside of messages
112        // MLS commits come from a strongly-ordered backend (like a blockchain)
113        // Application messages come from XMTPD nodes (not strongly ordered)
114        let is_commit = extract_common_mls(&mut group_message, &message.data)?;
115        let originator_node_id = if is_commit {
116            xmtp_configuration::Originators::MLS_COMMITS
117        } else {
118            xmtp_configuration::Originators::APPLICATION_MESSAGES
119        };
120        group_message
121            .cursor(Cursor::new(message.id, originator_node_id))
122            .created_ns(DateTime::from_timestamp_nanos(message.created_ns as i64))
123            .sender_hmac(message.sender_hmac.clone())
124            .should_push(message.should_push)
125            .payload_hash(payload_hash);
126        self.group_message = Some(group_message);
127        Ok(())
128    }
129}
130
131/// extract common mls config
132/// returns true if it is a commit
133fn extract_common_mls(
134    builder: &mut GroupMessageBuilder,
135    mut data: &[u8],
136) -> Result<bool, ConversionError> {
137    let msg_in = MlsMessageIn::tls_deserialize(&mut data)?;
138    let protocol_message: ProtocolMessage = msg_in.try_into_protocol_message()?;
139    let is_commit = protocol_message.content_type() == ContentType::Commit;
140
141    builder
142        .group_id(protocol_message.group_id().to_vec())
143        .message(protocol_message);
144    Ok(is_commit)
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ProtocolEnvelope;
    use crate::protocol::extractors::test_utils::*;

    // Builds an originator envelope carrying an application message and checks
    // that `GroupMessageExtractor` produces a message with the expected group id.
    //
    // NOTE(review): despite its name ("fails_with_mock_data"), this test expects
    // extraction to *succeed*: both `accept` and `get` are unwrapped, and the group
    // id is asserted. Presumably `with_application_message` wraps the bytes into a
    // decodable MLS message whose group id is `[1, 2, 3]`. Confirm this against
    // `test_utils` and rename the test (or fix its assertions) so that name and
    // body agree.
    #[xmtp_common::test]
    fn test_extract_group_message_fails_with_mock_data() {
        let envelope = TestEnvelopeBuilder::new()
            .with_originator_node_id(123)
            .with_originator_sequence_id(456)
            .with_originator_ns(789)
            .with_application_message(vec![1, 2, 3])
            .build();
        let mut extractor = GroupMessageExtractor::default();
        envelope.accept(&mut extractor).unwrap();
        let msg = extractor.get().unwrap();
        assert_eq!(vec![1, 2, 3], *msg.group_id);
    }
}