xmtp_db/encrypted_store/
tasks.rs

1use super::{ConnectionExt, db_connection::DbConnection, schema::tasks};
2use crate::StorageError;
3use derive_builder::Builder;
4use diesel::prelude::*;
5use prost::Message;
6use xmtp_common::{NS_IN_DAY, NS_IN_SEC, time::now_ns};
7use xmtp_proto::xmtp::mls::database::Task as TaskProto;
8
/// A row of the `tasks` table, as read back from the database.
///
/// `Queryable` maps the selected columns onto these fields by position, so the
/// field order here must stay in sync with the column order of `schema::tasks`.
#[derive(Queryable, Identifiable, Debug, Clone)]
#[diesel(table_name = tasks)]
#[diesel(primary_key(id))]
pub struct Task {
    /// Primary key of the row.
    pub id: i32,
    /// Sequence id of the message that produced this task (inferred from the
    /// name; NOTE(review): confirm against the producer).
    pub originating_message_sequence_id: i64,
    /// Originator id of the message that produced this task (inferred from the
    /// name; NOTE(review): confirm against the producer).
    pub originating_message_originator_id: i32,
    /// Creation timestamp in nanoseconds; `NewTaskBuilder::build` defaults it to
    /// the current time.
    pub created_at_ns: i64,
    /// Expiry timestamp in nanoseconds; `NewTaskBuilder::build` defaults it to
    /// three days after the current time. Nothing in this file enforces it.
    pub expires_at_ns: i64,
    /// Number of attempts recorded so far (written by `update_task`).
    pub attempts: i32,
    /// Upper bound on attempts; `NewTaskBuilder::build` defaults it to 20. Not
    /// enforced in this file.
    pub max_attempts: i32,
    /// Timestamp in nanoseconds of the most recent attempt (written by `update_task`).
    pub last_attempted_at_ns: i64,
    /// Backoff growth factor between attempts (default 1.5); presumably applied
    /// by the scheduler that calls `update_task` (TODO confirm).
    pub backoff_scaling_factor: f32,
    /// Cap on the backoff delay in nanoseconds (default 60 s).
    pub max_backoff_duration_ns: i64,
    /// Backoff delay before the first retry in nanoseconds (default 2 s).
    pub initial_backoff_duration_ns: i64,
    /// When the task should next run, in nanoseconds; `get_next_task` orders by it.
    pub next_attempt_at_ns: i64,
    /// `xmtp_common::sha256_bytes` of `data`.
    pub data_hash: Vec<u8>,
    /// Protobuf-encoded `TaskProto` payload.
    pub data: Vec<u8>,
}
28
/// An insertable `tasks` row.
///
/// Build one with [`NewTask::builder`]. derive_builder's generated `build` is
/// suppressed (`build_fn(skip)`) in favour of the hand-written
/// `NewTaskBuilder::build(task)`. That method fills `data` and `data_hash` from a
/// `TaskProto` and supplies defaults for the optional scheduling fields.
#[derive(Insertable, Debug, PartialEq, Clone, Builder)]
#[diesel(table_name = tasks)]
#[builder(build_fn(skip))]
pub struct NewTask {
    /// Required by `NewTaskBuilder::build`.
    pub originating_message_sequence_id: i64,
    /// Required by `NewTaskBuilder::build`.
    pub originating_message_originator_id: i32,
    pub created_at_ns: i64,
    pub expires_at_ns: i64,
    pub attempts: i32,
    pub max_attempts: i32,
    pub last_attempted_at_ns: i64,
    pub backoff_scaling_factor: f32,
    pub max_backoff_duration_ns: i64,
    pub initial_backoff_duration_ns: i64,
    pub next_attempt_at_ns: i64,
    // No setter: derived from the encoded payload inside `NewTaskBuilder::build`.
    #[builder(setter(skip))]
    pub data_hash: Vec<u8>,
    // No setter: the protobuf-encoded `TaskProto` passed to `NewTaskBuilder::build`.
    #[builder(setter(skip))]
    pub data: Vec<u8>,
}
49
50impl NewTask {
51    pub fn builder() -> NewTaskBuilder {
52        NewTaskBuilder::default()
53    }
54}
55
56impl NewTaskBuilder {
57    pub fn build(&mut self, task: TaskProto) -> Result<NewTask, StorageError> {
58        use derive_builder::UninitializedFieldError;
59        let err = |s: &'static str| UninitializedFieldError::new(s);
60        let data = task.encode_to_vec();
61        let data_hash = xmtp_common::sha256_bytes(&data);
62        let new_task = NewTask {
63            originating_message_sequence_id: self
64                .originating_message_sequence_id
65                .ok_or_else(|| err("originating_message_sequence_id"))?,
66            originating_message_originator_id: self
67                .originating_message_originator_id
68                .ok_or_else(|| err("originating_message_originator_id"))?,
69            created_at_ns: self.created_at_ns.unwrap_or_else(now_ns),
70            expires_at_ns: self
71                .expires_at_ns
72                .unwrap_or_else(|| now_ns() + NS_IN_DAY * 3),
73            attempts: self.attempts.unwrap_or(0),
74            max_attempts: self.max_attempts.unwrap_or(20),
75            last_attempted_at_ns: self.last_attempted_at_ns.unwrap_or_else(now_ns),
76            backoff_scaling_factor: self.backoff_scaling_factor.unwrap_or(1.5),
77            max_backoff_duration_ns: self.max_backoff_duration_ns.unwrap_or(60 * NS_IN_SEC),
78            initial_backoff_duration_ns: self.initial_backoff_duration_ns.unwrap_or(2 * NS_IN_SEC),
79            next_attempt_at_ns: self.next_attempt_at_ns.unwrap_or_else(now_ns),
80            data_hash,
81            data,
82        };
83        Ok(new_task)
84    }
85}
86
87// impl_store_or_ignore!(Task, tasks);
88
/// Storage operations on the `tasks` table.
pub trait QueryTasks {
    /// Inserts `task` and returns the stored row, including its assigned `id`.
    fn create_task(&self, task: NewTask) -> Result<Task, StorageError>;

    /// Returns every stored task. Callers should not rely on any particular order.
    fn get_tasks(&self) -> Result<Vec<Task>, StorageError>;

    /// Returns the task with the smallest `next_attempt_at_ns`, or `None` when no
    /// tasks are stored.
    ///
    /// This does not check whether that time has already passed; callers decide
    /// whether the task is due.
    fn get_next_task(&self) -> Result<Option<Task>, StorageError>;

    /// Sets `attempts`, `last_attempted_at_ns` and `next_attempt_at_ns` on the
    /// task with the given `id` and returns the updated row.
    ///
    /// Returns an error if no task has that `id`.
    fn update_task(
        &self,
        id: i32,
        attempts: i32,
        last_attempted_at_ns: i64,
        next_attempt_at_ns: i64,
    ) -> Result<Task, StorageError>;

    /// Deletes the task with the given `id`.
    ///
    /// Returns `Ok(true)` if a row was removed and `Ok(false)` if no task had
    /// that `id`.
    fn delete_task(&self, id: i32) -> Result<bool, StorageError>;
}
106
107impl<T: QueryTasks> QueryTasks for &'_ T {
108    fn create_task(&self, task: NewTask) -> Result<Task, StorageError> {
109        (**self).create_task(task)
110    }
111
112    fn get_tasks(&self) -> Result<Vec<Task>, StorageError> {
113        (**self).get_tasks()
114    }
115
116    fn get_next_task(&self) -> Result<Option<Task>, StorageError> {
117        (**self).get_next_task()
118    }
119
120    fn update_task(
121        &self,
122        id: i32,
123        attempts: i32,
124        last_attempted_at_ns: i64,
125        next_attempt_at_ns: i64,
126    ) -> Result<Task, StorageError> {
127        (**self).update_task(id, attempts, last_attempted_at_ns, next_attempt_at_ns)
128    }
129
130    fn delete_task(&self, id: i32) -> Result<bool, StorageError> {
131        (**self).delete_task(id)
132    }
133}
134
135impl<C: ConnectionExt> QueryTasks for DbConnection<C> {
136    fn create_task(&self, task: NewTask) -> Result<Task, StorageError> {
137        self.raw_query_write(|conn| {
138            diesel::insert_into(tasks::table)
139                .values(task)
140                .get_result::<Task>(conn)
141        })
142        .map_err(Into::into)
143    }
144
145    fn get_tasks(&self) -> Result<Vec<Task>, StorageError> {
146        self.raw_query_read(|conn| tasks::table.load::<Task>(conn))
147            .map_err(Into::into)
148    }
149
150    fn get_next_task(&self) -> Result<Option<Task>, StorageError> {
151        self.raw_query_read(|conn| {
152            tasks::table
153                .order(tasks::next_attempt_at_ns)
154                .first::<Task>(conn)
155                .optional()
156        })
157        .map_err(Into::into)
158    }
159
160    fn update_task(
161        &self,
162        id: i32,
163        attempts: i32,
164        last_attempted_at_ns: i64,
165        next_attempt_at_ns: i64,
166    ) -> Result<Task, StorageError> {
167        self.raw_query_write(|conn| {
168            diesel::update(tasks::table.filter(tasks::id.eq(id)))
169                .set((
170                    tasks::attempts.eq(attempts),
171                    tasks::last_attempted_at_ns.eq(last_attempted_at_ns),
172                    tasks::next_attempt_at_ns.eq(next_attempt_at_ns),
173                ))
174                .get_result::<Task>(conn)
175        })
176        .map_err(Into::into)
177    }
178
179    fn delete_task(&self, id: i32) -> Result<bool, StorageError> {
180        let num_deleted = self.raw_query_write(|conn| {
181            diesel::delete(tasks::table.filter(tasks::id.eq(id))).execute(conn)
182        })?;
183        Ok(num_deleted == 1)
184    }
185}
186
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;
    use crate::test_utils::with_connection;

    #[xmtp_common::test]
    fn get_tasks_returns_empty_list_initially() {
        with_connection(|conn| {
            let tasks = conn.get_tasks().unwrap();
            assert!(tasks.is_empty());
        })
    }

    #[xmtp_common::test]
    fn update_task_returns_error_when_not_found() {
        with_connection(|conn| {
            // No task with id 999 was ever inserted, so the update must report an
            // error instead of returning a row.
            let result = conn.update_task(999, 5, 1000, 2000);
            assert!(result.is_err());
        })
    }

    #[xmtp_common::test]
    fn delete_task_returns_false_when_not_found() {
        with_connection(|conn| {
            let deleted = conn.delete_task(999).unwrap();
            assert!(!deleted);
        })
    }

    /// Builds a `TaskProto` wrapping a welcome pointer filled with random bytes.
    /// Each call therefore encodes to a different payload and a distinct
    /// `data_hash`. Presumably the schema requires unique hashes; confirm.
    fn gen_task_data() -> TaskProto {
        use xmtp_proto::xmtp::mls::database::task::Task as TaskKind;
        use xmtp_proto::xmtp::mls::message_contents::{
            WelcomePointeeEncryptionAeadType, WelcomePointer, welcome_pointer,
        };

        let v1_pointer = welcome_pointer::WelcomeV1Pointer {
            destination: xmtp_common::rand_vec::<32>(),
            aead_type: WelcomePointeeEncryptionAeadType::Chacha20Poly1305.into(),
            encryption_key: xmtp_common::rand_vec::<32>(),
            data_nonce: xmtp_common::rand_vec::<12>(),
            welcome_metadata_nonce: xmtp_common::rand_vec::<12>(),
        };

        TaskProto {
            task: Some(TaskKind::ProcessWelcomePointer(WelcomePointer {
                version: Some(welcome_pointer::Version::WelcomeV1Pointer(v1_pointer)),
            })),
        }
    }

    #[xmtp_common::test]
    fn all_task_operations_work_together() {
        with_connection(|conn| {
            let now = xmtp_common::time::now_ns();

            // 1. Build task1, scheduled later than task2.
            let task1 = NewTaskBuilder::default()
                .originating_message_sequence_id(1)
                .originating_message_originator_id(1)
                .created_at_ns(now)
                .expires_at_ns(now + 60 * 60 * NS_IN_SEC) // 1 hour from now
                .attempts(0)
                .max_attempts(5)
                .last_attempted_at_ns(0)
                .backoff_scaling_factor(1.5)
                .max_backoff_duration_ns(600 * NS_IN_SEC)
                .initial_backoff_duration_ns(2 * NS_IN_SEC)
                .next_attempt_at_ns(now + 1000) // later attempt time
                .build(gen_task_data())
                .unwrap();

            // 2. Build task2, which should be the first to run.
            let task2 = NewTaskBuilder::default()
                .originating_message_sequence_id(2)
                .originating_message_originator_id(1)
                .created_at_ns(now)
                .expires_at_ns(now + 2 * 60 * 60 * NS_IN_SEC) // 2 hours from now
                .attempts(0)
                .max_attempts(3)
                .last_attempted_at_ns(0)
                .backoff_scaling_factor(2.0)
                .max_backoff_duration_ns(300 * NS_IN_SEC)
                .initial_backoff_duration_ns(NS_IN_SEC)
                .next_attempt_at_ns(now + 500) // earlier attempt time
                .build(gen_task_data())
                .unwrap();

            // 3. Nothing has been inserted yet.
            assert!(conn.get_next_task().unwrap().is_none());
            assert!(conn.get_tasks().unwrap().is_empty());

            // 4. Insert both tasks; each gets its own id.
            let task1_id = conn.create_task(task1).unwrap().id;
            let task2_id = conn.create_task(task2).unwrap().id;
            assert!(task1_id >= 0, "task1_id: {task1_id}");
            assert!(task2_id >= 0, "task2_id: {task2_id}");
            assert_ne!(task1_id, task2_id);

            // 5. Both tasks are listed.
            assert_eq!(conn.get_tasks().unwrap().len(), 2);

            // 6. task2 has the earlier next_attempt_at_ns, so it is next.
            let next_task = conn
                .get_next_task()
                .unwrap()
                .expect("two tasks are stored, so one must be next");
            assert_eq!(next_task.id, task2_id);
            assert_eq!(next_task.next_attempt_at_ns, now + 500);

            // 7. Reschedule task1 ahead of task2.
            let updated_task1 = conn
                .update_task(
                    task1_id,
                    1,          // attempts
                    now + 2000, // last_attempted_at_ns
                    now + 200,  // next_attempt_at_ns, now the earliest
                )
                .unwrap();
            assert_eq!(updated_task1.id, task1_id);
            assert_eq!(updated_task1.attempts, 1);
            assert_eq!(updated_task1.next_attempt_at_ns, now + 200);

            // 8. task1 is now next.
            let next_task = conn
                .get_next_task()
                .unwrap()
                .expect("two tasks are stored, so one must be next");
            assert_eq!(next_task.id, task1_id);
            assert_eq!(next_task.next_attempt_at_ns, now + 200);

            // 9. The listing reflects task1's update and leaves task2 untouched.
            let all_tasks = conn.get_tasks().unwrap();
            assert_eq!(all_tasks.len(), 2);
            let listed_task1 = all_tasks
                .iter()
                .find(|t| t.id == task1_id)
                .expect("task1 should be listed");
            let listed_task2 = all_tasks
                .iter()
                .find(|t| t.id == task2_id)
                .expect("task2 should be listed");
            assert_eq!(listed_task1.attempts, 1);
            assert_eq!(listed_task1.next_attempt_at_ns, now + 200);
            assert_eq!(listed_task2.attempts, 0);
            assert_eq!(listed_task2.next_attempt_at_ns, now + 500);

            // 10. Delete task1.
            assert!(conn.delete_task(task1_id).unwrap());

            // 11. task2 is next again.
            let next_task = conn
                .get_next_task()
                .unwrap()
                .expect("task2 is still stored");
            assert_eq!(next_task.id, task2_id);

            // 12. Only task2 remains.
            let remaining_tasks = conn.get_tasks().unwrap();
            assert_eq!(remaining_tasks.len(), 1);
            assert_eq!(remaining_tasks[0].id, task2_id);

            // 13. Delete task2.
            assert!(conn.delete_task(task2_id).unwrap());

            // 14. The table is empty.
            assert!(conn.get_tasks().unwrap().is_empty());
            assert!(conn.get_next_task().unwrap().is_none());

            // 15. Deleting an already-deleted task reports that nothing was removed.
            assert!(!conn.delete_task(task1_id).unwrap());
        })
    }
}