221 lines
7.4 KiB
Rust
221 lines
7.4 KiB
Rust
use std::{
|
|
io::{Cursor, Write},
|
|
pin::Pin,
|
|
task::Poll,
|
|
};
|
|
|
|
use futures::{Stream, StreamExt};
|
|
|
|
use crate::{fetch::HTTPResponse, ErrorResponse};
|
|
|
|
// MIME sniffing data.
//
// Splices in `magic.rs` from Cargo's `OUT_DIR`, i.e. a file generated at build time
// (presumably by this crate's build script — confirm). It must provide `SNIFF_SIZE` (the
// size of the sniff buffer) and `MAGICS` (the `MIMEAssociation` table iterated during
// sniffing); neither is declared anywhere else in this file.
include!(concat!(env!("OUT_DIR"), "/magic.rs"));
|
|
|
|
/// An association between a MIME type and a file extension
|
|
#[derive(Clone)]
|
|
pub struct MIMEAssociation {
|
|
/// The MIME type
|
|
pub mime: &'static str,
|
|
/// The file extension
|
|
pub ext: &'static str,
|
|
/// Whether the file is safe to display
|
|
pub safe: bool,
|
|
/// The file signatures
|
|
signatures: &'static [FlattenedFileSignature],
|
|
}
|
|
|
|
/// A file signature stored as `(expected_byte, mask)` pairs. Each pair is compared, after
/// masking, against the byte at the same offset at the start of the sniffed data (see
/// `matches`).
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
struct FlattenedFileSignature(&'static [(u8, u8)]);
|
|
|
|
impl FlattenedFileSignature {
|
|
#[inline]
|
|
fn matches(&self, test: &[u8]) -> bool {
|
|
if self.0.len() > test.len() {
|
|
return false;
|
|
}
|
|
self.0
|
|
.iter()
|
|
.zip(test.iter())
|
|
.all(|((sig, mask), byte)| sig & mask == *byte & mask)
|
|
}
|
|
}
|
|
|
|
/// A stream that sniffs the MIME type of the data it receives
|
|
pub struct SniffingStream<R: HTTPResponse> {
|
|
body: <R as HTTPResponse>::BodyStream,
|
|
sniff_buffer: Cursor<[u8; SNIFF_SIZE]>,
|
|
sniffed: Option<SniffResult>,
|
|
}
|
|
|
|
/// The result of MIME sniffing
|
|
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
|
|
pub struct SniffResult {
|
|
/// The MIME type that was sniffed
|
|
pub sniffed_mime: Option<&'static str>,
|
|
/// Whether the file may be unsafe
|
|
pub maybe_unsafe: bool,
|
|
}
|
|
|
|
impl<R: HTTPResponse> SniffingStream<R> {
|
|
/// Create a new `SniffingStream` from a response
|
|
pub fn new(response: R) -> Self {
|
|
Self {
|
|
body: response.body(),
|
|
sniff_buffer: Cursor::new([0; SNIFF_SIZE]),
|
|
sniffed: None,
|
|
}
|
|
}
|
|
/// Create a new `SniffingStream` from a body stream
|
|
pub fn new_from_body_stream(response: <R as HTTPResponse>::BodyStream) -> Self {
|
|
Self {
|
|
body: response,
|
|
sniff_buffer: Cursor::new([0; SNIFF_SIZE]),
|
|
sniffed: None,
|
|
}
|
|
}
|
|
/// Get the result of MIME sniffing
|
|
pub fn result_ref(&self) -> Option<&SniffResult> {
|
|
self.sniffed.as_ref()
|
|
}
|
|
/// Get the result of MIME sniffing
|
|
pub fn result(self) -> Option<SniffResult> {
|
|
self.sniffed
|
|
}
|
|
/// Get the remaining body stream
|
|
pub fn into_remaining_body(self) -> <R as HTTPResponse>::BodyStream {
|
|
self.body
|
|
}
|
|
/// Poll the stream for MIME sniffing, writing any data consumed to a buffer
|
|
pub fn poll_sniff<'a>(
|
|
mut self: Pin<&'a mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
mut notify: impl FnMut(&[u8]),
|
|
) -> Poll<Option<Result<usize, ErrorResponse>>>
|
|
where
|
|
<R as HTTPResponse>::BodyStream: Unpin,
|
|
{
|
|
#[allow(clippy::cast_possible_truncation)]
|
|
let remaining_sniff_buffer = SNIFF_SIZE - self.sniff_buffer.position() as usize;
|
|
|
|
if remaining_sniff_buffer > 0 {
|
|
match self.body.poll_next_unpin(cx) {
|
|
Poll::Ready(None) => {
|
|
return Poll::Ready(None);
|
|
}
|
|
Poll::Ready(Some(Err(e))) => {
|
|
return Poll::Ready(Some(Err(e)));
|
|
}
|
|
Poll::Pending => {
|
|
return Poll::Pending;
|
|
}
|
|
Poll::Ready(Some(Ok(bytes))) => {
|
|
notify(bytes.as_ref());
|
|
self.sniff_buffer.write_all(bytes.as_ref()).ok();
|
|
|
|
if self.sniff_buffer.position() == SNIFF_SIZE as u64 {
|
|
let cands = MAGICS.iter().filter(|assoc| {
|
|
assoc
|
|
.signatures
|
|
.iter()
|
|
.any(|sig| sig.matches(self.sniff_buffer.get_ref()))
|
|
});
|
|
let mut all_safe = true;
|
|
let mut best_match = None;
|
|
|
|
for cand in cands {
|
|
match best_match {
|
|
None => best_match = Some(cand),
|
|
Some(assoc) if assoc.signatures.len() < cand.signatures.len() => {
|
|
best_match = Some(cand);
|
|
}
|
|
_ => {}
|
|
}
|
|
if !cand.safe {
|
|
all_safe = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
self.sniffed = Some(SniffResult {
|
|
sniffed_mime: best_match.map(|assoc| assoc.mime),
|
|
maybe_unsafe: !all_safe,
|
|
});
|
|
self.sniff_buffer.set_position(0);
|
|
|
|
return Poll::Ready(None);
|
|
}
|
|
|
|
return Poll::Ready(Some(Ok(bytes.as_ref().len())));
|
|
}
|
|
}
|
|
}
|
|
|
|
Poll::Ready(None)
|
|
}
|
|
}
|
|
|
|
impl<R: HTTPResponse> Stream for SniffingStream<R>
|
|
where
|
|
<R as HTTPResponse>::BodyStream: Unpin,
|
|
{
|
|
type Item = Result<<R as HTTPResponse>::Bytes, ErrorResponse>;
|
|
|
|
fn poll_next(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Option<Self::Item>> {
|
|
#[allow(clippy::cast_possible_truncation)]
|
|
let remaining_sniff_buffer = SNIFF_SIZE - self.sniff_buffer.position() as usize;
|
|
if self.sniffed.is_none() && remaining_sniff_buffer > 0 {
|
|
match self.body.poll_next_unpin(cx) {
|
|
Poll::Ready(None) => {
|
|
return Poll::Ready(None);
|
|
}
|
|
Poll::Ready(Some(Err(e))) => {
|
|
return Poll::Ready(Some(Err(e)));
|
|
}
|
|
Poll::Pending => {
|
|
return Poll::Pending;
|
|
}
|
|
Poll::Ready(Some(Ok(bytes))) => {
|
|
#[allow(clippy::unused_io_amount)]
|
|
self.sniff_buffer.write(bytes.as_ref()).ok();
|
|
|
|
if self.sniff_buffer.position() == SNIFF_SIZE as u64 {
|
|
let cands = MAGICS.iter().filter(|assoc| {
|
|
assoc
|
|
.signatures
|
|
.iter()
|
|
.any(|sig| sig.matches(self.sniff_buffer.get_ref()))
|
|
});
|
|
let mut all_safe = true;
|
|
let mut best_match = None;
|
|
|
|
for cand in cands {
|
|
if best_match
|
|
.map_or(0, |assoc: &MIMEAssociation| assoc.signatures.len())
|
|
< cand.signatures.len()
|
|
{
|
|
best_match = Some(cand);
|
|
}
|
|
if !cand.safe {
|
|
all_safe = false;
|
|
break;
|
|
}
|
|
}
|
|
|
|
self.sniffed = Some(SniffResult {
|
|
sniffed_mime: best_match.map(|assoc| assoc.mime),
|
|
maybe_unsafe: !all_safe,
|
|
});
|
|
self.sniff_buffer.set_position(0);
|
|
}
|
|
|
|
return Poll::Ready(Some(Ok(bytes)));
|
|
}
|
|
}
|
|
}
|
|
|
|
self.body.poll_next_unpin(cx)
|
|
}
|
|
}
|