xmtp_db/encrypted_store/
refresh_state.rs

1use std::collections::HashMap;
2
3use diesel::{
4    backend::Backend,
5    connection::DefaultLoadingMode,
6    deserialize::{self, FromSql, FromSqlRow},
7    expression::AsExpression,
8    prelude::*,
9    serialize::{self, IsNull, Output, ToSql},
10    sql_types::{BigInt, Binary, Integer},
11};
12use itertools::Itertools;
13use xmtp_configuration::Originators;
14use xmtp_proto::types::{Cursor, GlobalCursor, OriginatorId, Topic, TopicKind};
15
16use super::{ConnectionExt, Sqlite, db_connection::DbConnection, schema::refresh_state};
17use crate::{StorageError, StoreOrIgnore, impl_store_or_ignore};
18
19allow_columns_to_appear_in_same_group_by_clause!(
20    super::schema::identity_updates::originator_id,
21    super::schema::identity_updates::sequence_id,
22    super::schema::refresh_state::originator_id,
23    super::schema::refresh_state::sequence_id
24);
25
26#[repr(i32)]
27#[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, Hash, FromSqlRow)]
28#[diesel(sql_type = Integer)]
29pub enum EntityKind {
30    Welcome = 1,
31    ApplicationMessage = 2,       // Application messages (originator 10)
32    CommitLogUpload = 3, // Rowid of the last local entry we uploaded to the remote commit log
33    CommitLogDownload = 4, // Server log sequence id of last remote entry we downloaded from the remote commit log
34    CommitLogForkCheckLocal = 5, // Last rowid verified in local commit log
35    CommitLogForkCheckRemote = 6, // Last rowid verified in remote commit log
36    CommitMessage = 7,     // MLS commit messages (originator 0)
37}
38
39pub trait HasEntityKind {
40    fn entity_kind(&self) -> EntityKind;
41}
42
43impl HasEntityKind for xmtp_proto::types::GroupMessage {
44    fn entity_kind(&self) -> EntityKind {
45        if self.is_commit() {
46            EntityKind::CommitMessage
47        } else {
48            EntityKind::ApplicationMessage
49        }
50    }
51}
52
53impl HasEntityKind for xmtp_proto::types::WelcomeMessage {
54    fn entity_kind(&self) -> EntityKind {
55        EntityKind::Welcome
56    }
57}
58
59impl std::fmt::Display for EntityKind {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        use EntityKind::*;
62        match self {
63            Welcome => write!(f, "welcome"),
64            ApplicationMessage => write!(f, "group"),
65            CommitLogUpload => write!(f, "commit_log_upload"),
66            CommitLogDownload => write!(f, "commit_log_download"),
67            CommitLogForkCheckLocal => write!(f, "commit_log_fork_check_local"),
68            CommitLogForkCheckRemote => write!(f, "commit_log_fork_check_remote"),
69            CommitMessage => write!(f, "commit_message"),
70        }
71    }
72}
73
74impl ToSql<Integer, Sqlite> for EntityKind
75where
76    i32: ToSql<Integer, Sqlite>,
77{
78    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
79        out.set_value(*self as i32);
80        Ok(IsNull::No)
81    }
82}
83
84impl FromSql<Integer, Sqlite> for EntityKind
85where
86    i32: FromSql<Integer, Sqlite>,
87{
88    fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
89        match i32::from_sql(bytes)? {
90            1 => Ok(EntityKind::Welcome),
91            2 => Ok(EntityKind::ApplicationMessage),
92            3 => Ok(EntityKind::CommitLogUpload),
93            4 => Ok(EntityKind::CommitLogDownload),
94            5 => Ok(EntityKind::CommitLogForkCheckLocal),
95            6 => Ok(EntityKind::CommitLogForkCheckRemote),
96            7 => Ok(EntityKind::CommitMessage),
97            x => Err(format!("Unrecognized variant {}", x).into()),
98        }
99    }
100}
101
102#[derive(Insertable, Identifiable, Queryable, Debug, Clone)]
103#[diesel(table_name = refresh_state)]
104#[diesel(primary_key(entity_id, entity_kind, originator_id))]
105pub struct RefreshState {
106    pub entity_id: Vec<u8>,
107    pub entity_kind: EntityKind,
108    pub sequence_id: i64,
109    pub originator_id: i32,
110}
111
112impl_store_or_ignore!(RefreshState, refresh_state);
113
114#[derive(QueryableByName, Selectable)]
115#[diesel(check_for_backend(Sqlite), table_name = super::schema::refresh_state)]
116struct SingleCursor {
117    #[diesel(sql_type = Integer)]
118    originator_id: i32,
119    #[diesel(sql_type = BigInt)]
120    sequence_id: i64,
121}
122
123/// Helper function to convert rows of (entity_id, originator_id, sequence_id) into a HashMap
124/// where each entity_id maps to a GlobalCursor containing all its originator->sequence_id pairs.
125/// Null sequence_id values are coalesced to 0.
126fn rows_to_global_cursor_map(
127    rows: Vec<(Vec<u8>, i32, Option<i64>)>,
128) -> HashMap<Vec<u8>, GlobalCursor> {
129    let mut map: HashMap<Vec<u8>, GlobalCursor> = HashMap::new();
130
131    for (entity_id, originator_id, sequence_id) in rows {
132        let cursors = map.entry(entity_id).or_default();
133        let originator_id_u32 = originator_id as u32;
134        let sequence_id_u64 = sequence_id.unwrap_or(0) as u64;
135
136        cursors.insert(originator_id_u32, sequence_id_u64);
137    }
138
139    map
140}
141
142pub trait QueryRefreshState {
143    fn get_refresh_state<EntityId: AsRef<[u8]>>(
144        &self,
145        entity_id: EntityId,
146        entity_kind: EntityKind,
147        originator_id: u32,
148    ) -> Result<Option<RefreshState>, StorageError>;
149
150    fn get_last_cursor_for_originators<Id: AsRef<[u8]>>(
151        &self,
152        id: Id,
153        entity_kind: EntityKind,
154        originator_ids: &[u32],
155    ) -> Result<Vec<Cursor>, StorageError>;
156
157    fn get_last_cursor_for_originator<Id: AsRef<[u8]>>(
158        &self,
159        id: Id,
160        entity_kind: EntityKind,
161        originator_id: u32,
162    ) -> Result<Cursor, StorageError> {
163        // get_last_cursor guaranteed to return entry for id
164        self.get_last_cursor_for_originators(id, entity_kind, &[originator_id])
165            .map(|c| c[0])
166    }
167
168    fn get_last_cursor_for_ids<Id: AsRef<[u8]>>(
169        &self,
170        ids: &[Id],
171        entities: &[EntityKind],
172    ) -> Result<HashMap<Vec<u8>, GlobalCursor>, StorageError>;
173
174    fn update_cursor<Id: AsRef<[u8]>>(
175        &self,
176        entity_id: Id,
177        entity_kind: EntityKind,
178        cursor: Cursor,
179    ) -> Result<bool, StorageError>;
180
181    fn lowest_common_cursor(&self, topics: &[&Topic]) -> Result<GlobalCursor, StorageError>;
182
183    fn lowest_common_cursor_combined(
184        &self,
185        topics: &[&Topic],
186    ) -> Result<GlobalCursor, StorageError>;
187
188    fn latest_cursor_for_id<Id: AsRef<[u8]>>(
189        &self,
190        entity_id: Id,
191        entities: &[EntityKind],
192        originators: Option<&[&OriginatorId]>,
193    ) -> Result<GlobalCursor, StorageError>;
194
195    fn latest_cursor_combined<Id: AsRef<[u8]>>(
196        &self,
197        entity_id: Id,
198        entities: &[EntityKind],
199        originators: Option<&[&OriginatorId]>,
200    ) -> Result<GlobalCursor, StorageError>;
201
202    fn get_remote_log_cursors(
203        &self,
204        conversation_ids: &[&Vec<u8>],
205    ) -> Result<HashMap<Vec<u8>, Cursor>, crate::ConnectionError>;
206}
207
208impl<T: QueryRefreshState> QueryRefreshState for &'_ T {
209    fn get_refresh_state<EntityId: AsRef<[u8]>>(
210        &self,
211        entity_id: EntityId,
212        entity_kind: EntityKind,
213        originator: u32,
214    ) -> Result<Option<RefreshState>, StorageError> {
215        (**self).get_refresh_state(entity_id, entity_kind, originator)
216    }
217
218    fn get_last_cursor_for_ids<Id: AsRef<[u8]>>(
219        &self,
220        ids: &[Id],
221        entities: &[EntityKind],
222    ) -> Result<HashMap<Vec<u8>, GlobalCursor>, StorageError> {
223        (**self).get_last_cursor_for_ids(ids, entities)
224    }
225
226    fn update_cursor<Id: AsRef<[u8]>>(
227        &self,
228        entity_id: Id,
229        entity_kind: EntityKind,
230        cursor: Cursor,
231    ) -> Result<bool, StorageError> {
232        (**self).update_cursor(entity_id, entity_kind, cursor)
233    }
234
235    fn get_remote_log_cursors(
236        &self,
237        conversation_ids: &[&Vec<u8>],
238    ) -> Result<HashMap<Vec<u8>, Cursor>, crate::ConnectionError> {
239        (**self).get_remote_log_cursors(conversation_ids)
240    }
241
242    fn get_last_cursor_for_originators<Id: AsRef<[u8]>>(
243        &self,
244        id: Id,
245        entity_kind: EntityKind,
246        originator_ids: &[u32],
247    ) -> Result<Vec<Cursor>, StorageError> {
248        (**self).get_last_cursor_for_originators(id, entity_kind, originator_ids)
249    }
250
251    fn lowest_common_cursor(&self, topics: &[&Topic]) -> Result<GlobalCursor, StorageError> {
252        (**self).lowest_common_cursor(topics)
253    }
254
255    fn lowest_common_cursor_combined(
256        &self,
257        topics: &[&Topic],
258    ) -> Result<GlobalCursor, StorageError> {
259        (**self).lowest_common_cursor_combined(topics)
260    }
261
262    fn latest_cursor_for_id<Id: AsRef<[u8]>>(
263        &self,
264        entity_id: Id,
265        entities: &[EntityKind],
266        originators: Option<&[&OriginatorId]>,
267    ) -> Result<GlobalCursor, StorageError> {
268        (**self).latest_cursor_for_id(entity_id, entities, originators)
269    }
270
271    fn latest_cursor_combined<Id: AsRef<[u8]>>(
272        &self,
273        entity_id: Id,
274        entities: &[EntityKind],
275        originators: Option<&[&OriginatorId]>,
276    ) -> Result<GlobalCursor, StorageError> {
277        (**self).latest_cursor_combined(entity_id, entities, originators)
278    }
279}
280
281impl<C: ConnectionExt> QueryRefreshState for DbConnection<C> {
282    fn get_refresh_state<EntityId: AsRef<[u8]>>(
283        &self,
284        entity_id: EntityId,
285        entity_kind: EntityKind,
286        originator_id: u32,
287    ) -> Result<Option<RefreshState>, StorageError> {
288        use super::schema::refresh_state::dsl;
289
290        let res = self.raw_query_read(|conn| {
291            dsl::refresh_state
292                .find((entity_id.as_ref(), entity_kind, originator_id as i32))
293                .first(conn)
294                .optional()
295        })?;
296        Ok(res)
297    }
298
299    fn get_last_cursor_for_originators<Id: AsRef<[u8]>>(
300        &self,
301        id: Id,
302        entity_kind: EntityKind,
303        originator_ids: &[u32],
304    ) -> Result<Vec<Cursor>, StorageError> {
305        use super::schema::refresh_state::dsl;
306
307        let id_ref = id.as_ref();
308
309        let originator_ids_i32: Vec<i32> = originator_ids.iter().map(|o| *o as i32).collect();
310        let found_states: Vec<RefreshState> = self.raw_query_read(|conn| {
311            dsl::refresh_state
312                .filter(dsl::entity_id.eq(id_ref))
313                .filter(dsl::entity_kind.eq(entity_kind))
314                .filter(dsl::originator_id.eq_any(originator_ids_i32))
315                .load(conn)
316        })?;
317        let state_map: HashMap<u32, &RefreshState> = found_states
318            .iter()
319            .map(|s| (s.originator_id as u32, s))
320            .collect();
321        // Identify missing originators and create default states
322        let mut missing_states = Vec::new();
323        for originator in originator_ids {
324            if !state_map.contains_key(originator) {
325                missing_states.push(RefreshState {
326                    entity_id: id_ref.to_vec(),
327                    entity_kind,
328                    sequence_id: 0,
329                    originator_id: *originator as i32,
330                });
331            }
332        }
333
334        // Insert missing states
335        for missing_state in &missing_states {
336            missing_state.store_or_ignore(self)?;
337        }
338
339        // Build result vector maintaining input order
340        let result: Vec<Cursor> = originator_ids
341            .iter()
342            .map(|originator| match state_map.get(originator) {
343                Some(state) => Cursor::new(state.sequence_id as u64, state.originator_id as u32),
344                None => Cursor::new(0, *originator),
345            })
346            .collect();
347
348        Ok(result)
349    }
350
351    fn get_last_cursor_for_ids<Id: AsRef<[u8]>>(
352        &self,
353        ids: &[Id],
354        entities: &[EntityKind],
355    ) -> Result<HashMap<Vec<u8>, GlobalCursor>, StorageError> {
356        use super::schema::refresh_state::dsl;
357        use std::collections::HashMap;
358
359        if ids.is_empty() {
360            return Ok(HashMap::new());
361        }
362
363        // Run multiple small IN-queries and merge results.
364        // Keep chunks comfortably under SQLite's default 999-bind limit.
365        const CHUNK: usize = 900;
366
367        let map = self.raw_query_read(|conn| {
368            ids.chunks(CHUNK)
369                .map(|chunk| {
370                    let id_refs: Vec<&[u8]> = chunk.iter().map(|id| id.as_ref()).collect();
371                    let rows = dsl::refresh_state
372                        .filter(dsl::entity_kind.eq_any(entities))
373                        .filter(dsl::entity_id.eq_any(&id_refs))
374                        .group_by((dsl::entity_id, dsl::originator_id))
375                        .select((
376                            dsl::entity_id,
377                            dsl::originator_id,
378                            diesel::dsl::max(dsl::sequence_id),
379                        ))
380                        .load::<(Vec<u8>, i32, Option<i64>)>(conn)?;
381
382                    // Convert this chunk's rows to a partial map immediately
383                    Ok(rows_to_global_cursor_map(rows))
384                })
385                .collect::<Result<Vec<_>, _>>()
386                .map(|partial_maps| {
387                    // Flatten all partial maps into a single map
388                    // No merging needed since entity_ids don't repeat across chunks
389                    partial_maps
390                        .into_iter()
391                        .flat_map(|partial_map| partial_map.into_iter())
392                        .collect()
393                })
394        })?;
395
396        Ok(map)
397    }
398
399    #[tracing::instrument(level = "info", skip(self), fields(entity_id = %hex::encode(&entity_id)))]
400    fn update_cursor<Id: AsRef<[u8]>>(
401        &self,
402        entity_id: Id,
403        entity_kind: EntityKind,
404        cursor: Cursor,
405    ) -> Result<bool, StorageError> {
406        use super::schema::refresh_state::dsl;
407        use crate::diesel::upsert::excluded;
408        use diesel::query_dsl::methods::FilterDsl;
409
410        let state = RefreshState {
411            entity_id: entity_id.as_ref().to_vec(),
412            entity_kind,
413            sequence_id: cursor.sequence_id as i64,
414            originator_id: cursor.originator_id as i32,
415        };
416        let num_updated = self.raw_query_write(|conn| {
417            diesel::insert_into(dsl::refresh_state)
418                .values(&state)
419                .on_conflict((dsl::entity_id, dsl::entity_kind, dsl::originator_id))
420                .do_update()
421                .set(dsl::sequence_id.eq(excluded(dsl::sequence_id)))
422                .filter(dsl::sequence_id.lt(excluded(dsl::sequence_id)))
423                .execute(conn)
424        })?;
425        Ok(num_updated >= 1)
426    }
427
428    fn get_remote_log_cursors(
429        &self,
430        conversation_ids: &[&Vec<u8>],
431    ) -> Result<HashMap<Vec<u8>, Cursor>, crate::ConnectionError> {
432        let mut cursor_map: HashMap<Vec<u8>, Cursor> = HashMap::new();
433        for conversation_id in conversation_ids {
434            let cursor = self
435                .get_last_cursor_for_originator(
436                    conversation_id,
437                    EntityKind::CommitLogDownload,
438                    Originators::REMOTE_COMMIT_LOG,
439                )
440                .unwrap_or_default();
441            cursor_map.insert(conversation_id.to_vec(), cursor);
442        }
443        Ok(cursor_map)
444    }
445
446    fn lowest_common_cursor(&self, topics: &[&Topic]) -> Result<GlobalCursor, StorageError> {
447        use super::schema::identity_updates::dsl as idsl;
448        use super::schema::refresh_state::dsl as rdsl;
449
450        // diesel does not support eq_any (IN) on tuple types.
451        // so, something like `.filter((dsl::entity_id, dsl::entity_kind).eq_any(entities))` will not compile. its possible to implement
452        // with a custom QueryFragment, but maybe that's a future
453        // exercise. ref: https://github.com/diesel-rs/diesel/issues/3222#issuecomment-2079474318
454        // it also does not support group_by on boxed queries
455        let entities = topics
456            .iter()
457            .flat_map(|t| match t.kind() {
458                TopicKind::GroupMessagesV1 => {
459                    vec![
460                        (t.identifier().to_vec(), EntityKind::ApplicationMessage),
461                        (t.identifier().to_vec(), EntityKind::CommitMessage),
462                    ]
463                }
464                TopicKind::WelcomeMessagesV1 => {
465                    vec![(t.identifier().to_vec(), EntityKind::Welcome)]
466                }
467                TopicKind::IdentityUpdatesV1 | TopicKind::KeyPackagesV1 | _ => vec![],
468            })
469            .collect::<Vec<_>>();
470
471        let identity_inbox_ids: Vec<String> = topics
472            .iter()
473            .filter_map(|t| Topic::identity_updates(t))
474            .map(|t| hex::encode(t.identifier()))
475            .collect();
476
477        let mut refresh = rdsl::refresh_state
478            .select((rdsl::originator_id, rdsl::sequence_id))
479            .filter(rdsl::entity_kind.eq(-1)) // Start with a query that will never return any results
480            .into_boxed();
481        for (entity_id, entity_kind) in &entities {
482            refresh = refresh.or_filter(
483                rdsl::entity_id
484                    .eq(entity_id)
485                    .and(rdsl::entity_kind.eq(entity_kind)),
486            );
487        }
488
489        let identity = idsl::identity_updates
490            .select((idsl::originator_id, idsl::sequence_id))
491            .filter(idsl::inbox_id.eq_any(identity_inbox_ids))
492            .into_boxed();
493        let cursor = self.raw_query_read(|conn| {
494            refresh
495                .select((rdsl::originator_id, rdsl::sequence_id))
496                .union_all(identity)
497                .load_iter::<(i32, i64), DefaultLoadingMode>(conn)?
498                .map_ok(|(o, s)| (o as u32, s as u64))
499                .process_results(|iter| iter.into_grouping_map().min())
500        })?;
501
502        Ok(GlobalCursor::with_hashmap(cursor))
503    }
504
505    // _NOTE:_ TEMP until reliable streams
506    // and cursor can be updated from streams
507    fn lowest_common_cursor_combined(
508        &self,
509        topics: &[&Topic],
510    ) -> Result<GlobalCursor, StorageError> {
511        // Build entities list from topics, including both refresh_state entries and group_messages
512        let entities = topics
513            .iter()
514            .flat_map(|t| match t.kind() {
515                TopicKind::GroupMessagesV1 => {
516                    vec![
517                        (t.identifier().to_vec(), EntityKind::ApplicationMessage),
518                        (t.identifier().to_vec(), EntityKind::CommitMessage),
519                    ]
520                }
521                TopicKind::WelcomeMessagesV1 => {
522                    vec![(t.identifier().to_vec(), EntityKind::Welcome)]
523                }
524                TopicKind::IdentityUpdatesV1 | TopicKind::KeyPackagesV1 | _ => vec![],
525            })
526            .collect::<Vec<_>>();
527
528        // Collect identity update inbox IDs
529        let identity_inbox_ids: Vec<String> = topics
530            .iter()
531            .filter_map(|t| match t.kind() {
532                TopicKind::IdentityUpdatesV1 => Some(hex::encode(t.identifier())),
533                _ => None,
534            })
535            .collect();
536
537        let has_identity_updates = !identity_inbox_ids.is_empty();
538        let has_entities = !entities.is_empty();
539
540        if !has_entities && !has_identity_updates {
541            return Ok(GlobalCursor::default());
542        }
543
544        let mut query_parts = Vec::new();
545
546        // Add refresh_state and group_messages parts if we have entities
547        if has_entities {
548            let placeholders = entities
549                .iter()
550                .map(|_| "(?, ?)")
551                .collect::<Vec<_>>()
552                .join(", ");
553
554            query_parts.push(format!(
555                "SELECT originator_id, sequence_id
556                FROM refresh_state
557                WHERE (entity_id, entity_kind) IN ({})",
558                placeholders
559            ));
560
561            query_parts.push(format!(
562                "SELECT originator_id, sequence_id
563                FROM conversation_list
564                WHERE (id, CASE message_kind
565                    WHEN 1 THEN 2  -- GroupMessageKind::Application -> EntityKind::ApplicationMessage
566                    WHEN 2 THEN 7  -- GroupMessageKind::MembershipChange -> EntityKind::CommitMessage
567                END) IN ({})",
568                placeholders
569            ));
570        }
571
572        // Add identity_updates part if we have inbox IDs
573        if has_identity_updates {
574            let inbox_placeholders = identity_inbox_ids
575                .iter()
576                .map(|_| "?")
577                .collect::<Vec<_>>()
578                .join(", ");
579            query_parts.push(format!(
580                "SELECT originator_id, sequence_id
581                FROM identity_updates
582                WHERE inbox_id IN ({})",
583                inbox_placeholders
584            ));
585        }
586
587        // Build a query that unions all sources, then finds MIN per originator
588        let query = format!(
589            "SELECT originator_id, MIN(sequence_id) AS sequence_id
590            FROM ({})
591            GROUP BY originator_id",
592            query_parts.join(" UNION ALL ")
593        );
594
595        let cursor = self.raw_query_read(|conn| {
596            let mut q = diesel::sql_query(query).into_boxed();
597
598            if has_entities {
599                // Bind entity_id and entity_kind pairs for refresh_state
600                for (id, kind) in &entities {
601                    q = q.bind::<Binary, _>(id);
602                    q = q.bind::<Integer, _>(*kind);
603                }
604
605                // Bind entity_id and entity_kind pairs for group_messages
606                for (id, kind) in &entities {
607                    q = q.bind::<Binary, _>(id);
608                    q = q.bind::<Integer, _>(*kind);
609                }
610            }
611
612            // Bind identity_updates parameters
613            if has_identity_updates {
614                for inbox_id in identity_inbox_ids {
615                    q = q.bind::<diesel::sql_types::Text, _>(inbox_id);
616                }
617            }
618
619            q.load_iter::<SingleCursor, DefaultLoadingMode>(conn)?
620                .map_ok(|c| (c.originator_id as u32, c.sequence_id as u64))
621                .collect::<QueryResult<GlobalCursor>>()
622        })?;
623
624        Ok(cursor)
625    }
626
627    fn latest_cursor_for_id<Id: AsRef<[u8]>>(
628        &self,
629        entity_id: Id,
630        entities: &[EntityKind],
631        originators: Option<&[&OriginatorId]>,
632    ) -> Result<GlobalCursor, StorageError> {
633        use super::schema::refresh_state::dsl;
634        use diesel::dsl::max;
635
636        let entity_ref = entity_id.as_ref();
637
638        let cursor_map = self.raw_query_read(|conn| {
639            // Build base query with entity_id and entity_kind filters
640            let base_query = dsl::refresh_state
641                .filter(dsl::entity_id.eq(entity_ref))
642                .filter(dsl::entity_kind.eq_any(entities));
643
644            // Add originator filter if provided, then group and select
645            let results = if let Some(oids) = originators {
646                let originator_ids_i32: Vec<i32> = oids.iter().map(|o| **o as i32).collect();
647                base_query
648                    .filter(dsl::originator_id.eq_any(originator_ids_i32))
649                    .group_by(dsl::originator_id)
650                    .select((dsl::originator_id, max(dsl::sequence_id)))
651                    .load::<(i32, Option<i64>)>(conn)?
652            } else {
653                base_query
654                    .group_by(dsl::originator_id)
655                    .select((dsl::originator_id, max(dsl::sequence_id)))
656                    .load::<(i32, Option<i64>)>(conn)?
657            };
658
659            Ok(results
660                .into_iter()
661                .filter_map(|(orig_id, seq_id)| seq_id.map(|seq| (orig_id as u32, seq as u64)))
662                .collect::<GlobalCursor>())
663        })?;
664
665        Ok(cursor_map)
666    }
667
668    // _NOTE:_ TEMP until reliable streams
669    // and cursor can be updated from streams
670    fn latest_cursor_combined<Id: AsRef<[u8]>>(
671        &self,
672        entity_id: Id,
673        entities: &[EntityKind],
674        originators: Option<&[&OriginatorId]>,
675    ) -> Result<GlobalCursor, StorageError> {
676        let entity_ref = entity_id.as_ref();
677
678        // Build entity_kind placeholders for refresh_state
679        let entity_kind_placeholders = entities.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
680
681        // Build a query that unions refresh_state and group_messages
682        let mut query = format!(
683            "SELECT originator_id, MAX(sequence_id) AS sequence_id
684            FROM (
685                SELECT originator_id, sequence_id
686                FROM refresh_state
687                WHERE entity_id = ? AND entity_kind IN ({})
688                UNION ALL
689                SELECT originator_id, sequence_id
690                FROM group_messages
691                WHERE group_id = ? AND kind IN (",
692            entity_kind_placeholders
693        );
694
695        // Map EntityKind to GroupMessageKind
696        let group_message_kinds: Vec<i32> = entities
697            .iter()
698            .filter_map(|e| match e {
699                EntityKind::ApplicationMessage => Some(1), // GroupMessageKind::Application
700                EntityKind::CommitMessage => Some(2),      // GroupMessageKind::MembershipChange
701                _ => None,
702            })
703            .collect();
704
705        // Add placeholders for group_message kinds
706        let kind_placeholders = group_message_kinds
707            .iter()
708            .map(|_| "?")
709            .collect::<Vec<_>>()
710            .join(", ");
711        query.push_str(&kind_placeholders);
712        query.push(')');
713
714        // Add originator filter if provided
715        if let Some(oids) = originators {
716            let originator_placeholders = oids.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
717            query.push_str(&format!(
718                "
719            ) WHERE originator_id IN ({})
720            GROUP BY originator_id",
721                originator_placeholders
722            ));
723        } else {
724            query.push_str(
725                "
726            ) GROUP BY originator_id",
727            );
728        }
729
730        let cursor_map = self.raw_query_read(|conn| {
731            let mut q = diesel::sql_query(query).into_boxed();
732
733            // Bind entity_id for refresh_state
734            q = q.bind::<Binary, _>(entity_ref);
735
736            // Bind entity_kinds for refresh_state
737            for kind in entities {
738                q = q.bind::<Integer, _>(*kind);
739            }
740
741            // Bind group_id for group_messages
742            q = q.bind::<Binary, _>(entity_ref);
743
744            // Bind group_message_kinds for group_messages
745            for kind in &group_message_kinds {
746                q = q.bind::<Integer, _>(*kind);
747            }
748
749            // Bind originators if provided
750            if let Some(oids) = originators {
751                for oid in oids {
752                    q = q.bind::<Integer, _>(**oid as i32);
753                }
754            }
755
756            q.load_iter::<SingleCursor, DefaultLoadingMode>(conn)?
757                .map_ok(|c| (c.originator_id as u32, c.sequence_id as u64))
758                .collect::<QueryResult<GlobalCursor>>()
759        })?;
760
761        Ok(cursor_map)
762    }
763}
764
765#[cfg(test)]
766pub(crate) mod tests {
767    #[cfg(target_arch = "wasm32")]
768    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker);
769
770    use super::*;
771    use crate::identity_update::StoredIdentityUpdateBuilder;
772    use crate::test_utils::with_connection;
773    use crate::{Store, StoreOrIgnore};
774    use rstest::rstest;
775
776    #[xmtp_common::test]
777    fn get_cursor_with_no_existing_state() {
778        with_connection(|conn| {
779            let id = vec![1, 2, 3];
780            let kind = EntityKind::ApplicationMessage;
781            let entry: Option<RefreshState> = conn
782                .get_refresh_state(&id, kind, Originators::MLS_COMMITS)
783                .unwrap();
784            assert!(entry.is_none());
785            assert_eq!(
786                conn.get_last_cursor_for_originator(&id, kind, Originators::MLS_COMMITS)
787                    .unwrap(),
788                Cursor::mls_commits(0)
789            );
790            let entry: Option<RefreshState> = conn
791                .get_refresh_state(&id, kind, Originators::MLS_COMMITS)
792                .unwrap();
793            assert!(entry.is_some());
794        })
795    }
796
797    #[xmtp_common::test]
798    fn get_cursor_with_no_existing_state_originator() {
799        with_connection(|conn| {
800            let id = vec![1, 2, 3];
801            let kind = EntityKind::ApplicationMessage;
802            let entry: Option<RefreshState> = conn
803                .get_refresh_state(&id, kind, Originators::MLS_COMMITS)
804                .unwrap();
805            assert!(entry.is_none());
806            assert_eq!(
807                conn.get_last_cursor_for_originators(&id, kind, &[0])
808                    .unwrap()[0],
809                Cursor::mls_commits(0)
810            );
811            let entry: Option<RefreshState> = conn
812                .get_refresh_state(&id, kind, Originators::MLS_COMMITS)
813                .unwrap();
814            assert!(entry.is_some());
815        })
816    }
817
818    #[xmtp_common::test]
819    fn get_timestamp_with_existing_state() {
820        with_connection(|conn| {
821            let id = vec![1, 2, 3];
822            let entity_kind = EntityKind::Welcome;
823            let entry = RefreshState {
824                entity_id: id.clone(),
825                entity_kind,
826                sequence_id: 123,
827                originator_id: Originators::MLS_COMMITS as i32,
828            };
829            entry.store_or_ignore(conn).unwrap();
830            assert_eq!(
831                conn.get_last_cursor_for_originator(&id, entity_kind, Originators::MLS_COMMITS)
832                    .unwrap(),
833                Cursor::mls_commits(123)
834            );
835        })
836    }
837
838    #[xmtp_common::test]
839    fn update_timestamp_when_bigger() {
840        with_connection(|conn| {
841            let id = vec![1, 2, 3];
842            let entity_kind = EntityKind::ApplicationMessage;
843            let entry = RefreshState {
844                entity_id: id.clone(),
845                entity_kind,
846                sequence_id: 123,
847                originator_id: 10,
848            };
849            entry.store_or_ignore(conn).unwrap();
850            assert!(
851                conn.update_cursor(
852                    &id,
853                    entity_kind,
854                    Cursor::new(124, Originators::APPLICATION_MESSAGES)
855                )
856                .unwrap()
857            );
858            let entry: Option<RefreshState> = conn
859                .get_refresh_state(&id, entity_kind, Originators::APPLICATION_MESSAGES)
860                .unwrap();
861            assert_eq!(entry.unwrap().sequence_id, 124);
862        })
863    }
864
865    #[xmtp_common::test]
866    fn dont_update_timestamp_when_smaller() {
867        with_connection(|conn| {
868            let entity_id = vec![1, 2, 3];
869            let entity_kind = EntityKind::Welcome;
870
871            let entry = RefreshState {
872                entity_id: entity_id.clone(),
873                entity_kind,
874                sequence_id: 123,
875                originator_id: 10,
876            };
877            entry.store_or_ignore(conn).unwrap();
878            assert!(
879                !conn
880                    .update_cursor(
881                        &entity_id,
882                        entity_kind,
883                        Cursor::new(122, Originators::APPLICATION_MESSAGES)
884                    )
885                    .unwrap()
886            );
887            let entry: Option<RefreshState> = conn
888                .get_refresh_state(&entity_id, entity_kind, Originators::APPLICATION_MESSAGES)
889                .unwrap();
890            assert_eq!(entry.unwrap().sequence_id, 123);
891        })
892    }
893
894    #[xmtp_common::test]
895    fn allow_installation_and_welcome_same_id() {
896        with_connection(|conn| {
897            let entity_id = vec![1, 2, 3];
898            let welcome_state = RefreshState {
899                entity_id: entity_id.clone(),
900                entity_kind: EntityKind::Welcome,
901                sequence_id: 123,
902                originator_id: Originators::MLS_COMMITS as i32,
903            };
904            welcome_state.store_or_ignore(conn).unwrap();
905
906            let group_state = RefreshState {
907                entity_id: entity_id.clone(),
908                entity_kind: EntityKind::ApplicationMessage,
909                sequence_id: 456,
910                originator_id: Originators::MLS_COMMITS as i32,
911            };
912            group_state.store_or_ignore(conn).unwrap();
913
914            let welcome_state_retrieved = conn
915                .get_refresh_state(&entity_id, EntityKind::Welcome, Originators::MLS_COMMITS)
916                .unwrap()
917                .unwrap();
918            assert_eq!(welcome_state_retrieved.sequence_id, 123);
919
920            let group_state_retrieved = conn
921                .get_refresh_state(
922                    &entity_id,
923                    EntityKind::ApplicationMessage,
924                    Originators::MLS_COMMITS,
925                )
926                .unwrap()
927                .unwrap();
928            assert_eq!(group_state_retrieved.sequence_id, 456);
929        })
930    }
931
932    // Helper function to create and store a RefreshState
933    fn create_state<C: ConnectionExt>(
934        conn: &DbConnection<C>,
935        entity_id: &[u8],
936        entity_kind: EntityKind,
937        originator_id: i32,
938        sequence_id: i64,
939    ) {
940        RefreshState {
941            entity_id: entity_id.to_vec(),
942            entity_kind,
943            sequence_id,
944            originator_id,
945        }
946        .store_or_ignore(conn)
947        .unwrap();
948    }
949
950    // Helper function to create and store a RefreshState
951    fn create_identity_update<C: ConnectionExt>(
952        conn: &DbConnection<C>,
953        originator_id: i32,
954        sequence_id: i64,
955    ) {
956        StoredIdentityUpdateBuilder::default()
957            .inbox_id(xmtp_common::rand_string::<32>())
958            .sequence_id(sequence_id)
959            .originator_id(originator_id)
960            .payload(xmtp_common::rand_vec::<32>())
961            .server_timestamp_ns(xmtp_common::rand_i64())
962            .build()
963            .unwrap()
964            .store(conn)
965            .unwrap();
966    }
967
968    #[rstest]
969    #[case::mixed_existing_missing(
970        vec![(0, 100), (10, 200)], // Pre-populate originators 0 and 10
971        vec![0, 10, 20],            // Request 0, 10, and missing 20
972        vec![(0, 100), (10, 200), (20, 0)] // Expected results
973    )]
974    #[case::preserves_order(
975        vec![(5, 555), (10, 1010), (15, 1515)],
976        vec![15, 5, 10], // Non-sequential order
977        vec![(15, 1515), (5, 555), (10, 1010)]
978    )]
979    #[case::all_missing(
980        vec![], // No pre-populated states
981        vec![1, 2, 3],
982        vec![(1, 0), (2, 0), (3, 0)]
983    )]
984    #[case::empty_request(
985        vec![(5, 500)],
986        vec![], // Empty request
987        vec![]  // Empty result
988    )]
989    #[xmtp_common::test]
990    async fn batch_query_scenarios(
991        #[case] pre_populate: Vec<(i32, i64)>,
992        #[case] request_originators: Vec<u32>,
993        #[case] expected: Vec<(u32, u64)>,
994    ) {
995        with_connection(|conn| {
996            let entity_id = vec![1, 1, 1];
997            let entity_kind = EntityKind::CommitMessage;
998            // Pre-populate states
999            for (orig, seq) in pre_populate {
1000                create_state(conn, &entity_id, entity_kind, orig, seq);
1001            }
1002
1003            // Execute query
1004            let cursors = conn
1005                .get_last_cursor_for_originators(&entity_id, entity_kind, &request_originators)
1006                .unwrap();
1007
1008            // Verify results
1009            assert_eq!(cursors.len(), expected.len());
1010            for (i, (expected_orig, expected_seq)) in expected.iter().enumerate() {
1011                assert_eq!(cursors[i].originator_id, *expected_orig);
1012                assert_eq!(cursors[i].sequence_id, *expected_seq);
1013            }
1014
1015            // Verify missing originators were persisted
1016            for orig in &request_originators {
1017                let state = conn
1018                    .get_refresh_state(&entity_id, entity_kind, *orig)
1019                    .unwrap();
1020                assert!(state.is_some(), "Originator {} should be persisted", orig);
1021            }
1022        })
1023    }
1024
1025    #[rstest]
1026    #[case::finds_maximum_per_originator(
1027        vec![
1028            (EntityKind::ApplicationMessage, 5, 100),  // Originator 5, ApplicationMessage
1029            (EntityKind::CommitMessage, 5, 150),       // Originator 5, CommitMessage (higher)
1030            (EntityKind::ApplicationMessage, 10, 500), // Originator 10
1031            (EntityKind::CommitMessage, 0, 250),       // Originator 0
1032        ],
1033        vec![EntityKind::ApplicationMessage, EntityKind::CommitMessage],
1034        vec![0, 5, 10],
1035        vec![(0, 250), (5, 150), (10, 500)] // Expected: max per originator across entity kinds
1036    )]
1037    #[case::single_entry(
1038        vec![(EntityKind::Welcome, 11, 999)],
1039        vec![EntityKind::Welcome],
1040        vec![11],
1041        vec![(11, 999)]
1042    )]
1043    #[case::filters_by_entity_kind(
1044        vec![
1045            (EntityKind::ApplicationMessage, 5, 1000),
1046            (EntityKind::CommitMessage, 5, 2000),  // Higher but filtered out
1047            (EntityKind::Welcome, 5, 3000),        // Highest but filtered out
1048        ],
1049        vec![EntityKind::ApplicationMessage],  // Only query ApplicationMessage
1050        vec![5],
1051        vec![(5, 1000)]  // Should get ApplicationMessage's value, not others
1052    )]
1053    #[case::filters_by_originator(
1054        vec![
1055            (EntityKind::ApplicationMessage, 5, 500),
1056            (EntityKind::ApplicationMessage, 10, 1000),
1057            (EntityKind::ApplicationMessage, 15, 1500), // Filtered out
1058        ],
1059        vec![EntityKind::ApplicationMessage],
1060        vec![5, 10],  // Don't include 15
1061        vec![(5, 500), (10, 1000)]  // Should get originator 5 and 10, not 15
1062    )]
1063    #[xmtp_common::test]
1064    async fn latest_cursor_for_id(
1065        #[case] pre_populate: Vec<(EntityKind, i32, i64)>,
1066        #[case] query_entities: Vec<EntityKind>,
1067        #[case] query_originators: Vec<u32>,
1068        #[case] expected: Vec<(u32, u64)>,
1069    ) {
1070        with_connection(|conn| {
1071            let entity_id = vec![99, 88, 77];
1072
1073            // Pre-populate states
1074            for (kind, orig, seq) in pre_populate {
1075                create_state(conn, &entity_id, kind, orig, seq);
1076            }
1077
1078            // Convert to OriginatorId references
1079            let originator_refs: Vec<&OriginatorId> = query_originators
1080                .iter()
1081                .map(|o| o as &OriginatorId)
1082                .collect();
1083
1084            // Execute query
1085            let cursor = conn
1086                .latest_cursor_for_id(&entity_id, &query_entities, Some(&originator_refs))
1087                .unwrap();
1088
1089            // Verify results
1090            assert_eq!(cursor.len(), expected.len());
1091            for (expected_orig, expected_seq) in expected {
1092                assert_eq!(
1093                    cursor.get(&expected_orig),
1094                    expected_seq,
1095                    "Mismatch for originator {}: expected {}, got {}",
1096                    expected_orig,
1097                    expected_seq,
1098                    cursor.get(&expected_orig)
1099                );
1100            }
1101        })
1102    }
1103
1104    #[rstest]
1105    #[case::single_topic_minimium(
1106        vec![
1107            (vec![1, 1, 1], EntityKind::ApplicationMessage, 200, 127),
1108            (vec![1, 1, 1], EntityKind::CommitMessage, 0, 115),
1109            (vec![1, 1, 1], EntityKind::CommitLogDownload, 100, 0),
1110            (vec![1, 1, 1], EntityKind::CommitLogUpload, 100, 2),
1111            (vec![1, 1, 1], EntityKind::CommitLogForkCheckLocal, 100, 0),
1112            (vec![1, 1, 1], EntityKind::CommitLogForkCheckRemote, 100, 0)
1113        ],
1114        vec![
1115            TopicKind::GroupMessagesV1.create(vec![1, 1, 1]),
1116        ],
1117        vec![(200, 127), (0, 115)]  // MIN across both topics: min(min(100, 150), min(50, 75)) = 50
1118    )]
1119    #[case::multiple_topics_finds_minimum(
1120        vec![
1121            (vec![1, 1, 1], EntityKind::ApplicationMessage, 0, 100),
1122            (vec![1, 1, 1], EntityKind::CommitMessage, 0, 150),
1123            (vec![2, 2, 2], EntityKind::ApplicationMessage, 0, 50),  // Lower value in topic 2
1124            (vec![2, 2, 2], EntityKind::CommitMessage, 0, 75),
1125        ],
1126        vec![
1127            TopicKind::GroupMessagesV1.create(vec![1, 1, 1]),
1128            TopicKind::GroupMessagesV1.create(vec![2, 2, 2]),
1129        ],
1130        vec![(0, 50)]  // MIN across both topics: min(min(100, 150), min(50, 75)) = 50
1131    )]
1132    #[case::multiple_topics_different_originators(
1133        vec![
1134            (vec![3, 3, 3], EntityKind::ApplicationMessage, 5, 500),
1135            (vec![3, 3, 3], EntityKind::CommitMessage, 5, 600),
1136            (vec![4, 4, 4], EntityKind::ApplicationMessage, 10, 1000),
1137            (vec![4, 4, 4], EntityKind::CommitMessage, 10, 1100),
1138            (vec![4, 4, 4], EntityKind::ApplicationMessage, 5, 300),  // Lower value for originator 5
1139        ],
1140        vec![
1141            TopicKind::GroupMessagesV1.create(vec![3, 3, 3]),
1142            TopicKind::GroupMessagesV1.create(vec![4, 4, 4]),
1143        ],
1144        vec![(5, 300), (10, 1000)]  // MIN for each originator across topics
1145    )]
1146    #[case::mixed_group_and_welcome_topics(
1147        vec![
1148            (vec![6, 6, 6], EntityKind::ApplicationMessage, 0, 100),
1149            (vec![6, 6, 6], EntityKind::CommitMessage, 0, 150),
1150            (vec![7, 7, 7], EntityKind::Welcome, 0, 50),  // Lower value for originator 0
1151            (vec![7, 7, 7], EntityKind::Welcome, 10, 200),
1152        ],
1153        vec![
1154            TopicKind::GroupMessagesV1.create(vec![6, 6, 6]),
1155            TopicKind::WelcomeMessagesV1.create(vec![7, 7, 7]),
1156        ],
1157        vec![(0, 50), (10, 200)]  // MIN across different entity kinds
1158    )]
1159    #[case::originator_in_some_topics_only(
1160        vec![
1161            (vec![8, 8, 8], EntityKind::ApplicationMessage, 5, 100),
1162            (vec![8, 8, 8], EntityKind::CommitMessage, 5, 200),
1163            (vec![9, 9, 9], EntityKind::ApplicationMessage, 10, 300),
1164            (vec![9, 9, 9], EntityKind::CommitMessage, 10, 400),
1165        ],
1166        vec![
1167            TopicKind::GroupMessagesV1.create(vec![8, 8, 8]),
1168            TopicKind::GroupMessagesV1.create(vec![9, 9, 9]),
1169        ],
1170        vec![(5, 100), (10, 300)]  // Each originator appears in only one topic
1171    )]
1172    #[xmtp_common::test]
1173    async fn lowest_common_cursor_scenarios(
1174        #[case] pre_populate: Vec<(Vec<u8>, EntityKind, i32, i64)>,
1175        #[case] query_topics: Vec<xmtp_proto::types::Topic>,
1176        #[case] expected: Vec<(u32, u64)>,
1177    ) {
1178        with_connection(|conn| {
1179            // Pre-populate states
1180            for (entity_id, kind, orig, seq) in pre_populate {
1181                create_state(conn, &entity_id, kind, orig, seq);
1182            }
1183
1184            // Execute query
1185            let topic_refs: Vec<&xmtp_proto::types::Topic> = query_topics.iter().collect();
1186            let cursor = conn.lowest_common_cursor(&topic_refs).unwrap();
1187
1188            // Verify results
1189            assert_eq!(
1190                cursor.len(),
1191                expected.len(),
1192                "Expected {} originators, got {}",
1193                expected.len(),
1194                cursor.len()
1195            );
1196            for (expected_orig, expected_seq) in expected {
1197                assert_eq!(
1198                    cursor.get(&expected_orig),
1199                    expected_seq,
1200                    "Mismatch for originator {}: expected {}, got {}",
1201                    expected_orig,
1202                    expected_seq,
1203                    cursor.get(&expected_orig)
1204                );
1205            }
1206        })
1207    }
1208
1209    #[xmtp_common::test]
1210    fn lowest_common_cursor_empty_topics() {
1211        with_connection(|conn| {
1212            create_state(conn, &[1, 2, 3], EntityKind::ApplicationMessage, 0, 100);
1213            create_identity_update(conn, 1, 100);
1214            let result = conn.lowest_common_cursor(&[]);
1215            match result {
1216                Ok(cursor) => {
1217                    tracing::info!("{:?}", cursor);
1218                    assert_eq!(cursor.len(), 0, "Empty topics should return empty cursor");
1219                }
1220                Err(_e) => {
1221                    // Also acceptable to return an error for empty topics
1222                }
1223            }
1224        })
1225    }
1226
1227    #[xmtp_common::test]
1228    fn lowest_common_cursor_no_matching_states() {
1229        with_connection(|conn| {
1230            let topics = [
1231                TopicKind::GroupMessagesV1.create(vec![99, 99, 99]),
1232                TopicKind::WelcomeMessagesV1.create(vec![88, 88, 88]),
1233                TopicKind::IdentityUpdatesV1.create(b"test inbox"),
1234                TopicKind::IdentityUpdatesV1.create(b"inbox test 2"),
1235            ];
1236            let topic_refs: Vec<&xmtp_proto::types::Topic> = topics.iter().collect();
1237            create_identity_update(conn, 1, 100);
1238            let cursor = conn.lowest_common_cursor(&topic_refs).unwrap();
1239            assert_eq!(cursor.len(), 0);
1240        })
1241    }
1242
1243    #[xmtp_common::test]
1244    fn get_last_cursor_for_ids_empty() {
1245        with_connection(|conn| {
1246            let ids: Vec<Vec<u8>> = vec![];
1247            let entities = vec![EntityKind::ApplicationMessage];
1248            let result = conn.get_last_cursor_for_ids(&ids, &entities).unwrap();
1249            assert!(result.is_empty());
1250        })
1251    }
1252
1253    #[xmtp_common::test]
1254    async fn get_last_cursor_for_ids_single() {
1255        with_connection(|conn| {
1256            let id = vec![1, 2, 3];
1257            let entity_kind = EntityKind::ApplicationMessage;
1258
1259            // Store a state with originator 10 and sequence_id 456
1260            create_state(conn, &id, entity_kind, 10, 456);
1261
1262            // Query for it
1263            let ids = vec![id.clone()];
1264            let entities = vec![entity_kind];
1265            let result = conn.get_last_cursor_for_ids(&ids, &entities).unwrap();
1266
1267            assert_eq!(result.len(), 1);
1268            let cursor = result.get(&id).expect("Should have cursor for id");
1269            assert_eq!(cursor.get(&10), 456);
1270        })
1271    }
1272
1273    #[xmtp_common::test]
1274    fn get_last_cursor_for_ids_multiple_mixed() {
1275        with_connection(|conn| {
1276            let entity_kind = EntityKind::ApplicationMessage;
1277
1278            // Create some ids with existing state
1279            let id1 = vec![1, 0, 0];
1280            let id2 = vec![2, 0, 0];
1281            let id3 = vec![3, 0, 0];
1282            let id4 = vec![4, 0, 0]; // This one won't have state
1283
1284            create_state(conn, &id1, entity_kind, 10, 100);
1285            create_state(conn, &id2, entity_kind, 10, 200);
1286            create_state(conn, &id3, entity_kind, 10, 300);
1287
1288            // Query for all ids including one without state
1289            let ids = vec![id1.clone(), id2.clone(), id3.clone(), id4.clone()];
1290            let entities = vec![entity_kind];
1291            let result = conn.get_last_cursor_for_ids(&ids, &entities).unwrap();
1292
1293            // Should only return the ones with existing state
1294            assert_eq!(result.len(), 3);
1295            assert_eq!(result.get(&id1).unwrap().get(&10), 100);
1296            assert_eq!(result.get(&id2).unwrap().get(&10), 200);
1297            assert_eq!(result.get(&id3).unwrap().get(&10), 300);
1298            assert!(!result.contains_key(&id4));
1299        })
1300    }
1301
1302    #[xmtp_common::test]
1303    fn get_last_cursor_for_ids_exactly_900() {
1304        with_connection(|conn| {
1305            let entity_kind = EntityKind::ApplicationMessage;
1306
1307            // Create exactly 900 ids
1308            let mut ids = Vec::new();
1309            for i in 0..900 {
1310                let id = vec![(i / 256) as u8, (i % 256) as u8];
1311                create_state(conn, &id, entity_kind, 10, i as i64);
1312                ids.push(id);
1313            }
1314
1315            // Query for all 900 ids
1316            let entities = vec![entity_kind];
1317            let result = conn.get_last_cursor_for_ids(&ids, &entities).unwrap();
1318
1319            assert_eq!(result.len(), 900);
1320            for (idx, id) in ids.iter().enumerate() {
1321                assert_eq!(result.get(id).unwrap().get(&10), idx as u64);
1322            }
1323        })
1324    }
1325
1326    #[xmtp_common::test]
1327    fn get_last_cursor_for_ids_over_900() {
1328        with_connection(|conn| {
1329            let entity_kind = EntityKind::ApplicationMessage;
1330
1331            // Create 1000 ids to test chunking
1332            let mut ids = Vec::new();
1333            for i in 0..1000 {
1334                let id = vec![(i / 256) as u8, (i % 256) as u8, 0];
1335                create_state(conn, &id, entity_kind, 10, i as i64);
1336                ids.push(id);
1337            }
1338
1339            // Query for all 1000 ids (should use 2 chunks)
1340            let entities = vec![entity_kind];
1341            let result = conn.get_last_cursor_for_ids(&ids, &entities).unwrap();
1342
1343            assert_eq!(result.len(), 1000);
1344            for (idx, id) in ids.iter().enumerate() {
1345                assert_eq!(
1346                    result.get(id).unwrap().get(&10),
1347                    idx as u64,
1348                    "Mismatch for id at index {}",
1349                    idx
1350                );
1351            }
1352        })
1353    }
1354
1355    #[xmtp_common::test]
1356    fn get_last_cursor_for_ids_over_1800() {
1357        with_connection(|conn| {
1358            let entity_kind = EntityKind::ApplicationMessage;
1359
1360            // Create 2000 ids to test multiple chunks
1361            let mut ids = Vec::new();
1362            for i in 0..2000 {
1363                let id = vec![(i / 256) as u8, (i % 256) as u8, 1];
1364                create_state(conn, &id, entity_kind, 10, i as i64);
1365                ids.push(id);
1366            }
1367
1368            // Query for all 2000 ids (should use 3 chunks: 900, 900, 200)
1369            let entities = vec![entity_kind];
1370            let result = conn.get_last_cursor_for_ids(&ids, &entities).unwrap();
1371
1372            assert_eq!(result.len(), 2000);
1373            for (idx, id) in ids.iter().enumerate() {
1374                assert_eq!(
1375                    result.get(id).unwrap().get(&10),
1376                    idx as u64,
1377                    "Mismatch for id at index {}",
1378                    idx
1379                );
1380            }
1381        })
1382    }
1383
1384    #[xmtp_common::test]
1385    fn get_last_cursor_for_ids_different_entity_kinds() {
1386        with_connection(|conn| {
1387            let id1 = vec![1, 2, 3];
1388            let id2 = vec![4, 5, 6];
1389
1390            // Store same ids with different entity kinds
1391            create_state(conn, &id1, EntityKind::ApplicationMessage, 10, 100);
1392            create_state(conn, &id1, EntityKind::Welcome, 10, 200);
1393            create_state(conn, &id2, EntityKind::ApplicationMessage, 10, 300);
1394
1395            // Query for ApplicationMessage entity kind only
1396            let ids = vec![id1.clone(), id2.clone()];
1397            let result = conn
1398                .get_last_cursor_for_ids(&ids, &[EntityKind::ApplicationMessage])
1399                .unwrap();
1400
1401            assert_eq!(result.len(), 2);
1402            assert_eq!(result.get(&id1).unwrap().get(&10), 100);
1403            assert_eq!(result.get(&id2).unwrap().get(&10), 300);
1404
1405            // Query for Welcome entity kind only
1406            let result = conn
1407                .get_last_cursor_for_ids(&ids, &[EntityKind::Welcome])
1408                .unwrap();
1409
1410            assert_eq!(result.len(), 1);
1411            assert_eq!(result.get(&id1).unwrap().get(&10), 200);
1412            assert!(!result.contains_key(&id2));
1413        })
1414    }
1415}