xmtp_db/
sql_key_store.rs

1use xmtp_common::{RetryableError, retryable};
2
3use self::transactions::MutableTransactionConnection;
4use crate::{ConnectionExt, TransactionalKeyStore, XmtpMlsStorageProvider};
5
6use bincode;
7use diesel::{
8    prelude::*,
9    sql_types::Binary,
10    {RunQueryDsl, sql_query},
11};
12use openmls_traits::storage::*;
13use serde::Serialize;
14
15#[cfg(any(feature = "test-utils", test))]
16pub mod mock;
17mod transactions;
18
/// Reads `value_bytes` for one `(key_bytes, version)` pair.
/// Bind order: `key_bytes`, `version`.
const SELECT_QUERY: &str =
    "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?";

/// Inserts a row, replacing any row it conflicts with.
/// Bind order: `key_bytes`, `version`, `value_bytes`.
const REPLACE_QUERY: &str =
    "REPLACE INTO openmls_key_value (key_bytes, version, value_bytes) VALUES (?, ?, ?)";

/// Overwrites `value_bytes` of an existing row.
/// Bind order: `value_bytes`, `key_bytes`, `version` (the value comes first here).
const UPDATE_QUERY: &str =
    "UPDATE openmls_key_value SET value_bytes = ? WHERE key_bytes = ? AND version = ?";

/// Deletes the row for one `(key_bytes, version)` pair.
/// Bind order: `key_bytes`, `version`.
const DELETE_QUERY: &str = "DELETE FROM openmls_key_value WHERE key_bytes = ? AND version = ?";
26
/// One full row of the `openmls_key_value` table.
///
/// Only compiled when the `test-utils` feature is enabled.
// NOTE(review): `Queryable` maps columns by position, so the field order below is presumably
// required to match the column order in `crate::schema::openmls_key_value` -- confirm before
// reordering.
#[cfg(feature = "test-utils")]
#[derive(Selectable, Queryable, QueryableByName)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[diesel(table_name = crate::schema::openmls_key_value)]
pub struct OpenMlsKeyValue {
    /// Value of the `version` column.
    pub version: i32,
    /// Value of the `key_bytes` column.
    pub key_bytes: Vec<u8>,
    /// Value of the `value_bytes` column.
    pub value_bytes: Vec<u8>,
}
37
#[cfg(feature = "test-utils")]
impl OpenMlsKeyValue {
    /// Computes a SHA-256 digest over every row of `openmls_key_value`.
    ///
    /// Rows are visited ordered by `version` and then by `key_bytes`. For each row the
    /// digest absorbs the version, key bytes, value bytes and the row's position in that
    /// ordering, each preceded by a short ASCII tag.
    ///
    /// # Errors
    /// Returns the diesel error if the query fails or a row cannot be loaded.
    pub fn hash_all(conn: &mut SqliteConnection) -> Result<Vec<u8>, diesel::result::Error> {
        use crate::schema::openmls_key_value;
        use xmtp_common::Sha2Digest;

        // Both sort keys must go into a single `order` call: a second `.order(..)` replaces
        // the first one instead of adding a tie-breaker.
        let rows = openmls_key_value::table
            .order((
                openmls_key_value::version.asc(),
                openmls_key_value::key_bytes.asc(),
            ))
            .load_iter::<OpenMlsKeyValue, _>(conn)?;

        let mut hasher = xmtp_common::Sha256Digest::new();
        for (index, row) in rows.enumerate() {
            let row = row?;
            hasher.update(b"version");
            hasher.update(row.version.to_be_bytes());
            hasher.update(b"key_bytes");
            hasher.update(&row.key_bytes);
            hasher.update(b"value_bytes");
            hasher.update(&row.value_bytes);
            hasher.update(b"index");
            hasher.update(index.to_be_bytes());
            hasher.update(b"\n");
        }
        Ok(hasher.finalize().to_vec())
    }
}
64
65#[derive(QueryableByName, Debug, Clone, PartialEq, Eq)]
66#[diesel(table_name = openmls_key_value)]
67struct StorageData {
68    #[diesel(sql_type = Binary)]
69    value_bytes: Vec<u8>,
70}
71
72impl TransactionalKeyStore for diesel::SqliteConnection {
73    type Store<'a>
74        = SqlKeyStore<MutableTransactionConnection<'a>>
75    where
76        Self: 'a;
77
78    fn key_store<'a>(&'a mut self) -> Self::Store<'a> {
79        SqlKeyStore::new_transactional(self)
80    }
81}
82
/// Key-value store that runs its queries through a wrapped connection of type `T`.
#[derive(Clone)]
pub struct SqlKeyStore<T> {
    /// The wrapped connection; every query of the store goes through it.
    conn: T,
}

impl<A> SqlKeyStore<A> {
    /// Wraps `conn` in a new key store.
    pub fn new(conn: A) -> Self {
        SqlKeyStore { conn }
    }
}
94
95impl<'a> SqlKeyStore<SqliteConnection> {
96    pub fn new_transactional(
97        conn: &'a mut SqliteConnection,
98    ) -> SqlKeyStore<MutableTransactionConnection<'a>> {
99        SqlKeyStore {
100            conn: MutableTransactionConnection::new(conn),
101        }
102    }
103}
104
105// refactor to use diesel directly
106impl<C> SqlKeyStore<C>
107where
108    C: ConnectionExt,
109{
110    fn select_query<const VERSION: u16>(
111        &self,
112        storage_key: &Vec<u8>,
113    ) -> Result<Vec<StorageData>, crate::ConnectionError> {
114        self.conn.raw_query_read(|conn| {
115            sql_query(SELECT_QUERY)
116                .bind::<diesel::sql_types::Binary, _>(&storage_key)
117                .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
118                .load(conn)
119        })
120    }
121
122    fn replace_query<const VERSION: u16>(
123        &self,
124        storage_key: &Vec<u8>,
125        value: &[u8],
126    ) -> Result<usize, crate::ConnectionError> {
127        self.conn.raw_query_write(|conn| {
128            sql_query(REPLACE_QUERY)
129                .bind::<diesel::sql_types::Binary, _>(&storage_key)
130                .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
131                .bind::<diesel::sql_types::Binary, _>(&value)
132                .execute(conn)
133        })
134    }
135
136    fn update_query<const VERSION: u16>(
137        &self,
138        storage_key: &Vec<u8>,
139        modified_data: &Vec<u8>,
140    ) -> Result<usize, crate::ConnectionError> {
141        self.conn.raw_query_write(|conn| {
142            sql_query(UPDATE_QUERY)
143                .bind::<diesel::sql_types::Binary, _>(&modified_data)
144                .bind::<diesel::sql_types::Binary, _>(&storage_key)
145                .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
146                .execute(conn)
147        })
148    }
149
150    pub fn write<const VERSION: u16>(
151        &self,
152        label: &[u8],
153        key: &[u8],
154        value: &[u8],
155    ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
156        let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
157        let _ = self.replace_query::<VERSION>(&storage_key, value)?;
158        Ok(())
159    }
160
161    pub fn append<const VERSION: u16>(
162        &self,
163        label: &[u8],
164        key: &[u8],
165        value: &[u8],
166    ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
167        tracing::trace!("append {}", String::from_utf8_lossy(label));
168
169        let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
170        let data = self.select_query::<VERSION>(&storage_key)?;
171
172        if let Some(entry) = data.into_iter().next() {
173            // The value in the storage is an array of array of bytes
174            match bincode::deserialize::<Vec<Vec<u8>>>(&entry.value_bytes) {
175                Ok(mut deserialized) => {
176                    deserialized.push(value.to_vec());
177                    let modified_data = bincode::serialize(&deserialized)?;
178
179                    let _ = self.update_query::<VERSION>(&storage_key, &modified_data)?;
180                    Ok(())
181                }
182                Err(_e) => Err(SqlKeyStoreError::SerializationError),
183            }
184        } else {
185            // Add a first entry
186            let value_bytes = &bincode::serialize(&vec![value])?;
187            let _ = self.replace_query::<VERSION>(&storage_key, value_bytes)?;
188
189            Ok(())
190        }
191    }
192
193    pub fn remove_item<const VERSION: u16>(
194        &self,
195        label: &[u8],
196        key: &[u8],
197        value: &[u8],
198    ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
199        tracing::trace!("remove_item {}", String::from_utf8_lossy(label));
200
201        let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
202        let data: Vec<StorageData> = self.select_query::<VERSION>(&storage_key)?;
203
204        if let Some(entry) = data.into_iter().next() {
205            // The value in the storage is an array of array of bytes.
206            let mut deserialized = bincode::deserialize::<Vec<Vec<u8>>>(&entry.value_bytes)
207                .map_err(|_| SqlKeyStoreError::SerializationError)?;
208            let vpos = deserialized.iter().position(|v| v == value);
209
210            if let Some(pos) = vpos {
211                deserialized.remove(pos);
212            }
213            let modified_data = bincode::serialize(&deserialized)
214                .map_err(|_| SqlKeyStoreError::SerializationError)?;
215
216            let _ = self.update_query::<VERSION>(&storage_key, &modified_data)?;
217            Ok(())
218        } else {
219            // Add a first entry
220            let value_bytes =
221                bincode::serialize(&[value]).map_err(|_| SqlKeyStoreError::SerializationError)?;
222            let _ = self.replace_query::<VERSION>(&storage_key, &value_bytes)?;
223            Ok(())
224        }
225    }
226
227    pub fn read<const VERSION: u16, V: Entity<VERSION>>(
228        &self,
229        label: &[u8],
230        key: &[u8],
231    ) -> Result<Option<V>, <Self as StorageProvider<CURRENT_VERSION>>::Error> {
232        let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
233
234        let data = self.select_query::<VERSION>(&storage_key)?;
235
236        if let Some(entry) = data.into_iter().next() {
237            let deserialized = bincode::deserialize::<V>(&entry.value_bytes)
238                .map_err(|_| SqlKeyStoreError::SerializationError)?;
239
240            Ok(Some(deserialized))
241        } else {
242            Ok(None)
243        }
244    }
245
246    pub fn read_list<const VERSION: u16, V: Entity<VERSION>>(
247        &self,
248        label: &[u8],
249        key: &[u8],
250    ) -> Result<Vec<V>, <Self as StorageProvider<CURRENT_VERSION>>::Error> {
251        let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
252        let results = self.select_query::<VERSION>(&storage_key)?;
253
254        if let Some(entry) = results.into_iter().next() {
255            let list = bincode::deserialize::<Vec<Vec<u8>>>(&entry.value_bytes)?;
256
257            // Read the values from the bytes in the list
258            let mut deserialized_list = Vec::new();
259            for v in list {
260                match bincode::deserialize::<V>(&v) {
261                    Ok(deserialized_value) => deserialized_list.push(deserialized_value),
262                    Err(e) => {
263                        tracing::error!("Error occurred: {}", e);
264                        return Err(SqlKeyStoreError::SerializationError);
265                    }
266                }
267            }
268            Ok(deserialized_list)
269        } else {
270            Ok(vec![])
271        }
272    }
273
274    pub fn delete<const VERSION: u16>(
275        &self,
276        label: &[u8],
277        key: &[u8],
278    ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
279        let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
280        self.conn.raw_query_write(|conn| {
281            sql_query(DELETE_QUERY)
282                .bind::<diesel::sql_types::Binary, _>(&storage_key)
283                .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
284                .execute(conn)
285        })?;
286        Ok(())
287    }
288}
289
290/// Errors thrown by the key store.
291/// General error type for Mls Storage Trait
292#[derive(thiserror::Error, Debug)]
293pub enum SqlKeyStoreError {
294    #[error("The key store does not allow storing serialized values.")]
295    UnsupportedValueTypeBytes,
296    #[error("Updating is not supported by this key store.")]
297    UnsupportedMethod,
298    #[error("Error serializing value.")]
299    SerializationError,
300    #[error("Value does not exist.")]
301    NotFound,
302    #[error("database error: {0}")]
303    Storage(#[from] diesel::result::Error),
304    #[error("connection {0}")]
305    Connection(#[from] crate::ConnectionError),
306}
307
308impl RetryableError for SqlKeyStoreError {
309    fn is_retryable(&self) -> bool {
310        use SqlKeyStoreError::*;
311        match self {
312            Storage(err) => retryable!(err),
313            SerializationError => false,
314            UnsupportedMethod => false,
315            UnsupportedValueTypeBytes => false,
316            NotFound => false,
317            Connection(c) => retryable!(c),
318        }
319    }
320}
321
// Storage labels. Each label names one kind of stored value and is passed to the
// read/write/delete helpers together with the entity key.

// Labels for key material.
const KEY_PACKAGE_LABEL: &[u8] = b"KeyPackage";
const ENCRYPTION_KEY_PAIR_LABEL: &[u8] = b"EncryptionKeyPair";
const SIGNATURE_KEY_PAIR_LABEL: &[u8] = b"SignatureKeyPair";
const EPOCH_KEY_PAIRS_LABEL: &[u8] = b"EpochKeyPairs";
/// Label for key package references.
pub const KEY_PACKAGE_REFERENCES: &[u8] = b"KeyPackageReferences";
/// Label for the key package wrapper private key.
pub const KEY_PACKAGE_WRAPPER_PRIVATE_KEY: &[u8] = b"KeyPackageWrapperPrivateKey";
/// Label for the commit log signer private key.
pub const COMMIT_LOG_SIGNER_PRIVATE_KEY: &[u8] = b"CommitLogSignerPrivateKey";

// Labels related to PublicGroup.
const TREE_LABEL: &[u8] = b"Tree";
const GROUP_CONTEXT_LABEL: &[u8] = b"GroupContext";
const INTERIM_TRANSCRIPT_HASH_LABEL: &[u8] = b"InterimTranscriptHash";
const CONFIRMATION_TAG_LABEL: &[u8] = b"ConfirmationTag";

// Labels related to CoreGroup.
const OWN_LEAF_NODE_INDEX_LABEL: &[u8] = b"OwnLeafNodeIndex";
const EPOCH_SECRETS_LABEL: &[u8] = b"EpochSecrets";
const MESSAGE_SECRETS_LABEL: &[u8] = b"MessageSecrets";

// Labels related to MlsGroup.
const JOIN_CONFIG_LABEL: &[u8] = b"MlsGroupJoinConfig";
const OWN_LEAF_NODES_LABEL: &[u8] = b"OwnLeafNodes";
const GROUP_STATE_LABEL: &[u8] = b"GroupState";
const QUEUED_PROPOSAL_LABEL: &[u8] = b"QueuedProposal";
const PROPOSAL_QUEUE_REFS_LABEL: &[u8] = b"ProposalQueueRefs";
const RESUMPTION_PSK_STORE_LABEL: &[u8] = b"ResumptionPskStore";
348
349impl<C> StorageProvider<CURRENT_VERSION> for SqlKeyStore<C>
350where
351    C: ConnectionExt,
352{
353    type Error = SqlKeyStoreError;
354
355    fn queue_proposal<
356        GroupId: traits::GroupId<CURRENT_VERSION>,
357        ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
358        QueuedProposal: traits::QueuedProposal<CURRENT_VERSION>,
359    >(
360        &self,
361        group_id: &GroupId,
362        proposal_ref: &ProposalRef,
363        proposal: &QueuedProposal,
364    ) -> Result<(), Self::Error> {
365        // write proposal to key (group_id, proposal_ref)
366        let key = bincode::serialize(&(group_id, proposal_ref))?;
367        let value = bincode::serialize(proposal)?;
368        self.write::<CURRENT_VERSION>(QUEUED_PROPOSAL_LABEL, &key, &value)?;
369
370        // update proposal list for group_id
371        let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
372        let value = bincode::serialize(proposal_ref)?;
373        self.append::<CURRENT_VERSION>(PROPOSAL_QUEUE_REFS_LABEL, &key, &value)?;
374
375        Ok(())
376    }
377
378    fn write_tree<
379        GroupId: traits::GroupId<CURRENT_VERSION>,
380        TreeSync: traits::TreeSync<CURRENT_VERSION>,
381    >(
382        &self,
383        group_id: &GroupId,
384        tree: &TreeSync,
385    ) -> Result<(), Self::Error> {
386        let key = build_key::<CURRENT_VERSION, &GroupId>(TREE_LABEL, group_id)?;
387        let value = bincode::serialize(&tree)?;
388        self.write::<CURRENT_VERSION>(TREE_LABEL, &key, &value)
389    }
390
391    fn write_interim_transcript_hash<
392        GroupId: traits::GroupId<CURRENT_VERSION>,
393        InterimTranscriptHash: traits::InterimTranscriptHash<CURRENT_VERSION>,
394    >(
395        &self,
396        group_id: &GroupId,
397        interim_transcript_hash: &InterimTranscriptHash,
398    ) -> Result<(), Self::Error> {
399        let key = build_key::<CURRENT_VERSION, &GroupId>(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?;
400        let value = bincode::serialize(&interim_transcript_hash)?;
401        let _ = self.write::<CURRENT_VERSION>(INTERIM_TRANSCRIPT_HASH_LABEL, &key, &value);
402
403        Ok(())
404    }
405
406    fn write_context<
407        GroupId: traits::GroupId<CURRENT_VERSION>,
408        GroupContext: traits::GroupContext<CURRENT_VERSION>,
409    >(
410        &self,
411        group_id: &GroupId,
412        group_context: &GroupContext,
413    ) -> Result<(), Self::Error> {
414        let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_CONTEXT_LABEL, group_id)?;
415        let value = bincode::serialize(&group_context)?;
416
417        self.write::<CURRENT_VERSION>(GROUP_CONTEXT_LABEL, &key, &value)
418    }
419
420    fn write_confirmation_tag<
421        GroupId: traits::GroupId<CURRENT_VERSION>,
422        ConfirmationTag: traits::ConfirmationTag<CURRENT_VERSION>,
423    >(
424        &self,
425        group_id: &GroupId,
426        confirmation_tag: &ConfirmationTag,
427    ) -> Result<(), Self::Error> {
428        let key = build_key::<CURRENT_VERSION, &GroupId>(CONFIRMATION_TAG_LABEL, group_id)?;
429        let value = bincode::serialize(&confirmation_tag)?;
430
431        self.write::<CURRENT_VERSION>(CONFIRMATION_TAG_LABEL, &key, &value)
432    }
433
434    fn write_signature_key_pair<
435        SignaturePublicKey: traits::SignaturePublicKey<CURRENT_VERSION>,
436        SignatureKeyPair: traits::SignatureKeyPair<CURRENT_VERSION>,
437    >(
438        &self,
439        public_key: &SignaturePublicKey,
440        signature_key_pair: &SignatureKeyPair,
441    ) -> Result<(), Self::Error> {
442        let key = build_key::<CURRENT_VERSION, &SignaturePublicKey>(
443            SIGNATURE_KEY_PAIR_LABEL,
444            public_key,
445        )?;
446        let value = bincode::serialize(&signature_key_pair)?;
447
448        self.write::<CURRENT_VERSION>(SIGNATURE_KEY_PAIR_LABEL, &key, &value)
449    }
450
451    fn queued_proposal_refs<
452        GroupId: traits::GroupId<CURRENT_VERSION>,
453        ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
454    >(
455        &self,
456        group_id: &GroupId,
457    ) -> Result<Vec<ProposalRef>, Self::Error> {
458        let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
459        self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)
460    }
461
462    fn queued_proposals<
463        GroupId: traits::GroupId<CURRENT_VERSION>,
464        ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
465        QueuedProposal: traits::QueuedProposal<CURRENT_VERSION>,
466    >(
467        &self,
468        group_id: &GroupId,
469    ) -> Result<Vec<(ProposalRef, QueuedProposal)>, Self::Error> {
470        let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
471        let refs: Vec<ProposalRef> = self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)?;
472
473        refs.into_iter()
474            .map(|proposal_ref| -> Result<_, _> {
475                let key = bincode::serialize(&(group_id, &proposal_ref))?;
476                match self.read(QUEUED_PROPOSAL_LABEL, &key)? {
477                    Some(proposal) => Ok((proposal_ref, proposal)),
478                    None => Err(SqlKeyStoreError::NotFound),
479                }
480            })
481            .collect::<Result<Vec<_>, _>>()
482    }
483
484    fn tree<
485        GroupId: traits::GroupId<CURRENT_VERSION>,
486        TreeSync: traits::TreeSync<CURRENT_VERSION>,
487    >(
488        &self,
489        group_id: &GroupId,
490    ) -> Result<Option<TreeSync>, Self::Error> {
491        let key = build_key::<CURRENT_VERSION, &GroupId>(TREE_LABEL, group_id)?;
492
493        self.read(TREE_LABEL, &key)
494    }
495
496    fn group_context<
497        GroupId: traits::GroupId<CURRENT_VERSION>,
498        GroupContext: traits::GroupContext<CURRENT_VERSION>,
499    >(
500        &self,
501        group_id: &GroupId,
502    ) -> Result<Option<GroupContext>, Self::Error> {
503        let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_CONTEXT_LABEL, group_id)?;
504
505        self.read(GROUP_CONTEXT_LABEL, &key)
506    }
507
508    fn interim_transcript_hash<
509        GroupId: traits::GroupId<CURRENT_VERSION>,
510        InterimTranscriptHash: traits::InterimTranscriptHash<CURRENT_VERSION>,
511    >(
512        &self,
513        group_id: &GroupId,
514    ) -> Result<Option<InterimTranscriptHash>, Self::Error> {
515        let key = build_key::<CURRENT_VERSION, &GroupId>(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?;
516
517        self.read(INTERIM_TRANSCRIPT_HASH_LABEL, &key)
518    }
519
520    fn confirmation_tag<
521        GroupId: traits::GroupId<CURRENT_VERSION>,
522        ConfirmationTag: traits::ConfirmationTag<CURRENT_VERSION>,
523    >(
524        &self,
525        group_id: &GroupId,
526    ) -> Result<Option<ConfirmationTag>, Self::Error> {
527        let key = build_key::<CURRENT_VERSION, &GroupId>(CONFIRMATION_TAG_LABEL, group_id)?;
528
529        self.read(CONFIRMATION_TAG_LABEL, &key)
530    }
531
532    fn signature_key_pair<
533        SignaturePublicKey: traits::SignaturePublicKey<CURRENT_VERSION>,
534        SignatureKeyPair: traits::SignatureKeyPair<CURRENT_VERSION>,
535    >(
536        &self,
537        public_key: &SignaturePublicKey,
538    ) -> Result<Option<SignatureKeyPair>, Self::Error> {
539        let key = build_key::<CURRENT_VERSION, &SignaturePublicKey>(
540            SIGNATURE_KEY_PAIR_LABEL,
541            public_key,
542        )?;
543
544        self.read(SIGNATURE_KEY_PAIR_LABEL, &key)
545    }
546
547    fn write_key_package<
548        HashReference: traits::HashReference<CURRENT_VERSION>,
549        KeyPackage: traits::KeyPackage<CURRENT_VERSION>,
550    >(
551        &self,
552        hash_ref: &HashReference,
553        key_package: &KeyPackage,
554    ) -> Result<(), Self::Error> {
555        let key = build_key::<CURRENT_VERSION, &HashReference>(KEY_PACKAGE_LABEL, hash_ref)?;
556        let value = bincode::serialize(&key_package)?;
557
558        // Store the key package
559        self.write::<CURRENT_VERSION>(KEY_PACKAGE_LABEL, &key, &value)
560    }
561
562    fn write_psk<
563        PskId: traits::PskId<CURRENT_VERSION>,
564        PskBundle: traits::PskBundle<CURRENT_VERSION>,
565    >(
566        &self,
567        _psk_id: &PskId,
568        _psk: &PskBundle,
569    ) -> Result<(), Self::Error> {
570        Ok(())
571    }
572
573    fn write_encryption_key_pair<
574        EncryptionKey: traits::EncryptionKey<CURRENT_VERSION>,
575        HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
576    >(
577        &self,
578        public_key: &EncryptionKey,
579        key_pair: &HpkeKeyPair,
580    ) -> Result<(), Self::Error> {
581        let key =
582            build_key::<CURRENT_VERSION, &EncryptionKey>(ENCRYPTION_KEY_PAIR_LABEL, public_key)?;
583
584        self.write::<CURRENT_VERSION>(
585            ENCRYPTION_KEY_PAIR_LABEL,
586            &key,
587            &bincode::serialize(key_pair)?,
588        )
589    }
590
591    fn key_package<
592        HashReference: traits::HashReference<CURRENT_VERSION>,
593        KeyPackage: traits::KeyPackage<CURRENT_VERSION>,
594    >(
595        &self,
596        hash_ref: &HashReference,
597    ) -> Result<Option<KeyPackage>, Self::Error> {
598        let key = build_key::<CURRENT_VERSION, &HashReference>(KEY_PACKAGE_LABEL, hash_ref)?;
599
600        self.read(KEY_PACKAGE_LABEL, &key)
601    }
602
603    fn psk<PskBundle: traits::PskBundle<CURRENT_VERSION>, PskId: traits::PskId<CURRENT_VERSION>>(
604        &self,
605        _psk_id: &PskId,
606    ) -> Result<Option<PskBundle>, Self::Error> {
607        Ok(None)
608    }
609
610    fn encryption_key_pair<
611        HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
612        EncryptionKey: traits::EncryptionKey<CURRENT_VERSION>,
613    >(
614        &self,
615        public_key: &EncryptionKey,
616    ) -> Result<Option<HpkeKeyPair>, Self::Error> {
617        let key =
618            build_key::<CURRENT_VERSION, &EncryptionKey>(ENCRYPTION_KEY_PAIR_LABEL, public_key)?;
619
620        self.read(ENCRYPTION_KEY_PAIR_LABEL, &key)
621    }
622
623    fn delete_signature_key_pair<
624        SignaturePublicKey: traits::SignaturePublicKey<CURRENT_VERSION>,
625    >(
626        &self,
627        public_key: &SignaturePublicKey,
628    ) -> Result<(), Self::Error> {
629        let key = build_key::<CURRENT_VERSION, &SignaturePublicKey>(
630            SIGNATURE_KEY_PAIR_LABEL,
631            public_key,
632        )?;
633
634        self.delete::<CURRENT_VERSION>(SIGNATURE_KEY_PAIR_LABEL, &key)
635    }
636
637    fn delete_encryption_key_pair<EncryptionKey: traits::EncryptionKey<CURRENT_VERSION>>(
638        &self,
639        public_key: &EncryptionKey,
640    ) -> Result<(), Self::Error> {
641        let key =
642            build_key::<CURRENT_VERSION, &EncryptionKey>(ENCRYPTION_KEY_PAIR_LABEL, public_key)?;
643
644        self.delete::<CURRENT_VERSION>(ENCRYPTION_KEY_PAIR_LABEL, &key)
645    }
646
647    fn delete_key_package<HashReference: traits::HashReference<CURRENT_VERSION>>(
648        &self,
649        hash_ref: &HashReference,
650    ) -> Result<(), Self::Error> {
651        let key = build_key::<CURRENT_VERSION, &HashReference>(KEY_PACKAGE_LABEL, hash_ref)?;
652        self.delete::<CURRENT_VERSION>(KEY_PACKAGE_LABEL, &key)
653    }
654
655    fn delete_psk<PskKey: traits::PskId<CURRENT_VERSION>>(
656        &self,
657        _psk_id: &PskKey,
658    ) -> Result<(), Self::Error> {
659        Err(SqlKeyStoreError::UnsupportedMethod)
660    }
661
662    fn group_state<
663        GroupState: traits::GroupState<CURRENT_VERSION>,
664        GroupId: traits::GroupId<CURRENT_VERSION>,
665    >(
666        &self,
667        group_id: &GroupId,
668    ) -> Result<Option<GroupState>, Self::Error> {
669        let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_STATE_LABEL, group_id)?;
670
671        self.read(GROUP_STATE_LABEL, &key)
672    }
673
674    fn write_group_state<
675        GroupState: traits::GroupState<CURRENT_VERSION>,
676        GroupId: traits::GroupId<CURRENT_VERSION>,
677    >(
678        &self,
679        group_id: &GroupId,
680        group_state: &GroupState,
681    ) -> Result<(), Self::Error> {
682        let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_STATE_LABEL, group_id)?;
683
684        self.write::<CURRENT_VERSION>(GROUP_STATE_LABEL, &key, &bincode::serialize(group_state)?)
685    }
686
687    fn delete_group_state<GroupId: traits::GroupId<CURRENT_VERSION>>(
688        &self,
689        group_id: &GroupId,
690    ) -> Result<(), Self::Error> {
691        let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_STATE_LABEL, group_id)?;
692
693        self.delete::<CURRENT_VERSION>(GROUP_STATE_LABEL, &key)
694    }
695
696    fn message_secrets<
697        GroupId: traits::GroupId<CURRENT_VERSION>,
698        MessageSecrets: traits::MessageSecrets<CURRENT_VERSION>,
699    >(
700        &self,
701        group_id: &GroupId,
702    ) -> Result<Option<MessageSecrets>, Self::Error> {
703        let key = build_key::<CURRENT_VERSION, &GroupId>(MESSAGE_SECRETS_LABEL, group_id)?;
704
705        self.read(MESSAGE_SECRETS_LABEL, &key)
706    }
707
708    fn write_message_secrets<
709        GroupId: traits::GroupId<CURRENT_VERSION>,
710        MessageSecrets: traits::MessageSecrets<CURRENT_VERSION>,
711    >(
712        &self,
713        group_id: &GroupId,
714        message_secrets: &MessageSecrets,
715    ) -> Result<(), Self::Error> {
716        let key = build_key::<CURRENT_VERSION, &GroupId>(MESSAGE_SECRETS_LABEL, group_id)?;
717
718        self.write::<CURRENT_VERSION>(
719            MESSAGE_SECRETS_LABEL,
720            &key,
721            &bincode::serialize(message_secrets)?,
722        )
723    }
724
725    fn delete_message_secrets<GroupId: traits::GroupId<CURRENT_VERSION>>(
726        &self,
727        group_id: &GroupId,
728    ) -> Result<(), Self::Error> {
729        let key = build_key::<CURRENT_VERSION, &GroupId>(MESSAGE_SECRETS_LABEL, group_id)?;
730
731        self.delete::<CURRENT_VERSION>(MESSAGE_SECRETS_LABEL, &key)
732    }
733
734    fn resumption_psk_store<
735        GroupId: traits::GroupId<CURRENT_VERSION>,
736        ResumptionPskStore: traits::ResumptionPskStore<CURRENT_VERSION>,
737    >(
738        &self,
739        group_id: &GroupId,
740    ) -> Result<Option<ResumptionPskStore>, Self::Error> {
741        self.read(RESUMPTION_PSK_STORE_LABEL, &bincode::serialize(group_id)?)
742    }
743
744    fn write_resumption_psk_store<
745        GroupId: traits::GroupId<CURRENT_VERSION>,
746        ResumptionPskStore: traits::ResumptionPskStore<CURRENT_VERSION>,
747    >(
748        &self,
749        group_id: &GroupId,
750        resumption_psk_store: &ResumptionPskStore,
751    ) -> Result<(), Self::Error> {
752        self.write::<CURRENT_VERSION>(
753            RESUMPTION_PSK_STORE_LABEL,
754            &bincode::serialize(group_id)?,
755            &bincode::serialize(resumption_psk_store)?,
756        )
757    }
758
759    fn delete_all_resumption_psk_secrets<GroupId: traits::GroupId<CURRENT_VERSION>>(
760        &self,
761        group_id: &GroupId,
762    ) -> Result<(), Self::Error> {
763        self.delete::<CURRENT_VERSION>(RESUMPTION_PSK_STORE_LABEL, &bincode::serialize(group_id)?)
764    }
765
766    fn own_leaf_index<
767        GroupId: traits::GroupId<CURRENT_VERSION>,
768        LeafNodeIndex: traits::LeafNodeIndex<CURRENT_VERSION>,
769    >(
770        &self,
771        group_id: &GroupId,
772    ) -> Result<Option<LeafNodeIndex>, Self::Error> {
773        let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODE_INDEX_LABEL, group_id)?;
774        self.read(OWN_LEAF_NODE_INDEX_LABEL, &key)
775    }
776
777    fn write_own_leaf_index<
778        GroupId: traits::GroupId<CURRENT_VERSION>,
779        LeafNodeIndex: traits::LeafNodeIndex<CURRENT_VERSION>,
780    >(
781        &self,
782        group_id: &GroupId,
783        own_leaf_index: &LeafNodeIndex,
784    ) -> Result<(), Self::Error> {
785        let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODE_INDEX_LABEL, group_id)?;
786        self.write::<CURRENT_VERSION>(
787            OWN_LEAF_NODE_INDEX_LABEL,
788            &key,
789            &bincode::serialize(own_leaf_index)?,
790        )
791    }
792
793    fn delete_own_leaf_index<GroupId: traits::GroupId<CURRENT_VERSION>>(
794        &self,
795        group_id: &GroupId,
796    ) -> Result<(), Self::Error> {
797        let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODE_INDEX_LABEL, group_id)?;
798        self.delete::<CURRENT_VERSION>(OWN_LEAF_NODE_INDEX_LABEL, &key)
799    }
800
801    fn group_epoch_secrets<
802        GroupId: traits::GroupId<CURRENT_VERSION>,
803        GroupEpochSecrets: traits::GroupEpochSecrets<CURRENT_VERSION>,
804    >(
805        &self,
806        group_id: &GroupId,
807    ) -> Result<Option<GroupEpochSecrets>, Self::Error> {
808        let key = build_key::<CURRENT_VERSION, &GroupId>(EPOCH_SECRETS_LABEL, group_id)?;
809        self.read(EPOCH_SECRETS_LABEL, &key)
810    }
811
812    fn write_group_epoch_secrets<
813        GroupId: traits::GroupId<CURRENT_VERSION>,
814        GroupEpochSecrets: traits::GroupEpochSecrets<CURRENT_VERSION>,
815    >(
816        &self,
817        group_id: &GroupId,
818        group_epoch_secrets: &GroupEpochSecrets,
819    ) -> Result<(), Self::Error> {
820        let key = build_key::<CURRENT_VERSION, &GroupId>(EPOCH_SECRETS_LABEL, group_id)?;
821        self.write::<CURRENT_VERSION>(
822            EPOCH_SECRETS_LABEL,
823            &key,
824            &bincode::serialize(group_epoch_secrets)?,
825        )
826    }
827
828    fn delete_group_epoch_secrets<GroupId: traits::GroupId<CURRENT_VERSION>>(
829        &self,
830        group_id: &GroupId,
831    ) -> Result<(), Self::Error> {
832        let key = build_key::<CURRENT_VERSION, &GroupId>(EPOCH_SECRETS_LABEL, group_id)?;
833        self.delete::<CURRENT_VERSION>(EPOCH_SECRETS_LABEL, &key)
834    }
835
836    fn write_encryption_epoch_key_pairs<
837        GroupId: traits::GroupId<CURRENT_VERSION>,
838        EpochKey: traits::EpochKey<CURRENT_VERSION>,
839        HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
840    >(
841        &self,
842        group_id: &GroupId,
843        epoch: &EpochKey,
844        leaf_index: u32,
845        key_pairs: &[HpkeKeyPair],
846    ) -> Result<(), Self::Error> {
847        let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?;
848        let value = bincode::serialize(key_pairs)?;
849        tracing::trace!("Writing encryption epoch key pairs");
850
851        self.write::<CURRENT_VERSION>(EPOCH_KEY_PAIRS_LABEL, &key, &value)
852    }
853
854    fn encryption_epoch_key_pairs<
855        GroupId: traits::GroupId<CURRENT_VERSION>,
856        EpochKey: traits::EpochKey<CURRENT_VERSION>,
857        HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
858    >(
859        &self,
860        group_id: &GroupId,
861        epoch: &EpochKey,
862        leaf_index: u32,
863    ) -> Result<Vec<HpkeKeyPair>, Self::Error> {
864        tracing::trace!("Reading encryption epoch key pairs");
865
866        let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?;
867        let storage_key = build_key_from_vec::<CURRENT_VERSION>(EPOCH_KEY_PAIRS_LABEL, key);
868        tracing::trace!("  key: {}", hex::encode(&storage_key));
869
870        let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?";
871
872        let data: Vec<StorageData> = self.conn.raw_query_read(|conn| {
873            sql_query(query)
874                .bind::<diesel::sql_types::Binary, _>(&storage_key)
875                .bind::<diesel::sql_types::Integer, _>(CURRENT_VERSION as i32)
876                .load(conn)
877        })?;
878
879        if let Some(entry) = data.into_iter().next() {
880            match bincode::deserialize::<Vec<HpkeKeyPair>>(&entry.value_bytes) {
881                Ok(deserialized) => Ok(deserialized),
882                Err(e) => {
883                    eprintln!("Error occurred: {}", e);
884                    Err(SqlKeyStoreError::SerializationError)
885                }
886            }
887        } else {
888            Ok(vec![])
889        }
890    }
891
892    fn delete_encryption_epoch_key_pairs<
893        GroupId: traits::GroupId<CURRENT_VERSION>,
894        EpochKey: traits::EpochKey<CURRENT_VERSION>,
895    >(
896        &self,
897        group_id: &GroupId,
898        epoch: &EpochKey,
899        leaf_index: u32,
900    ) -> Result<(), Self::Error> {
901        let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?;
902
903        self.delete::<CURRENT_VERSION>(EPOCH_KEY_PAIRS_LABEL, &key)
904    }
905
906    fn clear_proposal_queue<
907        GroupId: traits::GroupId<CURRENT_VERSION>,
908        ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
909    >(
910        &self,
911        group_id: &GroupId,
912    ) -> Result<(), Self::Error> {
913        let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
914        let proposal_refs: Vec<ProposalRef> = self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)?;
915
916        for proposal_ref in proposal_refs {
917            let key = bincode::serialize(&(group_id, proposal_ref))?;
918            self.delete::<CURRENT_VERSION>(QUEUED_PROPOSAL_LABEL, &key)?;
919        }
920
921        let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
922
923        self.delete::<CURRENT_VERSION>(PROPOSAL_QUEUE_REFS_LABEL, &key)
924    }
925
926    fn mls_group_join_config<
927        GroupId: traits::GroupId<CURRENT_VERSION>,
928        MlsGroupJoinConfig: traits::MlsGroupJoinConfig<CURRENT_VERSION>,
929    >(
930        &self,
931        group_id: &GroupId,
932    ) -> Result<Option<MlsGroupJoinConfig>, Self::Error> {
933        let key = build_key::<CURRENT_VERSION, &GroupId>(JOIN_CONFIG_LABEL, group_id)?;
934
935        self.read(JOIN_CONFIG_LABEL, &key)
936    }
937
938    fn write_mls_join_config<
939        GroupId: traits::GroupId<CURRENT_VERSION>,
940        MlsGroupJoinConfig: traits::MlsGroupJoinConfig<CURRENT_VERSION>,
941    >(
942        &self,
943        group_id: &GroupId,
944        config: &MlsGroupJoinConfig,
945    ) -> Result<(), Self::Error> {
946        let key = build_key::<CURRENT_VERSION, &GroupId>(JOIN_CONFIG_LABEL, group_id)?;
947        let value = bincode::serialize(config)?;
948
949        self.write::<CURRENT_VERSION>(JOIN_CONFIG_LABEL, &key, &value)
950    }
951
952    fn own_leaf_nodes<
953        GroupId: traits::GroupId<CURRENT_VERSION>,
954        LeafNode: traits::LeafNode<CURRENT_VERSION>,
955    >(
956        &self,
957        group_id: &GroupId,
958    ) -> Result<Vec<LeafNode>, Self::Error> {
959        tracing::trace!("own_leaf_nodes");
960        let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODES_LABEL, group_id)?;
961
962        self.read_list(OWN_LEAF_NODES_LABEL, &key)
963    }
964
965    fn append_own_leaf_node<
966        GroupId: traits::GroupId<CURRENT_VERSION>,
967        LeafNode: traits::LeafNode<CURRENT_VERSION>,
968    >(
969        &self,
970        group_id: &GroupId,
971        leaf_node: &LeafNode,
972    ) -> Result<(), Self::Error> {
973        let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODES_LABEL, group_id)?;
974        let value = bincode::serialize(leaf_node)?;
975
976        self.append::<CURRENT_VERSION>(OWN_LEAF_NODES_LABEL, &key, &value)
977    }
978
979    fn delete_own_leaf_nodes<GroupId: traits::GroupId<CURRENT_VERSION>>(
980        &self,
981        group_id: &GroupId,
982    ) -> Result<(), Self::Error> {
983        let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODES_LABEL, group_id)?;
984        self.delete::<CURRENT_VERSION>(OWN_LEAF_NODES_LABEL, &key)
985    }
986
987    fn delete_group_config<GroupId: traits::GroupId<CURRENT_VERSION>>(
988        &self,
989        group_id: &GroupId,
990    ) -> Result<(), Self::Error> {
991        let key = build_key::<CURRENT_VERSION, &GroupId>(JOIN_CONFIG_LABEL, group_id)?;
992        self.delete::<CURRENT_VERSION>(JOIN_CONFIG_LABEL, &key)
993    }
994
995    fn delete_tree<GroupId: traits::GroupId<CURRENT_VERSION>>(
996        &self,
997        group_id: &GroupId,
998    ) -> Result<(), Self::Error> {
999        let key = build_key::<CURRENT_VERSION, &GroupId>(TREE_LABEL, group_id)?;
1000
1001        self.delete::<CURRENT_VERSION>(TREE_LABEL, &key)
1002    }
1003
1004    fn delete_confirmation_tag<GroupId: traits::GroupId<CURRENT_VERSION>>(
1005        &self,
1006        group_id: &GroupId,
1007    ) -> Result<(), Self::Error> {
1008        let key = build_key::<CURRENT_VERSION, &GroupId>(CONFIRMATION_TAG_LABEL, group_id)?;
1009
1010        self.delete::<CURRENT_VERSION>(CONFIRMATION_TAG_LABEL, &key)
1011    }
1012
1013    fn delete_context<GroupId: traits::GroupId<CURRENT_VERSION>>(
1014        &self,
1015        group_id: &GroupId,
1016    ) -> Result<(), Self::Error> {
1017        let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_CONTEXT_LABEL, group_id)?;
1018
1019        self.delete::<CURRENT_VERSION>(GROUP_CONTEXT_LABEL, &key)
1020    }
1021
1022    fn delete_interim_transcript_hash<GroupId: traits::GroupId<CURRENT_VERSION>>(
1023        &self,
1024        group_id: &GroupId,
1025    ) -> Result<(), Self::Error> {
1026        let key = build_key::<CURRENT_VERSION, &GroupId>(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?;
1027
1028        self.delete::<CURRENT_VERSION>(INTERIM_TRANSCRIPT_HASH_LABEL, &key)
1029    }
1030
1031    fn remove_proposal<
1032        GroupId: traits::GroupId<CURRENT_VERSION>,
1033        ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
1034    >(
1035        &self,
1036        group_id: &GroupId,
1037        proposal_ref: &ProposalRef,
1038    ) -> Result<(), Self::Error> {
1039        // Delete the proposal ref
1040        let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
1041        let value = bincode::serialize(proposal_ref)?;
1042        self.remove_item::<CURRENT_VERSION>(PROPOSAL_QUEUE_REFS_LABEL, &key, &value)?;
1043
1044        // Delete the proposal
1045        let key = bincode::serialize(&(group_id, proposal_ref))?;
1046        self.delete::<CURRENT_VERSION>(QUEUED_PROPOSAL_LABEL, &key)
1047    }
1048}
1049
/// Builds a key by concatenating `label`, the already-serialized `key`
/// bytes, and the version `V` as two big-endian bytes, in that order.
fn build_key_from_vec<const V: u16>(label: &[u8], key: Vec<u8>) -> Vec<u8> {
    let version_suffix = V.to_be_bytes();
    [label, key.as_slice(), version_suffix.as_slice()].concat()
}
1057
1058/// Build a key with version and label.
1059fn build_key<const V: u16, K: Serialize>(
1060    label: &[u8],
1061    key: K,
1062) -> Result<Vec<u8>, SqlKeyStoreError> {
1063    let key_vec = bincode::serialize(&key)?;
1064    Ok(build_key_from_vec::<V>(label, key_vec))
1065}
1066
1067fn epoch_key_pairs_id(
1068    group_id: &impl traits::GroupId<CURRENT_VERSION>,
1069    epoch: &impl traits::EpochKey<CURRENT_VERSION>,
1070    leaf_index: u32,
1071) -> Result<Vec<u8>, SqlKeyStoreError> {
1072    let mut key = bincode::serialize(group_id)?;
1073    key.extend_from_slice(&bincode::serialize(epoch)?);
1074    key.extend_from_slice(&bincode::serialize(&leaf_index)?);
1075    Ok(key)
1076}
1077
1078impl From<bincode::Error> for SqlKeyStoreError {
1079    fn from(_: bincode::Error) -> Self {
1080        Self::SerializationError
1081    }
1082}
1083
/// Inspection helpers for the in-memory test storage backend.
#[cfg(any(test, feature = "test-utils"))]
impl SqlKeyStore<crate::test_utils::MemoryStorage> {
    /// Returns the string produced by the connection's `key_value_pairs()`.
    pub fn kv_pairs(&self) -> String {
        let storage = &self.conn;
        storage.key_value_pairs()
    }

    /// Returns the string produced by the connection's `key_value_pairs_utf8()`.
    pub fn kv_pairs_utf8(&self) -> String {
        let storage = &self.conn;
        storage.key_value_pairs_utf8()
    }
}
1094
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use openmls::group::GroupId;
    use openmls_basic_credential::{SignatureKeyPair, StorageId};
    use openmls_traits::{
        OpenMlsProvider,
        storage::{
            CURRENT_VERSION, Entity, Key, StorageProvider,
            traits::{self},
        },
    };
    use serde::{Deserialize, Serialize};

    use super::SqlKeyStore;
    use crate::encrypted_store::MlsProviderExt;
    use crate::{
        XmtpTestDb, sql_key_store::SqlKeyStoreError, xmtp_openmls_provider::XmtpOpenMlsProvider,
    };
    use xmtp_cryptography::configuration::CIPHERSUITE;

    /// A signature key pair is absent until written, present afterwards, and
    /// absent again once deleted.
    #[xmtp_common::test]
    async fn store_read_delete() {
        let store = crate::TestDb::create_persistent_store(None).await;
        let key_store = SqlKeyStore::new(store.conn());

        let signature_keys = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm()).unwrap();
        let public_key = StorageId::from(signature_keys.to_public_vec());
        // Looks the pair up again on every call so each assertion sees fresh state.
        let stored_pair = || {
            key_store
                .signature_key_pair::<StorageId, SignatureKeyPair>(&public_key)
                .unwrap()
        };

        assert!(stored_pair().is_none());

        key_store
            .write_signature_key_pair::<StorageId, SignatureKeyPair>(&public_key, &signature_keys)
            .unwrap();
        assert!(stored_pair().is_some());

        key_store
            .delete_signature_key_pair::<StorageId>(&public_key)
            .unwrap();
        assert!(stored_pair().is_none());
    }

    /// Minimal stand-in for a queued proposal.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
    struct Proposal(Vec<u8>);
    impl traits::QueuedProposal<CURRENT_VERSION> for Proposal {}
    impl Entity<CURRENT_VERSION> for Proposal {}

    /// Minimal stand-in for a proposal reference.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)]
    struct ProposalRef(usize);
    impl traits::ProposalRef<CURRENT_VERSION> for ProposalRef {}
    impl Key<CURRENT_VERSION> for ProposalRef {}
    impl Entity<CURRENT_VERSION> for ProposalRef {}

    /// Raw `write`/`read`: an unwritten key reads as `None`, a written key
    /// returns the stored value.
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_read_write() {
        let store = crate::TestDb::create_persistent_store(None).await;
        let provider = XmtpOpenMlsProvider::new(SqlKeyStore::new(store.conn()));
        let key_store = provider.key_store();

        let raw_value = vec![3u8; 32];
        let written_group = bincode::serialize(&[1u8; 32])?;
        let unwritten_group = bincode::serialize(&[2u8; 32])?;
        let serialized_value = bincode::serialize(&raw_value)?;

        key_store.write::<CURRENT_VERSION>(
            crate::sql_key_store::COMMIT_LOG_SIGNER_PRIVATE_KEY,
            &written_group,
            &serialized_value,
        )?;

        // Query on a value that hasn't been written
        let missing = key_store
            .read::<CURRENT_VERSION, Vec<u8>>(
                crate::sql_key_store::COMMIT_LOG_SIGNER_PRIVATE_KEY,
                &unwritten_group,
            )
            .unwrap_or_else(|e| panic!("{}", e));
        assert!(missing.is_none());

        let found = key_store
            .read::<CURRENT_VERSION, Vec<u8>>(
                crate::sql_key_store::COMMIT_LOG_SIGNER_PRIVATE_KEY,
                &written_group,
            )
            .unwrap_or_else(|e| panic!("{}", e));
        assert_eq!(found, Some(raw_value));
    }

    /// Queues ten proposals, removes one, then clears the queue, checking the
    /// reference list and the proposal list after every step.
    #[xmtp_common::test]
    async fn list_append_remove() {
        let store = crate::TestDb::create_persistent_store(None).await;
        let provider = XmtpOpenMlsProvider::new(SqlKeyStore::new(store.conn()));
        let group_id = GroupId::random(provider.rand());
        let storage = provider.storage();

        let proposals = (0..10)
            .map(|i| Proposal(format!("TestProposal{i}").as_bytes().to_vec()))
            .collect::<Vec<_>>();
        // Expected state is built once and reused by every assertion below.
        let expected_refs: Vec<ProposalRef> = (0..10).map(ProposalRef).collect();
        let expected_pairs: Vec<(ProposalRef, Proposal)> = expected_refs
            .iter()
            .copied()
            .zip(proposals.iter().cloned())
            .collect();

        // Store proposals
        for (proposal_ref, proposal) in expected_refs.iter().zip(&proposals) {
            storage
                .queue_proposal::<GroupId, ProposalRef, Proposal>(
                    &group_id,
                    proposal_ref,
                    proposal,
                )
                .expect("Failed to queue proposal");
        }

        tracing::trace!("Finished with queued proposals");
        // Read proposal refs
        let proposal_refs_read: Vec<ProposalRef> = storage
            .queued_proposal_refs(&group_id)
            .expect("Failed to read proposal refs");
        assert_eq!(expected_refs, proposal_refs_read);

        // Read proposals
        let proposals_read: Vec<(ProposalRef, Proposal)> =
            storage.queued_proposals(&group_id).unwrap();
        assert_eq!(expected_pairs, proposals_read);

        // Remove proposal 5
        storage
            .remove_proposal(&group_id, &ProposalRef(5))
            .unwrap();

        let mut refs_after_removal = expected_refs.clone();
        refs_after_removal.remove(5);
        let proposal_refs_read: Vec<ProposalRef> =
            storage.queued_proposal_refs(&group_id).unwrap();
        assert_eq!(refs_after_removal, proposal_refs_read);

        let mut pairs_after_removal = expected_pairs.clone();
        pairs_after_removal.remove(5);
        let proposals_read: Vec<(ProposalRef, Proposal)> =
            storage.queued_proposals(&group_id).unwrap();
        assert_eq!(pairs_after_removal, proposals_read);

        // Clear all proposals
        storage
            .clear_proposal_queue::<GroupId, ProposalRef>(&group_id)
            .unwrap();
        let proposal_refs_read: Result<Vec<ProposalRef>, SqlKeyStoreError> =
            storage.queued_proposal_refs(&group_id);
        assert!(proposal_refs_read.unwrap().is_empty());

        let proposals_read: Result<Vec<(ProposalRef, Proposal)>, SqlKeyStoreError> =
            storage.queued_proposals(&group_id);
        assert!(proposals_read.unwrap().is_empty());
    }

    /// A written group state can be read back unchanged.
    #[xmtp_common::test]
    async fn group_state() {
        #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)]
        struct GroupState(usize);
        impl traits::GroupState<CURRENT_VERSION> for GroupState {}
        impl Entity<CURRENT_VERSION> for GroupState {}

        let db = crate::TestDb::create_persistent_store(None).await;
        let provider = XmtpOpenMlsProvider::new(SqlKeyStore::new(db.conn()));
        let storage = provider.storage();
        let group_id = GroupId::random(provider.rand());

        // Group state
        let written = GroupState(77);
        storage.write_group_state(&group_id, &written).unwrap();

        // Read group state
        let read_back: Option<GroupState> = storage.group_state(&group_id).unwrap();
        assert_eq!(Some(written), read_back);
    }
}