use std::{ io::{Cursor, Write}, pin::Pin, task::Poll, }; use futures::{Stream, StreamExt}; use crate::{fetch::HTTPResponse, ErrorResponse}; // MIME sniffing data include!(concat!(env!("OUT_DIR"), "/magic.rs")); /// An association between a MIME type and a file extension #[derive(Clone)] pub struct MIMEAssociation { /// The MIME type pub mime: &'static str, /// The file extension pub ext: &'static str, /// Whether the file is safe to display pub safe: bool, /// The file signatures signatures: &'static [FlattenedFileSignature], } #[derive(Debug, Clone, PartialEq, serde::Serialize)] struct FlattenedFileSignature(&'static [(u8, u8)]); impl FlattenedFileSignature { #[inline] fn matches(&self, test: &[u8]) -> bool { if self.0.len() > test.len() { return false; } self.0 .iter() .zip(test.iter()) .all(|((sig, mask), byte)| sig & mask == *byte & mask) } } /// A stream that sniffs the MIME type of the data it receives pub struct SniffingStream { body: ::BodyStream, sniff_buffer: Cursor<[u8; SNIFF_SIZE]>, sniffed: Option, } /// The result of MIME sniffing #[derive(Debug, Clone, PartialEq, serde::Serialize)] pub struct SniffResult { /// The MIME type that was sniffed pub sniffed_mime: Option<&'static str>, /// Whether the file may be unsafe pub maybe_unsafe: bool, } impl SniffingStream { /// Create a new `SniffingStream` from a response pub fn new(response: R) -> Self { Self { body: response.body(), sniff_buffer: Cursor::new([0; SNIFF_SIZE]), sniffed: None, } } /// Create a new `SniffingStream` from a body stream pub fn new_from_body_stream(response: ::BodyStream) -> Self { Self { body: response, sniff_buffer: Cursor::new([0; SNIFF_SIZE]), sniffed: None, } } /// Get the result of MIME sniffing pub fn result_ref(&self) -> Option<&SniffResult> { self.sniffed.as_ref() } /// Get the result of MIME sniffing pub fn result(self) -> Option { self.sniffed } /// Get the remaining body stream pub fn into_remaining_body(self) -> ::BodyStream { self.body } /// Poll the stream for MIME sniffing, writing any data consumed to a buffer pub fn poll_sniff<'a>( mut self: Pin<&'a mut Self>, cx: &mut std::task::Context<'_>, mut notify: impl FnMut(&[u8]), ) -> Poll>> where ::BodyStream: Unpin, { #[allow(clippy::cast_possible_truncation)] let remaining_sniff_buffer = SNIFF_SIZE - self.sniff_buffer.position() as usize; if remaining_sniff_buffer > 0 { match self.body.poll_next_unpin(cx) { Poll::Ready(None) => { return Poll::Ready(None); } Poll::Ready(Some(Err(e))) => { return Poll::Ready(Some(Err(e))); } Poll::Pending => { return Poll::Pending; } Poll::Ready(Some(Ok(bytes))) => { notify(bytes.as_ref()); self.sniff_buffer.write_all(bytes.as_ref()).ok(); if self.sniff_buffer.position() == SNIFF_SIZE as u64 { let cands = MAGICS.iter().filter(|assoc| { assoc .signatures .iter() .any(|sig| sig.matches(self.sniff_buffer.get_ref())) }); let mut all_safe = true; let mut best_match = None; for cand in cands { match best_match { None => best_match = Some(cand), Some(assoc) if assoc.signatures.len() < cand.signatures.len() => { best_match = Some(cand); } _ => {} } if !cand.safe { all_safe = false; break; } } self.sniffed = Some(SniffResult { sniffed_mime: best_match.map(|assoc| assoc.mime), maybe_unsafe: !all_safe, }); self.sniff_buffer.set_position(0); return Poll::Ready(None); } return Poll::Ready(Some(Ok(bytes.as_ref().len()))); } } } Poll::Ready(None) } } impl Stream for SniffingStream where ::BodyStream: Unpin, { type Item = Result<::Bytes, ErrorResponse>; fn poll_next( mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { #[allow(clippy::cast_possible_truncation)] let remaining_sniff_buffer = SNIFF_SIZE - self.sniff_buffer.position() as usize; if self.sniffed.is_none() && remaining_sniff_buffer > 0 { match self.body.poll_next_unpin(cx) { Poll::Ready(None) => { return Poll::Ready(None); } Poll::Ready(Some(Err(e))) => { return Poll::Ready(Some(Err(e))); } Poll::Pending => { return Poll::Pending; } Poll::Ready(Some(Ok(bytes))) => { #[allow(clippy::unused_io_amount)] self.sniff_buffer.write(bytes.as_ref()).ok(); if self.sniff_buffer.position() == SNIFF_SIZE as u64 { let cands = MAGICS.iter().filter(|assoc| { assoc .signatures .iter() .any(|sig| sig.matches(self.sniff_buffer.get_ref())) }); let mut all_safe = true; let mut best_match = None; for cand in cands { if best_match .map_or(0, |assoc: &MIMEAssociation| assoc.signatures.len()) < cand.signatures.len() { best_match = Some(cand); } if !cand.safe { all_safe = false; break; } } self.sniffed = Some(SniffResult { sniffed_mime: best_match.map(|assoc| assoc.mime), maybe_unsafe: !all_safe, }); self.sniff_buffer.set_position(0); } return Poll::Ready(Some(Ok(bytes))); } } } self.body.poll_next_unpin(cx) } }