1use diesel::RunQueryDsl;
2
3use crate::{
4 ConnectionExt, DbConnection, impl_store, schema::remote_commit_log,
5 schema::remote_commit_log::dsl,
6};
7use diesel::{
8 Insertable, Queryable,
9 backend::Backend,
10 deserialize::{self, FromSql, FromSqlRow},
11 expression::AsExpression,
12 prelude::*,
13 serialize::{self, IsNull, Output, ToSql},
14 sql_types::Integer,
15 sqlite::Sqlite,
16};
17
18use serde::{Deserialize, Serialize};
19use xmtp_common::snippet::Snippet;
20use xmtp_proto::xmtp::mls::message_contents::CommitResult as ProtoCommitResult;
21
/// A new row for the `remote_commit_log` table.
///
/// Carries the same columns as [`RemoteCommitLog`] except `rowid`, which is not
/// set here (presumably assigned by SQLite on insert — confirm against the schema).
#[derive(Insertable, Debug, Clone)]
#[diesel(table_name = remote_commit_log)]
pub struct NewRemoteCommitLog {
    pub log_sequence_id: i64,
    pub group_id: Vec<u8>,
    pub commit_sequence_id: i64,
    /// Written to the database as its integer discriminant via the `ToSql` impl
    /// on [`CommitResult`].
    pub commit_result: CommitResult,
    pub applied_epoch_number: i64,
    pub applied_epoch_authenticator: Vec<u8>,
}

// Hooks `NewRemoteCommitLog` into the crate's store helpers for the
// `remote_commit_log` table (see the `impl_store!` macro definition for what it generates).
impl_store!(NewRemoteCommitLog, remote_commit_log);
34
/// A row of the `remote_commit_log` table, including its SQLite `rowid`.
///
/// diesel's `Queryable` derive maps selected columns onto fields by position, and
/// the queries in this file load whole rows without a `.select(..)`, so the field
/// order here must match the column order of `schema::remote_commit_log`.
#[derive(Insertable, Queryable, Clone)]
#[diesel(table_name = remote_commit_log)]
// Treat `rowid` as the primary key for diesel's purposes.
#[diesel(primary_key(rowid))]
pub struct RemoteCommitLog {
    /// SQLite row id; `get_remote_commit_log_after_cursor` filters and orders by it.
    pub rowid: i32,
    pub log_sequence_id: i64,
    pub group_id: Vec<u8>,
    pub commit_sequence_id: i64,
    /// Decoded from its integer discriminant via the `FromSql` impl on [`CommitResult`].
    pub commit_result: CommitResult,
    pub applied_epoch_number: i64,
    pub applied_epoch_authenticator: Vec<u8>,
}

// Hooks `RemoteCommitLog` into the crate's store helpers for the
// `remote_commit_log` table (see the `impl_store!` macro definition for what it generates).
impl_store!(RemoteCommitLog, remote_commit_log);
56
/// Outcome recorded for a commit in the remote commit log.
///
/// Persisted in an `Integer` column as the `i32` discriminant (see the `ToSql` /
/// `FromSql` impls), so the explicit discriminants below must never be renumbered.
/// NOTE(review): serde's derived impls may also depend on variant order for some
/// formats — avoid reordering variants as well.
#[repr(i32)]
#[derive(Copy, Clone, Serialize, Deserialize, Eq, PartialEq, AsExpression, FromSqlRow)]
#[diesel(sql_type = Integer)]
pub enum CommitResult {
    /// Produced from `ProtoCommitResult::Unspecified`.
    Unknown = 0,
    /// Produced from `ProtoCommitResult::Applied`.
    Success = 1,
    WrongEpoch = 2,
    Undecryptable = 3,
    Invalid = 4,
}
67
68impl std::fmt::Debug for CommitResult {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 let s = match self {
71 CommitResult::Unknown => "Unknown",
72 CommitResult::Success => "Success",
73 CommitResult::WrongEpoch => "WrongEpoch",
74 CommitResult::Undecryptable => "Undecryptable",
75 CommitResult::Invalid => "Invalid",
76 };
77 write!(f, "{}", s)
78 }
79}
80
81impl std::fmt::Debug for RemoteCommitLog {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(
84 f,
85 "RemoteCommitLog {{ rowid: {:?}, log_sequence_id: {:?}, group_id {:?}, commit_sequence_id: {:?}, commit_result: {:?}, applied_epoch_number: {:?}, applied_epoch_authenticator: {:?} }}",
86 self.rowid,
87 self.log_sequence_id,
88 &self.group_id.snippet(),
89 self.commit_sequence_id,
90 self.commit_result,
91 self.applied_epoch_number,
92 &self.applied_epoch_authenticator.snippet()
93 )
94 }
95}
96
97impl ToSql<Integer, Sqlite> for CommitResult
98where
99 i32: ToSql<Integer, Sqlite>,
100{
101 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
102 out.set_value(*self as i32);
103 Ok(IsNull::No)
104 }
105}
106
107impl FromSql<Integer, Sqlite> for CommitResult
108where
109 i32: FromSql<Integer, Sqlite>,
110{
111 fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
112 match i32::from_sql(bytes)? {
113 0 => Ok(Self::Unknown),
114 1 => Ok(Self::Success),
115 2 => Ok(Self::WrongEpoch),
116 3 => Ok(Self::Undecryptable),
117 4 => Ok(Self::Invalid),
118 x => Err(format!("Unrecognized variant {}", x).into()),
119 }
120 }
121}
122
123impl From<ProtoCommitResult> for CommitResult {
124 fn from(value: ProtoCommitResult) -> Self {
125 match value {
126 ProtoCommitResult::Applied => Self::Success,
127 ProtoCommitResult::WrongEpoch => Self::WrongEpoch,
128 ProtoCommitResult::Undecryptable => Self::Undecryptable,
129 ProtoCommitResult::Invalid => Self::Invalid,
130 ProtoCommitResult::Unspecified => Self::Unknown,
131 }
132 }
133}
134
/// Sort direction applied to the `rowid` column by
/// [`QueryRemoteCommitLog::get_remote_commit_log_after_cursor`].
///
/// A plain fieldless selector, so it is `Copy` and comparable, and it is `Debug`
/// so it can appear in logs and assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemoteCommitLogOrder {
    /// Sort by ascending `rowid`.
    AscendingByRowid,
    /// Sort by descending `rowid`.
    DescendingByRowid,
}
139
/// Read queries over the `remote_commit_log` table.
pub trait QueryRemoteCommitLog {
    /// Returns the entry with the highest `log_sequence_id` for `group_id`, or
    /// `None` if the group has no entries.
    fn get_latest_remote_log_for_group(
        &self,
        group_id: &[u8],
    ) -> Result<Option<RemoteCommitLog>, crate::ConnectionError>;

    /// Returns the entries for `group_id` whose `rowid` is strictly greater than
    /// `after_cursor`, excluding rows with `commit_sequence_id == 0`, sorted by
    /// `rowid` in the direction given by `order_by`.
    ///
    /// The `DbConnection` implementation returns an error when `after_cursor`
    /// exceeds `i32::MAX`.
    fn get_remote_commit_log_after_cursor(
        &self,
        group_id: &[u8],
        after_cursor: i64,
        order_by: RemoteCommitLogOrder,
    ) -> Result<Vec<RemoteCommitLog>, crate::ConnectionError>;
}
153
154impl<T> QueryRemoteCommitLog for &T
155where
156 T: QueryRemoteCommitLog,
157{
158 fn get_latest_remote_log_for_group(
159 &self,
160 group_id: &[u8],
161 ) -> Result<Option<RemoteCommitLog>, crate::ConnectionError> {
162 (**self).get_latest_remote_log_for_group(group_id)
163 }
164
165 fn get_remote_commit_log_after_cursor(
166 &self,
167 group_id: &[u8],
168 after_cursor: i64,
169 order_by: RemoteCommitLogOrder,
170 ) -> Result<Vec<RemoteCommitLog>, crate::ConnectionError> {
171 (**self).get_remote_commit_log_after_cursor(group_id, after_cursor, order_by)
172 }
173}
174
175impl<C: ConnectionExt> QueryRemoteCommitLog for DbConnection<C> {
176 fn get_latest_remote_log_for_group(
177 &self,
178 group_id: &[u8],
179 ) -> Result<Option<RemoteCommitLog>, crate::ConnectionError> {
180 self.raw_query_read(|db| {
181 dsl::remote_commit_log
182 .filter(remote_commit_log::group_id.eq(group_id))
183 .order(remote_commit_log::log_sequence_id.desc())
184 .limit(1)
185 .first(db)
186 .optional()
187 })
188 }
189
190 fn get_remote_commit_log_after_cursor(
191 &self,
192 group_id: &[u8],
193 after_cursor: i64,
194 order: RemoteCommitLogOrder,
195 ) -> Result<Vec<RemoteCommitLog>, crate::ConnectionError> {
196 if after_cursor > i32::MAX as i64 {
199 return Err(crate::ConnectionError::Database(
200 diesel::result::Error::QueryBuilderError("Cursor value exceeds i32::MAX".into()),
201 ));
202 }
203 let after_cursor: i32 = after_cursor as i32;
204
205 let query = dsl::remote_commit_log
206 .filter(dsl::group_id.eq(group_id))
207 .filter(dsl::rowid.gt(after_cursor))
208 .filter(dsl::commit_sequence_id.ne(0));
209
210 self.raw_query_read(|db| match order {
211 RemoteCommitLogOrder::AscendingByRowid => query.order_by(dsl::rowid.asc()).load(db),
212 RemoteCommitLogOrder::DescendingByRowid => query.order_by(dsl::rowid.desc()).load(db),
213 })
214 }
215}