xmtp_db/encrypted_store/
processed_device_sync_messages.rs

1use super::{
2    ConnectionExt, Sqlite,
3    db_connection::DbConnection,
4    group::ConversationType,
5    group_message::StoredGroupMessage,
6    schema::{
7        group_messages::dsl as group_messages_dsl,
8        groups::dsl as groups_dsl,
9        processed_device_sync_messages::{self, dsl},
10    },
11};
12use crate::{StorageError, impl_store, impl_store_or_ignore};
13use diesel::{
14    backend::Backend,
15    deserialize::{self, FromSql, FromSqlRow},
16    expression::AsExpression,
17    prelude::*,
18    serialize::{self, IsNull, Output, ToSql},
19    sql_types::Integer,
20};
21use serde::{Deserialize, Serialize};
22
23/// The state of a device sync message processing
24#[repr(i32)]
25#[derive(
26    Debug, Default, Copy, Clone, Serialize, Deserialize, Eq, PartialEq, AsExpression, FromSqlRow,
27)]
28#[diesel(sql_type = Integer)]
29pub enum DeviceSyncProcessingState {
30    /// Message is pending processing
31    #[default]
32    Pending = 0,
33    /// Message has been successfully processed
34    Processed = 1,
35    /// Message processing failed permanently
36    Failed = 2,
37}
38
39impl ToSql<Integer, Sqlite> for DeviceSyncProcessingState
40where
41    i32: ToSql<Integer, Sqlite>,
42{
43    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
44        out.set_value(*self as i32);
45        Ok(IsNull::No)
46    }
47}
48
49impl FromSql<Integer, Sqlite> for DeviceSyncProcessingState
50where
51    i32: FromSql<Integer, Sqlite>,
52{
53    fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
54        match i32::from_sql(bytes)? {
55            0 => Ok(DeviceSyncProcessingState::Pending),
56            1 => Ok(DeviceSyncProcessingState::Processed),
57            2 => Ok(DeviceSyncProcessingState::Failed),
58            x => Err(format!("Unrecognized variant {}", x).into()),
59        }
60    }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Insertable, Identifiable, Queryable)]
64#[diesel(table_name = processed_device_sync_messages)]
65#[diesel(primary_key(message_id))]
66pub struct StoredProcessedDeviceSyncMessages {
67    pub message_id: Vec<u8>,
68    /// Number of processing attempts remaining
69    pub attempts: i32,
70    /// Current processing state
71    pub state: DeviceSyncProcessingState,
72}
73
74impl StoredProcessedDeviceSyncMessages {
75    /// Create a new stored processed device sync message with default values
76    pub fn new(message_id: Vec<u8>) -> Self {
77        Self {
78            message_id,
79            attempts: 0,
80            state: DeviceSyncProcessingState::Pending,
81        }
82    }
83}
84
// Persistence helpers for the row type. The tests in this file call `.store(conn)` on it, which
// presumably comes from `impl_store!`. `impl_store_or_ignore!` presumably adds an
// insert-or-ignore variant. NOTE(review): confirm both against the macro definitions.
impl_store!(
    StoredProcessedDeviceSyncMessages,
    processed_device_sync_messages
);
impl_store_or_ignore!(
    StoredProcessedDeviceSyncMessages,
    processed_device_sync_messages
);
93
94pub trait QueryDeviceSyncMessages {
95    fn unprocessed_sync_group_messages(&self) -> Result<Vec<StoredGroupMessage>, StorageError>;
96    fn sync_group_messages_paged(
97        &self,
98        offset: i64,
99        limit: i64,
100    ) -> Result<Vec<StoredGroupMessage>, StorageError>;
101    /// Marks a device sync message as processed.
102    fn mark_device_sync_msg_as_processed(&self, message_id: &[u8]) -> Result<(), StorageError>;
103    /// Increments the attempt count for a device sync message.
104    /// If the attempt count reaches max_attempts, the state is set to Failed.
105    /// Returns the new attempt count.
106    fn increment_device_sync_msg_attempt(
107        &self,
108        message_id: &[u8],
109        max_attempts: i32,
110    ) -> Result<i32, StorageError>;
111}
112
113impl<T> QueryDeviceSyncMessages for &T
114where
115    T: QueryDeviceSyncMessages,
116{
117    fn unprocessed_sync_group_messages(&self) -> Result<Vec<StoredGroupMessage>, StorageError> {
118        (**self).unprocessed_sync_group_messages()
119    }
120
121    fn sync_group_messages_paged(
122        &self,
123        offset: i64,
124        limit: i64,
125    ) -> Result<Vec<StoredGroupMessage>, StorageError> {
126        (**self).sync_group_messages_paged(offset, limit)
127    }
128
129    fn mark_device_sync_msg_as_processed(&self, message_id: &[u8]) -> Result<(), StorageError> {
130        (**self).mark_device_sync_msg_as_processed(message_id)
131    }
132
133    fn increment_device_sync_msg_attempt(
134        &self,
135        message_id: &[u8],
136        max_attempts: i32,
137    ) -> Result<i32, StorageError> {
138        (**self).increment_device_sync_msg_attempt(message_id, max_attempts)
139    }
140}
141
142impl<C: ConnectionExt> QueryDeviceSyncMessages for DbConnection<C> {
143    fn unprocessed_sync_group_messages(&self) -> Result<Vec<StoredGroupMessage>, StorageError> {
144        let result = self.raw_query_read(|conn| {
145            group_messages_dsl::group_messages
146                .inner_join(groups_dsl::groups.on(group_messages_dsl::group_id.eq(groups_dsl::id)))
147                .filter(groups_dsl::conversation_type.eq(ConversationType::Sync))
148                // Include messages that either:
149                // 1. Don't have an entry in processed_device_sync_messages, OR
150                // 2. Have an entry with state = Pending
151                .filter(
152                    diesel::dsl::not(diesel::dsl::exists(
153                        dsl::processed_device_sync_messages
154                            .filter(dsl::message_id.eq(group_messages_dsl::id)),
155                    ))
156                    .or(diesel::dsl::exists(
157                        dsl::processed_device_sync_messages
158                            .filter(dsl::message_id.eq(group_messages_dsl::id))
159                            .filter(dsl::state.eq(DeviceSyncProcessingState::Pending)),
160                    )),
161                )
162                .select(group_messages_dsl::group_messages::all_columns())
163                .load::<StoredGroupMessage>(conn)
164        })?;
165        Ok(result)
166    }
167
168    fn sync_group_messages_paged(
169        &self,
170        offset: i64,
171        limit: i64,
172    ) -> Result<Vec<StoredGroupMessage>, StorageError> {
173        let result = self.raw_query_read(|conn| {
174            group_messages_dsl::group_messages
175                .inner_join(groups_dsl::groups.on(group_messages_dsl::group_id.eq(groups_dsl::id)))
176                .filter(groups_dsl::conversation_type.eq(ConversationType::Sync))
177                .select(group_messages_dsl::group_messages::all_columns())
178                .order_by(group_messages_dsl::sent_at_ns.desc())
179                .limit(limit)
180                .offset(offset)
181                .load::<StoredGroupMessage>(conn)
182        })?;
183        Ok(result)
184    }
185
186    fn mark_device_sync_msg_as_processed(&self, message_id: &[u8]) -> Result<(), StorageError> {
187        self.raw_query_write(|conn| {
188            diesel::insert_into(dsl::processed_device_sync_messages)
189                .values(StoredProcessedDeviceSyncMessages {
190                    message_id: message_id.to_vec(),
191                    attempts: 0,
192                    state: DeviceSyncProcessingState::Processed,
193                })
194                .on_conflict(dsl::message_id)
195                .do_update()
196                .set(dsl::state.eq(DeviceSyncProcessingState::Processed))
197                .execute(conn)
198        })?;
199        Ok(())
200    }
201
202    fn increment_device_sync_msg_attempt(
203        &self,
204        message_id: &[u8],
205        max_attempts: i32,
206    ) -> Result<i32, StorageError> {
207        let attempts = self.raw_query_write(|conn| {
208            // Upsert: insert with attempts=1 if no record exists, or increment attempts if it does
209            diesel::insert_into(dsl::processed_device_sync_messages)
210                .values(StoredProcessedDeviceSyncMessages {
211                    message_id: message_id.to_vec(),
212                    attempts: 1,
213                    state: DeviceSyncProcessingState::Pending,
214                })
215                .on_conflict(dsl::message_id)
216                .do_update()
217                .set(dsl::attempts.eq(dsl::attempts + 1))
218                .execute(conn)?;
219
220            // Get the updated record
221            let record: StoredProcessedDeviceSyncMessages = dsl::processed_device_sync_messages
222                .find(message_id)
223                .first(conn)?;
224
225            // If we've reached max attempts, set state to Failed
226            if record.attempts >= max_attempts {
227                diesel::update(dsl::processed_device_sync_messages.find(message_id))
228                    .set(dsl::state.eq(DeviceSyncProcessingState::Failed))
229                    .execute(conn)?;
230            }
231
232            Ok(record.attempts)
233        })?;
234        Ok(attempts)
235    }
236}
237
#[cfg(test)]
mod tests {
    // Exercises the bookkeeping queries above against a connection provided by
    // `with_connection`; groups are switched to `ConversationType::Sync` so their messages are
    // picked up by the sync queries.
    use super::*;
    use crate::{
        Store,
        group::{ConversationType, tests::generate_group},
        group_message::tests::generate_message,
        test_utils::with_connection,
    };

    // A Pending tracking row leaves a message unprocessed; marking it Processed removes it.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_marks_as_processed() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let mut group2 = generate_group(None);
            group2.conversation_type = ConversationType::Sync;
            group2.store(conn)?;

            let message1 = generate_message(None, Some(&group.id), None, None, None, None);
            message1.store(conn)?;
            let message2 = generate_message(None, Some(&group2.id), None, None, None, None);
            message2.store(conn)?;

            // Neither message has a tracking row yet, so both are unprocessed.
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 2);

            // Storing with Pending state still counts as unprocessed
            StoredProcessedDeviceSyncMessages::new(message2.id.clone()).store(conn)?;
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 2);

            // Setting state to Processed marks it as processed
            conn.mark_device_sync_msg_as_processed(&message2.id)?;

            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);
        })
    }

    // `new` yields attempts = 0 / Pending, and marking Processed takes the message out of the
    // unprocessed set.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_stores_with_attempts_and_state() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let message = generate_message(None, Some(&group.id), None, None, None, None);
            message.store(conn)?;

            // Store with default values (Pending state)
            let stored = StoredProcessedDeviceSyncMessages::new(message.id.clone());
            assert_eq!(stored.attempts, 0);
            assert_eq!(stored.state, DeviceSyncProcessingState::Pending);
            stored.store(conn)?;

            // Pending state is still considered unprocessed
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);

            // Update to Processed state using mark_device_sync_msg_as_processed
            conn.mark_device_sync_msg_as_processed(&message.id)?;

            // Now it's no longer in unprocessed
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 0);
        })
    }

    // The upsert in `mark_device_sync_msg_as_processed` only changes `state`, so a previously
    // accumulated attempt count survives.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_preserves_attempts_when_marking_as_processed() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let message = generate_message(None, Some(&group.id), None, None, None, None);
            message.store(conn)?;

            // Store with Pending state
            StoredProcessedDeviceSyncMessages::new(message.id.clone()).store(conn)?;

            // Increment attempts a couple times
            conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            conn.increment_device_sync_msg_attempt(&message.id, 3)?;

            // Now mark as processed
            conn.mark_device_sync_msg_as_processed(&message.id)?;

            // Verify attempts are preserved (should be 2)
            let record: StoredProcessedDeviceSyncMessages = conn.raw_query_read(|c| {
                dsl::processed_device_sync_messages
                    .find(&message.id)
                    .first(c)
            })?;
            assert_eq!(record.attempts, 2);
            assert_eq!(record.state, DeviceSyncProcessingState::Processed);
        })
    }

    // The message stays unprocessed while attempts < max_attempts and becomes Failed (and so
    // leaves the unprocessed set) on the attempt that reaches it.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_increments_attempts_and_sets_failed_at_max() {
        with_connection(|conn| {
            let mut group = generate_group(None);
            group.conversation_type = ConversationType::Sync;
            group.store(conn)?;

            let message = generate_message(None, Some(&group.id), None, None, None, None);
            message.store(conn)?;

            // Store with default values (attempts = 0)
            StoredProcessedDeviceSyncMessages::new(message.id.clone()).store(conn)?;

            // Increment attempt 1
            let attempts = conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            assert_eq!(attempts, 1);
            // Still pending (below max)
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);

            // Increment attempt 2
            let attempts = conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            assert_eq!(attempts, 2);
            // Still pending (below max)
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 1);

            // Increment attempt 3 (reaches max_attempts = 3)
            let attempts = conn.increment_device_sync_msg_attempt(&message.id, 3)?;
            assert_eq!(attempts, 3);
            // Should now be Failed and no longer in unprocessed
            let unprocessed = conn.unprocessed_sync_group_messages()?;
            assert_eq!(unprocessed.len(), 0);
        })
    }

    // Paging returns only Sync-conversation messages, newest first, honoring offset/limit.
    #[xmtp_common::test(unwrap_try = true)]
    fn it_returns_sync_group_messages_paged() {
        with_connection(|conn| {
            let mut sync_group = generate_group(None);
            sync_group.conversation_type = ConversationType::Sync;
            sync_group.store(conn)?;

            // Create a non-sync group to verify filtering works
            let mut dm_group = generate_group(None);
            dm_group.conversation_type = ConversationType::Dm;
            dm_group.store(conn)?;

            // Create 5 messages in the sync group with decreasing sent_at_ns (5000, 4000, ...,
            // 1000). Results are ordered by sent_at_ns DESC, so insertion order here is exactly
            // the expected result order.
            let mut sync_message_ids = Vec::new();
            for i in 0..5 {
                let message = generate_message(
                    None,
                    Some(&sync_group.id),
                    Some(((5 - i) * 1000) as i64),
                    None,
                    None,
                    None,
                );
                message.store(conn)?;
                sync_message_ids.push(message.id);
            }

            // Create a message in the non-sync group (should be filtered out)
            let dm_message = generate_message(None, Some(&dm_group.id), None, None, None, None);
            dm_message.store(conn)?;

            // Test pagination: get first 2 messages
            let page1 = conn.sync_group_messages_paged(0, 2)?;
            assert_eq!(page1.len(), 2);
            assert_eq!(page1[0].id, sync_message_ids[0]);
            assert_eq!(page1[1].id, sync_message_ids[1]);

            // Test pagination: get next 2 messages
            let page2 = conn.sync_group_messages_paged(2, 2)?;
            assert_eq!(page2.len(), 2);
            assert_eq!(page2[0].id, sync_message_ids[2]);
            assert_eq!(page2[1].id, sync_message_ids[3]);

            // Test pagination: get last message
            let page3 = conn.sync_group_messages_paged(4, 2)?;
            assert_eq!(page3.len(), 1);
            assert_eq!(page3[0].id, sync_message_ids[4]);

            // Test pagination: offset beyond available messages
            let page4 = conn.sync_group_messages_paged(10, 2)?;
            assert_eq!(page4.len(), 0);

            // Test getting all messages at once
            let all_messages = conn.sync_group_messages_paged(0, 100)?;
            assert_eq!(all_messages.len(), 5);

            // Verify all returned messages are in order and belong to the sync group
            for (i, msg) in all_messages.iter().enumerate() {
                assert_eq!(msg.id, sync_message_ids[i]);
                assert_eq!(msg.group_id, sync_group.id);
            }
        })
    }
}