xmtp_db/encrypted_store/
migrations.rs

1use diesel::migration::{Migration, MigrationSource, MigrationVersion};
2use diesel_migrations::MigrationHarness;
3
4use super::{ConnectionExt, MIGRATIONS, Sqlite, db_connection::DbConnection};
5use crate::ConnectionError;
6
7/// Trait for database migration operations.
8///
9/// WARNING: These operations are dangerous and can cause data loss.
10/// They are intended for debugging and admin tools only.
11pub trait QueryMigrations {
12    /// Returns a list of all applied migration versions, most recent first.
13    fn applied_migrations(&self) -> Result<Vec<String>, ConnectionError>;
14
15    /// Returns a list of all available (embedded) migration names.
16    fn available_migrations(&self) -> Result<Vec<String>, ConnectionError>;
17
18    /// Rollback all migrations after and including the specified version.
19    ///
20    /// WARNING: This is destructive and may cause data loss.
21    fn rollback_to_version(&self, version: &str) -> Result<Vec<String>, ConnectionError>;
22
23    /// Run a specific migration by name.
24    ///
25    /// NOTE: This runs the migration SQL directly without updating the
26    /// schema_migrations tracking table.
27    fn run_migration(&self, name: &str) -> Result<(), ConnectionError>;
28
29    /// Revert a specific migration by name.
30    ///
31    /// NOTE: This runs the revert SQL directly without updating the
32    /// schema_migrations tracking table.
33    fn revert_migration(&self, name: &str) -> Result<(), ConnectionError>;
34
35    /// Run all pending migrations.
36    fn run_pending_migrations(&self) -> Result<Vec<String>, ConnectionError>;
37}
38
39fn get_migrations() -> Result<Vec<Box<dyn Migration<Sqlite>>>, ConnectionError> {
40    MigrationSource::<Sqlite>::migrations(&MIGRATIONS)
41        .map_err(|e| ConnectionError::Database(diesel::result::Error::QueryBuilderError(e)))
42}
43
44impl<C: ConnectionExt> QueryMigrations for DbConnection<C> {
45    fn applied_migrations(&self) -> Result<Vec<String>, ConnectionError> {
46        let applied: Vec<MigrationVersion<'static>> = self.raw_query_read(|conn| {
47            conn.applied_migrations()
48                .map_err(diesel::result::Error::QueryBuilderError)
49        })?;
50        Ok(applied.into_iter().map(|v| v.to_string()).collect())
51    }
52
53    fn available_migrations(&self) -> Result<Vec<String>, ConnectionError> {
54        let migrations = get_migrations()?;
55        let names: Vec<String> = migrations.iter().map(|m| m.name().to_string()).collect();
56        Ok(names)
57    }
58
59    fn rollback_to_version(&self, version: &str) -> Result<Vec<String>, ConnectionError> {
60        let target: String = version.chars().filter(|c| c.is_numeric()).collect();
61        let target: u64 = target.parse().map_err(|_| {
62            ConnectionError::InvalidQuery(format!("Invalid migration version: {version}"))
63        })?;
64
65        let mut reverted = Vec::new();
66
67        loop {
68            let applied = self.applied_migrations()?;
69            let Some(current_version) = applied.first() else {
70                break;
71            };
72
73            let version_number: String =
74                current_version.chars().filter(|c| c.is_numeric()).collect();
75            let current_num: u64 = version_number.parse().map_err(|_| {
76                ConnectionError::InvalidQuery(format!("Invalid applied version: {current_version}"))
77            })?;
78
79            if current_num < target {
80                break;
81            }
82
83            let result = self.raw_query_write(|conn| {
84                conn.revert_last_migration(MIGRATIONS)
85                    .map(|v| v.to_string())
86                    .map_err(diesel::result::Error::QueryBuilderError)
87            });
88
89            match result {
90                Ok(version) => {
91                    reverted.push(version);
92                }
93                Err(e) => {
94                    tracing::warn!("Migration rollback stopped: {e:?}");
95                    break;
96                }
97            }
98        }
99
100        Ok(reverted)
101    }
102
103    fn run_migration(&self, name: &str) -> Result<(), ConnectionError> {
104        let migrations = get_migrations()?;
105
106        for migration in &migrations {
107            if migration.name().to_string() == name {
108                self.raw_query_write(|c| {
109                    migration
110                        .run(c)
111                        .map_err(diesel::result::Error::QueryBuilderError)
112                })?;
113                return Ok(());
114            }
115        }
116
117        Err(ConnectionError::InvalidQuery(format!(
118            "Migration not found: {name}"
119        )))
120    }
121
122    fn revert_migration(&self, name: &str) -> Result<(), ConnectionError> {
123        let migrations = get_migrations()?;
124
125        for migration in &migrations {
126            if migration.name().to_string() == name {
127                self.raw_query_write(|c| {
128                    migration
129                        .revert(c)
130                        .map_err(diesel::result::Error::QueryBuilderError)
131                })?;
132                return Ok(());
133            }
134        }
135
136        Err(ConnectionError::InvalidQuery(format!(
137            "Migration not found: {name}"
138        )))
139    }
140
141    fn run_pending_migrations(&self) -> Result<Vec<String>, ConnectionError> {
142        let ran: Vec<String> = self.raw_query_write(|conn| {
143            conn.run_pending_migrations(MIGRATIONS)
144                .map(|versions| versions.into_iter().map(|v| v.to_string()).collect())
145                .map_err(diesel::result::Error::QueryBuilderError)
146        })?;
147        Ok(ran)
148    }
149}