xmtp_db/sql_key_store/
transactions.rs1use super::*;
2use crate::DbConnection;
3
4pub struct MutableTransactionConnection<'a> {
8 pub(crate) conn: parking_lot::Mutex<&'a mut SqliteConnection>,
13}
14
15impl<'a> MutableTransactionConnection<'a> {
16 pub fn new(conn: &'a mut SqliteConnection) -> Self {
17 Self {
18 conn: parking_lot::Mutex::new(conn),
19 }
20 }
21}
22
23impl<'a> ConnectionExt for MutableTransactionConnection<'a> {
24 fn raw_query_read<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
25 where
26 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
27 Self: Sized,
28 {
29 let mut conn = self.conn.try_lock().expect("Lock is held somewhere else");
30 fun(&mut conn).map_err(crate::ConnectionError::from)
31 }
32
33 fn raw_query_write<T, F>(&self, fun: F) -> Result<T, crate::ConnectionError>
34 where
35 F: FnOnce(&mut SqliteConnection) -> Result<T, diesel::result::Error>,
36 Self: Sized,
37 {
38 let mut conn = self.conn.try_lock().expect("Lock is held somewhere else");
39 fun(&mut conn).map_err(crate::ConnectionError::from)
40 }
41
42 fn disconnect(&self) -> Result<(), crate::ConnectionError> {
44 Err(crate::ConnectionError::DisconnectInTransaction)
45 }
46
47 fn reconnect(&self) -> Result<(), crate::ConnectionError> {
48 Err(crate::ConnectionError::ReconnectInTransaction)
49 }
50}
51
52impl<C: ConnectionExt> XmtpMlsStorageProvider for SqlKeyStore<C> {
53 type Connection = C;
54
55 type TxQuery = SqliteConnection;
56
57 type DbQuery<'a>
58 = DbConnection<&'a C>
59 where
60 Self::Connection: 'a;
61
62 fn db<'a>(&'a self) -> Self::DbQuery<'a> {
63 DbConnection::new(&self.conn)
64 }
65
66 fn transaction<T, E, F>(&self, f: F) -> Result<T, E>
67 where
68 F: FnOnce(&mut Self::TxQuery) -> Result<T, E>,
69 E: From<diesel::result::Error> + From<crate::ConnectionError> + std::error::Error,
70 {
71 let conn = &self.conn;
72
73 conn.raw_query_write(|c| Ok(c.immediate_transaction(|sqlite_c| f(sqlite_c))))?
89 }
90
91 fn savepoint<T, E, F>(&self, f: F) -> Result<T, E>
92 where
93 F: FnOnce(&mut Self::TxQuery) -> Result<T, E>,
94 E: From<diesel::result::Error> + From<crate::ConnectionError> + std::error::Error,
95 {
96 self.conn
97 .raw_query_write(|c| Ok(c.transaction(|sqlite_c| f(sqlite_c))))?
98 }
99
100 fn read<V: Entity<CURRENT_VERSION>>(
101 &self,
102 label: &[u8],
103 key: &[u8],
104 ) -> Result<Option<V>, SqlKeyStoreError> {
105 self.read(label, key)
106 }
107
108 fn read_list<V: Entity<CURRENT_VERSION>>(
109 &self,
110 label: &[u8],
111 key: &[u8],
112 ) -> Result<Vec<V>, <Self as StorageProvider<CURRENT_VERSION>>::Error> {
113 self.read_list(label, key)
114 }
115
116 fn delete(
117 &self,
118 label: &[u8],
119 key: &[u8],
120 ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
121 self.delete::<CURRENT_VERSION>(label, key)
122 }
123
124 fn write(
125 &self,
126 label: &[u8],
127 key: &[u8],
128 value: &[u8],
129 ) -> Result<(), <Self as StorageProvider<CURRENT_VERSION>>::Error> {
130 self.write::<CURRENT_VERSION>(label, key, value)
131 }
132
133 #[cfg(feature = "test-utils")]
134 fn hash_all(&self) -> Result<Vec<u8>, SqlKeyStoreError> {
135 self.conn
136 .raw_query_read(OpenMlsKeyValue::hash_all)
137 .map_err(Into::into)
138 }
139}
140
#[cfg(test)]
mod tests {
    // Nothing here is a `#[test]` and `Foo` is never constructed, so the items
    // and some imports are unused. The module appears to exist only so the
    // compiler checks that a synchronous `transaction` can sit between `.await`s.
    #![allow(unused)]

    use super::*;
    use crate::{
        TestDb, XmtpTestDb,
        group_intent::{IntentKind, IntentState, NewGroupIntent},
        prelude::QueryGroupIntent,
    };

    /// Minimal owner of a key store, used to call `transaction` from async code.
    struct Foo<C> {
        key_store: SqlKeyStore<C>,
    }

    impl<C: ConnectionExt> Foo<C> {
        /// Stand-in for unrelated async work: sleeps for 10 ms.
        async fn long_async_call(&self) {
            let pause = std::time::Duration::from_millis(10);
            xmtp_common::time::sleep(pause).await;
        }

        /// Awaits, inserts one `SendMessage` intent inside a key-store
        /// transaction, then awaits again.
        async fn db_op(&self) {
            self.long_async_call().await;

            let intent = NewGroupIntent {
                kind: IntentKind::SendMessage,
                group_id: vec![],
                data: vec![],
                should_push: false,
                state: IntentState::ToPublish,
            };
            let outcome = self.key_store.transaction(move |conn| {
                let store = conn.key_store();
                store.db().insert_group_intent(intent)
            });
            outcome.unwrap();

            self.long_async_call().await;
        }
    }
}