1pub mod association_state;
14pub mod consent_record;
15pub mod conversation_list;
16pub mod database;
17pub mod db_connection;
18pub mod group;
19pub mod group_intent;
20pub mod group_message;
21pub mod icebox;
22pub mod identity;
23pub mod identity_cache;
24pub mod identity_update;
25pub mod key_package_history;
26pub mod key_store_entry;
27pub mod local_commit_log;
28pub mod message_deletion;
29pub mod migrations;
30pub mod pending_remove;
31pub mod pragmas;
32pub mod processed_device_sync_messages;
33pub mod readd_status;
34pub mod refresh_state;
35pub mod remote_commit_log;
36pub mod schema;
37mod schema_gen;
38pub mod store;
39pub mod tasks;
40pub mod user_preferences;
41
42#[cfg(test)]
43mod migration_test;
44
45pub use self::db_connection::DbConnection;
46use diesel::{migration::Migration, result::DatabaseErrorKind};
47pub use diesel::{
48 migration::MigrationSource,
49 sqlite::{Sqlite, SqliteConnection},
50};
51use openmls::storage::OpenMlsProvider;
52use prost::DecodeError;
53use xmtp_common::{MaybeSend, MaybeSync, RetryableError};
54
55use super::StorageError;
56use crate::sql_key_store::SqlKeyStoreError;
57use crate::{Store, XmtpMlsStorageProvider};
58
59pub use database::*;
60pub use store::*;
61
62use diesel::{prelude::*, sql_query};
63use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
64use std::sync::Arc;
65pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations/");
66
67pub type EncryptionKey = [u8; 32];
68
69#[derive(QueryableByName, Debug)]
71struct SqliteVersion {
72 #[diesel(sql_type = diesel::sql_types::Text)]
73 version: String,
74}
75
76#[derive(Default, Clone, Debug, zeroize::ZeroizeOnDrop)]
77pub enum StorageOption {
78 #[default]
79 Ephemeral,
80 Persistent(String),
81}
82
83impl std::fmt::Display for StorageOption {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 StorageOption::Ephemeral => write!(f, "Ephemeral"),
87 StorageOption::Persistent(path) => write!(f, "Persistent({})", path),
88 }
89 }
90}
91
92#[derive(thiserror::Error, Debug)]
93pub enum ConnectionError {
94 #[error(transparent)]
95 Database(#[from] diesel::result::Error),
96 #[error(transparent)]
97 Platform(#[from] PlatformStorageError),
98 #[error(transparent)]
99 DecodeError(#[from] DecodeError),
100 #[error("disconnect not possible in transaction")]
101 DisconnectInTransaction,
102 #[error("reconnect not possible in transaction")]
103 ReconnectInTransaction,
104 #[error("invalid query: {0}")]
105 InvalidQuery(String),
106 #[error(
107 "Applied migrations does not match available migrations.\n\
108 This is likely due to running a database that is newer than this version of libxmtp.\n\
109 Expected: {expected}, found: {found}"
110 )]
111 InvalidVersion { expected: String, found: String },
112}
113
114impl RetryableError for ConnectionError {
115 fn is_retryable(&self) -> bool {
116 match self {
117 Self::Database(d) => d.is_retryable(),
118 Self::Platform(n) => n.is_retryable(),
119 Self::DecodeError(_) => false,
120 Self::DisconnectInTransaction => true,
121 Self::ReconnectInTransaction => true,
122 Self::InvalidQuery(_) => false,
123 Self::InvalidVersion { .. } => false,
124 }
125 }
126}
127
128pub trait ConnectionExt: MaybeSend + MaybeSync {
129 fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
131 where
132 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
133 Self: Sized;
134
135 fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
138 where
139 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
140 Self: Sized;
141
142 fn disconnect(&self) -> Result<(), ConnectionError>;
143 fn reconnect(&self) -> Result<(), ConnectionError>;
144}
145
146impl<C> ConnectionExt for &C
147where
148 C: ConnectionExt,
149{
150 fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
151 where
152 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
153 Self: Sized,
154 {
155 <C as ConnectionExt>::raw_query_read(self, fun)
156 }
157
158 fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
159 where
160 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
161 Self: Sized,
162 {
163 <C as ConnectionExt>::raw_query_write(self, fun)
164 }
165
166 fn disconnect(&self) -> Result<(), ConnectionError> {
167 <C as ConnectionExt>::disconnect(self)
168 }
169
170 fn reconnect(&self) -> Result<(), ConnectionError> {
171 <C as ConnectionExt>::reconnect(self)
172 }
173}
174
175impl<C> ConnectionExt for &mut C
176where
177 C: ConnectionExt,
178{
179 fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
180 where
181 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
182 Self: Sized,
183 {
184 <C as ConnectionExt>::raw_query_read(self, fun)
185 }
186
187 fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
188 where
189 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
190 Self: Sized,
191 {
192 <C as ConnectionExt>::raw_query_write(self, fun)
193 }
194
195 fn disconnect(&self) -> Result<(), ConnectionError> {
196 <C as ConnectionExt>::disconnect(self)
197 }
198
199 fn reconnect(&self) -> Result<(), ConnectionError> {
200 <C as ConnectionExt>::reconnect(self)
201 }
202}
203
204impl<C> ConnectionExt for Arc<C>
205where
206 C: ConnectionExt,
207{
208 fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
209 where
210 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
211 Self: Sized,
212 {
213 <C as ConnectionExt>::raw_query_read(self, fun)
214 }
215
216 fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
217 where
218 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
219 Self: Sized,
220 {
221 <C as ConnectionExt>::raw_query_write(self, fun)
222 }
223
224 fn disconnect(&self) -> Result<(), ConnectionError> {
225 <C as ConnectionExt>::disconnect(self)
226 }
227
228 fn reconnect(&self) -> Result<(), ConnectionError> {
229 <C as ConnectionExt>::reconnect(self)
230 }
231}
232
233pub type BoxedDatabase = Box<
234 dyn XmtpDb<
235 Connection = diesel::SqliteConnection,
236 DbQuery = DbConnection<diesel::SqliteConnection>,
237 >,
238>;
239
240#[cfg_attr(any(feature = "test-utils", test), mockall::automock(type Connection = crate::mock::MockConnection; type DbQuery = crate::mock::MockDbQuery;))]
241pub trait XmtpDb: MaybeSend + MaybeSync {
242 type Connection: ConnectionExt + MaybeSend + MaybeSync;
244
245 type DbQuery: crate::DbQuery + MaybeSend + MaybeSync;
246
247 fn init(&self) -> Result<(), ConnectionError> {
248 self.conn().raw_query_write(|conn| {
249 self.validate(conn).map_err(|e| {
250 diesel::result::Error::DatabaseError(
251 DatabaseErrorKind::Unknown,
252 Box::new(e.to_string()),
253 )
254 })?;
255 conn.run_pending_migrations(MIGRATIONS)
256 .map_err(diesel::result::Error::QueryBuilderError)?;
257
258 let db_version = conn.final_migration()?;
260 let last_migration = MIGRATIONS.final_migration();
261 if db_version != last_migration {
262 return Ok(Err(ConnectionError::InvalidVersion {
263 expected: last_migration,
264 found: db_version,
265 }));
266 }
267
268 let sqlite_version =
269 sql_query("SELECT sqlite_version() AS version").load::<SqliteVersion>(conn)?;
270 tracing::info!("sqlite_version={}", sqlite_version[0].version);
271
272 tracing::info!("Migrations successful");
273 Ok(Ok(()))
274 })??;
275
276 Ok(())
277 }
278
279 fn opts(&self) -> &StorageOption;
281
282 fn validate(&self, _conn: &mut SqliteConnection) -> Result<(), ConnectionError> {
284 Ok(())
285 }
286
287 fn conn(&self) -> Self::Connection;
289
290 fn db(&self) -> Self::DbQuery;
293
294 fn reconnect(&self) -> Result<(), ConnectionError>;
296
297 fn disconnect(&self) -> Result<(), ConnectionError>;
299}
300
/// Implements `Fetch<$model>` for every `ConnectionExt` type.
///
/// * `impl_fetch!(Model, table)`: the key type is `()` and `fetch` returns
///   the first row of `table`, if any.
/// * `impl_fetch!(Model, table, Key)`: `fetch` looks the row up with
///   `table.find(key)` and returns it if present.
#[macro_export]
macro_rules! impl_fetch {
    ($model:ty, $table:ident) => {
        impl<C> $crate::Fetch<$model> for C
        where
            C: $crate::ConnectionExt,
        {
            type Key = ();

            fn fetch(&self, _key: &Self::Key) -> Result<Option<$model>, $crate::StorageError> {
                use $crate::encrypted_store::schema::$table::dsl::*;

                let row = self.raw_query_read(|conn| $table.first(conn).optional());
                row.map_err(Into::into)
            }
        }
    };

    ($model:ty, $table:ident, $key:ty) => {
        impl<C> $crate::Fetch<$model> for C
        where
            C: $crate::ConnectionExt,
        {
            type Key = $key;

            fn fetch(&self, key: &Self::Key) -> Result<Option<$model>, $crate::StorageError> {
                use $crate::encrypted_store::schema::$table::dsl::*;

                let row = self
                    .raw_query_read(|conn| $table.find(key.clone()).first(conn).optional());
                row.map_err(Into::into)
            }
        }
    };
}
331
/// Implements `FetchList<$model>` for every `ConnectionExt` type by loading
/// all rows of `$table`.
#[macro_export]
macro_rules! impl_fetch_list {
    ($model:ty, $table:ident) => {
        impl<C> $crate::FetchList<$model> for C
        where
            C: $crate::ConnectionExt,
        {
            fn fetch_list(&self) -> Result<Vec<$model>, $crate::StorageError> {
                use $crate::encrypted_store::schema::$table::dsl::*;

                let rows = self.raw_query_read(|conn| $table.load::<$model>(conn));
                rows.map_err(Into::into)
            }
        }
    };
}
347
/// Implements `Store<C>` for `$model` on every `ConnectionExt` type by
/// inserting the value into `$table`. The number of affected rows is
/// discarded.
#[macro_export]
macro_rules! impl_store {
    ($model:ty, $table:ident) => {
        impl<C> $crate::Store<C> for $model
        where
            C: $crate::ConnectionExt,
        {
            type Output = ();

            fn store(&self, into: &C) -> Result<(), $crate::StorageError> {
                let outcome = into.raw_query_write(|conn| {
                    diesel::insert_into($table::table)
                        .values(self)
                        .execute(conn)?;
                    Ok(())
                });
                outcome.map_err(Into::into)
            }
        }
    };
}
370
/// Implements `StoreOrIgnore<C>` for `$model` on every `ConnectionExt` type
/// using diesel's `insert_or_ignore_into` on `$table`. The number of
/// affected rows is discarded.
#[macro_export]
macro_rules! impl_store_or_ignore {
    ($model:ty, $table:ident) => {
        impl<C> $crate::StoreOrIgnore<C> for $model
        where
            C: $crate::ConnectionExt,
        {
            type Output = ();

            fn store_or_ignore(&self, into: &C) -> Result<(), $crate::StorageError> {
                let outcome = into.raw_query_write(|conn| {
                    diesel::insert_or_ignore_into($table::table)
                        .values(self)
                        .execute(conn)
                        .map(|_rows| ())
                });
                outcome.map_err(Into::into)
            }
        }
    };
}
394
395impl<T, C> Store<DbConnection<C>> for Vec<T>
396where
397 T: Store<DbConnection<C>>,
398{
399 type Output = ();
400 fn store(&self, into: &DbConnection<C>) -> Result<Self::Output, StorageError> {
401 for item in self {
402 item.store(into)?;
403 }
404 Ok(())
405 }
406}
407
408pub trait MlsProviderExt: OpenMlsProvider<StorageError = SqlKeyStoreError> {
409 type XmtpStorage: XmtpMlsStorageProvider;
410
411 fn key_store(&self) -> &Self::XmtpStorage;
412}
413
414trait EmbeddedMigrationsExt {
415 fn final_migration(&self) -> String;
416}
417impl EmbeddedMigrationsExt for EmbeddedMigrations {
418 fn final_migration(&self) -> String {
419 let migrations: Vec<Box<dyn Migration<Sqlite>>> = self
420 .migrations()
421 .expect("Migrations are directly embedded, so this cannot error");
422 migrations
423 .first()
424 .expect("There is at least one migration")
425 .name()
426 .to_string()
427 .chars()
428 .filter(|c| c.is_numeric())
429 .collect()
430 }
431}
432
433trait MigrationHarnessExt {
434 fn final_migration(&mut self) -> Result<String, diesel::result::Error>;
435}
436
437impl MigrationHarnessExt for SqliteConnection {
438 fn final_migration(&mut self) -> Result<String, diesel::result::Error> {
439 let migration: String = self
440 .applied_migrations()
441 .map_err(diesel::result::Error::QueryBuilderError)?
442 .pop()
443 .expect("This function should be run after migrations are applied")
444 .to_string();
445
446 Ok(migration)
447 }
448}
449
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;
    use crate::{Fetch, Store, XmtpTestDb, identity::StoredIdentity};
    use xmtp_common::{rand_vec, tmp_path};

    /// Inbox id used by every identity stored in these tests.
    const INBOX_ID: &str = "inbox_id";

    /// Builds an identity for [`INBOX_ID`] with two random 24-byte vectors.
    fn sample_identity() -> StoredIdentity {
        StoredIdentity::new(INBOX_ID.to_string(), rand_vec::<24>(), rand_vec::<24>())
    }

    /// An identity written to an ephemeral store can be read back.
    #[xmtp_common::test]
    async fn ephemeral_store() {
        let store = crate::TestDb::create_ephemeral_store().await;
        let conn = store.conn();

        sample_identity().store(&conn).unwrap();

        let fetched: StoredIdentity = conn.fetch(&()).unwrap().unwrap();
        assert_eq!(fetched.inbox_id, INBOX_ID);
    }

    /// An identity written to a persistent store can be read back; the
    /// database files are removed once the store has gone out of scope.
    #[xmtp_common::test]
    async fn persistent_store() {
        let db_path = tmp_path();
        {
            let store = crate::TestDb::create_persistent_store(Some(db_path.clone())).await;
            let conn = &store.conn();

            sample_identity().store(conn).unwrap();

            let fetched: StoredIdentity = conn.fetch(&()).unwrap().unwrap();
            assert_eq!(fetched.inbox_id, INBOX_ID);
        }
        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }

    /// Data written through one connection is visible through a second
    /// connection obtained from the same store.
    #[xmtp_common::test]
    async fn encrypted_db_with_multiple_connections() {
        let db_path = tmp_path();
        {
            let store = crate::TestDb::create_persistent_store(Some(db_path.clone())).await;
            let writer = &store.conn();
            sample_identity().store(writer).unwrap();

            let reader = &store.conn();
            tracing::info!("Getting conn 2");
            let fetched: StoredIdentity = reader.fetch(&()).unwrap().unwrap();
            assert_eq!(fetched.inbox_id, INBOX_ID);
        }
        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }
}