xmtp_api_d14n/protocol/extractors/
group_messages.rs1use xmtp_cryptography::hash::sha256_bytes;
2use xmtp_proto::{
3 ConversionError,
4 mls_v1::group_message,
5 types::{Cursor, GlobalCursor, GroupMessage, GroupMessageBuilder},
6};
7
8use crate::protocol::traits::EnvelopeVisitor;
9use crate::protocol::{ExtractionError, Extractor};
10use chrono::{DateTime, Utc};
11use openmls::{
12 framing::MlsMessageIn,
13 prelude::{ContentType, ProtocolMessage, tls_codec::Deserialize},
14};
15use xmtp_proto::xmtp::mls::api::v1::group_message_input;
16use xmtp_proto::xmtp::xmtpv4::envelopes::UnsignedOriginatorEnvelope;
17
18#[derive(Default, Clone, Debug)]
20pub struct GroupMessageExtractor {
21 cursor: Cursor,
22 created_ns: DateTime<Utc>,
23 group_message: Option<GroupMessageBuilder>,
24 depends_on: GlobalCursor,
25}
26
27impl Extractor for GroupMessageExtractor {
28 type Output = Result<GroupMessage, ExtractionError>;
29
30 fn get(self) -> Self::Output {
31 let Self {
32 cursor,
33 created_ns,
34 group_message,
35 depends_on,
36 } = self;
37 if let Some(mut gm) = group_message {
38 gm.cursor(cursor);
39 gm.created_ns(created_ns);
40 gm.depends_on(depends_on);
41 Ok(gm.build()?)
42 } else {
43 Err(ExtractionError::Conversion(ConversionError::Missing {
44 item: "group_message",
45 r#type: std::any::type_name::<GroupMessage>(),
46 }))
47 }
48 }
49}
50
51impl EnvelopeVisitor<'_> for GroupMessageExtractor {
52 type Error = ConversionError;
53
54 fn visit_unsigned_originator(
55 &mut self,
56 envelope: &UnsignedOriginatorEnvelope,
57 ) -> Result<(), Self::Error> {
58 self.cursor = Cursor::new(envelope.originator_sequence_id, envelope.originator_node_id);
59 self.created_ns = DateTime::from_timestamp_nanos(envelope.originator_ns);
60 Ok(())
61 }
62
63 fn visit_client(
64 &mut self,
65 e: &xmtp_proto::xmtp::xmtpv4::envelopes::ClientEnvelope,
66 ) -> Result<(), Self::Error> {
67 if let Some(ref aad) = e.aad {
68 self.depends_on = aad.depends_on.clone().unwrap_or_default().into();
69 }
70 Ok(())
71 }
72
73 fn visit_group_message_v1(
74 &mut self,
75 message: &group_message_input::V1,
76 ) -> Result<(), Self::Error> {
77 let mut gm = GroupMessageBuilder::default();
78 let payload_hash = sha256_bytes(message.data.as_slice());
79 gm.sender_hmac(message.sender_hmac.clone())
80 .should_push(message.should_push)
81 .payload_hash(payload_hash);
82 extract_common_mls(&mut gm, &message.data)?;
83 self.group_message = Some(gm);
84 Ok(())
85 }
86}
87
88#[derive(Default)]
89pub struct V3GroupMessageExtractor {
90 group_message: Option<GroupMessageBuilder>,
91}
92
93impl Extractor for V3GroupMessageExtractor {
94 type Output = Result<Option<GroupMessage>, ConversionError>;
95
96 fn get(self) -> Self::Output {
97 if let Some(gm) = self.group_message {
98 Ok(Some(gm.build()?))
99 } else {
100 Ok(None)
101 }
102 }
103}
104
105impl EnvelopeVisitor<'_> for V3GroupMessageExtractor {
106 type Error = ConversionError;
107
108 fn visit_v3_group_message(&mut self, message: &group_message::V1) -> Result<(), Self::Error> {
109 let mut group_message = GroupMessage::builder();
110 let payload_hash = sha256_bytes(message.data.as_slice());
111 let is_commit = extract_common_mls(&mut group_message, &message.data)?;
115 let originator_node_id = if is_commit {
116 xmtp_configuration::Originators::MLS_COMMITS
117 } else {
118 xmtp_configuration::Originators::APPLICATION_MESSAGES
119 };
120 group_message
121 .cursor(Cursor::new(message.id, originator_node_id))
122 .created_ns(DateTime::from_timestamp_nanos(message.created_ns as i64))
123 .sender_hmac(message.sender_hmac.clone())
124 .should_push(message.should_push)
125 .payload_hash(payload_hash);
126 self.group_message = Some(group_message);
127 Ok(())
128 }
129}
130
131fn extract_common_mls(
134 builder: &mut GroupMessageBuilder,
135 mut data: &[u8],
136) -> Result<bool, ConversionError> {
137 let msg_in = MlsMessageIn::tls_deserialize(&mut data)?;
138 let protocol_message: ProtocolMessage = msg_in.try_into_protocol_message()?;
139 let is_commit = protocol_message.content_type() == ContentType::Commit;
140
141 builder
142 .group_id(protocol_message.group_id().to_vec())
143 .message(protocol_message);
144 Ok(is_commit)
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ProtocolEnvelope;
    use crate::protocol::extractors::test_utils::*;

    /// The mock payload `[1, 2, 3]` is far too short to be a TLS-encoded MLS message.
    /// Extraction must therefore fail, instead of panicking on `unwrap()` or producing a
    /// `GroupMessage`.
    #[xmtp_common::test]
    fn test_extract_group_message_fails_with_mock_data() {
        let envelope = TestEnvelopeBuilder::new()
            .with_originator_node_id(123)
            .with_originator_sequence_id(456)
            .with_originator_ns(789)
            .with_application_message(vec![1, 2, 3])
            .build();
        let mut extractor = GroupMessageExtractor::default();

        let visited = envelope.accept(&mut extractor);
        assert!(
            visited.is_err(),
            "mock bytes must not deserialize as an MLS message"
        );

        // `visit_group_message_v1` only stores a builder after MLS parsing succeeds, so
        // nothing is available to assemble.
        assert!(
            extractor.get().is_err(),
            "no group message should be extractable from mock data"
        );
    }
}