xmtp_db/encrypted_store/database/
native.rs1mod pool;
2mod sqlcipher_connection;
3
4use crate::StorageError;
5use crate::database::instrumentation::TestInstrumentation;
6use crate::{ConnectionError, ConnectionExt, DbConnection, NotFound};
8use arc_swap::ArcSwapOption;
9use diesel::sqlite::SqliteConnection;
10use diesel::{
11 Connection,
12 connection::SimpleConnection,
13 r2d2::{self, CustomizeConnection, PooledConnection},
14};
15use parking_lot::Mutex;
16use std::sync::Arc;
17use thiserror::Error;
18use xmtp_common::{BoxDynError, RetryableError, retryable};
19use xmtp_configuration::BUSY_TIMEOUT;
20
21use pool::*;
22
23pub type RawDbConnection = PooledConnection<ConnectionManager>;
24
25pub use self::sqlcipher_connection::EncryptedConnection;
26use crate::{EncryptionKey, StorageOption, XmtpDb};
27
28use super::PersistentOrMem;
29
30trait XmtpConnection:
31 ValidatedConnection
32 + ConnectionOptions
33 + CustomizeConnection<SqliteConnection, r2d2::Error>
34 + dyn_clone::DynClone
35{
36}
37
38trait ConnectionOptions {
39 fn options(&self) -> &StorageOption;
40 fn is_persistent(&self) -> bool {
41 matches!(self.options(), StorageOption::Persistent(_))
42 }
43}
44
45impl<T> XmtpConnection for T where
46 T: ValidatedConnection
47 + CustomizeConnection<SqliteConnection, r2d2::Error>
48 + ConnectionOptions
49 + dyn_clone::DynClone
50{
51}
52
53dyn_clone::clone_trait_object!(XmtpConnection);
54
55pub(crate) trait ValidatedConnection {
56 fn validate(&self, _conn: &mut SqliteConnection) -> Result<(), PlatformStorageError> {
57 Ok(())
58 }
59}
60
61fn connection_pragmas(c: &mut impl SimpleConnection) -> diesel::result::QueryResult<()> {
66 c.batch_execute(&format!("PRAGMA busy_timeout = {};", BUSY_TIMEOUT))?; c.batch_execute("PRAGMA synchronous = NORMAL;")?; c.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?; c.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")?; c.batch_execute("PRAGMA query_only = OFF;")?; c.batch_execute("PRAGMA journal_size_limit = 67108864")?; c.batch_execute("PRAGMA mmap_size = 134217728")?; c.batch_execute("PRAGMA cache_size = 2000")?; c.batch_execute("PRAGMA foreign_keys = ON;")?; Ok(())
79}
80
81#[derive(Clone, Debug)]
86pub struct UnencryptedConnection {
87 options: StorageOption,
88}
89
90impl UnencryptedConnection {
91 pub fn new(options: StorageOption) -> Self {
92 Self { options }
93 }
94}
95
96impl ValidatedConnection for UnencryptedConnection {}
97
98impl ConnectionOptions for UnencryptedConnection {
99 fn options(&self) -> &StorageOption {
100 &self.options
101 }
102}
103
104impl CustomizeConnection<SqliteConnection, r2d2::Error> for UnencryptedConnection {
105 fn on_acquire(&self, c: &mut SqliteConnection) -> Result<(), r2d2::Error> {
106 if cfg!(any(test, feature = "test-utils")) {
107 c.set_instrumentation(TestInstrumentation);
108 }
109 connection_pragmas(c)?;
110 Ok(())
111 }
112}
113
114impl ConnectionOptions for NopConnection {
115 fn options(&self) -> &StorageOption {
116 &self.options
117 }
118}
119
120#[derive(Clone, Debug)]
121pub struct NopConnection {
122 options: StorageOption,
123}
124
125impl Default for NopConnection {
126 fn default() -> Self {
127 NopConnection {
128 options: StorageOption::Ephemeral,
129 }
130 }
131}
132
133impl ValidatedConnection for NopConnection {}
134impl CustomizeConnection<SqliteConnection, r2d2::Error> for NopConnection {
135 fn on_acquire(&self, c: &mut SqliteConnection) -> Result<(), r2d2::Error> {
136 if cfg!(any(test, feature = "test-utils")) {
137 c.set_instrumentation(TestInstrumentation);
138 }
139 Ok(())
140 }
141}
142
143impl StorageOption {
144 pub(super) fn path(&self) -> Option<&String> {
145 use StorageOption::*;
146 match self {
147 Persistent(path) => Some(path),
148 _ => None,
149 }
150 }
151}
152
/// Errors produced by the native SQLite storage backend (diesel + r2d2 pool).
#[derive(Debug, Error)]
pub enum PlatformStorageError {
    /// An r2d2 pool-level failure (e.g. timing out while waiting for a connection).
    #[error("Pool error: {0}")]
    Pool(#[from] diesel::r2d2::PoolError),
    /// An error from diesel's r2d2 integration, such as a failing `on_acquire`
    /// customizer hook.
    #[error("Error with connection to Sqlite {0}")]
    DbConnection(#[from] diesel::r2d2::Error),
    /// The pool was released via `disconnect` and must be re-created with
    /// `reconnect` before queries can run.
    #[error("Pool needs to reconnect before use")]
    PoolNeedsConnection,
    /// A pool was requested for storage that has no persistent path.
    #[error("Using a DB Pool requires a persistent path")]
    PoolRequiresPath,
    /// An encryption key was supplied, but the SQLCipher extension is unavailable.
    #[error("The SQLCipher Sqlite extension is not present, but an encryption key is given")]
    SqlCipherNotLoaded,
    /// The supplied key (or salt) does not match the existing encrypted database.
    #[error("PRAGMA key or salt has incorrect value")]
    SqlCipherKeyIncorrect,
    /// SQLite reported the database as locked.
    #[error("Database is locked")]
    DatabaseLocked,
    /// A diesel query/result error.
    #[error(transparent)]
    DieselResult(#[from] diesel::result::Error),
    /// A requested record was not found.
    #[error(transparent)]
    NotFound(#[from] NotFound),
    /// Filesystem I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// Hex decoding failure.
    #[error(transparent)]
    FromHex(#[from] hex::FromHexError),
    /// Failure establishing a raw diesel connection (`Connection::establish`).
    #[error(transparent)]
    DieselConnect(#[from] diesel::ConnectionError),
    /// Any other boxed error.
    #[error(transparent)]
    Boxed(#[from] BoxDynError),
}
182
183impl RetryableError for PlatformStorageError {
184 fn is_retryable(&self) -> bool {
185 match self {
186 Self::Pool(_) => true,
187 Self::SqlCipherNotLoaded => true,
188 Self::PoolNeedsConnection => true,
189 Self::SqlCipherKeyIncorrect => false,
190 Self::DatabaseLocked => true,
191 Self::DieselResult(result) => retryable!(result),
192 Self::Io(_) => true,
193 Self::DieselConnect(_) => true,
194
195 _ => false,
196 }
197 }
198}
199
200#[derive(Clone, Debug)]
201pub struct NativeDb {
203 customizer: Box<dyn XmtpConnection>,
204 conn: Arc<PersistentOrMem<NativeDbConnection, EphemeralDbConnection>>,
205 opts: StorageOption,
206}
207
208impl NativeDb {
209 pub fn new(opts: &StorageOption, enc_key: EncryptionKey) -> Result<Self, StorageError> {
210 Self::new_inner(opts, Some(enc_key)).map_err(Into::into)
211 }
212
213 pub fn new_unencrypted(opts: &StorageOption) -> Result<Self, StorageError> {
214 Self::new_inner(opts, None).map_err(Into::into)
215 }
216
217 fn new_inner(
219 opts: &StorageOption,
220 enc_key: Option<EncryptionKey>,
221 ) -> Result<Self, PlatformStorageError> {
222 let customizer = if let Some(key) = enc_key {
223 let enc_connection = EncryptedConnection::new(key, opts)?;
224 if let Some(path) = enc_connection.options().path() {
225 let mut conn = SqliteConnection::establish(path)?;
226 enc_connection.validate(&mut conn)?;
227 }
228 Box::new(enc_connection) as Box<dyn XmtpConnection>
229 } else if matches!(opts, StorageOption::Persistent(_)) {
230 Box::new(UnencryptedConnection::new(opts.clone())) as Box<dyn XmtpConnection>
231 } else {
232 Box::new(NopConnection::default()) as Box<dyn XmtpConnection>
233 };
234 let conn = if customizer.is_persistent() {
235 PersistentOrMem::Persistent(NativeDbConnection::new(customizer.clone())?)
236 } else {
237 PersistentOrMem::Mem(EphemeralDbConnection::new()?)
238 };
239
240 Ok(Self {
241 opts: opts.clone(),
242 conn: conn.into(),
243 customizer,
244 })
245 }
246}
247
248impl XmtpDb for NativeDb {
249 type Connection = Arc<PersistentOrMem<NativeDbConnection, EphemeralDbConnection>>;
250 type DbQuery = DbConnection<Self::Connection>;
251
252 fn conn(&self) -> Self::Connection {
253 self.conn.clone()
254 }
255
256 fn db(&self) -> Self::DbQuery {
257 DbConnection::new(self.conn.clone())
258 }
259
260 fn opts(&self) -> &StorageOption {
261 &self.opts
262 }
263
264 fn validate(&self, conn: &mut SqliteConnection) -> Result<(), ConnectionError> {
265 self.customizer.validate(conn)?;
266 Ok(())
267 }
268
269 fn disconnect(&self) -> Result<(), ConnectionError> {
270 self.conn.disconnect()
271 }
272
273 fn reconnect(&self) -> Result<(), ConnectionError> {
274 self.conn.reconnect()
275 }
276}
277
278pub struct EphemeralDbConnection {
279 conn: Arc<Mutex<SqliteConnection>>,
280}
281
282impl std::fmt::Debug for EphemeralDbConnection {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 write!(
285 f,
286 "EphemeralConnection {{ is_locked={} }}",
287 self.conn.is_locked()
288 )
289 }
290}
291
292impl EphemeralDbConnection {
293 pub fn new() -> Result<Self, PlatformStorageError> {
294 let mut c = SqliteConnection::establish(":memory:")?;
295 UnencryptedConnection::on_acquire(
296 &UnencryptedConnection::new(StorageOption::Ephemeral),
297 &mut c,
298 )?;
299 Ok(Self {
300 conn: Arc::new(Mutex::new(c)),
301 })
302 }
303
304 fn db_disconnect(&self) -> Result<(), PlatformStorageError> {
305 Ok(())
306 }
307
308 fn db_reconnect(&self) -> Result<(), PlatformStorageError> {
309 let mut w = self.conn.lock();
310 let conn = SqliteConnection::establish(":memory:")?;
311 *w = conn;
312 Ok(())
313 }
314}
315
316impl ConnectionExt for EphemeralDbConnection {
317 fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
318 where
319 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
320 Self: Sized,
321 {
322 let mut conn = self.conn.lock();
323 fun(&mut conn).map_err(ConnectionError::from)
324 }
325
326 fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
327 where
328 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
329 Self: Sized,
330 {
331 let mut conn = self.conn.lock();
332 fun(&mut conn).map_err(ConnectionError::from)
333 }
334
335 fn disconnect(&self) -> Result<(), crate::ConnectionError> {
336 Ok(self.db_disconnect()?)
337 }
338
339 fn reconnect(&self) -> Result<(), crate::ConnectionError> {
340 Ok(self.db_reconnect()?)
341 }
342}
343
344pub struct NativeDbConnection {
345 pub(super) pool: ArcSwapOption<DbPool>,
346 customizer: Box<dyn XmtpConnection>,
347}
348
349impl std::fmt::Debug for NativeDbConnection {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 write!(
352 f,
353 "NativeDbConnection {{ path: {}, state={:?} }}",
354 &self.customizer.options(),
355 self.pool.load().as_ref().map(|s| s.state()),
356 )
357 }
358}
359
360impl NativeDbConnection {
361 fn new(customizer: Box<dyn XmtpConnection>) -> Result<Self, PlatformStorageError> {
362 Ok(Self {
363 pool: ArcSwapOption::new(Some(Arc::new(DbPool::new(customizer.clone())?))),
364 customizer,
365 })
366 }
367
368 fn db_disconnect(&self) -> Result<(), PlatformStorageError> {
369 tracing::warn!("released sqlite database connection");
370 self.pool.store(None);
371 Ok(())
372 }
373
374 fn db_reconnect(&self) -> Result<(), PlatformStorageError> {
375 tracing::info!("reconnecting sqlite database connection");
376 self.pool
377 .store(Some(Arc::new(DbPool::new(self.customizer.clone())?)));
378 Ok(())
379 }
380}
381
382impl ConnectionExt for NativeDbConnection {
383 #[tracing::instrument(level = "trace", skip_all)]
384 fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
385 where
386 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
387 Self: Sized,
388 {
389 if let Some(pool) = &*self.pool.load() {
390 tracing::trace!(
391 "pulling connection from pool, idle={}, total={}",
392 pool.state().idle_connections,
393 pool.state().connections
394 );
395 let mut conn = pool.get()?;
396 fun(&mut conn).map_err(ConnectionError::from)
397 } else {
398 Err(ConnectionError::from(
399 PlatformStorageError::PoolNeedsConnection,
400 ))
401 }
402 }
403
404 #[tracing::instrument(level = "trace", skip_all)]
405 fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
406 where
407 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
408 Self: Sized,
409 {
410 if let Some(pool) = &*self.pool.load() {
411 tracing::trace!(
412 "pulling connection from pool for write, idle={}, total={}",
413 pool.state().idle_connections,
414 pool.state().connections
415 );
416 let mut conn = pool.get()?;
417 fun(&mut conn).map_err(ConnectionError::from)
418 } else {
419 Err(ConnectionError::from(
420 PlatformStorageError::PoolNeedsConnection,
421 ))
422 }
423 }
424
425 fn disconnect(&self) -> Result<(), ConnectionError> {
426 Ok(self.db_disconnect()?)
427 }
428
429 fn reconnect(&self) -> Result<(), ConnectionError> {
430 Ok(self.db_reconnect()?)
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EncryptedMessageStore, Fetch, Store, XmtpTestDb, identity::StoredIdentity};
    use xmtp_common::{rand_vec, tmp_path};

    /// Releasing the connection drops the pool; reconnecting restores access to
    /// the same on-disk data.
    #[tokio::test]
    async fn releases_db_lock() {
        let db_path = tmp_path();
        {
            let store = crate::TestDb::create_persistent_store(Some(db_path.clone())).await;
            let conn = &store.conn();
            let inbox_id = "inbox_id";

            StoredIdentity::new(inbox_id.to_string(), rand_vec::<24>(), rand_vec::<24>())
                .store(conn)
                .unwrap();
            let before: StoredIdentity = conn.fetch(&()).unwrap().unwrap();
            assert_eq!(before.inbox_id, inbox_id);

            store.release_connection().unwrap();
            match &*store.db.conn() {
                PersistentOrMem::Persistent(native) => assert!(native.pool.load().is_none()),
                PersistentOrMem::Mem(_) => panic!("conn expected"),
            }

            store.reconnect().unwrap();
            let after: StoredIdentity = conn.fetch(&()).unwrap().unwrap();
            assert_eq!(after.inbox_id, inbox_id);
        }

        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }

    /// Re-opening an encrypted database with a different key must fail with
    /// `SqlCipherKeyIncorrect`.
    #[tokio::test]
    async fn mismatched_encryption_key() {
        use crate::database::PlatformStorageError;

        let db_path = tmp_path();
        let opts = StorageOption::Persistent(db_path.clone());
        let mut enc_key = [1u8; 32];

        {
            let db = NativeDb::new(&opts, enc_key).unwrap();
            db.init().unwrap();
            StoredIdentity::new(
                "dummy_address".to_string(),
                rand_vec::<24>(),
                rand_vec::<24>(),
            )
            .store(&db.conn())
            .unwrap();
        }

        // Flip one byte of the key and try to open the same file again.
        enc_key[3] = 145;
        let err = NativeDb::new(&opts, enc_key).unwrap_err();
        assert!(
            matches!(
                err,
                crate::StorageError::Platform(PlatformStorageError::SqlCipherKeyIncorrect)
            ),
            "Expected SqlCipherKeyIncorrect error, got {}",
            err
        );

        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }
}