1use super::{ConnectionExt, db_connection::DbConnection, schema::tasks};
2use crate::StorageError;
3use derive_builder::Builder;
4use diesel::prelude::*;
5use prost::Message;
6use xmtp_common::{NS_IN_DAY, NS_IN_SEC, time::now_ns};
7use xmtp_proto::xmtp::mls::database::Task as TaskProto;
8
/// A row read back from the `tasks` table.
///
/// `Queryable` maps selected columns onto fields by position, so the field order
/// below must stay in sync with the column order of `tasks` in the Diesel schema.
#[derive(Queryable, Identifiable, Debug, Clone)]
#[diesel(table_name = tasks)]
#[diesel(primary_key(id))]
pub struct Task {
    /// Primary key; assigned by the database on insert (`NewTask` has no `id`).
    pub id: i32,
    pub originating_message_sequence_id: i64,
    pub originating_message_originator_id: i32,
    /// Creation time in nanoseconds (`NewTaskBuilder::build` defaults it to the current time).
    pub created_at_ns: i64,
    /// Expiry time in nanoseconds (`NewTaskBuilder::build` defaults it to three days out).
    pub expires_at_ns: i64,
    /// Attempts recorded so far; defaults to 0 and is overwritten by `update_task`.
    pub attempts: i32,
    /// Attempt limit (defaults to 20). NOTE(review): not enforced anywhere in this module —
    /// confirm the caller that schedules retries checks it.
    pub max_attempts: i32,
    /// Time of the most recent attempt in nanoseconds; overwritten by `update_task`.
    pub last_attempted_at_ns: i64,
    // Retry backoff parameters (build defaults: factor 1.5, initial 2 s, cap 60 s).
    // They are only stored here; this module does not interpret them.
    pub backoff_scaling_factor: f32,
    pub max_backoff_duration_ns: i64,
    pub initial_backoff_duration_ns: i64,
    /// Scheduling key in nanoseconds: `get_next_task` returns the task with the smallest value.
    pub next_attempt_at_ns: i64,
    /// Digest of `data`, computed with `xmtp_common::sha256_bytes`.
    pub data_hash: Vec<u8>,
    /// Protobuf-encoded `TaskProto` payload.
    pub data: Vec<u8>,
}
28
/// Values for inserting a new row into the `tasks` table.
///
/// Construct it with [`NewTask::builder`]. The builder's derived `build` function
/// is skipped (`build_fn(skip)`); the hand-written `NewTaskBuilder::build` takes the
/// task payload and fills in `data` / `data_hash` itself.
#[derive(Insertable, Debug, PartialEq, Clone, Builder)]
#[diesel(table_name = tasks)]
#[builder(build_fn(skip))]
pub struct NewTask {
    pub originating_message_sequence_id: i64,
    pub originating_message_originator_id: i32,
    /// Creation time in nanoseconds.
    pub created_at_ns: i64,
    /// Expiry time in nanoseconds.
    pub expires_at_ns: i64,
    /// Attempts recorded so far.
    pub attempts: i32,
    /// Attempt limit.
    pub max_attempts: i32,
    /// Time of the most recent attempt in nanoseconds.
    pub last_attempted_at_ns: i64,
    // Retry backoff parameters; stored as-is.
    pub backoff_scaling_factor: f32,
    pub max_backoff_duration_ns: i64,
    pub initial_backoff_duration_ns: i64,
    /// Scheduling key in nanoseconds used by `get_next_task`.
    pub next_attempt_at_ns: i64,
    // No setter: derived from `data` inside `NewTaskBuilder::build`.
    #[builder(setter(skip))]
    pub data_hash: Vec<u8>,
    // No setter: the encoded `TaskProto` passed to `NewTaskBuilder::build`.
    #[builder(setter(skip))]
    pub data: Vec<u8>,
}
49
50impl NewTask {
51 pub fn builder() -> NewTaskBuilder {
52 NewTaskBuilder::default()
53 }
54}
55
56impl NewTaskBuilder {
57 pub fn build(&mut self, task: TaskProto) -> Result<NewTask, StorageError> {
58 use derive_builder::UninitializedFieldError;
59 let err = |s: &'static str| UninitializedFieldError::new(s);
60 let data = task.encode_to_vec();
61 let data_hash = xmtp_common::sha256_bytes(&data);
62 let new_task = NewTask {
63 originating_message_sequence_id: self
64 .originating_message_sequence_id
65 .ok_or_else(|| err("originating_message_sequence_id"))?,
66 originating_message_originator_id: self
67 .originating_message_originator_id
68 .ok_or_else(|| err("originating_message_originator_id"))?,
69 created_at_ns: self.created_at_ns.unwrap_or_else(now_ns),
70 expires_at_ns: self
71 .expires_at_ns
72 .unwrap_or_else(|| now_ns() + NS_IN_DAY * 3),
73 attempts: self.attempts.unwrap_or(0),
74 max_attempts: self.max_attempts.unwrap_or(20),
75 last_attempted_at_ns: self.last_attempted_at_ns.unwrap_or_else(now_ns),
76 backoff_scaling_factor: self.backoff_scaling_factor.unwrap_or(1.5),
77 max_backoff_duration_ns: self.max_backoff_duration_ns.unwrap_or(60 * NS_IN_SEC),
78 initial_backoff_duration_ns: self.initial_backoff_duration_ns.unwrap_or(2 * NS_IN_SEC),
79 next_attempt_at_ns: self.next_attempt_at_ns.unwrap_or_else(now_ns),
80 data_hash,
81 data,
82 };
83 Ok(new_task)
84 }
85}
86
/// Storage operations on the `tasks` table.
pub trait QueryTasks {
    /// Inserts `task` and returns the persisted row, including its assigned `id`.
    fn create_task(&self, task: NewTask) -> Result<Task, StorageError>;

    /// Returns every stored task; no ordering is guaranteed.
    fn get_tasks(&self) -> Result<Vec<Task>, StorageError>;

    /// Returns the task with the smallest `next_attempt_at_ns`, or `None` when the
    /// table is empty.
    fn get_next_task(&self) -> Result<Option<Task>, StorageError>;

    /// Overwrites `attempts`, `last_attempted_at_ns` and `next_attempt_at_ns` on the
    /// task with `id` and returns the updated row.
    ///
    /// Returns an error when no task with `id` exists.
    fn update_task(
        &self,
        id: i32,
        attempts: i32,
        last_attempted_at_ns: i64,
        next_attempt_at_ns: i64,
    ) -> Result<Task, StorageError>;

    /// Deletes the task with `id`. Returns `true` if a row was removed and `false`
    /// if no task matched.
    fn delete_task(&self, id: i32) -> Result<bool, StorageError>;
}
106
107impl<T: QueryTasks> QueryTasks for &'_ T {
108 fn create_task(&self, task: NewTask) -> Result<Task, StorageError> {
109 (**self).create_task(task)
110 }
111
112 fn get_tasks(&self) -> Result<Vec<Task>, StorageError> {
113 (**self).get_tasks()
114 }
115
116 fn get_next_task(&self) -> Result<Option<Task>, StorageError> {
117 (**self).get_next_task()
118 }
119
120 fn update_task(
121 &self,
122 id: i32,
123 attempts: i32,
124 last_attempted_at_ns: i64,
125 next_attempt_at_ns: i64,
126 ) -> Result<Task, StorageError> {
127 (**self).update_task(id, attempts, last_attempted_at_ns, next_attempt_at_ns)
128 }
129
130 fn delete_task(&self, id: i32) -> Result<bool, StorageError> {
131 (**self).delete_task(id)
132 }
133}
134
135impl<C: ConnectionExt> QueryTasks for DbConnection<C> {
136 fn create_task(&self, task: NewTask) -> Result<Task, StorageError> {
137 self.raw_query_write(|conn| {
138 diesel::insert_into(tasks::table)
139 .values(task)
140 .get_result::<Task>(conn)
141 })
142 .map_err(Into::into)
143 }
144
145 fn get_tasks(&self) -> Result<Vec<Task>, StorageError> {
146 self.raw_query_read(|conn| tasks::table.load::<Task>(conn))
147 .map_err(Into::into)
148 }
149
150 fn get_next_task(&self) -> Result<Option<Task>, StorageError> {
151 self.raw_query_read(|conn| {
152 tasks::table
153 .order(tasks::next_attempt_at_ns)
154 .first::<Task>(conn)
155 .optional()
156 })
157 .map_err(Into::into)
158 }
159
160 fn update_task(
161 &self,
162 id: i32,
163 attempts: i32,
164 last_attempted_at_ns: i64,
165 next_attempt_at_ns: i64,
166 ) -> Result<Task, StorageError> {
167 self.raw_query_write(|conn| {
168 diesel::update(tasks::table.filter(tasks::id.eq(id)))
169 .set((
170 tasks::attempts.eq(attempts),
171 tasks::last_attempted_at_ns.eq(last_attempted_at_ns),
172 tasks::next_attempt_at_ns.eq(next_attempt_at_ns),
173 ))
174 .get_result::<Task>(conn)
175 })
176 .map_err(Into::into)
177 }
178
179 fn delete_task(&self, id: i32) -> Result<bool, StorageError> {
180 let num_deleted = self.raw_query_write(|conn| {
181 diesel::delete(tasks::table.filter(tasks::id.eq(id))).execute(conn)
182 })?;
183 Ok(num_deleted == 1)
184 }
185}
186
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;
    use crate::test_utils::with_connection;

    #[xmtp_common::test]
    fn get_tasks_returns_empty_list_initially() {
        with_connection(|conn| {
            let tasks = conn.get_tasks().unwrap();
            assert!(tasks.is_empty());
        })
    }

    #[xmtp_common::test]
    fn update_task_returns_error_when_not_found() {
        with_connection(|conn| {
            let result = conn.update_task(999, 5, 1000, 2000);
            assert!(result.is_err());
        })
    }

    #[xmtp_common::test]
    fn delete_task_returns_false_when_not_found() {
        with_connection(|conn| {
            let deleted = conn.delete_task(999).unwrap();
            assert!(!deleted);
        })
    }

    /// Builds a `TaskProto` wrapping a `WelcomePointer` filled with random bytes, so
    /// every call yields a distinct payload.
    fn gen_task_data() -> TaskProto {
        use xmtp_proto::xmtp::mls::database::task::Task as TaskKind;
        use xmtp_proto::xmtp::mls::message_contents::{
            WelcomePointeeEncryptionAeadType, WelcomePointer,
            welcome_pointer::{Version, WelcomeV1Pointer},
        };

        TaskProto {
            task: Some(TaskKind::ProcessWelcomePointer(WelcomePointer {
                version: Some(Version::WelcomeV1Pointer(WelcomeV1Pointer {
                    destination: xmtp_common::rand_vec::<32>(),
                    aead_type: WelcomePointeeEncryptionAeadType::Chacha20Poly1305.into(),
                    encryption_key: xmtp_common::rand_vec::<32>(),
                    data_nonce: xmtp_common::rand_vec::<12>(),
                    welcome_metadata_nonce: xmtp_common::rand_vec::<12>(),
                })),
            })),
        }
    }

    #[xmtp_common::test]
    fn build_fails_without_required_fields() {
        let missing_sequence_id = NewTaskBuilder::default()
            .originating_message_originator_id(1)
            .build(gen_task_data());
        assert!(missing_sequence_id.is_err());

        let missing_originator_id = NewTaskBuilder::default()
            .originating_message_sequence_id(1)
            .build(gen_task_data());
        assert!(missing_originator_id.is_err());
    }

    #[xmtp_common::test]
    fn build_applies_defaults_and_encodes_payload() {
        let proto = gen_task_data();
        let new_task = NewTask::builder()
            .originating_message_sequence_id(7)
            .originating_message_originator_id(3)
            .build(proto.clone())
            .unwrap();

        assert_eq!(new_task.originating_message_sequence_id, 7);
        assert_eq!(new_task.originating_message_originator_id, 3);
        assert_eq!(new_task.attempts, 0);
        assert_eq!(new_task.max_attempts, 20);
        assert_eq!(new_task.backoff_scaling_factor, 1.5);
        assert_eq!(new_task.initial_backoff_duration_ns, 2 * NS_IN_SEC);
        assert_eq!(new_task.max_backoff_duration_ns, 60 * NS_IN_SEC);
        assert!(new_task.expires_at_ns > new_task.created_at_ns);
        assert_eq!(new_task.data_hash, xmtp_common::sha256_bytes(&new_task.data));
        assert_eq!(TaskProto::decode(new_task.data.as_slice()).unwrap(), proto);
    }

    #[xmtp_common::test]
    fn all_task_operations_work_together() {
        with_connection(|conn| {
            let now = xmtp_common::time::now_ns();

            let task1 = NewTaskBuilder::default()
                .originating_message_sequence_id(1)
                .originating_message_originator_id(1)
                .created_at_ns(now)
                .expires_at_ns(now + 3_600_000_000_000)
                .attempts(0)
                .max_attempts(5)
                .last_attempted_at_ns(0)
                .backoff_scaling_factor(1.5)
                .max_backoff_duration_ns(600_000_000_000)
                .initial_backoff_duration_ns(2_000_000_000)
                .next_attempt_at_ns(now + 1000)
                .build(gen_task_data())
                .unwrap();

            let task2 = NewTaskBuilder::default()
                .originating_message_sequence_id(2)
                .originating_message_originator_id(1)
                .created_at_ns(now)
                .expires_at_ns(now + 7_200_000_000_000)
                .attempts(0)
                .max_attempts(3)
                .last_attempted_at_ns(0)
                .backoff_scaling_factor(2.0)
                .max_backoff_duration_ns(300_000_000_000)
                .initial_backoff_duration_ns(1_000_000_000)
                .next_attempt_at_ns(now + 500)
                .build(gen_task_data())
                .unwrap();

            // Empty table: nothing is scheduled yet.
            assert!(conn.get_next_task().unwrap().is_none());
            assert!(conn.get_tasks().unwrap().is_empty());

            let created_task1 = conn.create_task(task1).unwrap();
            let created_task2 = conn.create_task(task2).unwrap();

            let task1_id = created_task1.id;
            let task2_id = created_task2.id;
            assert!(task1_id >= 0, "task1_id: {task1_id}");
            assert!(task2_id >= 0, "task2_id: {task2_id}");
            assert_ne!(task1_id, task2_id);

            assert_eq!(conn.get_tasks().unwrap().len(), 2);

            // task2 has the earlier `next_attempt_at_ns` (now + 500 < now + 1000).
            let next_task = conn
                .get_next_task()
                .unwrap()
                .expect("two tasks are stored, so one must be next");
            assert_eq!(next_task.id, task2_id);
            assert_eq!(next_task.next_attempt_at_ns, now + 500);

            let updated_task1 = conn
                .update_task(
                    task1_id,
                    1,          // attempts
                    now + 2000, // last_attempted_at_ns
                    now + 200,  // next_attempt_at_ns: now earlier than task2's
                )
                .unwrap();

            assert_eq!(updated_task1.id, task1_id);
            assert_eq!(updated_task1.attempts, 1);
            assert_eq!(updated_task1.last_attempted_at_ns, now + 2000);
            assert_eq!(updated_task1.next_attempt_at_ns, now + 200);

            // After rescheduling, task1 moves to the front.
            let next_task = conn
                .get_next_task()
                .unwrap()
                .expect("two tasks are stored, so one must be next");
            assert_eq!(next_task.id, task1_id);
            assert_eq!(next_task.next_attempt_at_ns, now + 200);

            let all_tasks_after_update = conn.get_tasks().unwrap();
            assert_eq!(all_tasks_after_update.len(), 2);

            let updated_task1_in_list = all_tasks_after_update
                .iter()
                .find(|t| t.id == task1_id)
                .expect("task1 should still be stored");
            let task2_in_list = all_tasks_after_update
                .iter()
                .find(|t| t.id == task2_id)
                .expect("task2 should still be stored");

            assert_eq!(updated_task1_in_list.attempts, 1);
            assert_eq!(updated_task1_in_list.next_attempt_at_ns, now + 200);
            // The update must not have touched task2.
            assert_eq!(task2_in_list.attempts, 0);
            assert_eq!(task2_in_list.next_attempt_at_ns, now + 500);

            assert!(conn.delete_task(task1_id).unwrap());

            let next_task = conn
                .get_next_task()
                .unwrap()
                .expect("task2 remains, so it must be next");
            assert_eq!(next_task.id, task2_id);

            let remaining_tasks = conn.get_tasks().unwrap();
            assert_eq!(remaining_tasks.len(), 1);
            assert_eq!(remaining_tasks[0].id, task2_id);

            assert!(conn.delete_task(task2_id).unwrap());

            assert!(conn.get_tasks().unwrap().is_empty());
            assert!(conn.get_next_task().unwrap().is_none());

            // Deleting an already-deleted task reports that nothing was removed.
            assert!(!conn.delete_task(task1_id).unwrap());
        })
    }
}