1
use std::{
2
    ops::{Mul, Rem, Shr},
3
    sync::LazyLock,
4
};
5

            
6
use cbc::cipher::{
7
    BlockModeDecrypt, BlockModeEncrypt, BlockSizeUser, IvSizeUser, KeyInit, KeyIvInit, KeySizeUser,
8
    block_padding::{NoPadding, Pkcs7},
9
};
10
use hkdf::Hkdf;
11
use md5::Md5;
12
use num::{FromPrimitive, Integer, One, Zero};
13
use num_bigint_dig::BigUint;
14
use pbkdf2::hmac::{
15
    Mac,
16
    digest::{Digest, FixedOutput, Output, OutputSizeUser},
17
};
18
use sha2::Sha256;
19
use subtle::ConstantTimeEq;
20
use zeroize::{Zeroize, Zeroizing};
21

            
22
use crate::{Key, file};
23

            
24
type EncAlg = cbc::Encryptor<aes::Aes128>;
25
type DecAlg = cbc::Decryptor<aes::Aes128>;
26
type MacAlg = pbkdf2::hmac::Hmac<sha2::Sha256>;
27

            
28
35
pub fn encrypt(
29
    data: impl AsRef<[u8]>,
30
    key: &Key,
31
    iv: impl AsRef<[u8]>,
32
) -> Result<Vec<u8>, super::Error> {
33
67
    let mut blob = vec![0; data.as_ref().len() + EncAlg::block_size()];
34

            
35
    // Unwrapping since adding `CIPHER_BLOCK_SIZE` to array is enough space for
36
    // PKCS7
37
132
    let encrypted = EncAlg::new_from_slices(key.as_ref(), iv.as_ref())
38
        .expect("Invalid key length")
39
63
        .encrypt_padded_b2b::<Pkcs7>(data.as_ref(), &mut blob)?;
40

            
41
30
    Ok(encrypted.to_vec())
42
}
43

            
44
20
pub fn decrypt(
45
    blob: impl AsRef<[u8]>,
46
    key: &Key,
47
    iv: impl AsRef<[u8]>,
48
) -> Result<Zeroizing<Vec<u8>>, super::Error> {
49
40
    let mut data = blob.as_ref().to_vec();
50

            
51
80
    let decrypted = DecAlg::new_from_slices(key.as_ref(), iv.as_ref())
52
        .expect("Invalid key length")
53
40
        .decrypt_padded::<Pkcs7>(&mut data)?;
54

            
55
20
    Ok(decrypted.to_vec().into())
56
}
57

            
58
4
pub(crate) fn decrypt_no_padding(
59
    blob: impl AsRef<[u8]>,
60
    key: &Key,
61
    iv: impl AsRef<[u8]>,
62
) -> Result<Zeroizing<Vec<u8>>, super::Error> {
63
8
    let mut data = blob.as_ref().to_vec();
64

            
65
16
    let decrypted = DecAlg::new_from_slices(key.as_ref(), iv.as_ref())
66
        .expect("Invalid key length")
67
8
        .decrypt_padded::<NoPadding>(&mut data)?;
68

            
69
4
    Ok(decrypted.to_vec().into())
70
}
71

            
72
11
pub(crate) fn iv_len() -> usize {
73
10
    DecAlg::iv_size()
74
}
75

            
76
13
pub(crate) fn generate_private_key() -> Result<Zeroizing<Vec<u8>>, super::Error> {
77
13
    let mut key = vec![0u8; EncAlg::key_size()];
78
25
    getrandom::fill(&mut key)?;
79
10
    Ok(Zeroizing::new(key))
80
}
81

            
82
11
/// Computes the Diffie-Hellman public value `2^private_key mod p`.
///
/// `private_key` is read as a big-endian unsigned integer, and `p` is the fixed
/// prime used by `powm`. The result is encoded with `BigUint::to_bytes_be`, so it is
/// not padded to a fixed width.
pub(crate) fn generate_public_key(private_key: impl AsRef<[u8]>) -> Result<Vec<u8>, super::Error> {
    // The DH group generator.
    static DH_GENERATOR: LazyLock<BigUint> = LazyLock::new(|| BigUint::from_u64(0x2).unwrap());

    let exponent = BigUint::from_bytes_be(private_key.as_ref());
    let public_value = powm(&DH_GENERATOR, exponent);

    Ok(public_value.to_bytes_be())
}
89

            
90
11
pub(crate) fn generate_aes_key(
91
    private_key: impl AsRef<[u8]>,
92
    server_public_key: impl AsRef<[u8]>,
93
) -> Result<Zeroizing<Vec<u8>>, super::Error> {
94
22
    let server_public_key_uint = BigUint::from_bytes_be(server_public_key.as_ref());
95
22
    let private_key_uint = BigUint::from_bytes_be(private_key.as_ref());
96
11
    let common_secret = powm(&server_public_key_uint, private_key_uint);
97

            
98
10
    let mut common_secret_bytes = common_secret.to_bytes_be();
99
18
    let mut common_secret_padded = vec![0; 128 - common_secret_bytes.len()];
100
    // inefficient, but ok for now
101
11
    common_secret_padded.append(&mut common_secret_bytes);
102

            
103
    // hkdf
104
    // input_keying_material
105
11
    let ikm = common_secret_padded;
106
11
    let salt = None;
107
    let info = [];
108

            
109
    // output keying material
110
22
    let mut okm = Zeroizing::new(vec![0; 16]);
111

            
112
22
    let (_, hk) = Hkdf::<Sha256>::extract(salt, &ikm);
113
21
    hk.expand(&info, okm.as_mut())
114
        .expect("hkdf expand should never fail");
115

            
116
11
    Ok(okm)
117
}
118

            
119
12
pub fn generate_iv() -> Result<Vec<u8>, super::Error> {
120
13
    let mut iv = vec![0u8; EncAlg::iv_size()];
121
26
    getrandom::fill(&mut iv)?;
122
13
    Ok(iv)
123
}
124

            
125
10
pub(crate) fn mac_len() -> usize {
126
10
    MacAlg::output_size()
127
}
128

            
129
22
pub(crate) fn compute_mac(data: impl AsRef<[u8]>, key: &Key) -> Result<crate::Mac, super::Error> {
130
47
    let mut mac = MacAlg::new_from_slice(key.as_ref()).unwrap();
131
47
    mac.update(data.as_ref());
132
25
    Ok(crate::Mac::new(mac.finalize().into_bytes().to_vec()))
133
}
134

            
135
22
pub(crate) fn verify_mac(
136
    data: impl AsRef<[u8]>,
137
    key: &Key,
138
    expected_mac: impl AsRef<[u8]>,
139
) -> Result<bool, super::Error> {
140
43
    let mut mac = MacAlg::new_from_slice(key.as_ref()).unwrap();
141
43
    mac.update(data.as_ref());
142
21
    Ok(mac.verify_slice(expected_mac.as_ref()).is_ok())
143
}
144

            
145
4
pub(crate) fn verify_checksum_md5(digest: impl AsRef<[u8]>, content: impl AsRef<[u8]>) -> bool {
146
4
    let mut hasher = Md5::new();
147
8
    hasher.update(content.as_ref());
148
4
    hasher.finalize_fixed().ct_eq(digest.as_ref()).into()
149
}
150

            
151
22
pub(crate) fn derive_key(
152
    secret: impl AsRef<[u8]>,
153
    key_strength: Result<(), file::WeakKeyError>,
154
    salt: impl AsRef<[u8]>,
155
    iteration_count: usize,
156
) -> Result<Key, super::Error> {
157
38
    let mut key = Key::new_with_strength(vec![0; EncAlg::block_size()], key_strength);
158

            
159
    pbkdf2::pbkdf2::<pbkdf2::hmac::Hmac<sha2::Sha256>>(
160
22
        secret.as_ref(),
161
16
        salt.as_ref(),
162
22
        iteration_count.try_into().unwrap(),
163
16
        key.as_mut(),
164
    )
165
    .expect("HMAC can be initialized with any key length");
166

            
167
18
    Ok(key)
168
}
169

            
170
5
pub(crate) fn legacy_derive_key_and_iv(
171
    secret: impl AsRef<[u8]>,
172
    key_strength: Result<(), file::WeakKeyError>,
173
    salt: impl AsRef<[u8]>,
174
    iteration_count: usize,
175
) -> Result<(Key, Vec<u8>), super::Error> {
176
10
    let mut buffer = vec![0; EncAlg::key_size() + EncAlg::iv_size()];
177
5
    let mut hasher = Sha256::new();
178
10
    let mut digest_buffer = vec![0; <Sha256 as Digest>::output_size()];
179
    #[allow(deprecated)]
180
10
    let digest = Output::<Sha256>::from_mut_slice(digest_buffer.as_mut_slice());
181

            
182
5
    let mut pos = 0usize;
183

            
184
    loop {
185
5
        hasher.update(secret.as_ref());
186
5
        hasher.update(salt.as_ref());
187
5
        hasher.finalize_into_reset(digest);
188

            
189
5
        for _ in 1..iteration_count {
190
            // We can't pass an instance, the borrow checker
191
            // would complain about digest being dropped at the end of
192
            // for block
193
            #[allow(clippy::needless_borrows_for_generic_args)]
194
5
            hasher.update(&digest);
195
5
            hasher.finalize_into_reset(digest);
196
        }
197

            
198
5
        let to_read = usize::min(digest.len(), buffer.len() - pos);
199
5
        buffer[pos..].copy_from_slice(&digest[..to_read]);
200
5
        pos += to_read;
201

            
202
10
        if pos == buffer.len() {
203
            break;
204
        }
205

            
206
        // We can't pass an instance, the borrow checker
207
        // would complain about digest being dropped at the end of
208
        // loop block
209
        #[allow(clippy::needless_borrows_for_generic_args)]
210
        hasher.update(&digest);
211
    }
212

            
213
10
    let iv = buffer.split_off(EncAlg::key_size());
214
10
    Ok((Key::new_with_strength(buffer, key_strength), iv))
215
}
216

            
217
/// from https://github.com/plietar/librespot/blob/master/core/src/util/mod.rs#L53
218
11
fn powm(base: &BigUint, mut exp: BigUint) -> BigUint {
219
    // for key exchange
220
8
    static DH_PRIME: LazyLock<BigUint> = LazyLock::new(|| {
221
8
        BigUint::from_bytes_be(&[
222
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xC9, 0x0F, 0xDA, 0xA2, 0x21, 0x68,
223
            0xC2, 0x34, 0xC4, 0xC6, 0x62, 0x8B, 0x80, 0xDC, 0x1C, 0xD1, 0x29, 0x02, 0x4E, 0x08,
224
            0x8A, 0x67, 0xCC, 0x74, 0x02, 0x0B, 0xBE, 0xA6, 0x3B, 0x13, 0x9B, 0x22, 0x51, 0x4A,
225
            0x08, 0x79, 0x8E, 0x34, 0x04, 0xDD, 0xEF, 0x95, 0x19, 0xB3, 0xCD, 0x3A, 0x43, 0x1B,
226
            0x30, 0x2B, 0x0A, 0x6D, 0xF2, 0x5F, 0x14, 0x37, 0x4F, 0xE1, 0x35, 0x6D, 0x6D, 0x51,
227
            0xC2, 0x45, 0xE4, 0x85, 0xB5, 0x76, 0x62, 0x5E, 0x7E, 0xC6, 0xF4, 0x4C, 0x42, 0xE9,
228
            0xA6, 0x37, 0xED, 0x6B, 0x0B, 0xFF, 0x5C, 0xB6, 0xF4, 0x06, 0xB7, 0xED, 0xEE, 0x38,
229
            0x6B, 0xFB, 0x5A, 0x89, 0x9F, 0xA5, 0xAE, 0x9F, 0x24, 0x11, 0x7C, 0x4B, 0x1F, 0xE6,
230
            0x49, 0x28, 0x66, 0x51, 0xEC, 0xE6, 0x53, 0x81, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
231
            0xFF, 0xFF,
232
        ])
233
    });
234

            
235
11
    let mut base = base.clone();
236
11
    let mut result: BigUint = One::one();
237

            
238
33
    while !exp.is_zero() {
239
31
        if exp.is_odd() {
240
8
            result = result.mul(&base).rem(&*DH_PRIME);
241
        }
242
16
        exp = exp.shr(1);
243
8
        base = (&base).mul(&base).rem(&*DH_PRIME);
244
    }
245
9
    exp.zeroize();
246

            
247
9
    result
248
}