xmtp_db/encrypted_store/group/
dms.rs1use crate::ConnectionExt;
2
3use super::*;
4use crate::ConnectionError;
5
6pub trait QueryDms {
7 fn fetch_stitched(&self, key: &[u8]) -> Result<Option<StoredGroup>, ConnectionError>;
9
10 fn find_active_dm_group<M>(&self, members: M) -> Result<Option<StoredGroup>, ConnectionError>
11 where
12 M: std::fmt::Display;
13
14 fn other_dms(&self, group_id: &[u8]) -> Result<Vec<StoredGroup>, ConnectionError>;
16}
17
18impl<T> QueryDms for &T
19where
20 T: QueryDms,
21{
22 fn fetch_stitched(&self, key: &[u8]) -> Result<Option<StoredGroup>, ConnectionError> {
23 (**self).fetch_stitched(key)
24 }
25
26 fn find_active_dm_group<M>(&self, members: M) -> Result<Option<StoredGroup>, ConnectionError>
27 where
28 M: std::fmt::Display,
29 {
30 (**self).find_active_dm_group(members)
31 }
32
33 fn other_dms(&self, group_id: &[u8]) -> Result<Vec<StoredGroup>, ConnectionError> {
34 (**self).other_dms(group_id)
35 }
36}
37
38impl<C: ConnectionExt> QueryDms for DbConnection<C> {
39 fn fetch_stitched(&self, key: &[u8]) -> Result<Option<StoredGroup>, ConnectionError> {
41 let group = self.raw_query_read(|conn| {
42 groups::table
43 .filter(groups::id.eq(key))
44 .first::<StoredGroup>(conn)
45 .optional()
46 })?;
47
48 let Some(StoredGroup {
50 dm_id: Some(dm_id), ..
51 }) = group
52 else {
53 return Ok(group);
55 };
56
57 self.raw_query_read(|conn| {
59 groups::table
60 .filter(groups::dm_id.eq(dm_id))
61 .order_by(groups::last_message_ns.desc())
62 .first::<StoredGroup>(conn)
63 .optional()
64 })
65 }
66
67 fn find_active_dm_group<M>(&self, members: M) -> Result<Option<StoredGroup>, ConnectionError>
68 where
69 M: std::fmt::Display,
70 {
71 let query = dsl::groups
72 .filter(dsl::dm_id.eq(Some(members.to_string())))
73 .filter(dsl::membership_state.ne(GroupMembershipState::Restored))
74 .order_by(dsl::last_message_ns.desc());
75
76 self.raw_query_read(|conn| query.first(conn).optional())
77 }
78
79 fn other_dms(&self, group_id: &[u8]) -> Result<Vec<StoredGroup>, ConnectionError> {
81 let query = dsl::groups.filter(dsl::id.eq(group_id));
82
83 let groups: Vec<StoredGroup> = self.raw_query_read(|conn| query.load(conn))?;
84
85 let Some(StoredGroup {
87 id,
88 dm_id: Some(dm_id),
89 ..
90 }) = groups.into_iter().next()
91 else {
92 return Ok(vec![]);
93 };
94
95 let query = dsl::groups
96 .filter(dsl::dm_id.eq(dm_id))
97 .filter(dsl::id.ne(id));
98
99 let other_dms: Vec<StoredGroup> = self.raw_query_read(|conn| query.load(conn))?;
100 Ok(other_dms)
101 }
102}
103
#[cfg(test)]
pub(super) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);
    use super::*;
    use crate::{Store, test_utils::with_connection};
    use std::sync::atomic::{AtomicU16, Ordering};
    use xmtp_common::{rand_vec, time::now_ns};

    /// Counter giving each generated DM a distinct peer inbox id, and so a distinct `dm_id`,
    /// until the `u16` wraps.
    static TARGET_INBOX_ID: AtomicU16 = AtomicU16::new(2);

    /// Build a DM group with a fresh `dm_id`, without storing it.
    ///
    /// `state` defaults to `GroupMembershipState::Allowed` when `None`.
    pub fn generate_dm(state: Option<GroupMembershipState>) -> StoredGroup {
        let target = TARGET_INBOX_ID.fetch_add(1, Ordering::SeqCst).to_string();
        StoredGroup::builder()
            .id(rand_vec::<24>())
            .created_at_ns(now_ns())
            .membership_state(state.unwrap_or(GroupMembershipState::Allowed))
            .added_by_inbox_id("placeholder_address")
            .dm_id(format!(
                "dm:placeholder_inbox_id_1:placeholder_inbox_id_{target}",
            ))
            .build()
            .unwrap()
    }

    #[xmtp_common::test]
    fn test_dm_stitching() {
        with_connection(|conn| {
            let dm_id = "dm:some_wise_guy:thats_me";
            let build_dm = || {
                StoredGroup::builder()
                    .id(rand_vec::<24>())
                    .created_at_ns(now_ns())
                    .membership_state(GroupMembershipState::Allowed)
                    .added_by_inbox_id("placeholder_address")
                    .dm_id(Some(dm_id.to_string()))
                    .build()
                    .unwrap()
            };

            let first = build_dm();
            first.store(conn).unwrap();
            let second = build_dm();
            second.store(conn).unwrap();

            // Duplicate DMs collapse into a single entry by default.
            let all_groups = conn.find_groups(GroupQueryArgs::default()).unwrap();
            assert_eq!(all_groups.len(), 1);

            // Either id resolves to one of the two stitched DMs. Neither has a
            // `last_message_ns`, so which one wins is unspecified.
            for group in [&first, &second] {
                let stitched = conn
                    .fetch_stitched(&group.id)
                    .unwrap()
                    .expect("Should find the stitched DM");
                assert_eq!(stitched.dm_id.as_deref(), Some(dm_id));
                assert!(stitched.id == first.id || stitched.id == second.id);
            }

            // Each duplicate sees exactly the other one as its sibling.
            let others = conn.other_dms(&first.id).unwrap();
            assert_eq!(others.len(), 1);
            assert_eq!(others[0].id, second.id);
        })
    }

    #[xmtp_common::test]
    fn test_dm_deduplication() {
        with_connection(|conn| {
            let now = now_ns();
            let base_time = now - 1_000_000_000;
            let dm_id = "dm:alice:bob";

            let oldest_dm = StoredGroup::builder()
                .id(rand_vec::<24>())
                .created_at_ns(base_time)
                .last_message_ns(base_time)
                .membership_state(GroupMembershipState::Allowed)
                .added_by_inbox_id("alice")
                .dm_id(Some(dm_id.to_string()))
                .build()
                .unwrap();
            oldest_dm.store(conn).unwrap();

            let middle_dm = StoredGroup::builder()
                .id(rand_vec::<24>())
                .created_at_ns(base_time + 1_000_000)
                .last_message_ns(base_time + 1_000_000)
                .membership_state(GroupMembershipState::Allowed)
                .added_by_inbox_id("bob")
                .dm_id(Some(dm_id.to_string()))
                .build()
                .unwrap();
            middle_dm.store(conn).unwrap();

            let latest_dm = StoredGroup::builder()
                .id(rand_vec::<24>())
                .created_at_ns(base_time + 2_000_000)
                .last_message_ns(base_time + 2_000_000)
                .membership_state(GroupMembershipState::Allowed)
                .added_by_inbox_id("alice")
                .dm_id(Some(dm_id.to_string()))
                .build()
                .unwrap();
            latest_dm.store(conn).unwrap();

            let different_dm = StoredGroup::builder()
                .id(rand_vec::<24>())
                .created_at_ns(base_time + 500_000)
                .last_message_ns(base_time + 500_000)
                .membership_state(GroupMembershipState::Allowed)
                .added_by_inbox_id("charlie")
                .dm_id(Some("dm:charlie:dave".to_string()))
                .build()
                .unwrap();
            different_dm.store(conn).unwrap();

            let regular_group = StoredGroup::builder()
                .id(rand_vec::<24>())
                .created_at_ns(base_time + 1_500_000)
                .last_message_ns(base_time + 1_500_000)
                .membership_state(GroupMembershipState::Allowed)
                .added_by_inbox_id("alice")
                .dm_id(None)
                .build()
                .unwrap();
            regular_group.store(conn).unwrap();

            let deduplicated_groups = conn
                .find_groups(GroupQueryArgs {
                    include_duplicate_dms: false,
                    ..Default::default()
                })
                .unwrap();

            // One alice/bob DM, one charlie/dave DM, one regular group.
            assert_eq!(deduplicated_groups.len(), 3);

            let kept_dm = deduplicated_groups
                .iter()
                .find(|g| g.dm_id.as_deref() == Some(dm_id))
                .expect("Should find the DM group");
            assert_eq!(kept_dm.id, latest_dm.id);
            assert_eq!(kept_dm.last_message_ns, Some(base_time + 2_000_000));

            let kept_different_dm = deduplicated_groups
                .iter()
                .find(|g| g.dm_id.as_deref() == Some("dm:charlie:dave"))
                .expect("Should find the different DM group");
            assert_eq!(kept_different_dm.id, different_dm.id);

            let kept_regular = deduplicated_groups
                .iter()
                .find(|g| g.dm_id.is_none())
                .expect("Should find the regular group");
            assert_eq!(kept_regular.id, regular_group.id);

            let all_groups = conn
                .find_groups(GroupQueryArgs {
                    include_duplicate_dms: true,
                    ..Default::default()
                })
                .unwrap();

            assert_eq!(all_groups.len(), 5);

            // Stitching resolves any duplicate to the most recently active DM.
            let stitched = conn
                .fetch_stitched(&oldest_dm.id)
                .unwrap()
                .expect("Should find the stitched DM");
            assert_eq!(stitched.id, latest_dm.id);

            let active = conn
                .find_active_dm_group(dm_id)
                .unwrap()
                .expect("Should find the active DM");
            assert_eq!(active.id, latest_dm.id);

            // The siblings of the latest DM are exactly the oldest and middle duplicates.
            let other_ids: Vec<_> = conn
                .other_dms(&latest_dm.id)
                .unwrap()
                .into_iter()
                .map(|g| g.id)
                .collect();
            assert_eq!(other_ids.len(), 2);
            assert!(other_ids.contains(&oldest_dm.id));
            assert!(other_ids.contains(&middle_dm.id));

            // Non-DM groups have no siblings and are fetched back unchanged.
            assert!(conn.other_dms(&regular_group.id).unwrap().is_empty());
            let fetched_regular = conn
                .fetch_stitched(&regular_group.id)
                .unwrap()
                .expect("Should find the regular group");
            assert_eq!(fetched_regular.id, regular_group.id);
        })
    }
}