1use std::{
7 collections::{HashMap, HashSet},
8 fmt::{Debug, Write},
9};
10
11use prost::Message;
12use xmtp_db::association_state::StoredAssociationState;
13use xmtp_proto::{
14 ConversionError, xmtp::identity::associations::AssociationState as AssociationStateProto,
15};
16
17use super::{
18 AssociationError, MemberIdentifier, MemberKind, ident,
19 member::{Identifier, Member},
20};
21use crate::InboxIdRef;
22
/// Member identifiers that differ between two [`AssociationState`]s.
///
/// Produced by [`AssociationState::diff`] (old state vs. new state) and by
/// [`AssociationState::as_diff`] (every member of a state reported as new).
/// The order of identifiers within each list is unspecified, since both are
/// collected from `HashMap` key iteration.
#[derive(Debug, Clone)]
pub struct AssociationStateDiff {
    /// Identifiers present in the newer state but not in the older one.
    pub new_members: Vec<MemberIdentifier>,
    /// Identifiers present in the older state but not in the newer one.
    pub removed_members: Vec<MemberIdentifier>,
}
28
/// An installation member of an inbox, as returned by [`AssociationState::installations`].
#[derive(Debug)]
pub struct Installation {
    /// The installation id: the bytes wrapped by `ident::Installation`.
    pub id: Vec<u8>,
    /// The member's `client_timestamp_ns`, if one was recorded.
    pub client_timestamp_ns: Option<u64>,
}
34
35impl AssociationStateDiff {
36 pub fn new_installations(&self) -> Vec<Vec<u8>> {
37 self.new_members
38 .iter()
39 .filter_map(|member| match member {
40 MemberIdentifier::Installation(ident::Installation(key)) => Some(key.clone()),
41 _ => None,
42 })
43 .collect()
44 }
45
46 pub fn removed_installations(&self) -> Vec<Vec<u8>> {
47 self.removed_members
48 .iter()
49 .filter_map(|member| match member {
50 MemberIdentifier::Installation(ident::Installation(key)) => Some(key.clone()),
51 _ => None,
52 })
53 .collect()
54 }
55}
56
/// The resolved membership state of a single inbox.
///
/// The modifying methods (`add`, `remove`, `set_recovery_identifier`,
/// `add_seen_signatures`) clone the whole state and return the modified copy,
/// leaving the receiver untouched. `Debug` is implemented manually so that
/// seen signatures are printed as truncated hex.
#[derive(Clone)]
pub struct AssociationState {
    /// The inbox id this state describes.
    pub(crate) inbox_id: String,
    /// Current members, keyed by identifier; `add` uses `Member::identifier` as the key.
    pub(crate) members: HashMap<MemberIdentifier, Member>,
    /// The recovery identifier; `new` initialises it to the account identifier.
    pub(crate) recovery_identifier: Identifier,
    /// Raw signature bytes recorded via `add_seen_signatures` and queried by `has_seen`.
    /// NOTE(review): presumably used to reject replayed signatures — confirm with callers.
    pub(crate) seen_signatures: HashSet<Vec<u8>>,
}
64
65impl std::fmt::Debug for AssociationState {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 let mut members = String::new();
68 for member in self.members.keys() {
69 write!(members, "{:?}", member)?;
70 write!(members, ",")?;
71 }
72
73 let mut signatures = String::new();
74 for signature in self.seen_signatures.iter() {
75 write!(
76 signatures,
77 "{}",
78 xmtp_common::fmt::truncate_hex(hex::encode(signature))
79 )?;
80 write!(signatures, ",")?;
81 }
82
83 write!(
84 f,
85 "AssociationState {{ inbox_id: {}, members: {}, recovery: {}, seen_signatures: {} }}",
86 self.inbox_id, members, self.recovery_identifier, signatures
87 )
88 }
89}
90
91impl TryFrom<MemberIdentifier> for Identifier {
92 type Error = AssociationError;
93 fn try_from(ident: MemberIdentifier) -> Result<Self, Self::Error> {
94 let ident = match ident {
95 MemberIdentifier::Ethereum(eth) => Self::Ethereum(eth),
96 MemberIdentifier::Passkey(passkey) => Self::Passkey(passkey),
97 MemberIdentifier::Installation(_) => {
98 return Err(AssociationError::NotIdentifier(
99 "Installation Keys".to_string(),
100 ));
101 }
102 };
103 Ok(ident)
104 }
105}
106
107impl TryFrom<StoredAssociationState> for AssociationState {
108 type Error = ConversionError;
109
110 fn try_from(stored_state: StoredAssociationState) -> Result<Self, Self::Error> {
111 AssociationStateProto::decode(stored_state.state.as_slice())?.try_into()
112 }
113}
114
115impl AssociationState {
116 pub fn add(&self, member: Member) -> Self {
117 let mut new_state = self.clone();
118 let _ = new_state.members.insert(member.identifier.clone(), member);
119
120 new_state
121 }
122
123 pub fn remove(&self, identifier: &MemberIdentifier) -> Self {
124 let mut new_state = self.clone();
125 let _ = new_state.members.remove(identifier);
126
127 new_state
128 }
129
130 pub fn set_recovery_identifier(&self, recovery_identifier: Identifier) -> Self {
131 let mut new_state = self.clone();
132 new_state.recovery_identifier = recovery_identifier;
133
134 new_state
135 }
136
137 pub fn get(&self, identifier: &MemberIdentifier) -> Option<&Member> {
138 self.members.get(identifier)
139 }
140
141 pub fn add_seen_signatures(&self, signatures: Vec<Vec<u8>>) -> Self {
142 let mut new_state = self.clone();
143 new_state.seen_signatures.extend(signatures);
144
145 new_state
146 }
147
148 pub fn has_seen(&self, signature: &Vec<u8>) -> bool {
149 self.seen_signatures.contains(signature)
150 }
151
152 pub fn members(&self) -> Vec<Member> {
153 let mut sorted_members: Vec<_> = self.members.values().cloned().collect();
154 sorted_members.sort_by_key(|m| m.client_timestamp_ns.unwrap_or(u64::MAX));
155 sorted_members
156 }
157
158 pub fn inbox_id(&self) -> InboxIdRef<'_> {
159 &self.inbox_id
160 }
161
162 pub fn recovery_identifier(&self) -> &Identifier {
163 &self.recovery_identifier
164 }
165
166 pub fn members_by_parent(&self, parent_id: &MemberIdentifier) -> Vec<Member> {
167 self.members
168 .values()
169 .filter(|e| e.added_by_entity.eq(&Some(parent_id.clone())))
170 .cloned()
171 .collect()
172 }
173
174 pub fn members_by_kind(&self, kind: MemberKind) -> Vec<Member> {
175 self.members
176 .values()
177 .filter(|e| e.kind() == kind)
178 .cloned()
179 .collect()
180 }
181
182 pub fn identifiers(&self) -> Vec<Identifier> {
183 let mut address_members: Vec<_> = self.members.values().cloned().collect();
184
185 address_members.sort_by_key(|m| m.client_timestamp_ns.unwrap_or(u64::MAX));
186
187 address_members
188 .into_iter()
189 .filter_map(|member| match member.identifier {
190 MemberIdentifier::Ethereum(eth) => Some(Identifier::Ethereum(eth)),
191 MemberIdentifier::Passkey(pk) => Some(Identifier::Passkey(pk)),
192 _ => None,
193 })
194 .collect()
195 }
196
197 pub fn installation_ids(&self) -> Vec<Vec<u8>> {
198 self.members_by_kind(MemberKind::Installation)
199 .into_iter()
200 .filter_map(|member| match member.identifier {
201 MemberIdentifier::Installation(ident::Installation(key)) => Some(key),
202 _ => None,
203 })
204 .collect()
205 }
206
207 pub fn installations(&self) -> Vec<Installation> {
208 self.members()
209 .into_iter()
210 .filter_map(|member| match member.identifier {
211 MemberIdentifier::Installation(ident::Installation(id)) => Some(Installation {
212 id,
213 client_timestamp_ns: member.client_timestamp_ns,
214 }),
215 _ => None,
216 })
217 .collect()
218 }
219
220 pub fn diff(&self, new_state: &Self) -> AssociationStateDiff {
221 let new_members: Vec<MemberIdentifier> = new_state
222 .members
223 .keys()
224 .filter(|new_member_identifier| !self.members.contains_key(new_member_identifier))
225 .cloned()
226 .collect();
227
228 let removed_members: Vec<MemberIdentifier> = self
229 .members
230 .keys()
231 .filter(|existing_member_identifier| {
232 !new_state.members.contains_key(existing_member_identifier)
233 })
234 .cloned()
235 .collect();
236
237 AssociationStateDiff {
238 new_members,
239 removed_members,
240 }
241 }
242
243 pub fn as_diff(&self) -> AssociationStateDiff {
246 AssociationStateDiff {
247 new_members: self.members.keys().cloned().collect(),
248 removed_members: vec![],
249 }
250 }
251
252 pub fn new(
253 account_identifier: Identifier,
254 nonce: u64,
255 chain_id: Option<u64>,
256 ) -> Result<Self, AssociationError> {
257 let member_identifier: MemberIdentifier = account_identifier.clone().into();
258
259 let inbox_id = account_identifier.inbox_id(nonce)?;
260 let new_member = Member::new(member_identifier.clone(), None, None, chain_id);
261 Ok(Self {
262 members: HashMap::from_iter([(member_identifier, new_member)]),
263 seen_signatures: HashSet::new(),
264 recovery_identifier: account_identifier,
265 inbox_id,
266 })
267 }
268}
269
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;

    /// `add` and `remove` return new states and leave the receiver untouched.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    fn can_add_remove() {
        let starting_state = AssociationState::new(Identifier::rand_ethereum(), 0, None).unwrap();
        let new_entity = Member::default();
        let with_add = starting_state.add(new_entity.clone());
        assert!(with_add.get(&new_entity.identifier).is_some());
        assert!(starting_state.get(&new_entity.identifier).is_none());

        // Removal was previously untested despite the test's name.
        let with_remove = with_add.remove(&new_entity.identifier);
        assert!(with_remove.get(&new_entity.identifier).is_none());
        assert!(with_add.get(&new_entity.identifier).is_some());
    }

    /// `diff` reports members only in the new state as new, and members only in the old
    /// state as removed.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    fn can_diff() {
        let starting_state = AssociationState::new(Identifier::rand_ethereum(), 0, None).unwrap();
        let entity_1 = Member::default();
        let entity_2 = Member::default();
        let entity_3 = Member::default();

        let state_1 = starting_state.add(entity_1.clone()).add(entity_2.clone());
        let state_2 = state_1.remove(&entity_1.identifier).add(entity_3.clone());

        let diff = state_1.diff(&state_2);

        assert_eq!(diff.new_members, vec![entity_3.identifier]);
        assert_eq!(diff.removed_members, vec![entity_1.identifier]);
    }

    /// A fresh state holds only its account identifier, which `as_diff` reports as new.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    fn as_diff_reports_every_member_as_new() {
        let account = Identifier::rand_ethereum();
        let state = AssociationState::new(account.clone(), 0, None).unwrap();
        let expected: MemberIdentifier = account.into();

        let diff = state.as_diff();
        assert_eq!(diff.new_members, vec![expected]);
        assert!(diff.removed_members.is_empty());
    }

    /// Seen signatures are recorded on the returned copy only.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    fn tracks_seen_signatures() {
        let starting_state = AssociationState::new(Identifier::rand_ethereum(), 0, None).unwrap();
        let signature = vec![1u8, 2, 3];

        let with_seen = starting_state.add_seen_signatures(vec![signature.clone()]);
        assert!(with_seen.has_seen(&signature));
        assert!(!starting_state.has_seen(&signature));
    }
}