xmtp_db/encrypted_store/
migrations.rs1use diesel::migration::{Migration, MigrationSource, MigrationVersion};
2use diesel_migrations::MigrationHarness;
3
4use super::{ConnectionExt, MIGRATIONS, Sqlite, db_connection::DbConnection};
5use crate::ConnectionError;
6
7pub trait QueryMigrations {
12 fn applied_migrations(&self) -> Result<Vec<String>, ConnectionError>;
14
15 fn available_migrations(&self) -> Result<Vec<String>, ConnectionError>;
17
18 fn rollback_to_version(&self, version: &str) -> Result<Vec<String>, ConnectionError>;
22
23 fn run_migration(&self, name: &str) -> Result<(), ConnectionError>;
28
29 fn revert_migration(&self, name: &str) -> Result<(), ConnectionError>;
34
35 fn run_pending_migrations(&self) -> Result<Vec<String>, ConnectionError>;
37}
38
39fn get_migrations() -> Result<Vec<Box<dyn Migration<Sqlite>>>, ConnectionError> {
40 MigrationSource::<Sqlite>::migrations(&MIGRATIONS)
41 .map_err(|e| ConnectionError::Database(diesel::result::Error::QueryBuilderError(e)))
42}
43
44impl<C: ConnectionExt> QueryMigrations for DbConnection<C> {
45 fn applied_migrations(&self) -> Result<Vec<String>, ConnectionError> {
46 let applied: Vec<MigrationVersion<'static>> = self.raw_query_read(|conn| {
47 conn.applied_migrations()
48 .map_err(diesel::result::Error::QueryBuilderError)
49 })?;
50 Ok(applied.into_iter().map(|v| v.to_string()).collect())
51 }
52
53 fn available_migrations(&self) -> Result<Vec<String>, ConnectionError> {
54 let migrations = get_migrations()?;
55 let names: Vec<String> = migrations.iter().map(|m| m.name().to_string()).collect();
56 Ok(names)
57 }
58
59 fn rollback_to_version(&self, version: &str) -> Result<Vec<String>, ConnectionError> {
60 let target: String = version.chars().filter(|c| c.is_numeric()).collect();
61 let target: u64 = target.parse().map_err(|_| {
62 ConnectionError::InvalidQuery(format!("Invalid migration version: {version}"))
63 })?;
64
65 let mut reverted = Vec::new();
66
67 loop {
68 let applied = self.applied_migrations()?;
69 let Some(current_version) = applied.first() else {
70 break;
71 };
72
73 let version_number: String =
74 current_version.chars().filter(|c| c.is_numeric()).collect();
75 let current_num: u64 = version_number.parse().map_err(|_| {
76 ConnectionError::InvalidQuery(format!("Invalid applied version: {current_version}"))
77 })?;
78
79 if current_num < target {
80 break;
81 }
82
83 let result = self.raw_query_write(|conn| {
84 conn.revert_last_migration(MIGRATIONS)
85 .map(|v| v.to_string())
86 .map_err(diesel::result::Error::QueryBuilderError)
87 });
88
89 match result {
90 Ok(version) => {
91 reverted.push(version);
92 }
93 Err(e) => {
94 tracing::warn!("Migration rollback stopped: {e:?}");
95 break;
96 }
97 }
98 }
99
100 Ok(reverted)
101 }
102
103 fn run_migration(&self, name: &str) -> Result<(), ConnectionError> {
104 let migrations = get_migrations()?;
105
106 for migration in &migrations {
107 if migration.name().to_string() == name {
108 self.raw_query_write(|c| {
109 migration
110 .run(c)
111 .map_err(diesel::result::Error::QueryBuilderError)
112 })?;
113 return Ok(());
114 }
115 }
116
117 Err(ConnectionError::InvalidQuery(format!(
118 "Migration not found: {name}"
119 )))
120 }
121
122 fn revert_migration(&self, name: &str) -> Result<(), ConnectionError> {
123 let migrations = get_migrations()?;
124
125 for migration in &migrations {
126 if migration.name().to_string() == name {
127 self.raw_query_write(|c| {
128 migration
129 .revert(c)
130 .map_err(diesel::result::Error::QueryBuilderError)
131 })?;
132 return Ok(());
133 }
134 }
135
136 Err(ConnectionError::InvalidQuery(format!(
137 "Migration not found: {name}"
138 )))
139 }
140
141 fn run_pending_migrations(&self) -> Result<Vec<String>, ConnectionError> {
142 let ran: Vec<String> = self.raw_query_write(|conn| {
143 conn.run_pending_migrations(MIGRATIONS)
144 .map(|versions| versions.into_iter().map(|v| v.to_string()).collect())
145 .map_err(diesel::result::Error::QueryBuilderError)
146 })?;
147 Ok(ran)
148 }
149}