1use xmtp_common::{RetryableError, retryable};
2
3use self::transactions::MutableTransactionConnection;
4use crate::{ConnectionExt, TransactionalKeyStore, XmtpMlsStorageProvider};
5
6use bincode;
7use diesel::{
8 prelude::*,
9 sql_types::Binary,
10 {RunQueryDsl, sql_query},
11};
12use openmls_traits::storage::*;
13use serde::Serialize;
14
15#[cfg(any(feature = "test-utils", test))]
16pub mod mock;
17mod transactions;
18
/// Loads the stored value for one `(key_bytes, version)` pair.
const SELECT_QUERY: &str =
    "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?";
/// Inserts a row, replacing any row it conflicts with on the table's unique key
/// (presumably `(key_bytes, version)` — confirm against the schema).
const REPLACE_QUERY: &str =
    "REPLACE INTO openmls_key_value (key_bytes, version, value_bytes) VALUES (?, ?, ?)";
/// Overwrites the value of an existing `(key_bytes, version)` row; affects no rows if it is absent.
const UPDATE_QUERY: &str =
    "UPDATE openmls_key_value SET value_bytes = ? WHERE key_bytes = ? AND version = ?";
/// Removes the row for a `(key_bytes, version)` pair.
const DELETE_QUERY: &str = "DELETE FROM openmls_key_value WHERE key_bytes = ? AND version = ?";
26
/// Full row of the `openmls_key_value` table, exposed only for test utilities.
#[cfg(feature = "test-utils")]
#[derive(
    Selectable, Queryable, QueryableByName, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[diesel(table_name = crate::schema::openmls_key_value)]
pub struct OpenMlsKeyValue {
    /// Storage format version the row was written with.
    pub version: i32,
    /// Fully built storage key (label, serialized key and version bytes).
    pub key_bytes: Vec<u8>,
    /// Stored value bytes.
    pub value_bytes: Vec<u8>,
}
37
#[cfg(feature = "test-utils")]
impl OpenMlsKeyValue {
    /// Computes a SHA-256 digest over every row of `openmls_key_value`.
    ///
    /// Rows are visited in `(version, key_bytes)` order. Each row contributes its
    /// tagged fields plus its position, so two tables hash equal only if they hold
    /// the same rows.
    ///
    /// # Errors
    /// Returns any Diesel error raised while loading or iterating the rows.
    pub fn hash_all(conn: &mut SqliteConnection) -> Result<Vec<u8>, diesel::result::Error> {
        use crate::schema::openmls_key_value;
        use xmtp_common::Sha2Digest;

        // One `order` call with a tuple: chaining a second `.order(..)` would *replace*
        // the first clause rather than add a secondary sort key.
        let values = openmls_key_value::table
            .order((
                openmls_key_value::version.asc(),
                openmls_key_value::key_bytes.asc(),
            ))
            .load_iter::<OpenMlsKeyValue, _>(conn)?;

        let mut hasher = xmtp_common::Sha256Digest::new();
        for (i, result) in values.enumerate() {
            let value = result?;
            hasher.update(b"version");
            hasher.update(value.version.to_be_bytes());
            hasher.update(b"key_bytes");
            hasher.update(&value.key_bytes);
            hasher.update(b"value_bytes");
            hasher.update(&value.value_bytes);
            hasher.update(b"index");
            // Fixed 8-byte width so the digest does not depend on the platform's `usize`.
            hasher.update((i as u64).to_be_bytes());
            hasher.update(b"\n");
        }
        Ok(hasher.finalize().to_vec())
    }
}
64
/// Result row for the raw `SELECT value_bytes ...` queries against `openmls_key_value`.
#[derive(QueryableByName, Debug, Clone, PartialEq, Eq)]
#[diesel(table_name = openmls_key_value)]
struct StorageData {
    /// Stored value; the writers in this module put bincode-encoded data here.
    #[diesel(sql_type = Binary)]
    value_bytes: Vec<u8>,
}
71
72impl TransactionalKeyStore for diesel::SqliteConnection {
73 type Store<'a>
74 = SqlKeyStore<MutableTransactionConnection<'a>>
75 where
76 Self: 'a;
77
78 fn key_store<'a>(&'a mut self) -> Self::Store<'a> {
79 SqlKeyStore::new_transactional(self)
80 }
81}
82
/// OpenMLS key store persisted in the `openmls_key_value` SQL table.
///
/// `T` is the connection handle every query is issued through; the query methods
/// are available when `T: ConnectionExt`.
#[derive(Clone)]
pub struct SqlKeyStore<T> {
    conn: T,
}
88
89impl<A> SqlKeyStore<A> {
90 pub fn new(conn: A) -> Self {
91 Self { conn }
92 }
93}
94
95impl<'a> SqlKeyStore<SqliteConnection> {
96 pub fn new_transactional(
97 conn: &'a mut SqliteConnection,
98 ) -> SqlKeyStore<MutableTransactionConnection<'a>> {
99 SqlKeyStore {
100 conn: MutableTransactionConnection::new(conn),
101 }
102 }
103}
104
105impl<C> SqlKeyStore<C>
107where
108 C: ConnectionExt,
109{
110 fn select_query<const VERSION: u16>(
111 &self,
112 storage_key: &Vec<u8>,
113 ) -> Result<Vec<StorageData>, crate::ConnectionError> {
114 self.conn.raw_query_read(|conn| {
115 sql_query(SELECT_QUERY)
116 .bind::<diesel::sql_types::Binary, _>(&storage_key)
117 .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
118 .load(conn)
119 })
120 }
121
122 fn replace_query<const VERSION: u16>(
123 &self,
124 storage_key: &Vec<u8>,
125 value: &[u8],
126 ) -> Result<usize, crate::ConnectionError> {
127 self.conn.raw_query_write(|conn| {
128 sql_query(REPLACE_QUERY)
129 .bind::<diesel::sql_types::Binary, _>(&storage_key)
130 .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
131 .bind::<diesel::sql_types::Binary, _>(&value)
132 .execute(conn)
133 })
134 }
135
136 fn update_query<const VERSION: u16>(
137 &self,
138 storage_key: &Vec<u8>,
139 modified_data: &Vec<u8>,
140 ) -> Result<usize, crate::ConnectionError> {
141 self.conn.raw_query_write(|conn| {
142 sql_query(UPDATE_QUERY)
143 .bind::<diesel::sql_types::Binary, _>(&modified_data)
144 .bind::<diesel::sql_types::Binary, _>(&storage_key)
145 .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
146 .execute(conn)
147 })
148 }
149
150 pub fn write<const VERSION: u16>(
151 &self,
152 label: &[u8],
153 key: &[u8],
154 value: &[u8],
155 ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
156 let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
157 let _ = self.replace_query::<VERSION>(&storage_key, value)?;
158 Ok(())
159 }
160
161 pub fn append<const VERSION: u16>(
162 &self,
163 label: &[u8],
164 key: &[u8],
165 value: &[u8],
166 ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
167 tracing::trace!("append {}", String::from_utf8_lossy(label));
168
169 let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
170 let data = self.select_query::<VERSION>(&storage_key)?;
171
172 if let Some(entry) = data.into_iter().next() {
173 match bincode::deserialize::<Vec<Vec<u8>>>(&entry.value_bytes) {
175 Ok(mut deserialized) => {
176 deserialized.push(value.to_vec());
177 let modified_data = bincode::serialize(&deserialized)?;
178
179 let _ = self.update_query::<VERSION>(&storage_key, &modified_data)?;
180 Ok(())
181 }
182 Err(_e) => Err(SqlKeyStoreError::SerializationError),
183 }
184 } else {
185 let value_bytes = &bincode::serialize(&vec![value])?;
187 let _ = self.replace_query::<VERSION>(&storage_key, value_bytes)?;
188
189 Ok(())
190 }
191 }
192
193 pub fn remove_item<const VERSION: u16>(
194 &self,
195 label: &[u8],
196 key: &[u8],
197 value: &[u8],
198 ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
199 tracing::trace!("remove_item {}", String::from_utf8_lossy(label));
200
201 let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
202 let data: Vec<StorageData> = self.select_query::<VERSION>(&storage_key)?;
203
204 if let Some(entry) = data.into_iter().next() {
205 let mut deserialized = bincode::deserialize::<Vec<Vec<u8>>>(&entry.value_bytes)
207 .map_err(|_| SqlKeyStoreError::SerializationError)?;
208 let vpos = deserialized.iter().position(|v| v == value);
209
210 if let Some(pos) = vpos {
211 deserialized.remove(pos);
212 }
213 let modified_data = bincode::serialize(&deserialized)
214 .map_err(|_| SqlKeyStoreError::SerializationError)?;
215
216 let _ = self.update_query::<VERSION>(&storage_key, &modified_data)?;
217 Ok(())
218 } else {
219 let value_bytes =
221 bincode::serialize(&[value]).map_err(|_| SqlKeyStoreError::SerializationError)?;
222 let _ = self.replace_query::<VERSION>(&storage_key, &value_bytes)?;
223 Ok(())
224 }
225 }
226
227 pub fn read<const VERSION: u16, V: Entity<VERSION>>(
228 &self,
229 label: &[u8],
230 key: &[u8],
231 ) -> Result<Option<V>, <Self as StorageProvider<CURRENT_VERSION>>::Error> {
232 let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
233
234 let data = self.select_query::<VERSION>(&storage_key)?;
235
236 if let Some(entry) = data.into_iter().next() {
237 let deserialized = bincode::deserialize::<V>(&entry.value_bytes)
238 .map_err(|_| SqlKeyStoreError::SerializationError)?;
239
240 Ok(Some(deserialized))
241 } else {
242 Ok(None)
243 }
244 }
245
246 pub fn read_list<const VERSION: u16, V: Entity<VERSION>>(
247 &self,
248 label: &[u8],
249 key: &[u8],
250 ) -> Result<Vec<V>, <Self as StorageProvider<CURRENT_VERSION>>::Error> {
251 let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
252 let results = self.select_query::<VERSION>(&storage_key)?;
253
254 if let Some(entry) = results.into_iter().next() {
255 let list = bincode::deserialize::<Vec<Vec<u8>>>(&entry.value_bytes)?;
256
257 let mut deserialized_list = Vec::new();
259 for v in list {
260 match bincode::deserialize::<V>(&v) {
261 Ok(deserialized_value) => deserialized_list.push(deserialized_value),
262 Err(e) => {
263 tracing::error!("Error occurred: {}", e);
264 return Err(SqlKeyStoreError::SerializationError);
265 }
266 }
267 }
268 Ok(deserialized_list)
269 } else {
270 Ok(vec![])
271 }
272 }
273
274 pub fn delete<const VERSION: u16>(
275 &self,
276 label: &[u8],
277 key: &[u8],
278 ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
279 let storage_key = build_key_from_vec::<VERSION>(label, key.to_vec());
280 self.conn.raw_query_write(|conn| {
281 sql_query(DELETE_QUERY)
282 .bind::<diesel::sql_types::Binary, _>(&storage_key)
283 .bind::<diesel::sql_types::Integer, _>(VERSION as i32)
284 .execute(conn)
285 })?;
286 Ok(())
287 }
288}
289
/// Errors returned by [`SqlKeyStore`].
#[derive(thiserror::Error, Debug)]
pub enum SqlKeyStoreError {
    /// Storing pre-serialized values is not supported by this store.
    #[error("The key store does not allow storing serialized values.")]
    UnsupportedValueTypeBytes,
    /// The requested operation is not implemented (e.g. `delete_psk`).
    #[error("Updating is not supported by this key store.")]
    UnsupportedMethod,
    /// A value could not be encoded or decoded with bincode.
    #[error("Error serializing value.")]
    SerializationError,
    /// An expected entry was missing (e.g. a listed proposal ref with no stored proposal).
    #[error("Value does not exist.")]
    NotFound,
    /// Error reported by Diesel.
    #[error("database error: {0}")]
    Storage(#[from] diesel::result::Error),
    /// Error reported by the connection layer.
    #[error("connection {0}")]
    Connection(#[from] crate::ConnectionError),
}
307
308impl RetryableError for SqlKeyStoreError {
309 fn is_retryable(&self) -> bool {
310 use SqlKeyStoreError::*;
311 match self {
312 Storage(err) => retryable!(err),
313 SerializationError => false,
314 UnsupportedMethod => false,
315 UnsupportedValueTypeBytes => false,
316 NotFound => false,
317 Connection(c) => retryable!(c),
318 }
319 }
320}
321
// Labels that namespace each kind of OpenMLS entity inside `openmls_key_value` keys.
// NOTE: these bytes are part of the persisted key format; changing one orphans existing rows.

// Key material.
const KEY_PACKAGE_LABEL: &[u8] = b"KeyPackage";
const ENCRYPTION_KEY_PAIR_LABEL: &[u8] = b"EncryptionKeyPair";
const SIGNATURE_KEY_PAIR_LABEL: &[u8] = b"SignatureKeyPair";
const EPOCH_KEY_PAIRS_LABEL: &[u8] = b"EpochKeyPairs";
/// Storage label exported for use outside this module.
pub const KEY_PACKAGE_REFERENCES: &[u8] = b"KeyPackageReferences";
/// Storage label exported for use outside this module.
pub const KEY_PACKAGE_WRAPPER_PRIVATE_KEY: &[u8] = b"KeyPackageWrapperPrivateKey";
/// Storage label exported for use outside this module.
pub const COMMIT_LOG_SIGNER_PRIVATE_KEY: &[u8] = b"CommitLogSignerPrivateKey";

// Public group state.
const TREE_LABEL: &[u8] = b"Tree";
const GROUP_CONTEXT_LABEL: &[u8] = b"GroupContext";
const INTERIM_TRANSCRIPT_HASH_LABEL: &[u8] = b"InterimTranscriptHash";
const CONFIRMATION_TAG_LABEL: &[u8] = b"ConfirmationTag";

// Private group state.
const OWN_LEAF_NODE_INDEX_LABEL: &[u8] = b"OwnLeafNodeIndex";
const EPOCH_SECRETS_LABEL: &[u8] = b"EpochSecrets";
const MESSAGE_SECRETS_LABEL: &[u8] = b"MessageSecrets";

// MlsGroup bookkeeping.
const JOIN_CONFIG_LABEL: &[u8] = b"MlsGroupJoinConfig";
const OWN_LEAF_NODES_LABEL: &[u8] = b"OwnLeafNodes";
const GROUP_STATE_LABEL: &[u8] = b"GroupState";
const QUEUED_PROPOSAL_LABEL: &[u8] = b"QueuedProposal";
const PROPOSAL_QUEUE_REFS_LABEL: &[u8] = b"ProposalQueueRefs";
const RESUMPTION_PSK_STORE_LABEL: &[u8] = b"ResumptionPskStore";
348
349impl<C> StorageProvider<CURRENT_VERSION> for SqlKeyStore<C>
350where
351 C: ConnectionExt,
352{
353 type Error = SqlKeyStoreError;
354
355 fn queue_proposal<
356 GroupId: traits::GroupId<CURRENT_VERSION>,
357 ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
358 QueuedProposal: traits::QueuedProposal<CURRENT_VERSION>,
359 >(
360 &self,
361 group_id: &GroupId,
362 proposal_ref: &ProposalRef,
363 proposal: &QueuedProposal,
364 ) -> Result<(), Self::Error> {
365 let key = bincode::serialize(&(group_id, proposal_ref))?;
367 let value = bincode::serialize(proposal)?;
368 self.write::<CURRENT_VERSION>(QUEUED_PROPOSAL_LABEL, &key, &value)?;
369
370 let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
372 let value = bincode::serialize(proposal_ref)?;
373 self.append::<CURRENT_VERSION>(PROPOSAL_QUEUE_REFS_LABEL, &key, &value)?;
374
375 Ok(())
376 }
377
378 fn write_tree<
379 GroupId: traits::GroupId<CURRENT_VERSION>,
380 TreeSync: traits::TreeSync<CURRENT_VERSION>,
381 >(
382 &self,
383 group_id: &GroupId,
384 tree: &TreeSync,
385 ) -> Result<(), Self::Error> {
386 let key = build_key::<CURRENT_VERSION, &GroupId>(TREE_LABEL, group_id)?;
387 let value = bincode::serialize(&tree)?;
388 self.write::<CURRENT_VERSION>(TREE_LABEL, &key, &value)
389 }
390
391 fn write_interim_transcript_hash<
392 GroupId: traits::GroupId<CURRENT_VERSION>,
393 InterimTranscriptHash: traits::InterimTranscriptHash<CURRENT_VERSION>,
394 >(
395 &self,
396 group_id: &GroupId,
397 interim_transcript_hash: &InterimTranscriptHash,
398 ) -> Result<(), Self::Error> {
399 let key = build_key::<CURRENT_VERSION, &GroupId>(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?;
400 let value = bincode::serialize(&interim_transcript_hash)?;
401 let _ = self.write::<CURRENT_VERSION>(INTERIM_TRANSCRIPT_HASH_LABEL, &key, &value);
402
403 Ok(())
404 }
405
406 fn write_context<
407 GroupId: traits::GroupId<CURRENT_VERSION>,
408 GroupContext: traits::GroupContext<CURRENT_VERSION>,
409 >(
410 &self,
411 group_id: &GroupId,
412 group_context: &GroupContext,
413 ) -> Result<(), Self::Error> {
414 let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_CONTEXT_LABEL, group_id)?;
415 let value = bincode::serialize(&group_context)?;
416
417 self.write::<CURRENT_VERSION>(GROUP_CONTEXT_LABEL, &key, &value)
418 }
419
420 fn write_confirmation_tag<
421 GroupId: traits::GroupId<CURRENT_VERSION>,
422 ConfirmationTag: traits::ConfirmationTag<CURRENT_VERSION>,
423 >(
424 &self,
425 group_id: &GroupId,
426 confirmation_tag: &ConfirmationTag,
427 ) -> Result<(), Self::Error> {
428 let key = build_key::<CURRENT_VERSION, &GroupId>(CONFIRMATION_TAG_LABEL, group_id)?;
429 let value = bincode::serialize(&confirmation_tag)?;
430
431 self.write::<CURRENT_VERSION>(CONFIRMATION_TAG_LABEL, &key, &value)
432 }
433
434 fn write_signature_key_pair<
435 SignaturePublicKey: traits::SignaturePublicKey<CURRENT_VERSION>,
436 SignatureKeyPair: traits::SignatureKeyPair<CURRENT_VERSION>,
437 >(
438 &self,
439 public_key: &SignaturePublicKey,
440 signature_key_pair: &SignatureKeyPair,
441 ) -> Result<(), Self::Error> {
442 let key = build_key::<CURRENT_VERSION, &SignaturePublicKey>(
443 SIGNATURE_KEY_PAIR_LABEL,
444 public_key,
445 )?;
446 let value = bincode::serialize(&signature_key_pair)?;
447
448 self.write::<CURRENT_VERSION>(SIGNATURE_KEY_PAIR_LABEL, &key, &value)
449 }
450
451 fn queued_proposal_refs<
452 GroupId: traits::GroupId<CURRENT_VERSION>,
453 ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
454 >(
455 &self,
456 group_id: &GroupId,
457 ) -> Result<Vec<ProposalRef>, Self::Error> {
458 let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
459 self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)
460 }
461
462 fn queued_proposals<
463 GroupId: traits::GroupId<CURRENT_VERSION>,
464 ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
465 QueuedProposal: traits::QueuedProposal<CURRENT_VERSION>,
466 >(
467 &self,
468 group_id: &GroupId,
469 ) -> Result<Vec<(ProposalRef, QueuedProposal)>, Self::Error> {
470 let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
471 let refs: Vec<ProposalRef> = self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)?;
472
473 refs.into_iter()
474 .map(|proposal_ref| -> Result<_, _> {
475 let key = bincode::serialize(&(group_id, &proposal_ref))?;
476 match self.read(QUEUED_PROPOSAL_LABEL, &key)? {
477 Some(proposal) => Ok((proposal_ref, proposal)),
478 None => Err(SqlKeyStoreError::NotFound),
479 }
480 })
481 .collect::<Result<Vec<_>, _>>()
482 }
483
484 fn tree<
485 GroupId: traits::GroupId<CURRENT_VERSION>,
486 TreeSync: traits::TreeSync<CURRENT_VERSION>,
487 >(
488 &self,
489 group_id: &GroupId,
490 ) -> Result<Option<TreeSync>, Self::Error> {
491 let key = build_key::<CURRENT_VERSION, &GroupId>(TREE_LABEL, group_id)?;
492
493 self.read(TREE_LABEL, &key)
494 }
495
496 fn group_context<
497 GroupId: traits::GroupId<CURRENT_VERSION>,
498 GroupContext: traits::GroupContext<CURRENT_VERSION>,
499 >(
500 &self,
501 group_id: &GroupId,
502 ) -> Result<Option<GroupContext>, Self::Error> {
503 let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_CONTEXT_LABEL, group_id)?;
504
505 self.read(GROUP_CONTEXT_LABEL, &key)
506 }
507
508 fn interim_transcript_hash<
509 GroupId: traits::GroupId<CURRENT_VERSION>,
510 InterimTranscriptHash: traits::InterimTranscriptHash<CURRENT_VERSION>,
511 >(
512 &self,
513 group_id: &GroupId,
514 ) -> Result<Option<InterimTranscriptHash>, Self::Error> {
515 let key = build_key::<CURRENT_VERSION, &GroupId>(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?;
516
517 self.read(INTERIM_TRANSCRIPT_HASH_LABEL, &key)
518 }
519
520 fn confirmation_tag<
521 GroupId: traits::GroupId<CURRENT_VERSION>,
522 ConfirmationTag: traits::ConfirmationTag<CURRENT_VERSION>,
523 >(
524 &self,
525 group_id: &GroupId,
526 ) -> Result<Option<ConfirmationTag>, Self::Error> {
527 let key = build_key::<CURRENT_VERSION, &GroupId>(CONFIRMATION_TAG_LABEL, group_id)?;
528
529 self.read(CONFIRMATION_TAG_LABEL, &key)
530 }
531
532 fn signature_key_pair<
533 SignaturePublicKey: traits::SignaturePublicKey<CURRENT_VERSION>,
534 SignatureKeyPair: traits::SignatureKeyPair<CURRENT_VERSION>,
535 >(
536 &self,
537 public_key: &SignaturePublicKey,
538 ) -> Result<Option<SignatureKeyPair>, Self::Error> {
539 let key = build_key::<CURRENT_VERSION, &SignaturePublicKey>(
540 SIGNATURE_KEY_PAIR_LABEL,
541 public_key,
542 )?;
543
544 self.read(SIGNATURE_KEY_PAIR_LABEL, &key)
545 }
546
547 fn write_key_package<
548 HashReference: traits::HashReference<CURRENT_VERSION>,
549 KeyPackage: traits::KeyPackage<CURRENT_VERSION>,
550 >(
551 &self,
552 hash_ref: &HashReference,
553 key_package: &KeyPackage,
554 ) -> Result<(), Self::Error> {
555 let key = build_key::<CURRENT_VERSION, &HashReference>(KEY_PACKAGE_LABEL, hash_ref)?;
556 let value = bincode::serialize(&key_package)?;
557
558 self.write::<CURRENT_VERSION>(KEY_PACKAGE_LABEL, &key, &value)
560 }
561
562 fn write_psk<
563 PskId: traits::PskId<CURRENT_VERSION>,
564 PskBundle: traits::PskBundle<CURRENT_VERSION>,
565 >(
566 &self,
567 _psk_id: &PskId,
568 _psk: &PskBundle,
569 ) -> Result<(), Self::Error> {
570 Ok(())
571 }
572
573 fn write_encryption_key_pair<
574 EncryptionKey: traits::EncryptionKey<CURRENT_VERSION>,
575 HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
576 >(
577 &self,
578 public_key: &EncryptionKey,
579 key_pair: &HpkeKeyPair,
580 ) -> Result<(), Self::Error> {
581 let key =
582 build_key::<CURRENT_VERSION, &EncryptionKey>(ENCRYPTION_KEY_PAIR_LABEL, public_key)?;
583
584 self.write::<CURRENT_VERSION>(
585 ENCRYPTION_KEY_PAIR_LABEL,
586 &key,
587 &bincode::serialize(key_pair)?,
588 )
589 }
590
591 fn key_package<
592 HashReference: traits::HashReference<CURRENT_VERSION>,
593 KeyPackage: traits::KeyPackage<CURRENT_VERSION>,
594 >(
595 &self,
596 hash_ref: &HashReference,
597 ) -> Result<Option<KeyPackage>, Self::Error> {
598 let key = build_key::<CURRENT_VERSION, &HashReference>(KEY_PACKAGE_LABEL, hash_ref)?;
599
600 self.read(KEY_PACKAGE_LABEL, &key)
601 }
602
603 fn psk<PskBundle: traits::PskBundle<CURRENT_VERSION>, PskId: traits::PskId<CURRENT_VERSION>>(
604 &self,
605 _psk_id: &PskId,
606 ) -> Result<Option<PskBundle>, Self::Error> {
607 Ok(None)
608 }
609
610 fn encryption_key_pair<
611 HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
612 EncryptionKey: traits::EncryptionKey<CURRENT_VERSION>,
613 >(
614 &self,
615 public_key: &EncryptionKey,
616 ) -> Result<Option<HpkeKeyPair>, Self::Error> {
617 let key =
618 build_key::<CURRENT_VERSION, &EncryptionKey>(ENCRYPTION_KEY_PAIR_LABEL, public_key)?;
619
620 self.read(ENCRYPTION_KEY_PAIR_LABEL, &key)
621 }
622
623 fn delete_signature_key_pair<
624 SignaturePublicKey: traits::SignaturePublicKey<CURRENT_VERSION>,
625 >(
626 &self,
627 public_key: &SignaturePublicKey,
628 ) -> Result<(), Self::Error> {
629 let key = build_key::<CURRENT_VERSION, &SignaturePublicKey>(
630 SIGNATURE_KEY_PAIR_LABEL,
631 public_key,
632 )?;
633
634 self.delete::<CURRENT_VERSION>(SIGNATURE_KEY_PAIR_LABEL, &key)
635 }
636
637 fn delete_encryption_key_pair<EncryptionKey: traits::EncryptionKey<CURRENT_VERSION>>(
638 &self,
639 public_key: &EncryptionKey,
640 ) -> Result<(), Self::Error> {
641 let key =
642 build_key::<CURRENT_VERSION, &EncryptionKey>(ENCRYPTION_KEY_PAIR_LABEL, public_key)?;
643
644 self.delete::<CURRENT_VERSION>(ENCRYPTION_KEY_PAIR_LABEL, &key)
645 }
646
647 fn delete_key_package<HashReference: traits::HashReference<CURRENT_VERSION>>(
648 &self,
649 hash_ref: &HashReference,
650 ) -> Result<(), Self::Error> {
651 let key = build_key::<CURRENT_VERSION, &HashReference>(KEY_PACKAGE_LABEL, hash_ref)?;
652 self.delete::<CURRENT_VERSION>(KEY_PACKAGE_LABEL, &key)
653 }
654
655 fn delete_psk<PskKey: traits::PskId<CURRENT_VERSION>>(
656 &self,
657 _psk_id: &PskKey,
658 ) -> Result<(), Self::Error> {
659 Err(SqlKeyStoreError::UnsupportedMethod)
660 }
661
662 fn group_state<
663 GroupState: traits::GroupState<CURRENT_VERSION>,
664 GroupId: traits::GroupId<CURRENT_VERSION>,
665 >(
666 &self,
667 group_id: &GroupId,
668 ) -> Result<Option<GroupState>, Self::Error> {
669 let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_STATE_LABEL, group_id)?;
670
671 self.read(GROUP_STATE_LABEL, &key)
672 }
673
674 fn write_group_state<
675 GroupState: traits::GroupState<CURRENT_VERSION>,
676 GroupId: traits::GroupId<CURRENT_VERSION>,
677 >(
678 &self,
679 group_id: &GroupId,
680 group_state: &GroupState,
681 ) -> Result<(), Self::Error> {
682 let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_STATE_LABEL, group_id)?;
683
684 self.write::<CURRENT_VERSION>(GROUP_STATE_LABEL, &key, &bincode::serialize(group_state)?)
685 }
686
687 fn delete_group_state<GroupId: traits::GroupId<CURRENT_VERSION>>(
688 &self,
689 group_id: &GroupId,
690 ) -> Result<(), Self::Error> {
691 let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_STATE_LABEL, group_id)?;
692
693 self.delete::<CURRENT_VERSION>(GROUP_STATE_LABEL, &key)
694 }
695
696 fn message_secrets<
697 GroupId: traits::GroupId<CURRENT_VERSION>,
698 MessageSecrets: traits::MessageSecrets<CURRENT_VERSION>,
699 >(
700 &self,
701 group_id: &GroupId,
702 ) -> Result<Option<MessageSecrets>, Self::Error> {
703 let key = build_key::<CURRENT_VERSION, &GroupId>(MESSAGE_SECRETS_LABEL, group_id)?;
704
705 self.read(MESSAGE_SECRETS_LABEL, &key)
706 }
707
708 fn write_message_secrets<
709 GroupId: traits::GroupId<CURRENT_VERSION>,
710 MessageSecrets: traits::MessageSecrets<CURRENT_VERSION>,
711 >(
712 &self,
713 group_id: &GroupId,
714 message_secrets: &MessageSecrets,
715 ) -> Result<(), Self::Error> {
716 let key = build_key::<CURRENT_VERSION, &GroupId>(MESSAGE_SECRETS_LABEL, group_id)?;
717
718 self.write::<CURRENT_VERSION>(
719 MESSAGE_SECRETS_LABEL,
720 &key,
721 &bincode::serialize(message_secrets)?,
722 )
723 }
724
725 fn delete_message_secrets<GroupId: traits::GroupId<CURRENT_VERSION>>(
726 &self,
727 group_id: &GroupId,
728 ) -> Result<(), Self::Error> {
729 let key = build_key::<CURRENT_VERSION, &GroupId>(MESSAGE_SECRETS_LABEL, group_id)?;
730
731 self.delete::<CURRENT_VERSION>(MESSAGE_SECRETS_LABEL, &key)
732 }
733
734 fn resumption_psk_store<
735 GroupId: traits::GroupId<CURRENT_VERSION>,
736 ResumptionPskStore: traits::ResumptionPskStore<CURRENT_VERSION>,
737 >(
738 &self,
739 group_id: &GroupId,
740 ) -> Result<Option<ResumptionPskStore>, Self::Error> {
741 self.read(RESUMPTION_PSK_STORE_LABEL, &bincode::serialize(group_id)?)
742 }
743
744 fn write_resumption_psk_store<
745 GroupId: traits::GroupId<CURRENT_VERSION>,
746 ResumptionPskStore: traits::ResumptionPskStore<CURRENT_VERSION>,
747 >(
748 &self,
749 group_id: &GroupId,
750 resumption_psk_store: &ResumptionPskStore,
751 ) -> Result<(), Self::Error> {
752 self.write::<CURRENT_VERSION>(
753 RESUMPTION_PSK_STORE_LABEL,
754 &bincode::serialize(group_id)?,
755 &bincode::serialize(resumption_psk_store)?,
756 )
757 }
758
759 fn delete_all_resumption_psk_secrets<GroupId: traits::GroupId<CURRENT_VERSION>>(
760 &self,
761 group_id: &GroupId,
762 ) -> Result<(), Self::Error> {
763 self.delete::<CURRENT_VERSION>(RESUMPTION_PSK_STORE_LABEL, &bincode::serialize(group_id)?)
764 }
765
766 fn own_leaf_index<
767 GroupId: traits::GroupId<CURRENT_VERSION>,
768 LeafNodeIndex: traits::LeafNodeIndex<CURRENT_VERSION>,
769 >(
770 &self,
771 group_id: &GroupId,
772 ) -> Result<Option<LeafNodeIndex>, Self::Error> {
773 let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODE_INDEX_LABEL, group_id)?;
774 self.read(OWN_LEAF_NODE_INDEX_LABEL, &key)
775 }
776
777 fn write_own_leaf_index<
778 GroupId: traits::GroupId<CURRENT_VERSION>,
779 LeafNodeIndex: traits::LeafNodeIndex<CURRENT_VERSION>,
780 >(
781 &self,
782 group_id: &GroupId,
783 own_leaf_index: &LeafNodeIndex,
784 ) -> Result<(), Self::Error> {
785 let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODE_INDEX_LABEL, group_id)?;
786 self.write::<CURRENT_VERSION>(
787 OWN_LEAF_NODE_INDEX_LABEL,
788 &key,
789 &bincode::serialize(own_leaf_index)?,
790 )
791 }
792
793 fn delete_own_leaf_index<GroupId: traits::GroupId<CURRENT_VERSION>>(
794 &self,
795 group_id: &GroupId,
796 ) -> Result<(), Self::Error> {
797 let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODE_INDEX_LABEL, group_id)?;
798 self.delete::<CURRENT_VERSION>(OWN_LEAF_NODE_INDEX_LABEL, &key)
799 }
800
801 fn group_epoch_secrets<
802 GroupId: traits::GroupId<CURRENT_VERSION>,
803 GroupEpochSecrets: traits::GroupEpochSecrets<CURRENT_VERSION>,
804 >(
805 &self,
806 group_id: &GroupId,
807 ) -> Result<Option<GroupEpochSecrets>, Self::Error> {
808 let key = build_key::<CURRENT_VERSION, &GroupId>(EPOCH_SECRETS_LABEL, group_id)?;
809 self.read(EPOCH_SECRETS_LABEL, &key)
810 }
811
812 fn write_group_epoch_secrets<
813 GroupId: traits::GroupId<CURRENT_VERSION>,
814 GroupEpochSecrets: traits::GroupEpochSecrets<CURRENT_VERSION>,
815 >(
816 &self,
817 group_id: &GroupId,
818 group_epoch_secrets: &GroupEpochSecrets,
819 ) -> Result<(), Self::Error> {
820 let key = build_key::<CURRENT_VERSION, &GroupId>(EPOCH_SECRETS_LABEL, group_id)?;
821 self.write::<CURRENT_VERSION>(
822 EPOCH_SECRETS_LABEL,
823 &key,
824 &bincode::serialize(group_epoch_secrets)?,
825 )
826 }
827
828 fn delete_group_epoch_secrets<GroupId: traits::GroupId<CURRENT_VERSION>>(
829 &self,
830 group_id: &GroupId,
831 ) -> Result<(), Self::Error> {
832 let key = build_key::<CURRENT_VERSION, &GroupId>(EPOCH_SECRETS_LABEL, group_id)?;
833 self.delete::<CURRENT_VERSION>(EPOCH_SECRETS_LABEL, &key)
834 }
835
836 fn write_encryption_epoch_key_pairs<
837 GroupId: traits::GroupId<CURRENT_VERSION>,
838 EpochKey: traits::EpochKey<CURRENT_VERSION>,
839 HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
840 >(
841 &self,
842 group_id: &GroupId,
843 epoch: &EpochKey,
844 leaf_index: u32,
845 key_pairs: &[HpkeKeyPair],
846 ) -> Result<(), Self::Error> {
847 let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?;
848 let value = bincode::serialize(key_pairs)?;
849 tracing::trace!("Writing encryption epoch key pairs");
850
851 self.write::<CURRENT_VERSION>(EPOCH_KEY_PAIRS_LABEL, &key, &value)
852 }
853
854 fn encryption_epoch_key_pairs<
855 GroupId: traits::GroupId<CURRENT_VERSION>,
856 EpochKey: traits::EpochKey<CURRENT_VERSION>,
857 HpkeKeyPair: traits::HpkeKeyPair<CURRENT_VERSION>,
858 >(
859 &self,
860 group_id: &GroupId,
861 epoch: &EpochKey,
862 leaf_index: u32,
863 ) -> Result<Vec<HpkeKeyPair>, Self::Error> {
864 tracing::trace!("Reading encryption epoch key pairs");
865
866 let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?;
867 let storage_key = build_key_from_vec::<CURRENT_VERSION>(EPOCH_KEY_PAIRS_LABEL, key);
868 tracing::trace!(" key: {}", hex::encode(&storage_key));
869
870 let query = "SELECT value_bytes FROM openmls_key_value WHERE key_bytes = ? AND version = ?";
871
872 let data: Vec<StorageData> = self.conn.raw_query_read(|conn| {
873 sql_query(query)
874 .bind::<diesel::sql_types::Binary, _>(&storage_key)
875 .bind::<diesel::sql_types::Integer, _>(CURRENT_VERSION as i32)
876 .load(conn)
877 })?;
878
879 if let Some(entry) = data.into_iter().next() {
880 match bincode::deserialize::<Vec<HpkeKeyPair>>(&entry.value_bytes) {
881 Ok(deserialized) => Ok(deserialized),
882 Err(e) => {
883 eprintln!("Error occurred: {}", e);
884 Err(SqlKeyStoreError::SerializationError)
885 }
886 }
887 } else {
888 Ok(vec![])
889 }
890 }
891
892 fn delete_encryption_epoch_key_pairs<
893 GroupId: traits::GroupId<CURRENT_VERSION>,
894 EpochKey: traits::EpochKey<CURRENT_VERSION>,
895 >(
896 &self,
897 group_id: &GroupId,
898 epoch: &EpochKey,
899 leaf_index: u32,
900 ) -> Result<(), Self::Error> {
901 let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?;
902
903 self.delete::<CURRENT_VERSION>(EPOCH_KEY_PAIRS_LABEL, &key)
904 }
905
906 fn clear_proposal_queue<
907 GroupId: traits::GroupId<CURRENT_VERSION>,
908 ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
909 >(
910 &self,
911 group_id: &GroupId,
912 ) -> Result<(), Self::Error> {
913 let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
914 let proposal_refs: Vec<ProposalRef> = self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &key)?;
915
916 for proposal_ref in proposal_refs {
917 let key = bincode::serialize(&(group_id, proposal_ref))?;
918 self.delete::<CURRENT_VERSION>(QUEUED_PROPOSAL_LABEL, &key)?;
919 }
920
921 let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
922
923 self.delete::<CURRENT_VERSION>(PROPOSAL_QUEUE_REFS_LABEL, &key)
924 }
925
926 fn mls_group_join_config<
927 GroupId: traits::GroupId<CURRENT_VERSION>,
928 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<CURRENT_VERSION>,
929 >(
930 &self,
931 group_id: &GroupId,
932 ) -> Result<Option<MlsGroupJoinConfig>, Self::Error> {
933 let key = build_key::<CURRENT_VERSION, &GroupId>(JOIN_CONFIG_LABEL, group_id)?;
934
935 self.read(JOIN_CONFIG_LABEL, &key)
936 }
937
938 fn write_mls_join_config<
939 GroupId: traits::GroupId<CURRENT_VERSION>,
940 MlsGroupJoinConfig: traits::MlsGroupJoinConfig<CURRENT_VERSION>,
941 >(
942 &self,
943 group_id: &GroupId,
944 config: &MlsGroupJoinConfig,
945 ) -> Result<(), Self::Error> {
946 let key = build_key::<CURRENT_VERSION, &GroupId>(JOIN_CONFIG_LABEL, group_id)?;
947 let value = bincode::serialize(config)?;
948
949 self.write::<CURRENT_VERSION>(JOIN_CONFIG_LABEL, &key, &value)
950 }
951
952 fn own_leaf_nodes<
953 GroupId: traits::GroupId<CURRENT_VERSION>,
954 LeafNode: traits::LeafNode<CURRENT_VERSION>,
955 >(
956 &self,
957 group_id: &GroupId,
958 ) -> Result<Vec<LeafNode>, Self::Error> {
959 tracing::trace!("own_leaf_nodes");
960 let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODES_LABEL, group_id)?;
961
962 self.read_list(OWN_LEAF_NODES_LABEL, &key)
963 }
964
965 fn append_own_leaf_node<
966 GroupId: traits::GroupId<CURRENT_VERSION>,
967 LeafNode: traits::LeafNode<CURRENT_VERSION>,
968 >(
969 &self,
970 group_id: &GroupId,
971 leaf_node: &LeafNode,
972 ) -> Result<(), Self::Error> {
973 let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODES_LABEL, group_id)?;
974 let value = bincode::serialize(leaf_node)?;
975
976 self.append::<CURRENT_VERSION>(OWN_LEAF_NODES_LABEL, &key, &value)
977 }
978
979 fn delete_own_leaf_nodes<GroupId: traits::GroupId<CURRENT_VERSION>>(
980 &self,
981 group_id: &GroupId,
982 ) -> Result<(), Self::Error> {
983 let key = build_key::<CURRENT_VERSION, &GroupId>(OWN_LEAF_NODES_LABEL, group_id)?;
984 self.delete::<CURRENT_VERSION>(OWN_LEAF_NODES_LABEL, &key)
985 }
986
987 fn delete_group_config<GroupId: traits::GroupId<CURRENT_VERSION>>(
988 &self,
989 group_id: &GroupId,
990 ) -> Result<(), Self::Error> {
991 let key = build_key::<CURRENT_VERSION, &GroupId>(JOIN_CONFIG_LABEL, group_id)?;
992 self.delete::<CURRENT_VERSION>(JOIN_CONFIG_LABEL, &key)
993 }
994
995 fn delete_tree<GroupId: traits::GroupId<CURRENT_VERSION>>(
996 &self,
997 group_id: &GroupId,
998 ) -> Result<(), Self::Error> {
999 let key = build_key::<CURRENT_VERSION, &GroupId>(TREE_LABEL, group_id)?;
1000
1001 self.delete::<CURRENT_VERSION>(TREE_LABEL, &key)
1002 }
1003
1004 fn delete_confirmation_tag<GroupId: traits::GroupId<CURRENT_VERSION>>(
1005 &self,
1006 group_id: &GroupId,
1007 ) -> Result<(), Self::Error> {
1008 let key = build_key::<CURRENT_VERSION, &GroupId>(CONFIRMATION_TAG_LABEL, group_id)?;
1009
1010 self.delete::<CURRENT_VERSION>(CONFIRMATION_TAG_LABEL, &key)
1011 }
1012
1013 fn delete_context<GroupId: traits::GroupId<CURRENT_VERSION>>(
1014 &self,
1015 group_id: &GroupId,
1016 ) -> Result<(), Self::Error> {
1017 let key = build_key::<CURRENT_VERSION, &GroupId>(GROUP_CONTEXT_LABEL, group_id)?;
1018
1019 self.delete::<CURRENT_VERSION>(GROUP_CONTEXT_LABEL, &key)
1020 }
1021
1022 fn delete_interim_transcript_hash<GroupId: traits::GroupId<CURRENT_VERSION>>(
1023 &self,
1024 group_id: &GroupId,
1025 ) -> Result<(), Self::Error> {
1026 let key = build_key::<CURRENT_VERSION, &GroupId>(INTERIM_TRANSCRIPT_HASH_LABEL, group_id)?;
1027
1028 self.delete::<CURRENT_VERSION>(INTERIM_TRANSCRIPT_HASH_LABEL, &key)
1029 }
1030
1031 fn remove_proposal<
1032 GroupId: traits::GroupId<CURRENT_VERSION>,
1033 ProposalRef: traits::ProposalRef<CURRENT_VERSION>,
1034 >(
1035 &self,
1036 group_id: &GroupId,
1037 proposal_ref: &ProposalRef,
1038 ) -> Result<(), Self::Error> {
1039 let key = build_key::<CURRENT_VERSION, &GroupId>(PROPOSAL_QUEUE_REFS_LABEL, group_id)?;
1041 let value = bincode::serialize(proposal_ref)?;
1042 self.remove_item::<CURRENT_VERSION>(PROPOSAL_QUEUE_REFS_LABEL, &key, &value)?;
1043
1044 let key = bincode::serialize(&(group_id, proposal_ref))?;
1046 self.delete::<CURRENT_VERSION>(QUEUED_PROPOSAL_LABEL, &key)
1047 }
1048}
1049
/// Builds a versioned storage key laid out as `label || key || V`, where the
/// version `V` is appended as two big-endian bytes.
///
/// The output buffer is allocated once with its exact final length. Appending
/// the three parts therefore never reallocates, whereas starting from
/// `label.to_vec()` and extending could.
fn build_key_from_vec<const V: u16>(label: &[u8], key: Vec<u8>) -> Vec<u8> {
    let version = V.to_be_bytes();
    let mut key_out = Vec::with_capacity(label.len() + key.len() + version.len());
    key_out.extend_from_slice(label);
    key_out.extend_from_slice(&key);
    key_out.extend_from_slice(&version);
    key_out
}
1057
1058fn build_key<const V: u16, K: Serialize>(
1060 label: &[u8],
1061 key: K,
1062) -> Result<Vec<u8>, SqlKeyStoreError> {
1063 let key_vec = bincode::serialize(&key)?;
1064 Ok(build_key_from_vec::<V>(label, key_vec))
1065}
1066
1067fn epoch_key_pairs_id(
1068 group_id: &impl traits::GroupId<CURRENT_VERSION>,
1069 epoch: &impl traits::EpochKey<CURRENT_VERSION>,
1070 leaf_index: u32,
1071) -> Result<Vec<u8>, SqlKeyStoreError> {
1072 let mut key = bincode::serialize(group_id)?;
1073 key.extend_from_slice(&bincode::serialize(epoch)?);
1074 key.extend_from_slice(&bincode::serialize(&leaf_index)?);
1075 Ok(key)
1076}
1077
1078impl From<bincode::Error> for SqlKeyStoreError {
1079 fn from(_: bincode::Error) -> Self {
1080 Self::SerializationError
1081 }
1082}
1083
#[cfg(any(test, feature = "test-utils"))]
impl SqlKeyStore<crate::test_utils::MemoryStorage> {
    /// Test helper: returns the `String` produced by the wrapped connection's
    /// `key_value_pairs_utf8()`.
    pub fn kv_pairs_utf8(&self) -> String {
        self.conn.key_value_pairs_utf8()
    }

    /// Test helper: returns the `String` produced by the wrapped connection's
    /// `key_value_pairs()`.
    pub fn kv_pairs(&self) -> String {
        self.conn.key_value_pairs()
    }
}
1094
/// Storage-level tests for `SqlKeyStore`. They cover signature-key round trips,
/// raw `write`/`read`, proposal-queue list operations and group-state
/// persistence. Each test opens its own store through
/// `crate::TestDb::create_persistent_store(None)`.
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use openmls::group::GroupId;
    use openmls_basic_credential::{SignatureKeyPair, StorageId};
    use openmls_traits::{
        OpenMlsProvider,
        storage::{CURRENT_VERSION, Entity, Key, StorageProvider, traits},
    };
    use serde::{Deserialize, Serialize};

    use super::SqlKeyStore;
    use crate::encrypted_store::MlsProviderExt;
    use crate::{
        XmtpTestDb, sql_key_store::SqlKeyStoreError, xmtp_openmls_provider::XmtpOpenMlsProvider,
    };
    use xmtp_cryptography::configuration::CIPHERSUITE;

    // A signature key pair is absent before it is written, present afterwards,
    // and absent again once deleted.
    #[xmtp_common::test]
    async fn store_read_delete() {
        let store = crate::TestDb::create_persistent_store(None).await;
        let conn = store.conn();
        let key_store = SqlKeyStore::new(conn);

        let signature_keys = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm()).unwrap();
        let public_key = StorageId::from(signature_keys.to_public_vec());
        assert!(
            key_store
                .signature_key_pair::<StorageId, SignatureKeyPair>(&public_key)
                .unwrap()
                .is_none()
        );

        key_store
            .write_signature_key_pair::<StorageId, SignatureKeyPair>(&public_key, &signature_keys)
            .unwrap();

        assert!(
            key_store
                .signature_key_pair::<StorageId, SignatureKeyPair>(&public_key)
                .unwrap()
                .is_some()
        );

        key_store
            .delete_signature_key_pair::<StorageId>(&public_key)
            .unwrap();

        assert!(
            key_store
                .signature_key_pair::<StorageId, SignatureKeyPair>(&public_key)
                .unwrap()
                .is_none()
        );
    }

    /// Minimal queued-proposal payload used by the proposal-queue tests.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
    struct Proposal(Vec<u8>);
    impl traits::QueuedProposal<CURRENT_VERSION> for Proposal {}
    impl Entity<CURRENT_VERSION> for Proposal {}

    /// Minimal proposal reference wrapping a plain index.
    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)]
    struct ProposalRef(usize);
    impl traits::ProposalRef<CURRENT_VERSION> for ProposalRef {}
    impl Key<CURRENT_VERSION> for ProposalRef {}
    impl Entity<CURRENT_VERSION> for ProposalRef {}

    // A raw value written under one key is readable under that key and not
    // under a different key with the same label.
    #[xmtp_common::test(unwrap_try = true)]
    async fn test_read_write() {
        let store = crate::TestDb::create_persistent_store(None).await;
        let conn = store.conn();
        let mls_store = SqlKeyStore::new(conn);
        let provider = XmtpOpenMlsProvider::new(mls_store);
        let key_store = provider.key_store();

        let raw_value = vec![3u8; 32];
        let group_1 = bincode::serialize(&[1u8; 32])?;
        let group_2 = bincode::serialize(&[2u8; 32])?;
        let value_1 = bincode::serialize(&raw_value)?;

        key_store.write::<CURRENT_VERSION>(
            crate::sql_key_store::COMMIT_LOG_SIGNER_PRIVATE_KEY,
            &group_1,
            &value_1,
        )?;

        // Nothing was written for `group_2`.
        let result = key_store.read::<CURRENT_VERSION, Vec<u8>>(
            crate::sql_key_store::COMMIT_LOG_SIGNER_PRIVATE_KEY,
            &group_2,
        );
        assert!(result.is_ok(), "{}", result.err().unwrap());
        assert!(result.unwrap().is_none());

        // The value written for `group_1` decodes back to the original bytes.
        let result = key_store.read::<CURRENT_VERSION, Vec<u8>>(
            crate::sql_key_store::COMMIT_LOG_SIGNER_PRIVATE_KEY,
            &group_1,
        );
        assert!(result.is_ok(), "{}", result.err().unwrap());
        assert_eq!(result.unwrap(), Some(raw_value));
    }

    // Queue ten proposals, check refs and proposals come back in insertion
    // order, remove one, then clear the whole queue.
    #[xmtp_common::test]
    async fn list_append_remove() {
        let store = crate::TestDb::create_persistent_store(None).await;
        let conn = store.conn();
        let mls_store = SqlKeyStore::new(conn);
        let provider = XmtpOpenMlsProvider::new(mls_store);
        let group_id = GroupId::random(provider.rand());
        let proposals = (0..10)
            .map(|i| Proposal(format!("TestProposal{i}").as_bytes().to_vec()))
            .collect::<Vec<_>>();

        for (i, proposal) in proposals.iter().enumerate() {
            provider
                .storage()
                .queue_proposal::<GroupId, ProposalRef, Proposal>(
                    &group_id,
                    &ProposalRef(i),
                    proposal,
                )
                .expect("Failed to queue proposal");
        }

        tracing::trace!("Finished with queued proposals");
        let proposal_refs_read: Vec<ProposalRef> = provider
            .storage()
            .queued_proposal_refs(&group_id)
            .expect("Failed to read proposal refs");
        assert_eq!(
            (0..10).map(ProposalRef).collect::<Vec<_>>(),
            proposal_refs_read
        );

        let proposals_read: Vec<(ProposalRef, Proposal)> =
            provider.storage().queued_proposals(&group_id).unwrap();
        // Clone only the elements, not the whole `proposals` vector.
        let proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10)
            .map(ProposalRef)
            .zip(proposals.iter().cloned())
            .collect();
        assert_eq!(proposals_expected, proposals_read);

        provider
            .storage()
            .remove_proposal(&group_id, &ProposalRef(5))
            .unwrap();

        let proposal_refs_read: Vec<ProposalRef> =
            provider.storage().queued_proposal_refs(&group_id).unwrap();
        let mut expected = (0..10).map(ProposalRef).collect::<Vec<_>>();
        expected.remove(5);
        assert_eq!(expected, proposal_refs_read);

        let proposals_read: Vec<(ProposalRef, Proposal)> =
            provider.storage().queued_proposals(&group_id).unwrap();
        let mut proposals_expected: Vec<(ProposalRef, Proposal)> = (0..10)
            .map(ProposalRef)
            .zip(proposals.iter().cloned())
            .collect();
        proposals_expected.remove(5);
        assert_eq!(proposals_expected, proposals_read);

        provider
            .storage()
            .clear_proposal_queue::<GroupId, ProposalRef>(&group_id)
            .unwrap();
        let proposal_refs_read: Result<Vec<ProposalRef>, SqlKeyStoreError> =
            provider.storage().queued_proposal_refs(&group_id);
        assert!(proposal_refs_read.unwrap().is_empty());

        let proposals_read: Result<Vec<(ProposalRef, Proposal)>, SqlKeyStoreError> =
            provider.storage().queued_proposals(&group_id);
        assert!(proposals_read.unwrap().is_empty());
    }

    // A written group state can be read back unchanged.
    #[xmtp_common::test]
    async fn group_state() {
        let store = crate::TestDb::create_persistent_store(None).await;
        let conn = store.conn();
        // Named `mls_store` (as in the other tests) instead of shadowing `store`,
        // which keeps the database handle and the key store visibly distinct.
        let mls_store = SqlKeyStore::new(conn);
        let provider = XmtpOpenMlsProvider::new(mls_store);

        #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Copy)]
        struct GroupState(usize);
        impl traits::GroupState<CURRENT_VERSION> for GroupState {}
        impl Entity<CURRENT_VERSION> for GroupState {}

        let group_id = GroupId::random(provider.rand());

        provider
            .storage()
            .write_group_state(&group_id, &GroupState(77))
            .unwrap();

        let group_state: Option<GroupState> = provider.storage().group_state(&group_id).unwrap();
        assert_eq!(GroupState(77), group_state.unwrap());
    }
}