xmtp_db/encrypted_store/database/native/
sqlcipher_connection.rs1use diesel::{
3 connection::{LoadConnection, SimpleConnection},
4 deserialize::FromSqlRow,
5 prelude::*,
6 sql_query,
7};
8use std::{
9 fmt::Display,
10 fs::File,
11 io::{BufReader, Read, Write},
12 path::{Path, PathBuf},
13};
14
15use super::PlatformStorageError;
16use crate::{
17 NotFound,
18 database::instrumentation::TestInstrumentation,
19 native::{ConnectionOptions, connection_pragmas},
20};
21
22use crate::{EncryptionKey, StorageOption};
23
/// Suffix of the side-file (`<db file name>.sqlcipher_salt`) that stores the hex-encoded salt.
const SALT_FILE_NAME: &str = "sqlcipher_salt";
/// Value passed to `PRAGMA cipher_plaintext_header_size` for persistent databases.
const PLAINTEXT_HEADER_SIZE: usize = 32;
/// Raw 16-byte SQLCipher database salt.
pub type Salt = [u8; 16];
27
/// Single-column row returned by `PRAGMA cipher_version`.
///
/// The field name must stay equal to the pragma's result column name, since diesel's
/// `QueryableByName` maps columns to fields by name.
#[derive(QueryableByName, Debug)]
struct CipherVersion {
    /// Version string reported by SQLCipher; an empty result set is treated by the caller
    /// as "SQLCipher not loaded".
    #[diesel(sql_type = diesel::sql_types::Text)]
    cipher_version: String,
}
34
/// Single-column row returned by `PRAGMA cipher_provider_version`.
///
/// The field name must stay equal to the pragma's result column name, since diesel's
/// `QueryableByName` maps columns to fields by name.
#[derive(QueryableByName, Debug)]
struct CipherProviderVersion {
    /// Crypto provider version string; only used for logging during validation.
    #[diesel(sql_type = diesel::sql_types::Text)]
    cipher_provider_version: String,
}
41
42#[derive(Clone, Debug, zeroize::ZeroizeOnDrop)]
44pub struct EncryptedConnection {
45 key: EncryptionKey,
46 salt: Option<Salt>,
48 options: StorageOption,
49}
50
51impl EncryptedConnection {
52 pub fn new(key: EncryptionKey, opts: &StorageOption) -> Result<Self, PlatformStorageError> {
54 use crate::StorageOption::*;
55
56 let salt = match opts {
57 Ephemeral => None,
58 Persistent(db_path) => {
59 {
60 let mut conn = SqliteConnection::establish(db_path)?;
61 Self::check_for_sqlcipher(opts, &mut conn)?;
62 }
63 let mut salt = [0u8; 16];
64 let db_pathbuf = PathBuf::from(db_path);
65 let salt_path = Self::salt_file(db_path)?;
66
67 match (salt_path.try_exists()?, db_pathbuf.try_exists()?) {
68 (true, true) => {
70 tracing::debug!(
71 salt = %salt_path.display(),
72 db = %db_pathbuf.display(),
73 "salt and database exist, db=[{}], salt=[{}]",
74 db_pathbuf.display(),
75 salt_path.display(),
76 );
77 let file = BufReader::new(File::open(salt_path)?);
78 salt = <Salt as hex::FromHex>::from_hex(
79 file.bytes().take(32).collect::<Result<Vec<u8>, _>>()?,
80 )?;
81 }
82 (false, true) => {
84 tracing::debug!(
85 "migrating sqlcipher db=[{}] to plaintext header with salt=[{}]",
86 db_pathbuf.display(),
87 salt_path.display()
88 );
89 Self::migrate(db_path, key, &mut salt)?;
90 }
91 (false, false) => {
93 tracing::debug!(
94 "creating new sqlcipher db=[{}] with salt=[{}]",
95 db_pathbuf.display(),
96 salt_path.display()
97 );
98 Self::create(db_path, key, &mut salt)?;
99 }
100 (true, false) => {
104 tracing::debug!(
105 "database [{}] does not exist, but the salt [{}] does, re-creating",
106 db_pathbuf.display(),
107 salt_path.display(),
108 );
109 std::fs::remove_file(salt_path)?;
110 Self::create(db_path, key, &mut salt)?;
111 }
112 }
113 tracing::info!("db_path=[{}]", db_path);
114 Some(salt)
115 }
116 };
117
118 Ok(Self {
119 key,
120 salt,
121 options: opts.clone(),
122 })
123 }
124
125 fn create(
128 path: &String,
129 key: EncryptionKey,
130 salt: &mut [u8],
131 ) -> Result<(), PlatformStorageError> {
132 let conn = &mut SqliteConnection::establish(path)?;
133 conn.batch_execute(&format!(
134 r#"
135 {}
136 {}
137 "#,
138 pragma_key(hex::encode(key)),
139 pragma_plaintext_header()
140 ))?;
141
142 Self::write_salt(path, conn, salt)?;
143 Ok(())
144 }
145
146 fn migrate(
152 path: &String,
153 key: EncryptionKey,
154 salt: &mut [u8],
155 ) -> Result<(), PlatformStorageError> {
156 let conn = &mut SqliteConnection::establish(path)?;
157
158 conn.batch_execute(&format!(
159 r#"
160 {}
161 select count(*) from sqlite_master; -- trigger header read, currently it is encrypted
162 "#,
163 pragma_key(hex::encode(key))
164 ))?;
165
166 Self::write_salt(path, conn, salt)?;
168
169 conn.batch_execute(&format!(
170 r#"
171 {}
172 PRAGMA user_version = 1; -- force header write
173 "#,
174 pragma_plaintext_header()
175 ))?;
176
177 Ok(())
178 }
179
180 fn write_salt(
183 path: &String,
184 conn: &mut SqliteConnection,
185 buf: &mut [u8],
186 ) -> Result<(), PlatformStorageError> {
187 let mut row_iter = conn.load(sql_query("PRAGMA cipher_salt"))?;
188 let row = row_iter
190 .next()
191 .ok_or(NotFound::CipherSalt(path.to_string()))??;
192 let salt = <String as FromSqlRow<diesel::sql_types::Text, _>>::build_from_row(&row)?;
193 tracing::debug!(
194 salt,
195 file = %Self::salt_file(PathBuf::from(path))?.display(),
196 "writing salt to file"
197 );
198 let mut f = File::create(Self::salt_file(PathBuf::from(path))?)?;
199
200 f.write_all(salt.as_bytes())?;
201 let mut perms = f.metadata()?.permissions();
202 perms.set_readonly(true);
203 f.set_permissions(perms)?;
204
205 let salt = hex::decode(salt)?;
206 buf.copy_from_slice(&salt);
207 Ok(())
208 }
209
210 pub(crate) fn salt_file<P: AsRef<Path>>(db_path: P) -> std::io::Result<PathBuf> {
214 let db_path: &Path = db_path.as_ref();
215 let name = db_path.file_name().ok_or(std::io::Error::new(
216 std::io::ErrorKind::NotFound,
217 "database file has no name",
218 ))?;
219 let db_path = db_path.parent().ok_or(std::io::Error::new(
220 std::io::ErrorKind::NotFound,
221 "Parent directory could not be found",
222 ))?;
223 Ok(db_path.join(format!("{}.{}", name.to_string_lossy(), SALT_FILE_NAME)))
224 }
225
226 fn pragmas(&self) -> impl Display {
228 let Self { key, salt, .. } = self;
229
230 if let Some(s) = salt {
231 format!(
232 "{}\n{}\n{}",
233 pragma_key(hex::encode(key)),
234 pragma_plaintext_header(),
235 pragma_salt(hex::encode(s))
236 )
237 } else {
238 format!(
239 "{}\n{}",
240 pragma_key(hex::encode(key)),
241 pragma_plaintext_header()
242 )
243 }
244 }
245
246 fn check_for_sqlcipher(
247 opts: &StorageOption,
248 conn: &mut SqliteConnection,
249 ) -> Result<CipherVersion, PlatformStorageError> {
250 if cfg!(any(test, feature = "test-utils")) {
251 conn.batch_execute("pragma cipher_log = stdout; pragma cipher_log_level = NONE;")?;
252 }
253
254 if let Some(path) = opts.path() {
255 let exists = std::path::Path::new(path).exists();
256 tracing::debug!("db @ [{}] exists? [{}]", path, exists);
257 }
258 let mut cipher_version = sql_query("PRAGMA cipher_version").load::<CipherVersion>(conn)?;
259 if cipher_version.is_empty() {
260 return Err(PlatformStorageError::SqlCipherNotLoaded);
261 }
262 Ok(cipher_version.pop().expect("checked for empty"))
263 }
264}
265
266impl ConnectionOptions for EncryptedConnection {
267 fn options(&self) -> &StorageOption {
268 &self.options
269 }
270}
271
272impl super::ValidatedConnection for EncryptedConnection {
273 fn validate(&self, conn: &mut SqliteConnection) -> Result<(), PlatformStorageError> {
274 let sqlcipher_version = EncryptedConnection::check_for_sqlcipher(&self.options, conn)?;
275
276 conn.batch_execute(&format!(
279 "{}
280 SELECT count(*) FROM sqlite_master;",
281 self.pragmas()
282 ))
283 .map_err(|e| {
284 tracing::error!("SQLCipher PRAGMA batch_execute failed: {:?}", e);
285 PlatformStorageError::SqlCipherKeyIncorrect
286 })?;
287
288 let CipherProviderVersion {
289 cipher_provider_version,
290 } = sql_query("PRAGMA cipher_provider_version")
291 .get_result::<CipherProviderVersion>(conn)?;
292 tracing::info!(
293 "Sqlite cipher_version={:?}, cipher_provider_version={:?}",
294 sqlcipher_version.cipher_version,
295 cipher_provider_version
296 );
297 let log = std::env::var("SQLCIPHER_LOG");
298 let is_sqlcipher_log_enabled = matches!(log, Ok(s) if s == "true" || s == "1");
299 if is_sqlcipher_log_enabled {
301 conn.batch_execute("PRAGMA cipher_log = stderr; PRAGMA cipher_log_level = INFO;")
302 .ok();
303 }
304 tracing::debug!("SQLCipher Database validated.");
305 Ok(())
306 }
307}
308
309impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
310 for EncryptedConnection
311{
312 fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), diesel::r2d2::Error> {
313 if cfg!(any(test, feature = "test-utils")) {
314 conn.set_instrumentation(TestInstrumentation);
315 }
316 conn.batch_execute(&format!("{}", self.pragmas(),))
317 .map_err(diesel::r2d2::Error::QueryError)?;
318 connection_pragmas(conn)?;
319 Ok(())
320 }
321}
322
/// Renders `PRAGMA key` with `key` embedded as a blob literal (`x'…'`); callers pass the
/// hex-encoded key.
fn pragma_key(key: impl Display) -> impl Display {
    format!("PRAGMA key = \"x'{}'\";", key)
}
326
/// Renders `PRAGMA cipher_salt` with `salt` embedded as a blob literal (`x'…'`); callers
/// pass the hex-encoded salt.
fn pragma_salt(salt: impl Display) -> impl Display {
    format!("PRAGMA cipher_salt=\"x'{}'\";", salt)
}
330
331fn pragma_plaintext_header() -> impl Display {
332 format!(r#"PRAGMA cipher_plaintext_header_size={PLAINTEXT_HEADER_SIZE};"#)
333}
334
#[cfg(test)]
mod tests {
    use crate::{EncryptedMessageStore, NativeDb, XmtpTestDb};
    use diesel_migrations::MigrationHarness;
    use std::fs::File;
    use xmtp_common::tmp_path;

    use super::*;
    /// First 16 bytes of an SQLite database file whose header is stored in plaintext.
    const SQLITE3_PLAINTEXT_HEADER: &str = "SQLite format 3\0";
    use StorageOption::*;

    /// Returns the first 16 bytes of the file at `path`.
    fn read_header(path: &str) -> [u8; 16] {
        let mut header = [0; 16];
        File::open(path).unwrap().read_exact(&mut header).unwrap();
        header
    }

    /// Asserts that the salt side-file for `db_path` exists and holds 16 hex-encoded bytes.
    fn assert_salt_file_is_valid(db_path: &str) {
        let salt_path = EncryptedConnection::salt_file(db_path).unwrap();
        assert!(salt_path.exists());
        let salt = hex::decode(std::fs::read(&salt_path).unwrap()).unwrap();
        assert_eq!(salt.len(), 16);
    }

    #[tokio::test]
    async fn test_sqlcipher_version() {
        let db_path = tmp_path();
        {
            let opts = Persistent(db_path.clone());
            let mut conn = SqliteConnection::establish(&db_path).unwrap();
            let v = EncryptedConnection::check_for_sqlcipher(&opts, &mut conn).unwrap();
            println!("SQLCipher Version {}", v.cipher_version);
        }
        // Opening the connection can create the database file; don't leak it.
        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_db_creates_with_plaintext_header() {
        let db_path = tmp_path();
        {
            let _ = crate::TestDb::create_persistent_store(Some(db_path.clone())).await;

            assert_salt_file_is_valid(&db_path);
            assert_eq!(
                SQLITE3_PLAINTEXT_HEADER,
                String::from_utf8(read_header(&db_path).into()).unwrap()
            );
        }
        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }

    #[tokio::test]
    async fn test_db_migrates() {
        let db_path = tmp_path();
        {
            let key = EncryptedMessageStore::<()>::generate_enc_key();
            // Build a database keyed without the plaintext-header pragma and without a
            // salt side-file.
            {
                let conn = &mut SqliteConnection::establish(&db_path).unwrap();
                conn.batch_execute(&format!(
                    r#"
                    {}
                    PRAGMA busy_timeout = 5000;
                    PRAGMA journal_mode = WAL;
                    "#,
                    pragma_key(hex::encode(key))
                ))
                .unwrap();
                conn.run_pending_migrations(crate::MIGRATIONS).unwrap();
            }

            assert_ne!(
                String::from_utf8_lossy(&read_header(&db_path)),
                SQLITE3_PLAINTEXT_HEADER
            );

            tracing::info!("Creating store with file at {}", &db_path);
            let opts = Persistent(db_path.clone());
            let db = NativeDb::new(&opts, key).unwrap();
            let _ = EncryptedMessageStore::new(db);

            assert_salt_file_is_valid(&db_path);
            assert_eq!(
                SQLITE3_PLAINTEXT_HEADER,
                String::from_utf8(read_header(&db_path).into()).unwrap()
            );
        }
        EncryptedMessageStore::<()>::remove_db_files(db_path)
    }
}