1use diesel::prelude::*;
2
3use super::schema::association_state::{self, dsl};
4use crate::ConnectionExt;
5use crate::DbConnection;
6use crate::{Fetch, StorageError, StoreOrIgnore, impl_fetch, impl_store_or_ignore};
7use prost::Message;
8use xmtp_proto::xmtp::identity::associations::AssociationState as AssociationStateProto;
9
/// A single row of the `association_state` table.
///
/// Rows are keyed by the composite primary key `(inbox_id, sequence_id)`.
#[derive(Insertable, Identifiable, Queryable, Debug, Clone, PartialEq, Eq)]
#[diesel(table_name = association_state)]
#[diesel(primary_key(inbox_id, sequence_id))]
pub struct StoredAssociationState {
    /// Inbox the cached state belongs to (first half of the primary key).
    pub inbox_id: String,
    /// Sequence id the cached state corresponds to (second half of the primary key).
    pub sequence_id: i64,
    /// Encoded state bytes; `write_to_cache` fills this with the prost encoding
    /// of an `AssociationStateProto`.
    pub state: Vec<u8>,
}
// NOTE(review): these macros are defined elsewhere in the crate; presumably they
// generate the `Fetch` (keyed by `(String, i64)`) and `StoreOrIgnore`
// implementations that this file relies on -- confirm against the macro definitions.
impl_fetch!(StoredAssociationState, association_state, (String, i64));
impl_store_or_ignore!(StoredAssociationState, association_state);
21
/// Read/write access to a cache of association states keyed by
/// `(inbox_id, sequence_id)`.
pub trait QueryAssociationStateCache {
    /// Caches `state` under the key `(inbox_id, sequence_id)`.
    ///
    /// # Errors
    /// Returns a `StorageError` if the implementation fails to persist the entry.
    fn write_to_cache(
        &self,
        inbox_id: String,
        sequence_id: i64,
        state: AssociationStateProto,
    ) -> Result<(), StorageError>;

    /// Looks up the state cached under `(inbox_id, sequence_id)`.
    ///
    /// Returns `Ok(None)` when no entry is cached for that key.
    ///
    /// # Errors
    /// Returns a `StorageError` if the lookup fails or, in the `DbConnection`
    /// implementation, if the cached bytes cannot be decoded.
    fn read_from_cache<A: AsRef<str>>(
        &self,
        inbox_id: A,
        sequence_id: i64,
    ) -> Result<Option<AssociationStateProto>, StorageError>;

    /// Looks up several `(inbox_id, sequence_id)` keys in one call and returns
    /// the states that were found.
    ///
    /// The result carries no keys, so it may be shorter than `identifiers` when
    /// some entries are not cached.
    ///
    /// # Errors
    /// Returns a `StorageError` if the lookup fails or, in the `DbConnection`
    /// implementation, if any cached entry cannot be decoded.
    fn batch_read_from_cache(
        &self,
        identifiers: Vec<(String, i64)>,
    ) -> Result<Vec<AssociationStateProto>, StorageError>;
}
41
42impl<R> QueryAssociationStateCache for &R
43where
44 R: QueryAssociationStateCache,
45{
46 fn write_to_cache(
47 &self,
48 inbox_id: String,
49 sequence_id: i64,
50 state: AssociationStateProto,
51 ) -> Result<(), StorageError> {
52 (**self).write_to_cache(inbox_id, sequence_id, state)
53 }
54
55 fn read_from_cache<A: AsRef<str>>(
56 &self,
57 inbox_id: A,
58 sequence_id: i64,
59 ) -> Result<Option<AssociationStateProto>, StorageError> {
60 (**self).read_from_cache(inbox_id, sequence_id)
61 }
62
63 fn batch_read_from_cache(
64 &self,
65 identifiers: Vec<(String, i64)>,
66 ) -> Result<Vec<AssociationStateProto>, StorageError> {
67 (**self).batch_read_from_cache(identifiers)
68 }
69}
70
71impl<C: ConnectionExt> QueryAssociationStateCache for DbConnection<C> {
72 fn write_to_cache(
73 &self,
74 inbox_id: String,
75 sequence_id: i64,
76 state: AssociationStateProto,
77 ) -> Result<(), StorageError> {
78 let result = StoredAssociationState {
79 inbox_id: inbox_id.clone(),
80 sequence_id,
81 state: state.encode_to_vec(),
82 }
83 .store_or_ignore(self);
84
85 if result.is_ok() {
86 tracing::debug!(
87 "Wrote association state to cache: {} {}",
88 inbox_id,
89 sequence_id
90 );
91 }
92
93 result
94 }
95
96 fn read_from_cache<A: AsRef<str>>(
97 &self,
98 inbox_id: A,
99 sequence_id: i64,
100 ) -> Result<Option<AssociationStateProto>, StorageError> {
101 let inbox_id = inbox_id.as_ref();
102 let stored_state: Option<StoredAssociationState> =
103 self.fetch(&(inbox_id.to_string(), sequence_id))?;
104
105 let result = stored_state
106 .map(|stored_state| stored_state.state)
107 .inspect(|_| {
108 tracing::debug!(
109 "Loaded association state from cache: {} {}",
110 inbox_id,
111 sequence_id
112 )
113 });
114 Ok(result
115 .map(|r| AssociationStateProto::decode(r.as_slice()))
116 .transpose()?)
117 }
118
119 fn batch_read_from_cache(
120 &self,
121 identifiers: Vec<(String, i64)>,
122 ) -> Result<Vec<AssociationStateProto>, StorageError> {
123 if identifiers.is_empty() {
124 return Ok(vec![]);
125 }
126
127 let (inbox_ids, sequence_ids): (Vec<String>, Vec<i64>) = identifiers.into_iter().unzip();
128
129 let query = dsl::association_state
130 .select((dsl::inbox_id, dsl::sequence_id, dsl::state))
131 .filter(
132 dsl::inbox_id
133 .eq_any(inbox_ids)
134 .and(dsl::sequence_id.eq_any(sequence_ids)),
135 );
136
137 let association_states =
138 self.raw_query_read(|query_conn| query.load::<StoredAssociationState>(query_conn))?;
139
140 association_states
141 .into_iter()
142 .map(|stored_association_state| {
143 Ok(AssociationStateProto::decode(
144 stored_association_state.state.as_slice(),
145 )?)
146 })
147 .collect::<Result<Vec<_>, _>>()
148 }
149}
150
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;
    use crate::test_utils::with_connection;
    use serde::{Deserialize, Serialize};
    use xmtp_proto::xmtp::identity::associations::AssociationState as AssociationStateProto;

    /// Minimal stand-in for an association state that only tracks the inbox id.
    #[derive(Serialize, Deserialize)]
    pub struct MockState {
        inbox_id: String,
    }

    impl From<StoredAssociationState> for MockState {
        /// Deserializes the stored bytes with `crate::db_deserialize`; panics on failure.
        fn from(stored: StoredAssociationState) -> MockState {
            crate::db_deserialize(&stored.state).unwrap()
        }
    }

    impl From<AssociationStateProto> for MockState {
        /// Keeps only the inbox id of the proto.
        fn from(proto: AssociationStateProto) -> Self {
            let AssociationStateProto { inbox_id, .. } = proto;
            Self { inbox_id }
        }
    }

    /// Builds a proto state with the given inbox id, no members, and defaults elsewhere.
    fn empty_state(inbox_id: &str) -> AssociationStateProto {
        AssociationStateProto {
            inbox_id: inbox_id.into(),
            members: vec![],
            ..Default::default()
        }
    }

    #[xmtp_common::test]
    fn test_batch_read() {
        with_connection(|conn| {
            // Seed the cache with two entries under distinct keys.
            let first = empty_state("test_id1");
            conn.write_to_cache(first.inbox_id.clone(), 1, first.clone())
                .unwrap();

            let second = empty_state("test_id2");
            conn.write_to_cache(second.inbox_id.clone(), 2, second.clone())
                .unwrap();

            // Batch-reads the given keys and converts every hit into a `MockState`.
            let read_mocks = |identifiers: Vec<(String, i64)>| -> Vec<MockState> {
                conn.batch_read_from_cache(identifiers)
                    .unwrap()
                    .into_iter()
                    .map(MockState::from)
                    .collect()
            };

            // A single existing key yields exactly that entry.
            let only_first = read_mocks(vec![(first.inbox_id.to_string(), 1)]);
            assert_eq!(only_first.len(), 1);
            assert_eq!(&only_first[0].inbox_id, &first.inbox_id);

            // Both existing keys yield both entries.
            let both = read_mocks(vec![
                (first.inbox_id.clone(), 1),
                (second.inbox_id.clone(), 2),
            ]);
            assert_eq!(both.len(), 2);

            // A key that was never written yields nothing.
            let no_results = read_mocks(vec![(first.inbox_id.clone(), 2)]);
            assert_eq!(no_results.len(), 0);
        })
    }
}