Saves some allocations on E2E decryption

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-09-05 00:52:54 -05:00
parent 2863fe52c1
commit 0502c6098a
No known key found for this signature in database
2 changed files with 20 additions and 18 deletions

View file

@ -110,7 +110,7 @@ impl MatrixClient {
MediaSource::Plain(_) => {
Box::pin(body.map_err(DumpError::from))
}
MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err(
MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body)?.map_ok(|v| Bytes::from(v)).map_err(
|e| match e {
ErrOrWrongHash::Err(e) => e.into(),
ErrOrWrongHash::WrongHash => DumpError::HashMismatch,

View file

@ -9,7 +9,7 @@ type Aes256CTR = ctr::Ctr128BE<aes::Aes256>;
#[derive(Debug, thiserror::Error)]
pub enum ErrOrWrongHash<E: std::error::Error> {
#[error("Error: {0}")]
#[error("{0}")]
Err(#[from] E),
#[error("Wrong hash")]
WrongHash,
@ -36,11 +36,11 @@ pub enum DecryptError {
VersionMismatch,
}
pub fn try_decrypt<'s, E>(
pub fn try_decrypt<'s, S: AsMut<[u8]>, E>(
jwk: &JsonWebKey,
data: impl TryStream<Ok = Vec<u8>, Error = E> + Send + 's,
data: impl TryStream<Ok = S, Error = E> + Send + 's,
iv: &[u8],
) -> Result<impl TryStream<Ok = Vec<u8>, Error = E> + 's, DecryptError> {
) -> Result<impl TryStream<Ok = S, Error = E> + 's, DecryptError> {
if jwk.alg != "A256CTR" {
return Err(DecryptError::UnsupportedEncryptionAlgorithm(
jwk.alg.clone(),
@ -52,33 +52,36 @@ pub fn try_decrypt<'s, E>(
let mut cipher = Aes256CTR::new_from_slices(key, iv).map_err(|_| DecryptError::WrongKeySpec)?;
Ok(data.map_ok(move |mut chunk| {
cipher.apply_keystream(&mut chunk);
cipher.apply_keystream(chunk.as_mut());
chunk
}))
}
pub struct VerifyingStream<'a, S> {
pub struct VerifyingStream<'a, R, S> {
inner: Pin<Box<S>>,
hasher: Option<sha2::Sha256>,
expected: &'a [u8],
_marker: std::marker::PhantomData<R>,
}
impl<'a, S> VerifyingStream<'a, S> {
impl<'a, R, S> VerifyingStream<'a, R, S> {
#[must_use]
pub fn new(inner: Pin<Box<S>>, expected: &'a [u8]) -> Self {
Self {
inner,
hasher: Some(sha2::Sha256::new()),
expected,
_marker: std::marker::PhantomData,
}
}
}
impl<'a, S, E: std::error::Error> futures::Stream for VerifyingStream<'a, S>
impl<'a, S, R: AsRef<[u8]>, E: std::error::Error> futures::Stream for VerifyingStream<'a, R, S>
where
S: futures::Stream<Item = Result<Vec<u8>, E>>,
R: Unpin,
S: futures::Stream<Item = Result<R, E>>,
{
type Item = Result<Vec<u8>, ErrOrWrongHash<E>>;
type Item = Result<R, ErrOrWrongHash<E>>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
@ -103,7 +106,7 @@ where
}
}
pub async fn decrypt_file<'s, E: std::error::Error + 's>(
pub fn decrypt_file<'s, E: std::error::Error + 's>(
file: &'s EncryptedFile,
data: impl TryStream<Ok = Bytes, Error = E> + Send + 's,
) -> Result<impl TryStream<Ok = Vec<u8>, Error = ErrOrWrongHash<E>> + 's, DecryptError> {
@ -116,10 +119,9 @@ pub async fn decrypt_file<'s, E: std::error::Error + 's>(
let sha256 = file.hashes.get("sha256").ok_or(DecryptError::MissingHash)?;
let sha256_expect = sha256.as_bytes();
let data = Box::pin(VerifyingStream::new(
Box::pin(data.map_ok(|b| b.to_vec())),
sha256_expect,
));
try_decrypt(&file.key, data, iv)
try_decrypt(
&file.key,
VerifyingStream::new(Box::pin(data.map_ok(|b| b.to_vec())), sha256_expect),
iv,
)
}