1use super::schema::identity_cache;
2use super::{ConnectionExt, Sqlite};
3use crate::{DbConnection, StorageError};
4use crate::{Store, impl_fetch, impl_store};
5use diesel::backend::Backend;
6use diesel::deserialize::{self, FromSql, FromSqlRow};
7use diesel::expression::AsExpression;
8use diesel::serialize::{IsNull, Output, ToSql};
9use diesel::sql_types::Integer;
10use diesel::{Insertable, Queryable};
11use diesel::{prelude::*, serialize};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15#[derive(Insertable, Queryable, Debug, Clone, Deserialize, Serialize)]
16#[diesel(table_name = identity_cache)]
17#[diesel()]
18pub struct IdentityCache {
19 inbox_id: String,
20 identity: String,
21 identity_kind: StoredIdentityKind,
22}
23
/// Kind of identity an `identity_cache` row refers to.
///
/// Persisted in SQLite as an `Integer` using the explicit discriminants below
/// (written via `*self as i32` in the `ToSql` impl and decoded by the `FromSql`
/// impl in this file). Because these numbers are stored on disk, existing
/// variants must never be renumbered, and any new variant also needs a
/// matching arm in the `FromSql` impl.
#[repr(i32)]
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq, AsExpression, FromSqlRow)]
#[diesel(sql_type = Integer)]
pub enum StoredIdentityKind {
    /// Persisted as `1`.
    Ethereum = 1,
    /// Persisted as `2`.
    Passkey = 2,
}
32
// Crate macros that wire `IdentityCache` to the `identity_cache` table.
// NOTE(review): `impl_store!` presumably provides the `Store` impl used by
// `.store(...)` below; `impl_fetch!` presumably a by-key fetch — confirm
// against the macro definitions.
impl_store!(IdentityCache, identity_cache);
impl_fetch!(IdentityCache, identity_cache);
35
/// Read/write access to the cache mapping identities to inbox ids.
pub trait QueryIdentityCache {
    /// Looks up cached inbox ids for `identifiers`.
    ///
    /// Each identifier is identified by its `Display` string together with
    /// the [`StoredIdentityKind`] it converts into. The returned map is keyed
    /// by the identity string and holds the cached inbox id; identifiers with
    /// no cache entry are absent from the map.
    fn fetch_cached_inbox_ids<T>(
        &self,
        identifiers: &[T],
    ) -> Result<HashMap<String, String>, StorageError>
    where
        T: std::fmt::Display,
        for<'a> &'a T: Into<StoredIdentityKind>;

    /// Records `inbox_id` as the inbox for `identifier` (keyed by the
    /// identifier's `Display` string and its [`StoredIdentityKind`]).
    fn cache_inbox_id<T, S>(&self, identifier: &T, inbox_id: S) -> Result<(), StorageError>
    where
        T: std::fmt::Display,
        S: ToString,
        for<'a> &'a T: Into<StoredIdentityKind>;
}
52
53impl<G> QueryIdentityCache for &G
54where
55 G: QueryIdentityCache,
56{
57 fn fetch_cached_inbox_ids<T>(
58 &self,
59 identifiers: &[T],
60 ) -> Result<HashMap<String, String>, StorageError>
61 where
62 T: std::fmt::Display,
63 for<'a> &'a T: Into<StoredIdentityKind>,
64 {
65 (**self).fetch_cached_inbox_ids(identifiers)
66 }
67
68 fn cache_inbox_id<T, S>(&self, identifier: &T, inbox_id: S) -> Result<(), StorageError>
69 where
70 T: std::fmt::Display,
71 S: ToString,
72 for<'a> &'a T: Into<StoredIdentityKind>,
73 {
74 (**self).cache_inbox_id(identifier, inbox_id)
75 }
76}
77
78impl<C: ConnectionExt> QueryIdentityCache for DbConnection<C> {
79 fn fetch_cached_inbox_ids<T>(
81 &self,
82 identifiers: &[T],
83 ) -> Result<HashMap<String, String>, StorageError>
84 where
85 T: std::fmt::Display,
86 for<'a> &'a T: Into<StoredIdentityKind>,
87 {
88 use crate::encrypted_store::schema::identity_cache::*;
89
90 let mut conditions = identity_cache::table.into_boxed();
91
92 for ident in identifiers {
93 let addr = (&ident).to_string();
94 let kind: StoredIdentityKind = ident.into();
95 let cond = identity.eq(addr).and(identity_kind.eq(kind));
96 conditions = conditions.or_filter(cond);
97 }
98
99 let result = self
100 .raw_query_read(|conn| conditions.load::<IdentityCache>(conn))?
101 .into_iter()
102 .map(|entry| (entry.identity, entry.inbox_id))
103 .collect();
104 Ok(result)
105 }
106
107 fn cache_inbox_id<T, S>(&self, identifier: &T, inbox_id: S) -> Result<(), StorageError>
108 where
109 T: std::fmt::Display,
110 S: ToString,
111 for<'a> &'a T: Into<StoredIdentityKind>,
112 {
113 IdentityCache {
114 inbox_id: inbox_id.to_string(),
115 identity: identifier.to_string(),
116 identity_kind: identifier.into(),
117 }
118 .store(self)
119 }
120}
121
122impl ToSql<Integer, Sqlite> for StoredIdentityKind
123where
124 i32: ToSql<Integer, Sqlite>,
125{
126 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
127 out.set_value(*self as i32);
128 Ok(IsNull::No)
129 }
130}
131
132impl FromSql<Integer, Sqlite> for StoredIdentityKind
133where
134 i32: FromSql<Integer, Sqlite>,
135{
136 fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
137 match i32::from_sql(bytes)? {
138 1 => Ok(Self::Ethereum),
139 2 => Ok(Self::Passkey),
140 x => Err(format!("Unrecognized variant {}", x).into()),
141 }
142 }
143}
144
#[cfg(test)]
pub(crate) mod tests {
    use super::IdentityCache;
    use crate::{
        Store, identity_cache::StoredIdentityKind, prelude::*, test_utils::with_connection,
    };

    /// Test identifier with a random identity string and inbox id.
    #[derive(Clone)]
    struct MockIdentity {
        identity: String,
        /// `0` or `1` map to `Ethereum`, `2` to `Passkey`; anything else panics.
        kind: u8,
        inbox_id: String,
    }

    impl MockIdentity {
        fn create(kind: u8) -> Self {
            Self {
                identity: xmtp_common::rand_hexstring(),
                inbox_id: xmtp_common::rand_string::<32>(),
                kind,
            }
        }
    }

    impl<'a> From<&'a MockIdentity> for StoredIdentityKind {
        fn from(identity: &'a MockIdentity) -> StoredIdentityKind {
            match identity.kind {
                0 | 1 => StoredIdentityKind::Ethereum,
                2 => StoredIdentityKind::Passkey,
                _ => panic!("unknown kind"),
            }
        }
    }

    impl std::fmt::Display for MockIdentity {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.identity)
        }
    }

    #[xmtp_common::test]
    fn test_store_duplicated_wallets() {
        with_connection(|conn| {
            let entry1 = IdentityCache {
                inbox_id: "test_dup".to_string(),
                identity: "wallet_dup".to_string(),
                identity_kind: StoredIdentityKind::Ethereum,
            };
            let entry2 = IdentityCache {
                inbox_id: "test_dup".to_string(),
                identity: "wallet_dup".to_string(),
                identity_kind: StoredIdentityKind::Ethereum,
            };
            entry1.store(conn).expect("Failed to store wallet");
            let result = entry2.store(conn);
            assert!(
                result.is_err(),
                "Duplicated wallet stored without error, expected failure"
            );
        })
    }

    #[xmtp_common::test]
    fn test_fetch_and_store_identity_cache() {
        with_connection(|conn| {
            let ident1 = MockIdentity::create(0);
            let ident2 = MockIdentity::create(0);

            conn.cache_inbox_id(&ident1, &ident1.inbox_id)
                .expect("caching ident1 should succeed");
            let idents = &[ident1.clone(), ident2];
            let stored_wallets = conn
                .fetch_cached_inbox_ids(idents)
                .expect("fetching cached identities should succeed");

            // Only ident1 was cached, so ident2 must be absent.
            assert_eq!(stored_wallets.len(), 1);

            let cached_inbox_id = stored_wallets
                .get(&idents[0].to_string())
                .expect("ident1 should be present in the cache");
            assert_eq!(*cached_inbox_id, ident1.inbox_id);

            // A query error must fail the test instead of being treated as "no rows".
            let non_existent_wallets = conn
                .fetch_cached_inbox_ids(&[MockIdentity::create(1)])
                .expect("fetching uncached identities should not fail");
            assert!(
                non_existent_wallets.is_empty(),
                "Expected no wallets, found some"
            );
        })
    }
}