xmtp_api_d14n/protocol/extractors/key_packages.rs

1use xmtp_proto::ConversionError;
2
3use crate::protocol::Extractor;
4use crate::protocol::traits::EnvelopeVisitor;
5use xmtp_proto::xmtp::mls::api::v1::UploadKeyPackageRequest;
6use xmtp_proto::xmtp::mls::api::v1::fetch_key_packages_response::KeyPackage;
7
8/// Key Packages Extractor
9/// This Extractor can be applied to multiple envelopes without losing state
10#[derive(Default, Clone)]
11pub struct KeyPackagesExtractor {
12    key_packages: Vec<KeyPackage>,
13}
14
15impl Extractor for KeyPackagesExtractor {
16    type Output = Vec<KeyPackage>;
17
18    fn get(self) -> Self::Output {
19        self.key_packages
20    }
21}
22
23impl KeyPackagesExtractor {
24    pub fn new() -> Self {
25        Default::default()
26    }
27
28    pub fn get(self) -> Vec<KeyPackage> {
29        self.key_packages
30    }
31}
32
33impl EnvelopeVisitor<'_> for KeyPackagesExtractor {
34    type Error = ConversionError;
35
36    fn visit_upload_key_package(
37        &mut self,
38        req: &UploadKeyPackageRequest,
39    ) -> Result<(), Self::Error> {
40        let key_package = req.key_package.as_ref().ok_or(ConversionError::Missing {
41            item: "key_package",
42            r#type: "OriginatorEnvelope",
43        })?;
44        self.key_packages.push(KeyPackage {
45            key_package_tls_serialized: key_package.key_package_tls_serialized.clone(),
46        });
47        Ok(())
48    }
49}
50
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::extractors::test_utils::*;
    use crate::protocol::{EnvelopeError, ProtocolEnvelope};

    #[xmtp_common::test]
    fn test_extract_kp() {
        let kp = xmtp_common::rand_vec::<32>();
        let envelope = TestEnvelopeBuilder::new()
            .with_key_package_custom(kp.clone())
            .build();
        let mut extractor = KeyPackagesExtractor::new();
        envelope.accept(&mut extractor).unwrap();
        let extracted_kp = extractor.get();
        // One upload envelope must yield exactly one key package; checking the
        // length first gives a clear failure instead of an index-out-of-bounds panic.
        assert_eq!(extracted_kp.len(), 1);
        assert_eq!(kp, extracted_kp[0].key_package_tls_serialized);
    }

    /// The same extractor visiting several envelopes keeps every key package,
    /// in visit order.
    #[xmtp_common::test]
    fn extractor_accumulates_across_envelopes() {
        let first = xmtp_common::rand_vec::<32>();
        let second = xmtp_common::rand_vec::<32>();
        let mut extractor = KeyPackagesExtractor::new();
        for kp in [first.clone(), second.clone()] {
            let envelope = TestEnvelopeBuilder::new()
                .with_key_package_custom(kp)
                .build();
            envelope.accept(&mut extractor).unwrap();
        }
        let extracted: Vec<_> = extractor
            .get()
            .into_iter()
            .map(|kp| kp.key_package_tls_serialized)
            .collect();
        assert_eq!(extracted, vec![first, second]);
    }

    #[xmtp_common::test]
    fn extractor_errors_when_missing() {
        let envelope = TestEnvelopeBuilder::new()
            .with_invalid_key_package()
            .build();
        let mut extractor = KeyPackagesExtractor::new();
        let err = envelope.accept(&mut extractor).unwrap_err();
        assert!(matches!(
            err,
            EnvelopeError::Conversion(ConversionError::Missing { .. })
        ));
    }
}