xmtp_db/test_utils/
mls_memory_storage.rs

1use crate::schema::openmls_key_value::dsl;
2use crate::sql_key_store::SqlKeyStore;
3use crate::{ConnectionExt, MIGRATIONS};
4use diesel::prelude::*;
5use diesel::sqlite::SqliteConnection;
6use diesel_migrations::MigrationHarness;
7use parking_lot::Mutex;
8use std::fmt::Write;
9use std::sync::Arc;
10
/// [`SqlKeyStore`] specialised to run on top of the in-memory [`MemoryStorage`]
/// connection defined in this module.
pub type MlsMemoryStorage = SqlKeyStore<MemoryStorage>;
12
/// Handle to a single SQLite connection guarded by a mutex.
///
/// The connection lives behind an [`Arc`], so cloning this type yields another
/// handle to the same underlying connection rather than a new database.
#[derive(Clone)]
pub struct MemoryStorage {
    /// The shared connection; every access goes through the mutex.
    inner: Arc<Mutex<SqliteConnection>>,
}
17
18impl Default for MemoryStorage {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl MemoryStorage {
25    pub fn new() -> Self {
26        let mut conn = SqliteConnection::establish(":memory:").unwrap();
27        conn.run_pending_migrations(MIGRATIONS).unwrap();
28        Self {
29            inner: Arc::new(Mutex::new(conn)),
30        }
31    }
32
33    /// Print the key-value pairs in MLS memory as hex
34    pub fn key_value_pairs(&self) -> String {
35        let mut c = self.inner.lock();
36        let key_values = dsl::openmls_key_value
37            .select((dsl::key_bytes, dsl::value_bytes))
38            .load::<(Vec<u8>, Vec<u8>)>(&mut *c)
39            .unwrap();
40        let mut s = String::new();
41        s.push('\n');
42        for (key, value) in key_values.iter() {
43            write!(s, "{}:{}", hex::encode(key), hex::encode(value)).unwrap();
44            s.push('\n');
45        }
46        s
47    }
48
49    /// Print the key-value pairs in MLS memory as hex
50    pub fn key_value_pairs_utf8(&self) -> String {
51        let mut c = self.inner.lock();
52        let key_values = dsl::openmls_key_value
53            .select((dsl::key_bytes, dsl::value_bytes))
54            .load::<(Vec<u8>, Vec<u8>)>(&mut *c)
55            .unwrap();
56        let mut s = String::new();
57        s.push('\n');
58        for (key, value) in key_values.iter() {
59            write!(
60                s,
61                "{}:{}",
62                String::from_utf8_lossy(key),
63                String::from_utf8_lossy(value)
64            )
65            .unwrap();
66            s.push('\n');
67        }
68        s
69    }
70}
71
72impl ConnectionExt for MemoryStorage {
73    fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
74    where
75        F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
76        Self: Sized,
77    {
78        let mut c = self.inner.lock();
79        Ok(fun(&mut c)?)
80    }
81
82    fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
83    where
84        F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
85        Self: Sized,
86    {
87        let mut c = self.inner.lock();
88        Ok(fun(&mut c)?)
89    }
90
91    fn disconnect(&self) -> Result<(), crate::ConnectionError> {
92        unimplemented!()
93    }
94
95    fn reconnect(&self) -> Result<(), crate::ConnectionError> {
96        unimplemented!()
97    }
98}