Saves some allocations on E2E decryption

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-09-05 00:52:54 -05:00
parent 2863fe52c1
commit 0502c6098a
No known key found for this signature in database
2 changed files with 20 additions and 18 deletions

View file

@ -110,7 +110,7 @@ impl MatrixClient {
MediaSource::Plain(_) => { MediaSource::Plain(_) => {
Box::pin(body.map_err(DumpError::from)) Box::pin(body.map_err(DumpError::from))
} }
MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err( MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body)?.map_ok(|v| Bytes::from(v)).map_err(
|e| match e { |e| match e {
ErrOrWrongHash::Err(e) => e.into(), ErrOrWrongHash::Err(e) => e.into(),
ErrOrWrongHash::WrongHash => DumpError::HashMismatch, ErrOrWrongHash::WrongHash => DumpError::HashMismatch,

View file

@ -9,7 +9,7 @@ type Aes256CTR = ctr::Ctr128BE<aes::Aes256>;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ErrOrWrongHash<E: std::error::Error> { pub enum ErrOrWrongHash<E: std::error::Error> {
#[error("Error: {0}")] #[error("{0}")]
Err(#[from] E), Err(#[from] E),
#[error("Wrong hash")] #[error("Wrong hash")]
WrongHash, WrongHash,
@ -36,11 +36,11 @@ pub enum DecryptError {
VersionMismatch, VersionMismatch,
} }
pub fn try_decrypt<'s, E>( pub fn try_decrypt<'s, S: AsMut<[u8]>, E>(
jwk: &JsonWebKey, jwk: &JsonWebKey,
data: impl TryStream<Ok = Vec<u8>, Error = E> + Send + 's, data: impl TryStream<Ok = S, Error = E> + Send + 's,
iv: &[u8], iv: &[u8],
) -> Result<impl TryStream<Ok = Vec<u8>, Error = E> + 's, DecryptError> { ) -> Result<impl TryStream<Ok = S, Error = E> + 's, DecryptError> {
if jwk.alg != "A256CTR" { if jwk.alg != "A256CTR" {
return Err(DecryptError::UnsupportedEncryptionAlgorithm( return Err(DecryptError::UnsupportedEncryptionAlgorithm(
jwk.alg.clone(), jwk.alg.clone(),
@ -52,33 +52,36 @@ pub fn try_decrypt<'s, E>(
let mut cipher = Aes256CTR::new_from_slices(key, iv).map_err(|_| DecryptError::WrongKeySpec)?; let mut cipher = Aes256CTR::new_from_slices(key, iv).map_err(|_| DecryptError::WrongKeySpec)?;
Ok(data.map_ok(move |mut chunk| { Ok(data.map_ok(move |mut chunk| {
cipher.apply_keystream(&mut chunk); cipher.apply_keystream(chunk.as_mut());
chunk chunk
})) }))
} }
pub struct VerifyingStream<'a, S> { pub struct VerifyingStream<'a, R, S> {
inner: Pin<Box<S>>, inner: Pin<Box<S>>,
hasher: Option<sha2::Sha256>, hasher: Option<sha2::Sha256>,
expected: &'a [u8], expected: &'a [u8],
_marker: std::marker::PhantomData<R>,
} }
impl<'a, S> VerifyingStream<'a, S> { impl<'a, R, S> VerifyingStream<'a, R, S> {
#[must_use] #[must_use]
pub fn new(inner: Pin<Box<S>>, expected: &'a [u8]) -> Self { pub fn new(inner: Pin<Box<S>>, expected: &'a [u8]) -> Self {
Self { Self {
inner, inner,
hasher: Some(sha2::Sha256::new()), hasher: Some(sha2::Sha256::new()),
expected, expected,
_marker: std::marker::PhantomData,
} }
} }
} }
impl<'a, S, E: std::error::Error> futures::Stream for VerifyingStream<'a, S> impl<'a, S, R: AsRef<[u8]>, E: std::error::Error> futures::Stream for VerifyingStream<'a, R, S>
where where
S: futures::Stream<Item = Result<Vec<u8>, E>>, R: Unpin,
S: futures::Stream<Item = Result<R, E>>,
{ {
type Item = Result<Vec<u8>, ErrOrWrongHash<E>>; type Item = Result<R, ErrOrWrongHash<E>>;
fn poll_next( fn poll_next(
mut self: std::pin::Pin<&mut Self>, mut self: std::pin::Pin<&mut Self>,
@ -103,7 +106,7 @@ where
} }
} }
pub async fn decrypt_file<'s, E: std::error::Error + 's>( pub fn decrypt_file<'s, E: std::error::Error + 's>(
file: &'s EncryptedFile, file: &'s EncryptedFile,
data: impl TryStream<Ok = Bytes, Error = E> + Send + 's, data: impl TryStream<Ok = Bytes, Error = E> + Send + 's,
) -> Result<impl TryStream<Ok = Vec<u8>, Error = ErrOrWrongHash<E>> + 's, DecryptError> { ) -> Result<impl TryStream<Ok = Vec<u8>, Error = ErrOrWrongHash<E>> + 's, DecryptError> {
@ -116,10 +119,9 @@ pub async fn decrypt_file<'s, E: std::error::Error + 's>(
let sha256 = file.hashes.get("sha256").ok_or(DecryptError::MissingHash)?; let sha256 = file.hashes.get("sha256").ok_or(DecryptError::MissingHash)?;
let sha256_expect = sha256.as_bytes(); let sha256_expect = sha256.as_bytes();
let data = Box::pin(VerifyingStream::new( try_decrypt(
Box::pin(data.map_ok(|b| b.to_vec())), &file.key,
sha256_expect, VerifyingStream::new(Box::pin(data.map_ok(|b| b.to_vec())), sha256_expect),
)); iv,
)
try_decrypt(&file.key, data, iv)
} }