// xmtp_id/associations/state.rs

//! [`AssociationState`] describes a single point in time for an Inbox where it contains a set of
//! associated [`MemberIdentifier`]s, which may be one of [`MemberKind::Address`]
//! or [`MemberKind::Installation`]. A diff between two states can be calculated to determine
//! a change of membership between two periods of time. [XIP-46](https://github.com/xmtp/XIPs/pull/53)
5
6use std::{
7    collections::{HashMap, HashSet},
8    fmt::{Debug, Write},
9};
10
11use prost::Message;
12use xmtp_db::association_state::StoredAssociationState;
13use xmtp_proto::{
14    ConversionError, xmtp::identity::associations::AssociationState as AssociationStateProto,
15};
16
17use super::{
18    AssociationError, MemberIdentifier, MemberKind, ident,
19    member::{Identifier, Member},
20};
21use crate::InboxIdRef;
22
23#[derive(Debug, Clone)]
24pub struct AssociationStateDiff {
25    pub new_members: Vec<MemberIdentifier>,
26    pub removed_members: Vec<MemberIdentifier>,
27}
28
/// An installation member of an inbox: its raw installation key plus the client
/// timestamp recorded for it, if any.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can copy and compare the values
/// returned by `AssociationState::installations`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Installation {
    /// Raw installation key bytes.
    pub id: Vec<u8>,
    /// Client-reported timestamp in nanoseconds, when one was recorded.
    pub client_timestamp_ns: Option<u64>,
}
34
35impl AssociationStateDiff {
36    pub fn new_installations(&self) -> Vec<Vec<u8>> {
37        self.new_members
38            .iter()
39            .filter_map(|member| match member {
40                MemberIdentifier::Installation(ident::Installation(key)) => Some(key.clone()),
41                _ => None,
42            })
43            .collect()
44    }
45
46    pub fn removed_installations(&self) -> Vec<Vec<u8>> {
47        self.removed_members
48            .iter()
49            .filter_map(|member| match member {
50                MemberIdentifier::Installation(ident::Installation(key)) => Some(key.clone()),
51                _ => None,
52            })
53            .collect()
54    }
55}
56
57#[derive(Clone)]
58pub struct AssociationState {
59    pub(crate) inbox_id: String,
60    pub(crate) members: HashMap<MemberIdentifier, Member>,
61    pub(crate) recovery_identifier: Identifier,
62    pub(crate) seen_signatures: HashSet<Vec<u8>>,
63}
64
65impl std::fmt::Debug for AssociationState {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        let mut members = String::new();
68        for member in self.members.keys() {
69            write!(members, "{:?}", member)?;
70            write!(members, ",")?;
71        }
72
73        let mut signatures = String::new();
74        for signature in self.seen_signatures.iter() {
75            write!(
76                signatures,
77                "{}",
78                xmtp_common::fmt::truncate_hex(hex::encode(signature))
79            )?;
80            write!(signatures, ",")?;
81        }
82
83        write!(
84            f,
85            "AssociationState {{ inbox_id: {}, members: {}, recovery: {}, seen_signatures: {} }}",
86            self.inbox_id, members, self.recovery_identifier, signatures
87        )
88    }
89}
90
91impl TryFrom<MemberIdentifier> for Identifier {
92    type Error = AssociationError;
93    fn try_from(ident: MemberIdentifier) -> Result<Self, Self::Error> {
94        let ident = match ident {
95            MemberIdentifier::Ethereum(eth) => Self::Ethereum(eth),
96            MemberIdentifier::Passkey(passkey) => Self::Passkey(passkey),
97            MemberIdentifier::Installation(_) => {
98                return Err(AssociationError::NotIdentifier(
99                    "Installation Keys".to_string(),
100                ));
101            }
102        };
103        Ok(ident)
104    }
105}
106
107impl TryFrom<StoredAssociationState> for AssociationState {
108    type Error = ConversionError;
109
110    fn try_from(stored_state: StoredAssociationState) -> Result<Self, Self::Error> {
111        AssociationStateProto::decode(stored_state.state.as_slice())?.try_into()
112    }
113}
114
115impl AssociationState {
116    pub fn add(&self, member: Member) -> Self {
117        let mut new_state = self.clone();
118        let _ = new_state.members.insert(member.identifier.clone(), member);
119
120        new_state
121    }
122
123    pub fn remove(&self, identifier: &MemberIdentifier) -> Self {
124        let mut new_state = self.clone();
125        let _ = new_state.members.remove(identifier);
126
127        new_state
128    }
129
130    pub fn set_recovery_identifier(&self, recovery_identifier: Identifier) -> Self {
131        let mut new_state = self.clone();
132        new_state.recovery_identifier = recovery_identifier;
133
134        new_state
135    }
136
137    pub fn get(&self, identifier: &MemberIdentifier) -> Option<&Member> {
138        self.members.get(identifier)
139    }
140
141    pub fn add_seen_signatures(&self, signatures: Vec<Vec<u8>>) -> Self {
142        let mut new_state = self.clone();
143        new_state.seen_signatures.extend(signatures);
144
145        new_state
146    }
147
148    pub fn has_seen(&self, signature: &Vec<u8>) -> bool {
149        self.seen_signatures.contains(signature)
150    }
151
152    pub fn members(&self) -> Vec<Member> {
153        let mut sorted_members: Vec<_> = self.members.values().cloned().collect();
154        sorted_members.sort_by_key(|m| m.client_timestamp_ns.unwrap_or(u64::MAX));
155        sorted_members
156    }
157
158    pub fn inbox_id(&self) -> InboxIdRef<'_> {
159        &self.inbox_id
160    }
161
162    pub fn recovery_identifier(&self) -> &Identifier {
163        &self.recovery_identifier
164    }
165
166    pub fn members_by_parent(&self, parent_id: &MemberIdentifier) -> Vec<Member> {
167        self.members
168            .values()
169            .filter(|e| e.added_by_entity.eq(&Some(parent_id.clone())))
170            .cloned()
171            .collect()
172    }
173
174    pub fn members_by_kind(&self, kind: MemberKind) -> Vec<Member> {
175        self.members
176            .values()
177            .filter(|e| e.kind() == kind)
178            .cloned()
179            .collect()
180    }
181
182    pub fn identifiers(&self) -> Vec<Identifier> {
183        let mut address_members: Vec<_> = self.members.values().cloned().collect();
184
185        address_members.sort_by_key(|m| m.client_timestamp_ns.unwrap_or(u64::MAX));
186
187        address_members
188            .into_iter()
189            .filter_map(|member| match member.identifier {
190                MemberIdentifier::Ethereum(eth) => Some(Identifier::Ethereum(eth)),
191                MemberIdentifier::Passkey(pk) => Some(Identifier::Passkey(pk)),
192                _ => None,
193            })
194            .collect()
195    }
196
197    pub fn installation_ids(&self) -> Vec<Vec<u8>> {
198        self.members_by_kind(MemberKind::Installation)
199            .into_iter()
200            .filter_map(|member| match member.identifier {
201                MemberIdentifier::Installation(ident::Installation(key)) => Some(key),
202                _ => None,
203            })
204            .collect()
205    }
206
207    pub fn installations(&self) -> Vec<Installation> {
208        self.members()
209            .into_iter()
210            .filter_map(|member| match member.identifier {
211                MemberIdentifier::Installation(ident::Installation(id)) => Some(Installation {
212                    id,
213                    client_timestamp_ns: member.client_timestamp_ns,
214                }),
215                _ => None,
216            })
217            .collect()
218    }
219
220    pub fn diff(&self, new_state: &Self) -> AssociationStateDiff {
221        let new_members: Vec<MemberIdentifier> = new_state
222            .members
223            .keys()
224            .filter(|new_member_identifier| !self.members.contains_key(new_member_identifier))
225            .cloned()
226            .collect();
227
228        let removed_members: Vec<MemberIdentifier> = self
229            .members
230            .keys()
231            .filter(|existing_member_identifier| {
232                !new_state.members.contains_key(existing_member_identifier)
233            })
234            .cloned()
235            .collect();
236
237        AssociationStateDiff {
238            new_members,
239            removed_members,
240        }
241    }
242
243    /// Converts the [`AssociationState`] to a diff that represents all members
244    /// of the inbox at the current state.
245    pub fn as_diff(&self) -> AssociationStateDiff {
246        AssociationStateDiff {
247            new_members: self.members.keys().cloned().collect(),
248            removed_members: vec![],
249        }
250    }
251
252    pub fn new(
253        account_identifier: Identifier,
254        nonce: u64,
255        chain_id: Option<u64>,
256    ) -> Result<Self, AssociationError> {
257        let member_identifier: MemberIdentifier = account_identifier.clone().into();
258
259        let inbox_id = account_identifier.inbox_id(nonce)?;
260        let new_member = Member::new(member_identifier.clone(), None, None, chain_id);
261        Ok(Self {
262            members: HashMap::from_iter([(member_identifier, new_member)]),
263            seen_signatures: HashSet::new(),
264            recovery_identifier: account_identifier,
265            inbox_id,
266        })
267    }
268}
269
#[cfg(test)]
pub(crate) mod tests {
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);

    use super::*;

    /// Adding a member produces a new state containing it and leaves the source
    /// state untouched.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    fn can_add_remove() {
        let original = AssociationState::new(Identifier::rand_ethereum(), 0, None).unwrap();
        let member = Member::default();
        let member_id = member.identifier.clone();

        let updated = original.add(member);

        assert!(updated.get(&member_id).is_some());
        assert!(original.get(&member_id).is_none());
    }

    /// A diff reports exactly the identifiers that were added and removed.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    fn can_diff() {
        let base = AssociationState::new(Identifier::rand_ethereum(), 0, None).unwrap();
        let first = Member::default();
        let second = Member::default();
        let third = Member::default();

        let before = base.add(first.clone()).add(second);
        let after = before.remove(&first.identifier).add(third.clone());

        let changes = before.diff(&after);

        assert_eq!(changes.new_members, vec![third.identifier]);
        assert_eq!(changes.removed_members, vec![first.identifier]);
    }
}