xmtp_db/encrypted_store/database/
native.rs

1mod pool;
2mod sqlcipher_connection;
3
4use crate::StorageError;
5use crate::database::instrumentation::TestInstrumentation;
6/// Native SQLite connection using SqlCipher
7use crate::{ConnectionError, ConnectionExt, DbConnection, NotFound};
8use arc_swap::ArcSwapOption;
9use diesel::sqlite::SqliteConnection;
10use diesel::{
11    Connection,
12    connection::SimpleConnection,
13    r2d2::{self, CustomizeConnection, PooledConnection},
14};
15use parking_lot::Mutex;
16use std::sync::Arc;
17use thiserror::Error;
18use xmtp_common::{BoxDynError, RetryableError, retryable};
19use xmtp_configuration::BUSY_TIMEOUT;
20
21use pool::*;
22
23pub type RawDbConnection = PooledConnection<ConnectionManager>;
24
25pub use self::sqlcipher_connection::EncryptedConnection;
26use crate::{EncryptionKey, StorageOption, XmtpDb};
27
28use super::PersistentOrMem;
29
30trait XmtpConnection:
31    ValidatedConnection
32    + ConnectionOptions
33    + CustomizeConnection<SqliteConnection, r2d2::Error>
34    + dyn_clone::DynClone
35{
36}
37
38trait ConnectionOptions {
39    fn options(&self) -> &StorageOption;
40    fn is_persistent(&self) -> bool {
41        matches!(self.options(), StorageOption::Persistent(_))
42    }
43}
44
45impl<T> XmtpConnection for T where
46    T: ValidatedConnection
47        + CustomizeConnection<SqliteConnection, r2d2::Error>
48        + ConnectionOptions
49        + dyn_clone::DynClone
50{
51}
52
53dyn_clone::clone_trait_object!(XmtpConnection);
54
55pub(crate) trait ValidatedConnection {
56    fn validate(&self, _conn: &mut SqliteConnection) -> Result<(), PlatformStorageError> {
57        Ok(())
58    }
59}
60
61/// Pragmas to execute on acquiring a new SQLite connection
62/// According to [pragmas](https://docs.rs/diesel/latest/diesel/prelude/struct.SqliteConnection.html#concurrency)
63/// for concurrency
64/// these pragmas only required to be ran once per session.
65fn connection_pragmas(c: &mut impl SimpleConnection) -> diesel::result::QueryResult<()> {
66    // pragmas must be in a separate call to ensure they apply correctly
67    // _NOTE:_ order is important to ensure later pragmas do not timeout
68    c.batch_execute(&format!("PRAGMA busy_timeout = {};", BUSY_TIMEOUT))?; // sleep for 5s if the database is busy
69    c.batch_execute("PRAGMA synchronous = NORMAL;")?; // fsync only in critical moments
70    c.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?; // write WAL changes back every 1000 pages, for an in average 1MB WAL file. May affect readers if number is increased
71    c.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")?; // free some space by truncating possibly massive WAL files from the last run.
72    c.batch_execute("PRAGMA query_only = OFF;")?; // Enable writing with the connection
73    c.batch_execute("PRAGMA journal_size_limit = 67108864")?; // maximum size of the WAL file, corresponds to 64MB
74    c.batch_execute("PRAGMA mmap_size = 134217728")?; // maximum size of the internal mmap pool. Corresponds to 128MB
75    c.batch_execute("PRAGMA cache_size = 2000")?; // maximum number of database disk pages that will be hold in memory. Corresponds to ~8MB
76    c.batch_execute("PRAGMA foreign_keys = ON;")?; // enforce foreign keys
77
78    Ok(())
79}
80
81/// An Unencrypted Connection
82/// Creates a Sqlite3 Database/Connection in WAL mode.
83/// _*NOTE:*_Unencrypted Connections are not validated and mostly meant for testing.
84/// It is not recommended to use an unencrypted connection in production.
85#[derive(Clone, Debug)]
86pub struct UnencryptedConnection {
87    options: StorageOption,
88}
89
90impl UnencryptedConnection {
91    pub fn new(options: StorageOption) -> Self {
92        Self { options }
93    }
94}
95
96impl ValidatedConnection for UnencryptedConnection {}
97
98impl ConnectionOptions for UnencryptedConnection {
99    fn options(&self) -> &StorageOption {
100        &self.options
101    }
102}
103
104impl CustomizeConnection<SqliteConnection, r2d2::Error> for UnencryptedConnection {
105    fn on_acquire(&self, c: &mut SqliteConnection) -> Result<(), r2d2::Error> {
106        if cfg!(any(test, feature = "test-utils")) {
107            c.set_instrumentation(TestInstrumentation);
108        }
109        connection_pragmas(c)?;
110        Ok(())
111    }
112}
113
114impl ConnectionOptions for NopConnection {
115    fn options(&self) -> &StorageOption {
116        &self.options
117    }
118}
119
120#[derive(Clone, Debug)]
121pub struct NopConnection {
122    options: StorageOption,
123}
124
125impl Default for NopConnection {
126    fn default() -> Self {
127        NopConnection {
128            options: StorageOption::Ephemeral,
129        }
130    }
131}
132
133impl ValidatedConnection for NopConnection {}
134impl CustomizeConnection<SqliteConnection, r2d2::Error> for NopConnection {
135    fn on_acquire(&self, c: &mut SqliteConnection) -> Result<(), r2d2::Error> {
136        if cfg!(any(test, feature = "test-utils")) {
137            c.set_instrumentation(TestInstrumentation);
138        }
139        Ok(())
140    }
141}
142
143impl StorageOption {
144    pub(super) fn path(&self) -> Option<&String> {
145        use StorageOption::*;
146        match self {
147            Persistent(path) => Some(path),
148            _ => None,
149        }
150    }
151}
152
153#[derive(Debug, Error)]
154pub enum PlatformStorageError {
155    #[error("Pool error: {0}")]
156    Pool(#[from] diesel::r2d2::PoolError),
157    #[error("Error with connection to Sqlite {0}")]
158    DbConnection(#[from] diesel::r2d2::Error),
159    #[error("Pool needs to  reconnect before use")]
160    PoolNeedsConnection,
161    #[error("Using a DB Pool requires a persistent path")]
162    PoolRequiresPath,
163    #[error("The SQLCipher Sqlite extension is not present, but an encryption key is given")]
164    SqlCipherNotLoaded,
165    #[error("PRAGMA key or salt has incorrect value")]
166    SqlCipherKeyIncorrect,
167    #[error("Database is locked")]
168    DatabaseLocked,
169    #[error(transparent)]
170    DieselResult(#[from] diesel::result::Error),
171    #[error(transparent)]
172    NotFound(#[from] NotFound),
173    #[error(transparent)]
174    Io(#[from] std::io::Error),
175    #[error(transparent)]
176    FromHex(#[from] hex::FromHexError),
177    #[error(transparent)]
178    DieselConnect(#[from] diesel::ConnectionError),
179    #[error(transparent)]
180    Boxed(#[from] BoxDynError),
181}
182
183impl RetryableError for PlatformStorageError {
184    fn is_retryable(&self) -> bool {
185        match self {
186            Self::Pool(_) => true,
187            Self::SqlCipherNotLoaded => true,
188            Self::PoolNeedsConnection => true,
189            Self::SqlCipherKeyIncorrect => false,
190            Self::DatabaseLocked => true,
191            Self::DieselResult(result) => retryable!(result),
192            Self::Io(_) => true,
193            Self::DieselConnect(_) => true,
194
195            _ => false,
196        }
197    }
198}
199
200#[derive(Clone, Debug)]
201/// Database used in `native` (everywhere but web)
202pub struct NativeDb {
203    customizer: Box<dyn XmtpConnection>,
204    conn: Arc<PersistentOrMem<NativeDbConnection, EphemeralDbConnection>>,
205    opts: StorageOption,
206}
207
208impl NativeDb {
209    pub fn new(opts: &StorageOption, enc_key: EncryptionKey) -> Result<Self, StorageError> {
210        Self::new_inner(opts, Some(enc_key)).map_err(Into::into)
211    }
212
213    pub fn new_unencrypted(opts: &StorageOption) -> Result<Self, StorageError> {
214        Self::new_inner(opts, None).map_err(Into::into)
215    }
216
217    /// This function is private so that an unencrypted database cannot be created by accident
218    fn new_inner(
219        opts: &StorageOption,
220        enc_key: Option<EncryptionKey>,
221    ) -> Result<Self, PlatformStorageError> {
222        let customizer = if let Some(key) = enc_key {
223            let enc_connection = EncryptedConnection::new(key, opts)?;
224            if let Some(path) = enc_connection.options().path() {
225                let mut conn = SqliteConnection::establish(path)?;
226                enc_connection.validate(&mut conn)?;
227            }
228            Box::new(enc_connection) as Box<dyn XmtpConnection>
229        } else if matches!(opts, StorageOption::Persistent(_)) {
230            Box::new(UnencryptedConnection::new(opts.clone())) as Box<dyn XmtpConnection>
231        } else {
232            Box::new(NopConnection::default()) as Box<dyn XmtpConnection>
233        };
234        let conn = if customizer.is_persistent() {
235            PersistentOrMem::Persistent(NativeDbConnection::new(customizer.clone())?)
236        } else {
237            PersistentOrMem::Mem(EphemeralDbConnection::new()?)
238        };
239
240        Ok(Self {
241            opts: opts.clone(),
242            conn: conn.into(),
243            customizer,
244        })
245    }
246}
247
248impl XmtpDb for NativeDb {
249    type Connection = Arc<PersistentOrMem<NativeDbConnection, EphemeralDbConnection>>;
250    type DbQuery = DbConnection<Self::Connection>;
251
252    fn conn(&self) -> Self::Connection {
253        self.conn.clone()
254    }
255
256    fn db(&self) -> Self::DbQuery {
257        DbConnection::new(self.conn.clone())
258    }
259
260    fn opts(&self) -> &StorageOption {
261        &self.opts
262    }
263
264    fn validate(&self, conn: &mut SqliteConnection) -> Result<(), ConnectionError> {
265        self.customizer.validate(conn)?;
266        Ok(())
267    }
268
269    fn disconnect(&self) -> Result<(), ConnectionError> {
270        self.conn.disconnect()
271    }
272
273    fn reconnect(&self) -> Result<(), ConnectionError> {
274        self.conn.reconnect()
275    }
276}
277
278pub struct EphemeralDbConnection {
279    conn: Arc<Mutex<SqliteConnection>>,
280}
281
282impl std::fmt::Debug for EphemeralDbConnection {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        write!(
285            f,
286            "EphemeralConnection {{ is_locked={} }}",
287            self.conn.is_locked()
288        )
289    }
290}
291
292impl EphemeralDbConnection {
293    pub fn new() -> Result<Self, PlatformStorageError> {
294        let mut c = SqliteConnection::establish(":memory:")?;
295        UnencryptedConnection::on_acquire(
296            &UnencryptedConnection::new(StorageOption::Ephemeral),
297            &mut c,
298        )?;
299        Ok(Self {
300            conn: Arc::new(Mutex::new(c)),
301        })
302    }
303
304    fn db_disconnect(&self) -> Result<(), PlatformStorageError> {
305        Ok(())
306    }
307
308    fn db_reconnect(&self) -> Result<(), PlatformStorageError> {
309        let mut w = self.conn.lock();
310        let conn = SqliteConnection::establish(":memory:")?;
311        *w = conn;
312        Ok(())
313    }
314}
315
316impl ConnectionExt for EphemeralDbConnection {
317    fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
318    where
319        F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
320        Self: Sized,
321    {
322        let mut conn = self.conn.lock();
323        fun(&mut conn).map_err(ConnectionError::from)
324    }
325
326    fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
327    where
328        F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
329        Self: Sized,
330    {
331        let mut conn = self.conn.lock();
332        fun(&mut conn).map_err(ConnectionError::from)
333    }
334
335    fn disconnect(&self) -> Result<(), crate::ConnectionError> {
336        Ok(self.db_disconnect()?)
337    }
338
339    fn reconnect(&self) -> Result<(), crate::ConnectionError> {
340        Ok(self.db_reconnect()?)
341    }
342}
343
344pub struct NativeDbConnection {
345    pub(super) pool: ArcSwapOption<DbPool>,
346    customizer: Box<dyn XmtpConnection>,
347}
348
349impl std::fmt::Debug for NativeDbConnection {
350    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351        write!(
352            f,
353            "NativeDbConnection {{ path: {}, state={:?} }}",
354            &self.customizer.options(),
355            self.pool.load().as_ref().map(|s| s.state()),
356        )
357    }
358}
359
360impl NativeDbConnection {
361    fn new(customizer: Box<dyn XmtpConnection>) -> Result<Self, PlatformStorageError> {
362        Ok(Self {
363            pool: ArcSwapOption::new(Some(Arc::new(DbPool::new(customizer.clone())?))),
364            customizer,
365        })
366    }
367
368    fn db_disconnect(&self) -> Result<(), PlatformStorageError> {
369        tracing::warn!("released sqlite database connection");
370        self.pool.store(None);
371        Ok(())
372    }
373
374    fn db_reconnect(&self) -> Result<(), PlatformStorageError> {
375        tracing::info!("reconnecting sqlite database connection");
376        self.pool
377            .store(Some(Arc::new(DbPool::new(self.customizer.clone())?)));
378        Ok(())
379    }
380}
381
382impl ConnectionExt for NativeDbConnection {
383    #[tracing::instrument(level = "trace", skip_all)]
384    fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
385    where
386        F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
387        Self: Sized,
388    {
389        if let Some(pool) = &*self.pool.load() {
390            tracing::trace!(
391                "pulling connection from pool, idle={}, total={}",
392                pool.state().idle_connections,
393                pool.state().connections
394            );
395            let mut conn = pool.get()?;
396            fun(&mut conn).map_err(ConnectionError::from)
397        } else {
398            Err(ConnectionError::from(
399                PlatformStorageError::PoolNeedsConnection,
400            ))
401        }
402    }
403
404    #[tracing::instrument(level = "trace", skip_all)]
405    fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
406    where
407        F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
408        Self: Sized,
409    {
410        if let Some(pool) = &*self.pool.load() {
411            tracing::trace!(
412                "pulling connection from pool for write, idle={}, total={}",
413                pool.state().idle_connections,
414                pool.state().connections
415            );
416            let mut conn = pool.get()?;
417            fun(&mut conn).map_err(ConnectionError::from)
418        } else {
419            Err(ConnectionError::from(
420                PlatformStorageError::PoolNeedsConnection,
421            ))
422        }
423    }
424
425    fn disconnect(&self) -> Result<(), ConnectionError> {
426        Ok(self.db_disconnect()?)
427    }
428
429    fn reconnect(&self) -> Result<(), ConnectionError> {
430        Ok(self.db_reconnect()?)
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EncryptedMessageStore, Fetch, Store, XmtpTestDb, identity::StoredIdentity};
    use xmtp_common::{rand_vec, tmp_path};

    /// Releasing the connection must drop the pool. Reconnecting must make
    /// previously written data readable again through the same handle.
    #[tokio::test]
    async fn releases_db_lock() {
        let db_path = tmp_path();
        {
            let store = crate::TestDb::create_persistent_store(Some(db_path.clone())).await;
            let handle = &store.conn();
            let expected_inbox = "inbox_id";

            StoredIdentity::new(expected_inbox.to_string(), rand_vec::<24>(), rand_vec::<24>())
                .store(handle)
                .unwrap();
            let before: StoredIdentity = handle.fetch(&()).unwrap().unwrap();
            assert_eq!(before.inbox_id, expected_inbox);

            // After release, the persistent connection must hold no pool.
            store.release_connection().unwrap();
            match &*store.db.conn() {
                PersistentOrMem::Persistent(native) => assert!(native.pool.load().is_none()),
                _ => panic!("conn expected"),
            }

            // After reconnecting, the earlier write is visible again.
            store.reconnect().unwrap();
            let after: StoredIdentity = handle.fetch(&()).unwrap().unwrap();
            assert_eq!(after.inbox_id, expected_inbox);
        }

        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }

    /// Reopening an encrypted database with a different key must fail with
    /// `SqlCipherKeyIncorrect`.
    #[tokio::test]
    async fn mismatched_encryption_key() {
        use crate::database::PlatformStorageError;

        let db_path = tmp_path();
        let opts = StorageOption::Persistent(db_path.clone());
        let mut enc_key = [1u8; 32];

        // Create the database with the original key, write a row, then drop the handle.
        {
            let db = NativeDb::new(&opts, enc_key).unwrap();
            db.init().unwrap();
            let identity = StoredIdentity::new(
                "dummy_address".to_string(),
                rand_vec::<24>(),
                rand_vec::<24>(),
            );
            identity.store(&db.conn()).unwrap();
        }

        // Change one byte of the key; reopening must now be rejected.
        enc_key[3] = 145;
        let err = NativeDb::new(&opts, enc_key).unwrap_err();
        assert!(
            matches!(
                err,
                crate::StorageError::Platform(PlatformStorageError::SqlCipherKeyIncorrect)
            ),
            "Expected SqlCipherKeyIncorrect error, got {}",
            err
        );

        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }
}