// xmtp_db/encrypted_store/group/dms.rs

1use crate::ConnectionExt;
2
3use super::*;
4use crate::ConnectionError;
5
6pub trait QueryDms {
7    /// Same behavior as fetched, but will stitch DM groups
8    fn fetch_stitched(&self, key: &[u8]) -> Result<Option<StoredGroup>, ConnectionError>;
9
10    fn find_active_dm_group<M>(&self, members: M) -> Result<Option<StoredGroup>, ConnectionError>
11    where
12        M: std::fmt::Display;
13
14    /// Load the other DMs that are stitched into this group
15    fn other_dms(&self, group_id: &[u8]) -> Result<Vec<StoredGroup>, ConnectionError>;
16}
17
18impl<T> QueryDms for &T
19where
20    T: QueryDms,
21{
22    fn fetch_stitched(&self, key: &[u8]) -> Result<Option<StoredGroup>, ConnectionError> {
23        (**self).fetch_stitched(key)
24    }
25
26    fn find_active_dm_group<M>(&self, members: M) -> Result<Option<StoredGroup>, ConnectionError>
27    where
28        M: std::fmt::Display,
29    {
30        (**self).find_active_dm_group(members)
31    }
32
33    fn other_dms(&self, group_id: &[u8]) -> Result<Vec<StoredGroup>, ConnectionError> {
34        (**self).other_dms(group_id)
35    }
36}
37
38impl<C: ConnectionExt> QueryDms for DbConnection<C> {
39    /// Same behavior as fetched, but will stitch DM groups
40    fn fetch_stitched(&self, key: &[u8]) -> Result<Option<StoredGroup>, ConnectionError> {
41        let group = self.raw_query_read(|conn| {
42            groups::table
43                .filter(groups::id.eq(key))
44                .first::<StoredGroup>(conn)
45                .optional()
46        })?;
47
48        // Is this group a DM?
49        let Some(StoredGroup {
50            dm_id: Some(dm_id), ..
51        }) = group
52        else {
53            // If not, return the group
54            return Ok(group);
55        };
56
57        // Otherwise, return the stitched DM
58        self.raw_query_read(|conn| {
59            groups::table
60                .filter(groups::dm_id.eq(dm_id))
61                .order_by(groups::last_message_ns.desc())
62                .first::<StoredGroup>(conn)
63                .optional()
64        })
65    }
66
67    fn find_active_dm_group<M>(&self, members: M) -> Result<Option<StoredGroup>, ConnectionError>
68    where
69        M: std::fmt::Display,
70    {
71        let query = dsl::groups
72            .filter(dsl::dm_id.eq(Some(members.to_string())))
73            .filter(dsl::membership_state.ne(GroupMembershipState::Restored))
74            .order_by(dsl::last_message_ns.desc());
75
76        self.raw_query_read(|conn| query.first(conn).optional())
77    }
78
79    /// Load the other DMs that are stitched into this group
80    fn other_dms(&self, group_id: &[u8]) -> Result<Vec<StoredGroup>, ConnectionError> {
81        let query = dsl::groups.filter(dsl::id.eq(group_id));
82
83        let groups: Vec<StoredGroup> = self.raw_query_read(|conn| query.load(conn))?;
84
85        // Grab the dm_id of the group
86        let Some(StoredGroup {
87            id,
88            dm_id: Some(dm_id),
89            ..
90        }) = groups.into_iter().next()
91        else {
92            return Ok(vec![]);
93        };
94
95        let query = dsl::groups
96            .filter(dsl::dm_id.eq(dm_id))
97            .filter(dsl::id.ne(id));
98
99        let other_dms: Vec<StoredGroup> = self.raw_query_read(|conn| query.load(conn))?;
100        Ok(other_dms)
101    }
102}
103
#[cfg(test)]
pub(super) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);
    use super::*;
    use crate::{Store, test_utils::with_connection};
    use std::sync::atomic::{AtomicU16, Ordering};
    use xmtp_common::{rand_vec, time::now_ns};

    // Counter used to give every generated DM a distinct counterpart inbox id.
    static TARGET_INBOX_ID: AtomicU16 = AtomicU16::new(2);

    /// Build (without storing) a test DM group with a random id and a unique `dm_id`.
    ///
    /// `state` falls back to `GroupMembershipState::Allowed` when `None`.
    pub fn generate_dm(state: Option<GroupMembershipState>) -> StoredGroup {
        let suffix = TARGET_INBOX_ID.fetch_add(1, Ordering::SeqCst);
        let dm_id = format!("dm:placeholder_inbox_id_1:placeholder_inbox_id_{suffix}");
        let membership = state.unwrap_or(GroupMembershipState::Allowed);

        StoredGroup::builder()
            .id(rand_vec::<24>())
            .created_at_ns(now_ns())
            .membership_state(membership)
            .added_by_inbox_id("placeholder_address")
            .dm_id(dm_id)
            .build()
            .unwrap()
    }

    // Two stored groups sharing one dm_id collapse into a single entry under the
    // default `find_groups` arguments.
    #[xmtp_common::test]
    fn test_dm_stitching() {
        with_connection(|conn| {
            for _ in 0..2 {
                StoredGroup::builder()
                    .id(rand_vec::<24>())
                    .created_at_ns(now_ns())
                    .membership_state(GroupMembershipState::Allowed)
                    .added_by_inbox_id("placeholder_address")
                    .dm_id(Some("dm:some_wise_guy:thats_me".to_string()))
                    .build()
                    .unwrap()
                    .store(conn)
                    .unwrap();
            }

            let all_groups = conn.find_groups(GroupQueryArgs::default()).unwrap();
            assert_eq!(all_groups.len(), 1);
        })
    }

    // Deduplication keeps only the most recent DM per dm_id, while leaving DMs with
    // other dm_ids and non-DM groups untouched.
    #[xmtp_common::test]
    fn test_dm_deduplication() {
        with_connection(|conn| {
            let base_time = now_ns() - 1_000_000_000; // 1 second ago
            let dm_id = "dm:alice:bob";

            // (timestamp for created_at_ns and last_message_ns, added_by, dm_id),
            // stored in this exact order.
            let fixtures = [
                // Oldest DM (should be filtered out)
                (base_time, "alice", Some(dm_id)),
                // Middle DM (should be filtered out)
                (base_time + 1_000_000, "bob", Some(dm_id)),
                // Latest DM (should be kept)
                (base_time + 2_000_000, "alice", Some(dm_id)),
                // Another DM with a different dm_id (should always be kept)
                (base_time + 500_000, "charlie", Some("dm:charlie:dave")),
                // Regular group without a dm_id (should always be kept)
                (base_time + 1_500_000, "alice", None),
            ];

            let [_, _, latest_dm, different_dm, regular_group] =
                fixtures.map(|(timestamp, added_by, group_dm_id)| {
                    let group = StoredGroup::builder()
                        .id(rand_vec::<24>())
                        .created_at_ns(timestamp)
                        .last_message_ns(timestamp)
                        .membership_state(GroupMembershipState::Allowed)
                        .added_by_inbox_id(added_by)
                        .dm_id(group_dm_id.map(str::to_string))
                        .build()
                        .unwrap();
                    group.store(conn).unwrap();
                    group
                });

            // include_duplicate_dms = false: default deduplication
            let deduplicated_groups = conn
                .find_groups(GroupQueryArgs {
                    include_duplicate_dms: false,
                    ..Default::default()
                })
                .unwrap();

            // Latest DM, different DM, and regular group remain
            assert_eq!(deduplicated_groups.len(), 3);

            let find_by_dm_id = |wanted: Option<&str>| {
                deduplicated_groups
                    .iter()
                    .find(|g| g.dm_id.as_deref() == wanted)
            };

            // The surviving dm_id group is the one with the highest last_message_ns
            let kept_dm = find_by_dm_id(Some(dm_id)).expect("Should find the DM group");
            assert_eq!(kept_dm.id, latest_dm.id);
            assert_eq!(kept_dm.last_message_ns, Some(base_time + 2_000_000));

            let kept_different_dm = find_by_dm_id(Some("dm:charlie:dave"))
                .expect("Should find the different DM group");
            assert_eq!(kept_different_dm.id, different_dm.id);

            let kept_regular = find_by_dm_id(None).expect("Should find the regular group");
            assert_eq!(kept_regular.id, regular_group.id);

            // include_duplicate_dms = true: every stored group is returned
            let all_groups = conn
                .find_groups(GroupQueryArgs {
                    include_duplicate_dms: true,
                    ..Default::default()
                })
                .unwrap();
            assert_eq!(all_groups.len(), 5);
        })
    }
}