xmtp_db/encrypted_store/consent_record.rs

1use super::{ConnectionExt, Sqlite, group::StoredGroup};
2use super::{
3    db_connection::DbConnection,
4    schema::{
5        consent_records::{self, dsl},
6        groups::dsl as groups_dsl,
7    },
8};
9use crate::{DbQuery, StorageError, impl_store};
10use diesel::{
11    backend::Backend,
12    deserialize::{self, FromSql, FromSqlRow},
13    expression::AsExpression,
14    prelude::*,
15    serialize::{self, IsNull, Output, ToSql},
16    sql_types::Integer,
17    upsert::excluded,
18};
19use serde::{Deserialize, Serialize};
20use xmtp_common::time::now_ns;
21use xmtp_proto::{
22    ConversionError,
23    xmtp::device_sync::consent_backup::{ConsentSave, ConsentStateSave, ConsentTypeSave},
24};
25mod convert;
26
27/// StoredConsentRecord holds a serialized ConsentRecord
28#[derive(Insertable, Queryable, Debug, Clone, Eq, Deserialize, Serialize)]
29#[diesel(table_name = consent_records)]
30#[diesel(primary_key(entity_type, entity))]
31pub struct StoredConsentRecord {
32    /// Enum, [`ConsentType`] representing the type of consent (conversation_id inbox_id, etc..)
33    pub entity_type: ConsentType,
34    /// Enum, [`ConsentState`] representing the state of consent (allowed, denied, etc..)
35    pub state: ConsentState,
36    /// The entity of what was consented (0x00 etc..)
37    pub entity: String,
38
39    pub consented_at_ns: i64,
40}
41
42impl PartialEq for StoredConsentRecord {
43    fn eq(&self, other: &Self) -> bool {
44        self.entity == other.entity
45            && self.entity_type == other.entity_type
46            && self.state == other.state
47    }
48}
49
50impl StoredConsentRecord {
51    pub fn new(entity_type: ConsentType, state: ConsentState, entity: String) -> Self {
52        Self {
53            entity_type,
54            state,
55            entity,
56            consented_at_ns: now_ns(),
57        }
58    }
59
60    /// This function will perform some logic to see if a new group should be auto-consented
61    /// or auto-denied based on past consent.
62    pub fn stitch_dm_consent(conn: &impl DbQuery, group: &StoredGroup) -> Result<(), StorageError> {
63        if let Some(dm_id) = &group.dm_id {
64            let mut past_consent = conn.find_consent_by_dm_id(dm_id)?;
65            let Some(last_consent) = past_consent.pop() else {
66                return Ok(());
67            };
68
69            let cr = Self::new(
70                ConsentType::ConversationId,
71                last_consent.state,
72                hex::encode(&group.id),
73            );
74            conn.insert_newer_consent_record(cr)?;
75        }
76
77        Ok(())
78    }
79}
80
// Hooks `StoredConsentRecord` into the crate's generic store machinery for the
// `consent_records` table. Presumably this generates the `Store` impl whose `.store(conn)`
// the tests in this file call — see `impl_store!` for the exact expansion.
impl_store!(StoredConsentRecord, consent_records);
82
83pub trait QueryConsentRecord {
84    /// Returns the consent_records for the given entity up
85    fn get_consent_record(
86        &self,
87        entity: String,
88        entity_type: ConsentType,
89    ) -> Result<Option<StoredConsentRecord>, crate::ConnectionError>;
90
91    fn consent_records(&self) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError>;
92
93    fn consent_records_paged(
94        &self,
95        limit: i64,
96        offset: i64,
97    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError>;
98
99    /// Returns true if newer
100    fn insert_newer_consent_record(
101        &self,
102        record: StoredConsentRecord,
103    ) -> Result<bool, crate::ConnectionError>;
104
105    /// Insert consent_records, and replace existing entries, returns records that are new or changed
106    fn insert_or_replace_consent_records(
107        &self,
108        records: &[StoredConsentRecord],
109    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError>;
110
111    fn maybe_insert_consent_record_return_existing(
112        &self,
113        record: &StoredConsentRecord,
114    ) -> Result<Option<StoredConsentRecord>, crate::ConnectionError>;
115
116    fn find_consent_by_dm_id(
117        &self,
118        dm_id: &str,
119    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError>;
120}
121
122impl<C: ConnectionExt> QueryConsentRecord for DbConnection<C> {
123    /// Returns the consent_records for the given entity up
124    fn get_consent_record(
125        &self,
126        entity: String,
127        entity_type: ConsentType,
128    ) -> Result<Option<StoredConsentRecord>, crate::ConnectionError> {
129        self.raw_query_read(|conn| {
130            dsl::consent_records
131                .filter(dsl::entity.eq(entity))
132                .filter(dsl::entity_type.eq(entity_type))
133                .first(conn)
134                .optional()
135        })
136    }
137
138    fn consent_records(&self) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
139        self.raw_query_read(|conn| super::schema::consent_records::table.load(conn))
140    }
141
142    fn consent_records_paged(
143        &self,
144        limit: i64,
145        offset: i64,
146    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
147        let query = consent_records::table
148            .order_by((consent_records::entity_type, consent_records::entity))
149            .limit(limit)
150            .offset(offset);
151
152        self.raw_query_read(|conn| query.load::<StoredConsentRecord>(conn))
153    }
154
155    // returns true if newer
156    fn insert_newer_consent_record(
157        &self,
158        record: StoredConsentRecord,
159    ) -> Result<bool, crate::ConnectionError> {
160        self.raw_query_write(|conn| {
161            let maybe_inserted_consent_record: Option<StoredConsentRecord> =
162                diesel::insert_into(dsl::consent_records)
163                    .values(&record)
164                    .on_conflict_do_nothing()
165                    .get_result(conn)
166                    .optional()?;
167
168            // if record was not inserted...
169            if maybe_inserted_consent_record.is_none() {
170                let old_record = dsl::consent_records
171                    .find((&record.entity_type, &record.entity))
172                    .first::<StoredConsentRecord>(conn)?;
173
174                if old_record.eq(&record) {
175                    return Ok(false);
176                }
177
178                let should_replace = old_record.consented_at_ns < record.consented_at_ns;
179                if should_replace {
180                    diesel::insert_into(dsl::consent_records)
181                        .values(record)
182                        .on_conflict((dsl::entity_type, dsl::entity))
183                        .do_update()
184                        .set(dsl::state.eq(excluded(dsl::state)))
185                        .execute(conn)?;
186                }
187                return Ok(should_replace);
188            }
189
190            Ok(true)
191        })
192    }
193
194    /// Insert consent_records, and replace existing entries, returns records that are new or changed
195    fn insert_or_replace_consent_records(
196        &self,
197        records: &[StoredConsentRecord],
198    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
199        let mut query = consent_records::table
200            .into_boxed()
201            .filter(false.into_sql::<diesel::sql_types::Bool>());
202        let primary_keys: Vec<_> = records
203            .iter()
204            .map(|r| (&r.entity, &r.entity_type))
205            .collect();
206        for (entity, entity_type) in primary_keys {
207            query = query.or_filter(
208                consent_records::entity_type
209                    .eq(entity_type)
210                    .and(consent_records::entity.eq(entity)),
211            );
212        }
213
214        let changed = self.raw_query_write(|conn| {
215            let existing: Vec<StoredConsentRecord> = query.load(conn)?;
216            let changed: Vec<_> = records
217                .iter()
218                .filter(|r| !existing.contains(r))
219                .cloned()
220                .collect();
221
222            conn.transaction::<_, diesel::result::Error, _>(|conn| {
223                for record in records.iter() {
224                    diesel::insert_into(dsl::consent_records)
225                        .values(record)
226                        .on_conflict((dsl::entity_type, dsl::entity))
227                        .do_update()
228                        .set(dsl::state.eq(excluded(dsl::state)))
229                        .execute(conn)?;
230                }
231                Ok(())
232            })?;
233
234            Ok(changed)
235        })?;
236
237        Ok(changed)
238    }
239
240    fn maybe_insert_consent_record_return_existing(
241        &self,
242        record: &StoredConsentRecord,
243    ) -> Result<Option<StoredConsentRecord>, crate::ConnectionError> {
244        self.raw_query_write(|conn| {
245            let maybe_inserted_consent_record: Option<StoredConsentRecord> =
246                diesel::insert_into(dsl::consent_records)
247                    .values(record)
248                    .on_conflict_do_nothing()
249                    .get_result(conn)
250                    .optional()?;
251
252            // if record was not inserted...
253            if maybe_inserted_consent_record.is_none() {
254                return dsl::consent_records
255                    .find((&record.entity_type, &record.entity))
256                    .first(conn)
257                    .optional();
258            }
259
260            Ok(None)
261        })
262    }
263
264    fn find_consent_by_dm_id(
265        &self,
266        dm_id: &str,
267    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
268        self.raw_query_read(|conn| {
269            // First, get all group IDs for this dm_id
270            let group_ids: Vec<Vec<u8>> = groups_dsl::groups
271                .filter(groups_dsl::dm_id.eq(dm_id))
272                .select(groups_dsl::id)
273                .load::<Vec<u8>>(conn)?;
274
275            // Convert to hex strings
276            let group_id_hexes: Vec<String> = group_ids.iter().map(hex::encode).collect();
277
278            // Query consent records
279            dsl::consent_records
280                .filter(dsl::entity.eq_any(group_id_hexes))
281                .filter(dsl::entity_type.eq(ConsentType::ConversationId))
282                .order(dsl::consented_at_ns.desc())
283                .load::<StoredConsentRecord>(conn)
284        })
285    }
286}
287
288impl<T: QueryConsentRecord + ?Sized> QueryConsentRecord for &T {
289    fn get_consent_record(
290        &self,
291        entity: String,
292        entity_type: ConsentType,
293    ) -> Result<Option<StoredConsentRecord>, crate::ConnectionError> {
294        (**self).get_consent_record(entity, entity_type)
295    }
296
297    fn consent_records(&self) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
298        (**self).consent_records()
299    }
300
301    fn consent_records_paged(
302        &self,
303        limit: i64,
304        offset: i64,
305    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
306        (**self).consent_records_paged(limit, offset)
307    }
308
309    fn insert_newer_consent_record(
310        &self,
311        record: StoredConsentRecord,
312    ) -> Result<bool, crate::ConnectionError> {
313        (**self).insert_newer_consent_record(record)
314    }
315
316    fn insert_or_replace_consent_records(
317        &self,
318        records: &[StoredConsentRecord],
319    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
320        (**self).insert_or_replace_consent_records(records)
321    }
322
323    fn maybe_insert_consent_record_return_existing(
324        &self,
325        record: &StoredConsentRecord,
326    ) -> Result<Option<StoredConsentRecord>, crate::ConnectionError> {
327        (**self).maybe_insert_consent_record_return_existing(record)
328    }
329
330    fn find_consent_by_dm_id(
331        &self,
332        dm_id: &str,
333    ) -> Result<Vec<StoredConsentRecord>, crate::ConnectionError> {
334        (**self).find_consent_by_dm_id(dm_id)
335    }
336}
337
338#[repr(i32)]
339#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq, AsExpression, FromSqlRow)]
340#[diesel(sql_type = Integer)]
341/// Type of consent record stored
342pub enum ConsentType {
343    /// Consent is for a conversation
344    ConversationId = 1,
345    /// Consent is for an inbox
346    InboxId = 2,
347}
348
349impl ToSql<Integer, Sqlite> for ConsentType
350where
351    i32: ToSql<Integer, Sqlite>,
352{
353    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
354        out.set_value(*self as i32);
355        Ok(IsNull::No)
356    }
357}
358
359impl FromSql<Integer, Sqlite> for ConsentType
360where
361    i32: FromSql<Integer, Sqlite>,
362{
363    fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
364        match i32::from_sql(bytes)? {
365            1 => Ok(ConsentType::ConversationId),
366            2 => Ok(ConsentType::InboxId),
367            x => Err(format!("Unrecognized variant {}", x).into()),
368        }
369    }
370}
371
372#[repr(i32)]
373#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq, AsExpression, FromSqlRow)]
374#[diesel(sql_type = Integer)]
375/// The state of the consent
376pub enum ConsentState {
377    /// Consent is unknown
378    Unknown = 0,
379    /// Consent is allowed
380    Allowed = 1,
381    /// Consent is denied
382    Denied = 2,
383}
384
385impl ToSql<Integer, Sqlite> for ConsentState
386where
387    i32: ToSql<Integer, Sqlite>,
388{
389    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
390        out.set_value(*self as i32);
391        Ok(IsNull::No)
392    }
393}
394
395impl FromSql<Integer, Sqlite> for ConsentState
396where
397    i32: FromSql<Integer, Sqlite>,
398{
399    fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
400        match i32::from_sql(bytes)? {
401            0 => Ok(ConsentState::Unknown),
402            1 => Ok(ConsentState::Allowed),
403            2 => Ok(ConsentState::Denied),
404            x => Err(format!("Unrecognized variant {}", x).into()),
405        }
406    }
407}
408
#[cfg(test)]
mod tests {
    use crate::{Store, group::tests::generate_group, test_utils::with_connection};
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;

    /// Builds a consent record stamped with the current time (`now_ns()`).
    fn generate_consent_record(
        entity_type: ConsentType,
        state: ConsentState,
        entity: String,
    ) -> StoredConsentRecord {
        StoredConsentRecord {
            entity_type,
            state,
            entity,
            consented_at_ns: now_ns(),
        }
    }

    // NOTE(review): `unwrap_try = true` presumably makes the test macro rewrite `?` into
    // unwraps (the closure below returns `()`, so plain `?` could not compile) — confirm
    // against `xmtp_common::test`.
    /// A `ConversationId` consent record keyed by a group's hex id is found through that
    /// group's `dm_id`.
    #[xmtp_common::test(unwrap_try = true)]
    fn find_consent_by_dm_id() {
        with_connection(|conn| {
            let mut g = generate_group(None);
            g.dm_id = Some("dm:alpha:beta".to_string());
            g.store(conn)?;

            let cr = generate_consent_record(
                ConsentType::ConversationId,
                ConsentState::Allowed,
                hex::encode(g.id),
            );
            cr.store(conn)?;

            let mut records = conn.find_consent_by_dm_id("dm:alpha:beta")?;

            assert_eq!(records.len(), 1);
            assert_eq!(records.pop()?, cr);
        })
    }

    /// Exercises the change detection of `insert_or_replace_consent_records`, then reads the
    /// record back with `get_consent_record`. Finally checks that
    /// `maybe_insert_consent_record_return_existing` returns the stored record on conflict
    /// and leaves the database unchanged.
    #[xmtp_common::test]
    fn insert_and_read() {
        with_connection(|conn| {
            let inbox_id = "inbox_1";
            let consent_record = generate_consent_record(
                ConsentType::InboxId,
                ConsentState::Allowed,
                inbox_id.to_string(),
            );
            let consent_record_entity = consent_record.entity.clone();

            // Insert the record
            let result = conn
                .insert_or_replace_consent_records(std::slice::from_ref(&consent_record))
                .expect("should store without error");
            // One record was inserted
            assert_eq!(result.len(), 1);

            // Insert it again
            let result = conn
                .insert_or_replace_consent_records(std::slice::from_ref(&consent_record))
                .expect("should store without error");
            // Nothing should change
            assert_eq!(result.len(), 0);

            // Insert it again, this time with a Denied state
            let result = conn
                .insert_or_replace_consent_records(&[StoredConsentRecord {
                    state: ConsentState::Denied,
                    ..consent_record
                }])
                .expect("should store without error");
            // Should change
            assert_eq!(result.len(), 1);

            let consent_record = conn
                .get_consent_record(inbox_id.to_owned(), ConsentType::InboxId)
                .expect("query should work");

            assert_eq!(consent_record.unwrap().entity, consent_record_entity);

            // Same key as the stored (Denied) record, but with an Allowed state.
            let conflict = generate_consent_record(
                ConsentType::InboxId,
                ConsentState::Allowed,
                inbox_id.to_string(),
            );

            let existing = conn
                .maybe_insert_consent_record_return_existing(&conflict)
                .unwrap();
            assert!(existing.is_some());
            let existing = existing.unwrap();
            // we want the old record to be returned.
            assert_eq!(existing.state, ConsentState::Denied);

            let db_cr = conn
                .get_consent_record(existing.entity, existing.entity_type)
                .unwrap()
                .unwrap();
            // ensure the db matches the state of what was returned
            assert_eq!(db_cr.state, existing.state);
        })
    }
}
514}