xmtp_db/encrypted_store/
association_state.rs

1use diesel::prelude::*;
2
3use super::schema::association_state::{self, dsl};
4use crate::ConnectionExt;
5use crate::DbConnection;
6use crate::{Fetch, StorageError, StoreOrIgnore, impl_fetch, impl_store_or_ignore};
7use prost::Message;
8use xmtp_proto::xmtp::identity::associations::AssociationState as AssociationStateProto;
9
// NOTE(review): the diesel derives below work off the field list; `Queryable` in
// particular maps selected columns to fields by position, so keep the field order in
// step with the `select(..)` in `batch_read_from_cache` — confirm before reordering.
/// A cached association state, stored as one row of the `association_state` table.
///
/// Rows are identified by the composite primary key `(inbox_id, sequence_id)`.
/// `state` carries the bytes produced by `AssociationStateProto::encode_to_vec`
/// (see `write_to_cache`) and is turned back into a message with
/// `AssociationStateProto::decode` on the read paths.
#[derive(Insertable, Identifiable, Queryable, Debug, Clone, PartialEq, Eq)]
#[diesel(table_name = association_state)]
#[diesel(primary_key(inbox_id, sequence_id))]
pub struct StoredAssociationState {
    /// Inbox the cached state belongs to; first component of the primary key.
    pub inbox_id: String,
    /// Sequence id paired with `inbox_id`; second component of the primary key.
    pub sequence_id: i64,
    /// Protobuf-encoded `AssociationState` message.
    pub state: Vec<u8>,
}
// Keyed lookup by `(inbox_id, sequence_id)`; `read_from_cache` relies on it via `fetch`.
impl_fetch!(StoredAssociationState, association_state, (String, i64));
// Insert path used by `write_to_cache` via `store_or_ignore`.
// NOTE(review): presumably leaves an already-existing row untouched — confirm against
// the macro definition.
impl_store_or_ignore!(StoredAssociationState, association_state);
21
/// Read/write access to the association-state cache.
///
/// Entries are addressed by an `(inbox_id, sequence_id)` pair and hold an
/// `AssociationStateProto`. This file implements the trait for `DbConnection<C>`
/// and, by forwarding, for `&R` where `R` already implements it.
pub trait QueryAssociationStateCache {
    /// Stores `state` under `(inbox_id, sequence_id)`.
    ///
    /// The `DbConnection` implementation encodes the message and writes it with
    /// `store_or_ignore`.
    fn write_to_cache(
        &self,
        inbox_id: String,
        sequence_id: i64,
        state: AssociationStateProto,
    ) -> Result<(), StorageError>;

    /// Looks up the state stored under `(inbox_id, sequence_id)`.
    ///
    /// Returns `Ok(None)` when no entry is found. In the `DbConnection`
    /// implementation an error is returned when the lookup fails or the stored
    /// bytes cannot be decoded.
    fn read_from_cache<A: AsRef<str>>(
        &self,
        inbox_id: A,
        sequence_id: i64,
    ) -> Result<Option<AssociationStateProto>, StorageError>;

    /// Looks up several entries in one call.
    ///
    /// Entries that are not cached are simply absent from the result, so the
    /// returned vector can be shorter than `identifiers`.
    // NOTE(review): the `DbConnection` query applies no ordering, so do not assume
    // the results line up with the order of `identifiers`.
    fn batch_read_from_cache(
        &self,
        identifiers: Vec<(String, i64)>,
    ) -> Result<Vec<AssociationStateProto>, StorageError>;
}
41
42impl<R> QueryAssociationStateCache for &R
43where
44    R: QueryAssociationStateCache,
45{
46    fn write_to_cache(
47        &self,
48        inbox_id: String,
49        sequence_id: i64,
50        state: AssociationStateProto,
51    ) -> Result<(), StorageError> {
52        (**self).write_to_cache(inbox_id, sequence_id, state)
53    }
54
55    fn read_from_cache<A: AsRef<str>>(
56        &self,
57        inbox_id: A,
58        sequence_id: i64,
59    ) -> Result<Option<AssociationStateProto>, StorageError> {
60        (**self).read_from_cache(inbox_id, sequence_id)
61    }
62
63    fn batch_read_from_cache(
64        &self,
65        identifiers: Vec<(String, i64)>,
66    ) -> Result<Vec<AssociationStateProto>, StorageError> {
67        (**self).batch_read_from_cache(identifiers)
68    }
69}
70
71impl<C: ConnectionExt> QueryAssociationStateCache for DbConnection<C> {
72    fn write_to_cache(
73        &self,
74        inbox_id: String,
75        sequence_id: i64,
76        state: AssociationStateProto,
77    ) -> Result<(), StorageError> {
78        let result = StoredAssociationState {
79            inbox_id: inbox_id.clone(),
80            sequence_id,
81            state: state.encode_to_vec(),
82        }
83        .store_or_ignore(self);
84
85        if result.is_ok() {
86            tracing::debug!(
87                "Wrote association state to cache: {} {}",
88                inbox_id,
89                sequence_id
90            );
91        }
92
93        result
94    }
95
96    fn read_from_cache<A: AsRef<str>>(
97        &self,
98        inbox_id: A,
99        sequence_id: i64,
100    ) -> Result<Option<AssociationStateProto>, StorageError> {
101        let inbox_id = inbox_id.as_ref();
102        let stored_state: Option<StoredAssociationState> =
103            self.fetch(&(inbox_id.to_string(), sequence_id))?;
104
105        let result = stored_state
106            .map(|stored_state| stored_state.state)
107            .inspect(|_| {
108                tracing::debug!(
109                    "Loaded association state from cache: {} {}",
110                    inbox_id,
111                    sequence_id
112                )
113            });
114        Ok(result
115            .map(|r| AssociationStateProto::decode(r.as_slice()))
116            .transpose()?)
117    }
118
119    fn batch_read_from_cache(
120        &self,
121        identifiers: Vec<(String, i64)>,
122    ) -> Result<Vec<AssociationStateProto>, StorageError> {
123        if identifiers.is_empty() {
124            return Ok(vec![]);
125        }
126
127        let (inbox_ids, sequence_ids): (Vec<String>, Vec<i64>) = identifiers.into_iter().unzip();
128
129        let query = dsl::association_state
130            .select((dsl::inbox_id, dsl::sequence_id, dsl::state))
131            .filter(
132                dsl::inbox_id
133                    .eq_any(inbox_ids)
134                    .and(dsl::sequence_id.eq_any(sequence_ids)),
135            );
136
137        let association_states =
138            self.raw_query_read(|query_conn| query.load::<StoredAssociationState>(query_conn))?;
139
140        association_states
141            .into_iter()
142            .map(|stored_association_state| {
143                Ok(AssociationStateProto::decode(
144                    stored_association_state.state.as_slice(),
145                )?)
146            })
147            .collect::<Result<Vec<_>, _>>()
148    }
149}
150
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;
    use crate::test_utils::with_connection;
    use serde::{Deserialize, Serialize};
    use xmtp_proto::xmtp::identity::associations::AssociationState as AssociationStateProto;

    /// Test-only view of an association state that keeps just the inbox id.
    #[derive(Serialize, Deserialize)]
    pub struct MockState {
        inbox_id: String,
    }

    /// Builds a `MockState` by running the stored bytes through `crate::db_deserialize`.
    impl From<StoredAssociationState> for MockState {
        fn from(stored: StoredAssociationState) -> Self {
            let bytes = &stored.state;
            crate::db_deserialize(bytes).unwrap()
        }
    }

    /// Builds a `MockState` from the inbox id of a decoded proto.
    impl From<AssociationStateProto> for MockState {
        fn from(proto: AssociationStateProto) -> Self {
            let inbox_id = proto.inbox_id;
            Self { inbox_id }
        }
    }

    /// Writes two entries and checks batch reads for one hit, two hits and no hit.
    #[xmtp_common::test]
    fn test_batch_read() {
        with_connection(|conn| {
            // Two cached entries: ("test_id1", 1) and ("test_id2", 2).
            let first = AssociationStateProto {
                inbox_id: "test_id1".into(),
                members: Vec::new(),
                ..Default::default()
            };
            let second = AssociationStateProto {
                inbox_id: "test_id2".into(),
                members: Vec::new(),
                ..Default::default()
            };
            conn.write_to_cache(first.inbox_id.clone(), 1, first.clone())
                .unwrap();
            conn.write_to_cache(second.inbox_id.clone(), 2, second.clone())
                .unwrap();

            // Shared lookup: batch-read the given pairs and convert every hit.
            let lookup = |identifiers: Vec<(String, i64)>| -> Vec<MockState> {
                conn.batch_read_from_cache(identifiers)
                    .unwrap()
                    .into_iter()
                    .map(MockState::from)
                    .collect()
            };

            // A single matching pair yields exactly that entry.
            let only_first = lookup(vec![(first.inbox_id.clone(), 1)]);
            assert_eq!(only_first.len(), 1);
            assert_eq!(&only_first[0].inbox_id, &first.inbox_id);

            // Both stored pairs are returned when both are requested.
            let both = lookup(vec![
                (first.inbox_id.clone(), 1),
                (second.inbox_id.clone(), 2),
            ]);
            assert_eq!(both.len(), 2);

            // A known inbox id with a sequence id it was never stored under: no rows.
            let none = lookup(vec![(first.inbox_id.clone(), 2)]);
            assert_eq!(none.len(), 0);
        })
    }
}