xmtp_db/encrypted_store/
group_intent.rs

1use std::collections::HashMap;
2
3use derive_builder::Builder;
4use diesel::{
5    backend::Backend,
6    connection::DefaultLoadingMode,
7    deserialize::{self, FromSql, FromSqlRow},
8    expression::AsExpression,
9    prelude::*,
10    serialize::{self, IsNull, Output, ToSql},
11    sql_types::Integer,
12};
13use itertools::Itertools;
14use serde::{Deserialize, Serialize};
15use xmtp_common::fmt;
16use xmtp_proto::types::Cursor;
17
18use super::{
19    ConnectionExt, Sqlite,
20    db_connection::DbConnection,
21    group,
22    schema::group_intents::{self, dsl},
23};
24use crate::{
25    Delete, NotFound, StorageError, group_message::QueryGroupMessage, impl_fetch, impl_store,
26};
27
28mod error;
29mod types;
30pub use error::*;
31pub use types::*;
32
33pub type ID = i32;
34
35#[repr(i32)]
36#[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, FromSqlRow, Serialize, Deserialize)]
37#[diesel(sql_type = Integer)]
38pub enum IntentKind {
39    SendMessage = 1,
40    KeyUpdate = 2,
41    MetadataUpdate = 3,
42    UpdateGroupMembership = 4,
43    UpdateAdminList = 5,
44    UpdatePermission = 6,
45    ReaddInstallations = 7,
46}
47
48impl std::fmt::Display for IntentKind {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        let description = match self {
51            IntentKind::SendMessage => "SendMessage",
52            IntentKind::KeyUpdate => "KeyUpdate",
53            IntentKind::MetadataUpdate => "MetadataUpdate",
54            IntentKind::UpdateGroupMembership => "UpdateGroupMembership",
55            IntentKind::UpdateAdminList => "UpdateAdminList",
56            IntentKind::UpdatePermission => "UpdatePermission",
57            IntentKind::ReaddInstallations => "ReaddInstallations",
58        };
59        write!(f, "{}", description)
60    }
61}
62
63#[repr(i32)]
64#[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, FromSqlRow)]
65#[diesel(sql_type = Integer)]
66pub enum IntentState {
67    ToPublish = 1,
68    Published = 2,
69    Committed = 3,
70    Error = 4,
71    Processed = 5,
72}
73
74#[derive(Queryable, Identifiable, PartialEq, Clone)]
75#[diesel(table_name = group_intents)]
76#[diesel(primary_key(id))]
77pub struct StoredGroupIntent {
78    pub id: ID,
79    pub kind: IntentKind,
80    pub group_id: group::ID,
81    pub data: Vec<u8>,
82    pub state: IntentState,
83    pub payload_hash: Option<Vec<u8>>,
84    pub post_commit_data: Option<Vec<u8>>,
85    pub publish_attempts: i32,
86    pub staged_commit: Option<Vec<u8>>,
87    pub published_in_epoch: Option<i64>,
88    pub should_push: bool,
89    pub sequence_id: Option<i64>,
90    pub originator_id: Option<i64>,
91}
92
93impl std::fmt::Debug for StoredGroupIntent {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "StoredGroupIntent {{ ")?;
96        write!(f, "id: {}, ", self.id)?;
97        write!(f, "kind: {}, ", self.kind)?;
98        write!(
99            f,
100            "group_id: {}, ",
101            fmt::truncate_hex(hex::encode(&self.group_id))
102        )?;
103        write!(f, "data: {}, ", fmt::truncate_hex(hex::encode(&self.data)))?;
104        write!(f, "state: {:?}, ", self.state)?;
105        write!(
106            f,
107            "payload_hash: {:?}, ",
108            self.payload_hash
109                .as_ref()
110                .map(|h| fmt::truncate_hex(hex::encode(h)))
111        )?;
112        write!(
113            f,
114            "post_commit_data: {:?}, ",
115            self.post_commit_data
116                .as_ref()
117                .map(|d| fmt::truncate_hex(hex::encode(d)))
118        )?;
119        write!(f, "publish_attempts: {:?}, ", self.publish_attempts)?;
120        write!(
121            f,
122            "staged_commit: {:?}, ",
123            self.staged_commit
124                .as_ref()
125                .map(|c| fmt::truncate_hex(hex::encode(c)))
126        )?;
127        write!(f, "published_in_epoch: {:?} ", self.published_in_epoch)?;
128        write!(f, " }}")?;
129        Ok(())
130    }
131}
132
133impl_fetch!(StoredGroupIntent, group_intents, ID);
134
135impl<C: ConnectionExt> Delete<StoredGroupIntent> for DbConnection<C> {
136    type Key = ID;
137    fn delete(&self, key: ID) -> Result<usize, StorageError> {
138        Ok(self.raw_query_write(|raw_conn| {
139            diesel::delete(dsl::group_intents.find(key)).execute(raw_conn)
140        })?)
141    }
142}
143
144/// NewGroupIntent is the data needed to create a new group intent.
145/// Do not use this struct directly outside of the storage module.
146/// Use the `queue_intent` method on `MlsGroup` instead.
147#[derive(Insertable, Debug, PartialEq, Clone, Builder)]
148#[diesel(table_name = group_intents)]
149#[builder(setter(into), build_fn(error = "StorageError"))]
150pub struct NewGroupIntent {
151    pub kind: IntentKind,
152    pub group_id: Vec<u8>,
153    pub data: Vec<u8>,
154    pub should_push: bool,
155    #[builder(default = "IntentState::ToPublish")]
156    pub state: IntentState,
157}
158
159impl_store!(NewGroupIntent, group_intents);
160
161impl NewGroupIntent {
162    pub fn builder() -> NewGroupIntentBuilder {
163        NewGroupIntentBuilder::default()
164    }
165
166    pub fn new(kind: IntentKind, group_id: Vec<u8>, data: Vec<u8>, should_push: bool) -> Self {
167        Self {
168            kind,
169            group_id,
170            data,
171            state: IntentState::ToPublish,
172            should_push,
173        }
174    }
175}
176
177pub trait QueryGroupIntent {
178    fn insert_group_intent(
179        &self,
180        to_save: NewGroupIntent,
181    ) -> Result<StoredGroupIntent, crate::ConnectionError>;
182
183    // Query for group_intents by group_id, optionally filtering by state and kind
184    fn find_group_intents<Id: AsRef<[u8]>>(
185        &self,
186        group_id: Id,
187        allowed_states: Option<Vec<IntentState>>,
188        allowed_kinds: Option<Vec<IntentKind>>,
189    ) -> Result<Vec<StoredGroupIntent>, crate::ConnectionError>;
190
191    // Set the intent with the given ID to `Published` and set the payload hash. Optionally add
192    // `post_commit_data`
193    fn set_group_intent_published(
194        &self,
195        intent_id: ID,
196        payload_hash: &[u8],
197        post_commit_data: Option<Vec<u8>>,
198        staged_commit: Option<Vec<u8>>,
199        published_in_epoch: i64,
200    ) -> Result<(), StorageError>;
201
202    // Set the intent with the given ID to `Committed`
203    fn set_group_intent_committed(&self, intent_id: ID, cursor: Cursor)
204    -> Result<(), StorageError>;
205
206    // Set the intent with the given ID to `Committed`
207    fn set_group_intent_processed(&self, intent_id: ID) -> Result<(), StorageError>;
208
209    // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and
210    // `post_commit_data`
211    fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError>;
212
213    /// Set the intent with the given ID to `Error`
214    fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError>;
215
216    // Simple lookup of intents by payload hash, meant to be used when processing messages off the
217    // network
218    fn find_group_intent_by_payload_hash(
219        &self,
220        payload_hash: &[u8],
221    ) -> Result<Option<StoredGroupIntent>, StorageError>;
222
223    /// find the commit message refresh state for each intent payload hash
224    fn find_dependant_commits<P: AsRef<[u8]>>(
225        &self,
226        payload_hashes: &[P],
227    ) -> Result<HashMap<PayloadHash, IntentDependency>, StorageError>;
228
229    fn increment_intent_publish_attempt_count(&self, intent_id: ID) -> Result<(), StorageError>;
230
231    fn set_group_intent_error_and_fail_msg(
232        &self,
233        intent: &StoredGroupIntent,
234        msg_id: Option<Vec<u8>>,
235    ) -> Result<(), StorageError>;
236}
237
238impl<T> QueryGroupIntent for &T
239where
240    T: QueryGroupIntent,
241{
242    fn insert_group_intent(
243        &self,
244        to_save: NewGroupIntent,
245    ) -> Result<StoredGroupIntent, crate::ConnectionError> {
246        (**self).insert_group_intent(to_save)
247    }
248
249    fn find_group_intents<Id: AsRef<[u8]>>(
250        &self,
251        group_id: Id,
252        allowed_states: Option<Vec<IntentState>>,
253        allowed_kinds: Option<Vec<IntentKind>>,
254    ) -> Result<Vec<StoredGroupIntent>, crate::ConnectionError> {
255        (**self).find_group_intents(group_id, allowed_states, allowed_kinds)
256    }
257
258    fn set_group_intent_published(
259        &self,
260        intent_id: ID,
261        payload_hash: &[u8],
262        post_commit_data: Option<Vec<u8>>,
263        staged_commit: Option<Vec<u8>>,
264        published_in_epoch: i64,
265    ) -> Result<(), StorageError> {
266        (**self).set_group_intent_published(
267            intent_id,
268            payload_hash,
269            post_commit_data,
270            staged_commit,
271            published_in_epoch,
272        )
273    }
274
275    fn set_group_intent_committed(
276        &self,
277        intent_id: ID,
278        cursor: Cursor,
279    ) -> Result<(), StorageError> {
280        (**self).set_group_intent_committed(intent_id, cursor)
281    }
282
283    fn set_group_intent_processed(&self, intent_id: ID) -> Result<(), StorageError> {
284        (**self).set_group_intent_processed(intent_id)
285    }
286
287    fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> {
288        (**self).set_group_intent_to_publish(intent_id)
289    }
290
291    fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> {
292        (**self).set_group_intent_error(intent_id)
293    }
294
295    fn find_group_intent_by_payload_hash(
296        &self,
297        payload_hash: &[u8],
298    ) -> Result<Option<StoredGroupIntent>, StorageError> {
299        (**self).find_group_intent_by_payload_hash(payload_hash)
300    }
301
302    fn find_dependant_commits<P: AsRef<[u8]>>(
303        &self,
304        payload_hashes: &[P],
305    ) -> Result<HashMap<PayloadHash, IntentDependency>, StorageError> {
306        (**self).find_dependant_commits(payload_hashes)
307    }
308
309    fn increment_intent_publish_attempt_count(&self, intent_id: ID) -> Result<(), StorageError> {
310        (**self).increment_intent_publish_attempt_count(intent_id)
311    }
312
313    fn set_group_intent_error_and_fail_msg(
314        &self,
315        intent: &StoredGroupIntent,
316        msg_id: Option<Vec<u8>>,
317    ) -> Result<(), StorageError> {
318        (**self).set_group_intent_error_and_fail_msg(intent, msg_id)
319    }
320}
321
322impl<C: ConnectionExt> QueryGroupIntent for DbConnection<C> {
323    #[tracing::instrument(level = "trace", skip(self))]
324    fn insert_group_intent(
325        &self,
326        to_save: NewGroupIntent,
327    ) -> Result<StoredGroupIntent, crate::ConnectionError> {
328        self.raw_query_write(|conn| {
329            diesel::insert_into(dsl::group_intents)
330                .values(to_save)
331                .get_result(conn)
332        })
333    }
334
335    // Query for group_intents by group_id, optionally filtering by state and kind
336    #[tracing::instrument(level = "trace", skip(self), fields(group_id = hex::encode(group_id.as_ref())))]
337    fn find_group_intents<Id: AsRef<[u8]>>(
338        &self,
339        group_id: Id,
340        allowed_states: Option<Vec<IntentState>>,
341        allowed_kinds: Option<Vec<IntentKind>>,
342    ) -> Result<Vec<StoredGroupIntent>, crate::ConnectionError> {
343        let group_id = group_id.as_ref();
344        let mut query = dsl::group_intents
345            .into_boxed()
346            .filter(dsl::group_id.eq(group_id));
347
348        if let Some(allowed_states) = allowed_states {
349            query = query.filter(dsl::state.eq_any(allowed_states));
350        }
351
352        if let Some(allowed_kinds) = allowed_kinds {
353            query = query.filter(dsl::kind.eq_any(allowed_kinds));
354        }
355
356        query = query.order(dsl::id.asc());
357
358        self.raw_query_read(|conn| query.load::<StoredGroupIntent>(conn))
359    }
360
361    // Set the intent with the given ID to `Published` and set the payload hash. Optionally add
362    // `post_commit_data`
363    fn set_group_intent_published(
364        &self,
365        intent_id: ID,
366        payload_hash: &[u8],
367        post_commit_data: Option<Vec<u8>>,
368        staged_commit: Option<Vec<u8>>,
369        published_in_epoch: i64,
370    ) -> Result<(), StorageError> {
371        let rows_changed = self.raw_query_write(|conn| {
372            diesel::update(dsl::group_intents)
373                .filter(dsl::id.eq(intent_id))
374                // State machine requires that the only valid state transition to Published is from
375                // ToPublish
376                .filter(dsl::state.eq(IntentState::ToPublish))
377                .set((
378                    dsl::state.eq(IntentState::Published),
379                    dsl::payload_hash.eq(payload_hash),
380                    dsl::post_commit_data.eq(post_commit_data),
381                    dsl::staged_commit.eq(staged_commit),
382                    dsl::published_in_epoch.eq(published_in_epoch),
383                ))
384                .execute(conn)
385        })?;
386
387        if rows_changed == 0 {
388            let already_published = self.raw_query_read(|conn| {
389                dsl::group_intents
390                    .filter(dsl::id.eq(intent_id))
391                    .first::<StoredGroupIntent>(conn)
392            });
393
394            if already_published.is_ok() {
395                return Ok(());
396            } else {
397                return Err(NotFound::IntentForToPublish(intent_id).into());
398            }
399        }
400        Ok(())
401    }
402
403    // Set the intent with the given ID to `Committed`
404    fn set_group_intent_committed(
405        &self,
406        intent_id: ID,
407        cursor: Cursor,
408    ) -> Result<(), StorageError> {
409        let rows_changed: usize = self.raw_query_write(|conn| {
410            diesel::update(dsl::group_intents)
411                .filter(dsl::id.eq(intent_id))
412                // State machine requires that the only valid state transition to Committed is from
413                // Published
414                .filter(dsl::state.eq(IntentState::Published))
415                .set((
416                    dsl::state.eq(IntentState::Committed),
417                    dsl::sequence_id.eq(cursor.sequence_id as i64),
418                    dsl::originator_id.eq(cursor.originator_id as i64),
419                ))
420                .execute(conn)
421        })?;
422
423        // If nothing matched the query, return an error. Either ID or state was wrong
424        if rows_changed == 0 {
425            return Err(NotFound::IntentForCommitted(intent_id).into());
426        }
427
428        Ok(())
429    }
430
431    // Set the intent with the given ID to `Committed`
432    fn set_group_intent_processed(&self, intent_id: ID) -> Result<(), StorageError> {
433        let rows_changed = self.raw_query_write(|conn| {
434            diesel::update(dsl::group_intents)
435                .filter(dsl::id.eq(intent_id))
436                .set(dsl::state.eq(IntentState::Processed))
437                .execute(conn)
438        })?;
439
440        // If nothing matched the query, return an error. Either ID or state was wrong
441        if rows_changed == 0 {
442            return Err(NotFound::IntentById(intent_id).into());
443        }
444
445        Ok(())
446    }
447
448    // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and
449    // `post_commit_data`
450    fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> {
451        let rows_changed = self.raw_query_write(|conn| {
452            diesel::update(dsl::group_intents)
453                .filter(dsl::id.eq(intent_id))
454                // State machine requires that the only valid state transition to ToPublish is from
455                // Published
456                .filter(dsl::state.eq(IntentState::Published))
457                .set((
458                    dsl::state.eq(IntentState::ToPublish),
459                    // When moving to ToPublish, clear the payload hash and post commit data
460                    dsl::payload_hash.eq(None::<Vec<u8>>),
461                    dsl::post_commit_data.eq(None::<Vec<u8>>),
462                    dsl::published_in_epoch.eq(None::<i64>),
463                    dsl::staged_commit.eq(None::<Vec<u8>>),
464                ))
465                .execute(conn)
466        })?;
467
468        if rows_changed == 0 {
469            return Err(NotFound::IntentForPublish(intent_id).into());
470        }
471        Ok(())
472    }
473
474    /// Set the intent with the given ID to `Error`
475    #[tracing::instrument(level = "trace", skip(self))]
476    fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> {
477        let rows_changed = self.raw_query_write(|conn| {
478            diesel::update(dsl::group_intents)
479                .filter(dsl::id.eq(intent_id))
480                .set(dsl::state.eq(IntentState::Error))
481                .execute(conn)
482        })?;
483
484        if rows_changed == 0 {
485            return Err(NotFound::IntentById(intent_id).into());
486        }
487
488        Ok(())
489    }
490
491    // Simple lookup of intents by payload hash, meant to be used when processing messages off the
492    // network
493    #[tracing::instrument(
494        level = "trace",
495        skip_all,
496        fields(payload_hash = hex::encode(payload_hash))
497    )]
498    fn find_group_intent_by_payload_hash(
499        &self,
500        payload_hash: &[u8],
501    ) -> Result<Option<StoredGroupIntent>, StorageError> {
502        let result = self.raw_query_read(|conn| {
503            dsl::group_intents
504                .filter(dsl::payload_hash.eq(payload_hash))
505                .first::<StoredGroupIntent>(conn)
506                .optional()
507        })?;
508
509        Ok(result)
510    }
511
512    /// Find the commit message refresh state for each intent by payload hash.
513    /// Returns a map from payload hash to a vector of dependencies (one per originator).
514    fn find_dependant_commits<P: AsRef<[u8]>>(
515        &self,
516        payload_hashes: &[P],
517    ) -> Result<HashMap<PayloadHash, IntentDependency>, StorageError> {
518        use super::schema::refresh_state;
519        use crate::encrypted_store::refresh_state::EntityKind;
520
521        let hashes = payload_hashes
522            .iter()
523            .map(|h| PayloadHashRef::from(h.as_ref()));
524
525        // Query all dependencies in a single database call
526        let map: HashMap<PayloadHash, Vec<IntentDependency>> = self.raw_query_read(|conn| {
527            dsl::group_intents
528                .filter(dsl::payload_hash.eq_any(hashes))
529                .inner_join(
530                    refresh_state::table.on(refresh_state::entity_id
531                        .eq(dsl::group_id)
532                        .and(refresh_state::entity_kind.eq(EntityKind::CommitMessage))),
533                )
534                .select((
535                    dsl::payload_hash.assume_not_null(),
536                    refresh_state::sequence_id,
537                    refresh_state::originator_id,
538                    dsl::group_id,
539                ))
540                .load_iter::<(Vec<u8>, i64, i32, Vec<u8>), DefaultLoadingMode>(conn)?
541                .map_ok(|(hash, sequence_id, originator_id, group_id)| {
542                    (
543                        PayloadHash::from(hash),
544                        IntentDependency {
545                            cursor: Cursor::new(sequence_id as u64, originator_id as u32),
546                            group_id: group_id.into(),
547                        },
548                    )
549                })
550                .process_results(|iter| iter.into_grouping_map().collect())
551        })?;
552
553        let map = map
554            .into_iter()
555            .map(|(hash, mut d)| {
556                if d.len() > 1 {
557                    return Err(GroupIntentError::MoreThanOneDependency {
558                        payload_hash: hash.clone(),
559                        cursors: d.iter().map(|d| d.cursor).collect(),
560                        group_id: d[0].group_id.clone(),
561                    }
562                    .into());
563                }
564
565                // this should be impossible since the sql query wouldnt return anything for
566                // an empty payload hash.
567                let dep = d
568                    .pop()
569                    .ok_or_else(|| GroupIntentError::NoDependencyFound { hash: hash.clone() })
570                    .map_err(StorageError::from)?;
571                Ok::<_, StorageError>((hash, dep))
572            })
573            .try_collect()?;
574
575        Ok(map)
576    }
577
578    fn increment_intent_publish_attempt_count(&self, intent_id: ID) -> Result<(), StorageError> {
579        self.raw_query_write(|conn| {
580            diesel::update(dsl::group_intents)
581                .filter(dsl::id.eq(intent_id))
582                .set(dsl::publish_attempts.eq(dsl::publish_attempts + 1))
583                .execute(conn)
584        })?;
585
586        Ok(())
587    }
588
589    fn set_group_intent_error_and_fail_msg(
590        &self,
591        intent: &StoredGroupIntent,
592        msg_id: Option<Vec<u8>>,
593    ) -> Result<(), StorageError> {
594        self.set_group_intent_error(intent.id)?;
595        if let Some(id) = msg_id {
596            self.set_delivery_status_to_failed(&id)?;
597        }
598        Ok(())
599    }
600}
601
602impl ToSql<Integer, Sqlite> for IntentKind
603where
604    i32: ToSql<Integer, Sqlite>,
605{
606    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
607        out.set_value(*self as i32);
608        Ok(IsNull::No)
609    }
610}
611
612impl FromSql<Integer, Sqlite> for IntentKind
613where
614    i32: FromSql<Integer, Sqlite>,
615{
616    fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
617        match i32::from_sql(bytes)? {
618            1 => Ok(IntentKind::SendMessage),
619            2 => Ok(IntentKind::KeyUpdate),
620            3 => Ok(IntentKind::MetadataUpdate),
621            4 => Ok(IntentKind::UpdateGroupMembership),
622            5 => Ok(IntentKind::UpdateAdminList),
623            6 => Ok(IntentKind::UpdatePermission),
624            7 => Ok(IntentKind::ReaddInstallations),
625            x => Err(format!("Unrecognized variant {}", x).into()),
626        }
627    }
628}
629
630impl ToSql<Integer, Sqlite> for IntentState
631where
632    i32: ToSql<Integer, Sqlite>,
633{
634    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
635        out.set_value(*self as i32);
636        Ok(IsNull::No)
637    }
638}
639
640impl FromSql<Integer, Sqlite> for IntentState
641where
642    i32: FromSql<Integer, Sqlite>,
643{
644    fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
645        match i32::from_sql(bytes)? {
646            1 => Ok(IntentState::ToPublish),
647            2 => Ok(IntentState::Published),
648            3 => Ok(IntentState::Committed),
649            4 => Ok(IntentState::Error),
650            5 => Ok(IntentState::Processed),
651            x => Err(format!("Unrecognized variant {}", x).into()),
652        }
653    }
654}
655
656#[cfg(test)]
657pub(crate) mod tests {
658    #[cfg(target_arch = "wasm32")]
659    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);
660
661    use super::*;
662    use crate::{
663        Fetch, Store,
664        group::{GroupMembershipState, StoredGroup},
665        test_utils::with_connection,
666    };
667    use xmtp_common::rand_vec;
668
669    fn insert_group<C: ConnectionExt>(conn: &DbConnection<C>, group_id: Vec<u8>) {
670        StoredGroup::builder()
671            .id(group_id)
672            .created_at_ns(100)
673            .membership_state(GroupMembershipState::Allowed)
674            .added_by_inbox_id("placeholder_address")
675            .build()
676            .unwrap()
677            .store(conn)
678            .unwrap();
679    }
680
681    impl NewGroupIntent {
682        // Real group intents must always start as ToPublish. But for tests we allow forcing the
683        // state
684        pub fn new_test(
685            kind: IntentKind,
686            group_id: Vec<u8>,
687            data: Vec<u8>,
688            state: IntentState,
689        ) -> Self {
690            Self {
691                kind,
692                group_id,
693                data,
694                state,
695                should_push: false,
696            }
697        }
698    }
699
700    fn find_first_intent<C: ConnectionExt>(
701        conn: &DbConnection<C>,
702        group_id: group::ID,
703    ) -> StoredGroupIntent {
704        conn.raw_query_read(|raw_conn| {
705            dsl::group_intents
706                .filter(dsl::group_id.eq(group_id))
707                .first(raw_conn)
708        })
709        .unwrap()
710    }
711
712    #[xmtp_common::test]
713    fn test_store_and_fetch() {
714        let group_id = rand_vec::<24>();
715        let data = rand_vec::<24>();
716        let kind = IntentKind::UpdateGroupMembership;
717        let state = IntentState::ToPublish;
718
719        let to_insert = NewGroupIntent::new_test(kind, group_id.clone(), data.clone(), state);
720
721        with_connection(|conn| {
722            // Group needs to exist or FK constraint will fail
723            insert_group(conn, group_id.clone());
724
725            to_insert.store(conn).unwrap();
726
727            let results = conn
728                .find_group_intents(group_id.clone(), Some(vec![IntentState::ToPublish]), None)
729                .unwrap();
730
731            assert_eq!(results.len(), 1);
732            assert_eq!(results[0].kind, kind);
733            assert_eq!(results[0].data, data);
734            assert_eq!(results[0].group_id, group_id);
735
736            let id = results[0].id;
737
738            let fetched: StoredGroupIntent = conn.fetch(&id).unwrap().unwrap();
739
740            assert_eq!(fetched.id, id);
741        })
742    }
743
744    #[xmtp_common::test]
745    fn test_query() {
746        let group_id = rand_vec::<24>();
747
748        let test_intents: Vec<NewGroupIntent> = vec![
749            NewGroupIntent::new_test(
750                IntentKind::UpdateGroupMembership,
751                group_id.clone(),
752                rand_vec::<24>(),
753                IntentState::ToPublish,
754            ),
755            NewGroupIntent::new_test(
756                IntentKind::KeyUpdate,
757                group_id.clone(),
758                rand_vec::<24>(),
759                IntentState::Published,
760            ),
761            NewGroupIntent::new_test(
762                IntentKind::KeyUpdate,
763                group_id.clone(),
764                rand_vec::<24>(),
765                IntentState::Committed,
766            ),
767        ];
768
769        with_connection(|conn| {
770            // Group needs to exist or FK constraint will fail
771            insert_group(conn, group_id.clone());
772
773            for case in test_intents {
774                case.store(conn).unwrap();
775            }
776
777            // Can query for multiple states
778            let mut results = conn
779                .find_group_intents(
780                    group_id.clone(),
781                    Some(vec![IntentState::ToPublish, IntentState::Published]),
782                    None,
783                )
784                .unwrap();
785
786            assert_eq!(results.len(), 2);
787
788            // Can query by kind
789            results = conn
790                .find_group_intents(group_id.clone(), None, Some(vec![IntentKind::KeyUpdate]))
791                .unwrap();
792            assert_eq!(results.len(), 2);
793
794            // Can query by kind and state
795            results = conn
796                .find_group_intents(
797                    group_id.clone(),
798                    Some(vec![IntentState::Committed]),
799                    Some(vec![IntentKind::KeyUpdate]),
800                )
801                .unwrap();
802
803            assert_eq!(results.len(), 1);
804
805            // Can get no results
806            results = conn
807                .find_group_intents(
808                    group_id.clone(),
809                    Some(vec![IntentState::Committed]),
810                    Some(vec![IntentKind::SendMessage]),
811                )
812                .unwrap();
813
814            assert_eq!(results.len(), 0);
815
816            // Can get all intents
817            results = conn.find_group_intents(group_id, None, None).unwrap();
818            assert_eq!(results.len(), 3);
819        })
820    }
821
822    #[xmtp_common::test]
823    fn find_by_payload_hash() {
824        let group_id = rand_vec::<24>();
825
826        with_connection(|conn| {
827            insert_group(conn, group_id.clone());
828
829            // Store the intent
830            NewGroupIntent::new(
831                IntentKind::UpdateGroupMembership,
832                group_id.clone(),
833                rand_vec::<24>(),
834                false,
835            )
836            .store(conn)
837            .unwrap();
838
839            // Find the intent with the ID populated
840            let intent = find_first_intent(conn, group_id.clone());
841
842            // Set the payload hash
843            let payload_hash = rand_vec::<24>();
844            let post_commit_data = rand_vec::<24>();
845            conn.set_group_intent_published(
846                intent.id,
847                &payload_hash,
848                Some(post_commit_data.clone()),
849                None,
850                1,
851            )
852            .unwrap();
853
854            let find_result = conn
855                .find_group_intent_by_payload_hash(&payload_hash)
856                .unwrap()
857                .unwrap();
858
859            assert_eq!(find_result.id, intent.id);
860            assert_eq!(find_result.published_in_epoch, Some(1));
861        })
862    }
863
864    #[xmtp_common::test]
865    fn test_happy_path_state_transitions() {
866        let group_id = rand_vec::<24>();
867
868        with_connection(|conn| {
869            insert_group(conn, group_id.clone());
870
871            // Store the intent
872            NewGroupIntent::new(
873                IntentKind::UpdateGroupMembership,
874                group_id.clone(),
875                rand_vec::<24>(),
876                false,
877            )
878            .store(conn)
879            .unwrap();
880
881            let mut intent = find_first_intent(conn, group_id.clone());
882
883            // Set to published
884            let payload_hash = rand_vec::<24>();
885            let post_commit_data = rand_vec::<24>();
886            conn.set_group_intent_published(
887                intent.id,
888                &payload_hash,
889                Some(post_commit_data.clone()),
890                None,
891                1,
892            )
893            .unwrap();
894
895            intent = conn.fetch(&intent.id).unwrap().unwrap();
896            assert_eq!(intent.state, IntentState::Published);
897            assert_eq!(intent.payload_hash, Some(payload_hash.clone()));
898            assert_eq!(intent.post_commit_data, Some(post_commit_data.clone()));
899
900            conn.set_group_intent_committed(intent.id, Cursor::default())
901                .unwrap();
902            // Refresh from the DB
903            intent = conn.fetch(&intent.id).unwrap().unwrap();
904            assert_eq!(intent.state, IntentState::Committed);
905            // Make sure we haven't lost the payload hash
906            assert_eq!(intent.payload_hash, Some(payload_hash.clone()));
907        })
908    }
909
910    #[xmtp_common::test]
911    fn test_republish_state_transition() {
912        let group_id = rand_vec::<24>();
913
914        with_connection(|conn| {
915            insert_group(conn, group_id.clone());
916
917            // Store the intent
918            NewGroupIntent::new(
919                IntentKind::UpdateGroupMembership,
920                group_id.clone(),
921                rand_vec::<24>(),
922                false,
923            )
924            .store(conn)
925            .unwrap();
926
927            let mut intent = find_first_intent(conn, group_id.clone());
928
929            // Set to published
930            let payload_hash = rand_vec::<24>();
931            let post_commit_data = rand_vec::<24>();
932            conn.set_group_intent_published(
933                intent.id,
934                &payload_hash,
935                Some(post_commit_data.clone()),
936                None,
937                1,
938            )
939            .unwrap();
940
941            intent = conn.fetch(&intent.id).unwrap().unwrap();
942            assert_eq!(intent.state, IntentState::Published);
943            assert_eq!(intent.payload_hash, Some(payload_hash.clone()));
944
945            // Now revert back to ToPublish
946            conn.set_group_intent_to_publish(intent.id).unwrap();
947            intent = conn.fetch(&intent.id).unwrap().unwrap();
948            assert_eq!(intent.state, IntentState::ToPublish);
949            assert!(intent.payload_hash.is_none());
950            assert!(intent.post_commit_data.is_none());
951        })
952    }
953
954    #[xmtp_common::test]
955    fn test_invalid_state_transition() {
956        let group_id = rand_vec::<24>();
957
958        with_connection(|conn| {
959            insert_group(conn, group_id.clone());
960
961            // Store the intent
962            NewGroupIntent::new(
963                IntentKind::UpdateGroupMembership,
964                group_id.clone(),
965                rand_vec::<24>(),
966                false,
967            )
968            .store(conn)
969            .unwrap();
970
971            let intent = find_first_intent(conn, group_id.clone());
972
973            let commit_result = conn.set_group_intent_committed(intent.id, Cursor::default());
974            assert!(commit_result.is_err());
975            assert!(matches!(
976                commit_result.err().unwrap(),
977                StorageError::NotFound(_)
978            ));
979
980            let to_publish_result = conn.set_group_intent_to_publish(intent.id);
981            assert!(to_publish_result.is_err());
982            assert!(matches!(
983                to_publish_result.err().unwrap(),
984                StorageError::NotFound(_)
985            ));
986        })
987    }
988
989    #[xmtp_common::test]
990    fn test_increment_publish_attempts() {
991        let group_id = rand_vec::<24>();
992        with_connection(|conn| {
993            insert_group(conn, group_id.clone());
994            NewGroupIntent::new(
995                IntentKind::UpdateGroupMembership,
996                group_id.clone(),
997                rand_vec::<24>(),
998                false,
999            )
1000            .store(conn)
1001            .unwrap();
1002
1003            let mut intent = find_first_intent(conn, group_id.clone());
1004            assert_eq!(intent.publish_attempts, 0);
1005            conn.increment_intent_publish_attempt_count(intent.id)
1006                .unwrap();
1007            intent = find_first_intent(conn, group_id.clone());
1008            assert_eq!(intent.publish_attempts, 1);
1009            conn.increment_intent_publish_attempt_count(intent.id)
1010                .unwrap();
1011            intent = find_first_intent(conn, group_id.clone());
1012            assert_eq!(intent.publish_attempts, 2);
1013        })
1014    }
1015    #[xmtp_common::test]
1016    fn test_find_dependant_commits() {
1017        use crate::encrypted_store::refresh_state::{EntityKind, QueryRefreshState};
1018
1019        let group_id = rand_vec::<24>();
1020        let payload_hash1 = rand_vec::<24>();
1021        let payload_hash2 = rand_vec::<24>();
1022
1023        with_connection(|conn| {
1024            insert_group(conn, group_id.clone());
1025            NewGroupIntent::new(
1026                IntentKind::SendMessage,
1027                group_id.clone(),
1028                rand_vec::<24>(),
1029                false,
1030            )
1031            .store(conn)
1032            .unwrap();
1033
1034            let intent1 = find_first_intent(conn, group_id.clone());
1035            conn.set_group_intent_published(intent1.id, &payload_hash1, None, None, 1)
1036                .unwrap();
1037
1038            NewGroupIntent::new(
1039                IntentKind::KeyUpdate,
1040                group_id.clone(),
1041                rand_vec::<24>(),
1042                false,
1043            )
1044            .store(conn)
1045            .unwrap();
1046            let intents = conn
1047                .find_group_intents(group_id.clone(), None, None)
1048                .unwrap();
1049            let intent2 = intents.iter().find(|i| i.id != intent1.id).unwrap();
1050            conn.set_group_intent_published(intent2.id, &payload_hash2, None, None, 1)
1051                .unwrap();
1052
1053            conn.update_cursor(
1054                group_id.clone(),
1055                EntityKind::CommitMessage,
1056                Cursor::new(100, 42u32),
1057            )
1058            .unwrap();
1059
1060            let result = conn
1061                .find_dependant_commits(&[&payload_hash1, &payload_hash2])
1062                .unwrap();
1063
1064            assert_eq!(result.len(), 2);
1065            let dep1 = result
1066                .get(&PayloadHash::from(payload_hash1.clone()))
1067                .unwrap();
1068            assert_eq!(dep1.cursor.sequence_id, 100);
1069            assert_eq!(dep1.cursor.originator_id, 42);
1070            assert_eq!(dep1.group_id.as_ref(), &group_id);
1071
1072            let dep2 = result
1073                .get(&PayloadHash::from(payload_hash2.clone()))
1074                .unwrap();
1075            assert_eq!(dep2.cursor.sequence_id, 100);
1076            assert_eq!(dep2.cursor.originator_id, 42);
1077            assert_eq!(dep2.group_id.as_ref(), &group_id);
1078        })
1079    }
1080}