1use super::{
2 ConnectionExt, Sqlite,
3 db_connection::DbConnection,
4 group::ConversationType,
5 group_message::StoredGroupMessage,
6 schema::{
7 group_messages::dsl as group_messages_dsl,
8 groups::dsl as groups_dsl,
9 processed_device_sync_messages::{self, dsl},
10 },
11};
12use crate::{StorageError, impl_store, impl_store_or_ignore};
13use diesel::{
14 backend::Backend,
15 deserialize::{self, FromSql, FromSqlRow},
16 expression::AsExpression,
17 prelude::*,
18 serialize::{self, IsNull, Output, ToSql},
19 sql_types::Integer,
20};
21use serde::{Deserialize, Serialize};
22
/// Processing state of a device-sync message, persisted as an SQL `Integer`.
///
/// The explicit discriminants are the values written to and read from the
/// database by the `ToSql`/`FromSql` impls in this file, so they must stay
/// stable once rows exist.
#[repr(i32)]
#[derive(
    Debug, Default, Copy, Clone, Serialize, Deserialize, Eq, PartialEq, AsExpression, FromSqlRow,
)]
#[diesel(sql_type = Integer)]
pub enum DeviceSyncProcessingState {
    /// Not yet handled. This is the default state and the one used for rows
    /// created through `StoredProcessedDeviceSyncMessages::new`.
    #[default]
    Pending = 0,
    /// Handled successfully; written by `mark_device_sync_msg_as_processed`.
    Processed = 1,
    /// Written by `increment_device_sync_msg_attempt` once the attempt count
    /// reaches the caller-supplied maximum.
    Failed = 2,
}
38
39impl ToSql<Integer, Sqlite> for DeviceSyncProcessingState
40where
41 i32: ToSql<Integer, Sqlite>,
42{
43 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
44 out.set_value(*self as i32);
45 Ok(IsNull::No)
46 }
47}
48
49impl FromSql<Integer, Sqlite> for DeviceSyncProcessingState
50where
51 i32: FromSql<Integer, Sqlite>,
52{
53 fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
54 match i32::from_sql(bytes)? {
55 0 => Ok(DeviceSyncProcessingState::Pending),
56 1 => Ok(DeviceSyncProcessingState::Processed),
57 2 => Ok(DeviceSyncProcessingState::Failed),
58 x => Err(format!("Unrecognized variant {}", x).into()),
59 }
60 }
61}
62
/// Row of the `processed_device_sync_messages` table: per-message bookkeeping
/// for device-sync processing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Insertable, Identifiable, Queryable)]
#[diesel(table_name = processed_device_sync_messages)]
#[diesel(primary_key(message_id))]
pub struct StoredProcessedDeviceSyncMessages {
    /// Primary key. Queries in this file compare it against the `id` column of
    /// `group_messages`.
    pub message_id: Vec<u8>,
    /// Attempt counter; bumped by `increment_device_sync_msg_attempt`.
    pub attempts: i32,
    /// Current processing state of the message.
    pub state: DeviceSyncProcessingState,
}
73
74impl StoredProcessedDeviceSyncMessages {
75 pub fn new(message_id: Vec<u8>) -> Self {
77 Self {
78 message_id,
79 attempts: 0,
80 state: DeviceSyncProcessingState::Pending,
81 }
82 }
83}
84
// Storage impls for `StoredProcessedDeviceSyncMessages`, generated by crate macros
// whose bodies are defined elsewhere.
// NOTE(review): presumably `impl_store!` provides the `Store` impl that the tests
// use via `.store(conn)`, and `impl_store_or_ignore!` an insert-or-ignore
// variant — confirm against the macro definitions.
impl_store!(
    StoredProcessedDeviceSyncMessages,
    processed_device_sync_messages
);
impl_store_or_ignore!(
    StoredProcessedDeviceSyncMessages,
    processed_device_sync_messages
);
93
/// Queries and state transitions for device-sync message bookkeeping.
///
/// The per-method notes describe the `DbConnection` implementation in this file.
pub trait QueryDeviceSyncMessages {
    /// Returns the messages of `Sync` conversations that either have no row in
    /// `processed_device_sync_messages` or have one whose state is `Pending`.
    fn unprocessed_sync_group_messages(&self) -> Result<Vec<StoredGroupMessage>, StorageError>;
    /// Returns one page of messages belonging to `Sync` conversations, newest
    /// first by `sent_at_ns`.
    ///
    /// `offset` is the number of rows to skip and `limit` the maximum number of
    /// rows to return.
    fn sync_group_messages_paged(
        &self,
        offset: i64,
        limit: i64,
    ) -> Result<Vec<StoredGroupMessage>, StorageError>;
    /// Marks `message_id` as `Processed`, creating the bookkeeping row if it
    /// does not exist yet. An existing row keeps its attempt count.
    fn mark_device_sync_msg_as_processed(&self, message_id: &[u8]) -> Result<(), StorageError>;
    /// Records one more processing attempt for `message_id` and returns the
    /// attempt count after the increment.
    ///
    /// When the resulting count is at least `max_attempts`, the row's state is
    /// set to `Failed`.
    fn increment_device_sync_msg_attempt(
        &self,
        message_id: &[u8],
        max_attempts: i32,
    ) -> Result<i32, StorageError>;
}
112
113impl<T> QueryDeviceSyncMessages for &T
114where
115 T: QueryDeviceSyncMessages,
116{
117 fn unprocessed_sync_group_messages(&self) -> Result<Vec<StoredGroupMessage>, StorageError> {
118 (**self).unprocessed_sync_group_messages()
119 }
120
121 fn sync_group_messages_paged(
122 &self,
123 offset: i64,
124 limit: i64,
125 ) -> Result<Vec<StoredGroupMessage>, StorageError> {
126 (**self).sync_group_messages_paged(offset, limit)
127 }
128
129 fn mark_device_sync_msg_as_processed(&self, message_id: &[u8]) -> Result<(), StorageError> {
130 (**self).mark_device_sync_msg_as_processed(message_id)
131 }
132
133 fn increment_device_sync_msg_attempt(
134 &self,
135 message_id: &[u8],
136 max_attempts: i32,
137 ) -> Result<i32, StorageError> {
138 (**self).increment_device_sync_msg_attempt(message_id, max_attempts)
139 }
140}
141
142impl<C: ConnectionExt> QueryDeviceSyncMessages for DbConnection<C> {
143 fn unprocessed_sync_group_messages(&self) -> Result<Vec<StoredGroupMessage>, StorageError> {
144 let result = self.raw_query_read(|conn| {
145 group_messages_dsl::group_messages
146 .inner_join(groups_dsl::groups.on(group_messages_dsl::group_id.eq(groups_dsl::id)))
147 .filter(groups_dsl::conversation_type.eq(ConversationType::Sync))
148 .filter(
152 diesel::dsl::not(diesel::dsl::exists(
153 dsl::processed_device_sync_messages
154 .filter(dsl::message_id.eq(group_messages_dsl::id)),
155 ))
156 .or(diesel::dsl::exists(
157 dsl::processed_device_sync_messages
158 .filter(dsl::message_id.eq(group_messages_dsl::id))
159 .filter(dsl::state.eq(DeviceSyncProcessingState::Pending)),
160 )),
161 )
162 .select(group_messages_dsl::group_messages::all_columns())
163 .load::<StoredGroupMessage>(conn)
164 })?;
165 Ok(result)
166 }
167
168 fn sync_group_messages_paged(
169 &self,
170 offset: i64,
171 limit: i64,
172 ) -> Result<Vec<StoredGroupMessage>, StorageError> {
173 let result = self.raw_query_read(|conn| {
174 group_messages_dsl::group_messages
175 .inner_join(groups_dsl::groups.on(group_messages_dsl::group_id.eq(groups_dsl::id)))
176 .filter(groups_dsl::conversation_type.eq(ConversationType::Sync))
177 .select(group_messages_dsl::group_messages::all_columns())
178 .order_by(group_messages_dsl::sent_at_ns.desc())
179 .limit(limit)
180 .offset(offset)
181 .load::<StoredGroupMessage>(conn)
182 })?;
183 Ok(result)
184 }
185
186 fn mark_device_sync_msg_as_processed(&self, message_id: &[u8]) -> Result<(), StorageError> {
187 self.raw_query_write(|conn| {
188 diesel::insert_into(dsl::processed_device_sync_messages)
189 .values(StoredProcessedDeviceSyncMessages {
190 message_id: message_id.to_vec(),
191 attempts: 0,
192 state: DeviceSyncProcessingState::Processed,
193 })
194 .on_conflict(dsl::message_id)
195 .do_update()
196 .set(dsl::state.eq(DeviceSyncProcessingState::Processed))
197 .execute(conn)
198 })?;
199 Ok(())
200 }
201
202 fn increment_device_sync_msg_attempt(
203 &self,
204 message_id: &[u8],
205 max_attempts: i32,
206 ) -> Result<i32, StorageError> {
207 let attempts = self.raw_query_write(|conn| {
208 diesel::insert_into(dsl::processed_device_sync_messages)
210 .values(StoredProcessedDeviceSyncMessages {
211 message_id: message_id.to_vec(),
212 attempts: 1,
213 state: DeviceSyncProcessingState::Pending,
214 })
215 .on_conflict(dsl::message_id)
216 .do_update()
217 .set(dsl::attempts.eq(dsl::attempts + 1))
218 .execute(conn)?;
219
220 let record: StoredProcessedDeviceSyncMessages = dsl::processed_device_sync_messages
222 .find(message_id)
223 .first(conn)?;
224
225 if record.attempts >= max_attempts {
227 diesel::update(dsl::processed_device_sync_messages.find(message_id))
228 .set(dsl::state.eq(DeviceSyncProcessingState::Failed))
229 .execute(conn)?;
230 }
231
232 Ok(record.attempts)
233 })?;
234 Ok(attempts)
235 }
236}
237
// Tests for the device-sync bookkeeping queries. Each test runs against the
// connection handed out by `with_connection`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Store,
        group::{ConversationType, tests::generate_group},
        group_message::tests::generate_message,
        test_utils::with_connection,
    };

    // A `Pending` row does not hide a message from the unprocessed list; marking
    // it as processed does.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_marks_as_processed() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let mut group2 = generate_group(None);
            group2.conversation_type = ConversationType::Sync;
            group2.store(conn)?;

            let message1 = generate_message(None, Some(&group.id), None, None, None, None);
            message1.store(conn)?;
            let message2 = generate_message(None, Some(&group2.id), None, None, None, None);
            message2.store(conn)?;

            // Neither message has a bookkeeping row yet.
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 2);

            // A freshly created row is `Pending`, so the message is still listed.
            StoredProcessedDeviceSyncMessages::new(message2.id.clone()).store(conn)?;
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 2);

            conn.mark_device_sync_msg_as_processed(&message2.id)?;

            // Only message1 remains unprocessed.
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);
        })
    }

    // `new` yields zero attempts and `Pending`; the stored row follows the same
    // pending -> processed visibility rules.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_stores_with_attempts_and_state() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let message = generate_message(None, Some(&group.id), None, None, None, None);
            message.store(conn)?;

            let stored = StoredProcessedDeviceSyncMessages::new(message.id.clone());
            assert_eq!(stored.attempts, 0);
            assert_eq!(stored.state, DeviceSyncProcessingState::Pending);
            stored.store(conn)?;

            // Pending rows still count as unprocessed.
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);

            conn.mark_device_sync_msg_as_processed(&message.id)?;

            // Processed rows are excluded.
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 0);
        })
    }

    // Marking an existing row as processed must only touch `state`, leaving the
    // accumulated attempt count intact.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_preserves_attempts_when_marking_as_processed() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let message = generate_message(None, Some(&group.id), None, None, None, None);
            message.store(conn)?;

            StoredProcessedDeviceSyncMessages::new(message.id.clone()).store(conn)?;

            // Two attempts, both below the maximum of 3.
            conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            conn.increment_device_sync_msg_attempt(&message.id, 3)?;

            conn.mark_device_sync_msg_as_processed(&message.id)?;

            // Read the row back directly to inspect both columns.
            let record: StoredProcessedDeviceSyncMessages = conn.raw_query_read(|c| {
                dsl::processed_device_sync_messages
                    .find(&message.id)
                    .first(c)
            })?;
            assert_eq!(record.attempts, 2);
            assert_eq!(record.state, DeviceSyncProcessingState::Processed);
        })
    }

    // Each increment returns the new count; reaching `max_attempts` flips the row
    // to `Failed`, which removes the message from the unprocessed list.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_increments_attempts_and_sets_failed_at_max() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let message = generate_message(None, Some(&group.id), None, None, None, None);
            message.store(conn)?;

            StoredProcessedDeviceSyncMessages::new(message.id.clone()).store(conn)?;

            // Attempt 1 of 3: still pending.
            let attempts = conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            assert_eq!(attempts, 1);
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);

            // Attempt 2 of 3: still pending.
            let attempts = conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            assert_eq!(attempts, 2);
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);

            // Attempt 3 of 3: the maximum is reached, so the message drops out.
            let attempts = conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            assert_eq!(attempts, 3);
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 0);
        })
    }

    // Paging returns only messages of `Sync` conversations, in the expected
    // order, and honors offset/limit.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_returns_sync_group_messages_paged() {
        with_connection(|conn| {
            let mut sync_group = generate_group(None);
            sync_group.conversation_type = ConversationType::Sync;
            sync_group.store(conn)?;

            // A non-sync conversation whose message must never appear in the results.
            let mut dm_group = generate_group(None);
            dm_group.conversation_type = ConversationType::Dm;
            dm_group.store(conn)?;

            let mut sync_message_ids = Vec::new();
            // The generated value decreases with `i` (5000, 4000, ...).
            // NOTE(review): presumably the third argument of `generate_message` is
            // `sent_at_ns`, making insertion order equal to newest-first order — confirm.
            for i in 0..5 {
                let message = generate_message(
                    None,
                    Some(&sync_group.id),
                    Some(((5 - i) * 1000) as i64),
                    None,
                    None,
                    None,
                );
                message.store(conn)?;
                sync_message_ids.push(message.id);
            }

            let dm_message = generate_message(None, Some(&dm_group.id), None, None, None, None);
            dm_message.store(conn)?;

            // First page: items 0 and 1.
            let page1 = conn.sync_group_messages_paged(0, 2)?;
            assert_eq!(page1.len(), 2);
            assert_eq!(page1[0].id, sync_message_ids[0]);
            assert_eq!(page1[1].id, sync_message_ids[1]);

            // Second page: items 2 and 3.
            let page2 = conn.sync_group_messages_paged(2, 2)?;
            assert_eq!(page2.len(), 2);
            assert_eq!(page2[0].id, sync_message_ids[2]);
            assert_eq!(page2[1].id, sync_message_ids[3]);

            // Final, partially filled page.
            let page3 = conn.sync_group_messages_paged(4, 2)?;
            assert_eq!(page3.len(), 1);
            assert_eq!(page3[0].id, sync_message_ids[4]);

            // Offset past the end yields nothing.
            let page4 = conn.sync_group_messages_paged(10, 2)?;
            assert_eq!(page4.len(), 0);

            // A large limit returns exactly the five sync messages (the DM one is excluded).
            let all_messages = conn.sync_group_messages_paged(0, 100)?;
            assert_eq!(all_messages.len(), 5);

            for (i, msg) in all_messages.iter().enumerate() {
                assert_eq!(msg.id, sync_message_ids[i]);
                assert_eq!(msg.group_id, sync_group.id);
            }
        })
    }
}