1
use std::str::FromStr;
2

            
3
use serde::{Deserialize, Serialize};
4
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
5

            
6
#[derive(Default, PartialEq, Eq, Copy, Clone, Debug, zvariant::Type)]
7
#[zvariant(signature = "s")]
8
pub enum ContentType {
9
    Text,
10
    #[default]
11
    Blob,
12
}
13

            
14
impl Serialize for ContentType {
15
26
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
16
    where
17
        S: serde::Serializer,
18
    {
19
53
        self.as_str().serialize(serializer)
20
    }
21
}
22

            
23
impl<'de> Deserialize<'de> for ContentType {
24
22
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25
    where
26
        D: serde::Deserializer<'de>,
27
    {
28
23
        let s = String::deserialize(deserializer)?;
29
44
        Self::from_str(&s).map_err(serde::de::Error::custom)
30
    }
31
}
32

            
33
impl FromStr for ContentType {
34
    type Err = String;
35

            
36
21
    fn from_str(s: &str) -> Result<Self, Self::Err> {
37
        match s {
38
44
            "text/plain" => Ok(Self::Text),
39
26
            "application/octet-stream" => Ok(Self::Blob),
40
2
            e => Err(format!("Invalid content type: {e}")),
41
        }
42
    }
43
}
44

            
45
impl ContentType {
46
21
    pub const fn as_str(&self) -> &'static str {
47
21
        match self {
48
21
            Self::Text => "text/plain",
49
13
            Self::Blob => "application/octet-stream",
50
        }
51
    }
52
}
53

            
54
/// A wrapper around a combination of (secret, content-type).
55
#[derive(Clone, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
56
pub enum Secret {
57
    /// Corresponds to [`ContentType::Text`]
58
    Text(String),
59
    /// Corresponds to [`ContentType::Blob`]
60
    Blob(Vec<u8>),
61
}
62

            
63
impl std::fmt::Debug for Secret {
64
2
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65
2
        match self {
66
2
            Self::Text(_) => write!(f, "Secret::Text([REDACTED])"),
67
2
            Self::Blob(_) => write!(f, "Secret::Blob([REDACTED])"),
68
        }
69
    }
70
}
71

            
72
impl Secret {
73
    /// Generate a random secret, used when creating a session collection.
74
35
    pub fn random() -> Result<Self, getrandom::Error> {
75
39
        let mut secret = [0; 64];
76
        // Equivalent of `ring::rand::SecureRandom`
77
35
        getrandom::fill(&mut secret)?;
78

            
79
39
        Ok(Self::blob(secret))
80
    }
81

            
82
    /// Get the sandboxed app secret if the app is sandboxed using
83
    /// org.freedesktop.portal.Secret portal.
84
    pub async fn sandboxed() -> Result<Self, crate::file::Error> {
85
        Ok(Self::blob(
86
            ashpd::desktop::secret::retrieve()
87
                .await
88
                .map_err(crate::file::Error::from)?,
89
        ))
90
    }
91

            
92
    /// Create a text secret, stored with `text/plain` content type.
93
62
    pub fn text(value: impl AsRef<str>) -> Self {
94
116
        Self::Text(value.as_ref().to_owned())
95
    }
96

            
97
    /// Create a blob secret, stored with `application/octet-stream` content
98
    /// type.
99
117
    pub fn blob(value: impl AsRef<[u8]>) -> Self {
100
243
        Self::Blob(value.as_ref().to_owned())
101
    }
102

            
103
21
    pub const fn content_type(&self) -> ContentType {
104
22
        match self {
105
21
            Self::Text(_) => ContentType::Text,
106
13
            Self::Blob(_) => ContentType::Blob,
107
        }
108
    }
109

            
110
40
    pub fn as_bytes(&self) -> &[u8] {
111
71
        match self {
112
37
            Self::Text(text) => text.as_bytes(),
113
26
            Self::Blob(bytes) => bytes.as_ref(),
114
        }
115
    }
116

            
117
19
    pub fn with_content_type(content_type: ContentType, secret: impl AsRef<[u8]>) -> Self {
118
19
        match content_type {
119
38
            ContentType::Text => match String::from_utf8(secret.as_ref().to_owned()) {
120
19
                Ok(text) => Secret::text(text),
121
2
                Err(_e) => {
122
4
                    #[cfg(feature = "tracing")]
123
                    tracing::warn!(
124
                        "Failed to decode secret as UTF-8: {}, falling back to blob",
125
                        _e
126
                    );
127

            
128
2
                    Secret::blob(secret)
129
                }
130
            },
131
26
            _ => Secret::blob(secret),
132
        }
133
    }
134
}
135

            
136
impl From<&[u8]> for Secret {
137
8
    fn from(value: &[u8]) -> Self {
138
8
        Self::blob(value)
139
    }
140
}
141

            
142
impl From<Zeroizing<Vec<u8>>> for Secret {
    /// Converts a zeroizing byte buffer into a blob secret.
    ///
    /// The inner vector is moved out of the wrapper rather than copied, so
    /// no additional copy of the secret bytes is made. The wrapper is left
    /// holding an empty vector.
    fn from(mut value: Zeroizing<Vec<u8>>) -> Self {
        Self::Blob(std::mem::take(&mut *value))
    }
}
147

            
148
impl From<Vec<u8>> for Secret {
149
18
    fn from(value: Vec<u8>) -> Self {
150
18
        Self::blob(value)
151
    }
152
}
153

            
154
impl From<&Vec<u8>> for Secret {
155
6
    fn from(value: &Vec<u8>) -> Self {
156
6
        Self::blob(value)
157
    }
158
}
159

            
160
impl<const N: usize> From<&[u8; N]> for Secret {
161
    fn from(value: &[u8; N]) -> Self {
162
        Self::blob(value)
163
    }
164
}
165

            
166
impl From<String> for Secret {
167
6
    fn from(value: String) -> Self {
168
6
        Self::text(value)
169
    }
170
}
171

            
172
impl From<&str> for Secret {
173
38
    fn from(value: &str) -> Self {
174
36
        Self::text(value)
175
    }
176
}
177

            
178
impl std::ops::Deref for Secret {
179
    type Target = [u8];
180

            
181
37
    fn deref(&self) -> &Self::Target {
182
31
        self.as_bytes()
183
    }
184
}
185

            
186
impl AsRef<[u8]> for Secret {
187
12
    fn as_ref(&self) -> &[u8] {
188
12
        self.as_bytes()
189
    }
190
}
191

            
192
#[cfg(test)]
mod tests {
    use zvariant::{Endian, serialized::Context, to_bytes};

    use super::*;

    /// `Debug` output must show the variant but never the secret itself.
    #[test]
    fn secret_debug_is_redacted() {
        let rendered_text = format!("{:?}", Secret::text("password"));
        let rendered_blob = format!("{:?}", Secret::blob([1, 2, 3]));

        assert_eq!(rendered_text, "Secret::Text([REDACTED])");
        assert_eq!(rendered_blob, "Secret::Blob([REDACTED])");
    }

    /// Round-trips every content type through the D-Bus encoding and checks
    /// that unknown strings are rejected.
    #[test]
    fn content_type_serialization() {
        let ctxt = Context::new_dbus(Endian::Little, 0);
        let known = [
            (ContentType::Text, "text/plain"),
            (ContentType::Blob, "application/octet-stream"),
        ];

        // Serialization: each variant must encode as its string form.
        for (content_type, mime) in known {
            let encoded = to_bytes(ctxt, &content_type).unwrap();
            let (value, _): (String, _) = encoded.deserialize().unwrap();
            assert_eq!(value, mime);
        }

        // Deserialization: each string form must decode to its variant.
        for (expected, mime) in known {
            let encoded = to_bytes(ctxt, &mime).unwrap();
            let (content_type, _): (ContentType, _) = encoded.deserialize().unwrap();
            assert_eq!(content_type, expected);
        }

        // An unknown content type must fail with the parser's message.
        let encoded = to_bytes(ctxt, &"invalid/type").unwrap();
        let result: Result<(ContentType, _), _> = encoded.deserialize();
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Invalid content type"));
    }

    /// Parsing accepts exactly the two known strings.
    #[test]
    fn content_type_from_str() {
        let parsed_text = ContentType::from_str("text/plain").unwrap();
        assert_eq!(parsed_text, ContentType::Text);

        let parsed_blob = ContentType::from_str("application/octet-stream").unwrap();
        assert_eq!(parsed_blob, ContentType::Blob);

        // Anything else is an error mentioning the invalid content type.
        let rejected = ContentType::from_str("invalid");
        assert!(rejected.is_err());
        let message = rejected.unwrap_err();
        assert!(message.contains("Invalid content type"));
    }

    /// `with_content_type` keeps valid text as text and degrades invalid
    /// UTF-8 to a blob.
    #[test]
    fn invalid_utf8() {
        // Bytes that are not valid UTF-8 fall back to a blob.
        let garbage = vec![0xFF, 0xFE, 0xFD];
        let fallback = Secret::with_content_type(ContentType::Text, &garbage);
        assert_eq!(fallback.content_type(), ContentType::Blob);
        assert_eq!(&*fallback, &[0xFF, 0xFE, 0xFD]);

        // Valid UTF-8 stays a text secret.
        let greeting = "Hello, World!";
        let textual = Secret::with_content_type(ContentType::Text, greeting.as_bytes());
        assert_eq!(textual.content_type(), ContentType::Text);
        assert_eq!(&*textual, greeting.as_bytes());

        // A blob content type stores the bytes unchanged.
        let payload = vec![1, 2, 3, 4];
        let binary = Secret::with_content_type(ContentType::Blob, &payload);
        assert_eq!(binary.content_type(), ContentType::Blob);
        assert_eq!(&*binary, &[1, 2, 3, 4]);
    }

    /// Random secrets are 64-byte blobs and two draws differ.
    #[test]
    fn random() {
        let first = Secret::random().unwrap();
        let second = Secret::random().unwrap();

        // Random secrets should be blobs
        assert_eq!(first.content_type(), ContentType::Blob);
        assert_eq!(second.content_type(), ContentType::Blob);

        // Should be 64 bytes
        assert_eq!(first.as_bytes().len(), 64);
        assert_eq!(second.as_bytes().len(), 64);

        // Should be different
        assert_ne!(first.as_bytes(), second.as_bytes());
    }
}