yumechi-no-kuni-proxy-worker/src/post_process/sniff.rs
eternal-flame-AD 146e317849
init
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
2024-11-13 05:23:22 -06:00

221 lines
7.4 KiB
Rust

use std::{
io::{Cursor, Write},
pin::Pin,
task::Poll,
};
use futures::{Stream, StreamExt};
use crate::{fetch::HTTPResponse, ErrorResponse};
// MIME sniffing data
include!(concat!(env!("OUT_DIR"), "/magic.rs"));
/// An association between a MIME type and a file extension
///
/// Besides the MIME type and extension, each entry carries a display-safety
/// flag and the byte signatures used to recognise the format. The sniffing
/// code in this file iterates over entries of this type taken from `MAGICS`.
#[derive(Clone)]
pub struct MIMEAssociation {
/// The MIME type; reported as `SniffResult::sniffed_mime` when this entry is
/// picked as the best match
pub mime: &'static str,
/// The file extension
pub ext: &'static str,
/// Whether the file is safe to display; a matching entry with `safe == false`
/// causes the sniffer to set `SniffResult::maybe_unsafe`
pub safe: bool,
/// The file signatures; the entry counts as a candidate when any one of them
/// matches the sniff buffer
signatures: &'static [FlattenedFileSignature],
}
/// A file signature stored as a sequence of `(value, mask)` byte pairs
///
/// The pair at index `i` is compared by `matches` against the input byte at
/// offset `i`; only the bits that are set in `mask` take part in the comparison.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
struct FlattenedFileSignature(&'static [(u8, u8)]);
impl FlattenedFileSignature {
    /// Check whether `test` starts with this signature
    ///
    /// Returns `false` when `test` is shorter than the signature. Otherwise every
    /// `(value, mask)` pair must agree with the input byte at the same offset on
    /// all bits selected by `mask`. An empty signature matches any input.
    #[inline]
    fn matches(&self, test: &[u8]) -> bool {
        let pattern = self.0;
        // `get` yields `None` exactly when the input is shorter than the pattern.
        match test.get(..pattern.len()) {
            None => false,
            Some(head) => pattern
                .iter()
                .zip(head)
                // Two bytes agree under a mask iff their XOR has no masked bit set.
                .all(|(&(value, mask), &byte)| ((value ^ byte) & mask) == 0),
        }
    }
}
/// A stream that sniffs the MIME type of the data it receives
///
/// Wraps the body stream of an `HTTPResponse`, copies up to `SNIFF_SIZE`
/// leading bytes into a fixed-size buffer and, once that buffer is full,
/// stores a `SniffResult` computed from it.
pub struct SniffingStream<R: HTTPResponse> {
/// The wrapped body stream that data is pulled from
body: <R as HTTPResponse>::BodyStream,
/// Fixed-size buffer collecting the leading bytes of the body; the cursor
/// position is the number of bytes collected so far
sniff_buffer: Cursor<[u8; SNIFF_SIZE]>,
/// `None` until the buffer has been filled and checked against `MAGICS`
sniffed: Option<SniffResult>,
}
/// The result of MIME sniffing
///
/// Produced by `SniffingStream` once its sniff buffer has been filled.
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub struct SniffResult {
/// The MIME type that was sniffed, or `None` when no known signature matched
pub sniffed_mime: Option<&'static str>,
/// Whether the file may be unsafe; set when a matching signature belongs to
/// an association that is not marked as safe
pub maybe_unsafe: bool,
}
impl<R: HTTPResponse> SniffingStream<R> {
/// Create a new `SniffingStream` from a response
pub fn new(response: R) -> Self {
Self {
body: response.body(),
sniff_buffer: Cursor::new([0; SNIFF_SIZE]),
sniffed: None,
}
}
/// Create a new `SniffingStream` from a body stream
pub fn new_from_body_stream(response: <R as HTTPResponse>::BodyStream) -> Self {
Self {
body: response,
sniff_buffer: Cursor::new([0; SNIFF_SIZE]),
sniffed: None,
}
}
/// Get the result of MIME sniffing
pub fn result_ref(&self) -> Option<&SniffResult> {
self.sniffed.as_ref()
}
/// Get the result of MIME sniffing
pub fn result(self) -> Option<SniffResult> {
self.sniffed
}
/// Get the remaining body stream
pub fn into_remaining_body(self) -> <R as HTTPResponse>::BodyStream {
self.body
}
/// Poll the stream for MIME sniffing, writing any data consumed to a buffer
pub fn poll_sniff<'a>(
mut self: Pin<&'a mut Self>,
cx: &mut std::task::Context<'_>,
mut notify: impl FnMut(&[u8]),
) -> Poll<Option<Result<usize, ErrorResponse>>>
where
<R as HTTPResponse>::BodyStream: Unpin,
{
#[allow(clippy::cast_possible_truncation)]
let remaining_sniff_buffer = SNIFF_SIZE - self.sniff_buffer.position() as usize;
if remaining_sniff_buffer > 0 {
match self.body.poll_next_unpin(cx) {
Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Some(Ok(bytes))) => {
notify(bytes.as_ref());
self.sniff_buffer.write_all(bytes.as_ref()).ok();
if self.sniff_buffer.position() == SNIFF_SIZE as u64 {
let cands = MAGICS.iter().filter(|assoc| {
assoc
.signatures
.iter()
.any(|sig| sig.matches(self.sniff_buffer.get_ref()))
});
let mut all_safe = true;
let mut best_match = None;
for cand in cands {
match best_match {
None => best_match = Some(cand),
Some(assoc) if assoc.signatures.len() < cand.signatures.len() => {
best_match = Some(cand);
}
_ => {}
}
if !cand.safe {
all_safe = false;
break;
}
}
self.sniffed = Some(SniffResult {
sniffed_mime: best_match.map(|assoc| assoc.mime),
maybe_unsafe: !all_safe,
});
self.sniff_buffer.set_position(0);
return Poll::Ready(None);
}
return Poll::Ready(Some(Ok(bytes.as_ref().len())));
}
}
}
Poll::Ready(None)
}
}
impl<R: HTTPResponse> Stream for SniffingStream<R>
where
    <R as HTTPResponse>::BodyStream: Unpin,
{
    type Item = Result<<R as HTTPResponse>::Bytes, ErrorResponse>;

    /// Forward the next chunk of the body, sniffing the leading bytes on the way
    ///
    /// While no result exists, each chunk is copied into the sniff buffer
    /// before being handed on unchanged; when the buffer becomes full it is
    /// checked against `MAGICS` and the result is stored. Afterwards the body
    /// stream is polled directly.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // The position never exceeds the length of the fixed buffer.
        #[allow(clippy::cast_possible_truncation)]
        let filled = self.sniff_buffer.position() as usize;
        let still_sniffing = self.sniffed.is_none() && filled < SNIFF_SIZE;
        if !still_sniffing {
            // Nothing left to inspect: behave like the wrapped body stream.
            return self.body.poll_next_unpin(cx);
        }

        let chunk = match self.body.poll_next_unpin(cx) {
            Poll::Ready(Some(Ok(chunk))) => chunk,
            // End of stream, error and pending are passed through untouched.
            other => return other,
        };

        // A short write is expected here: only the bytes that still fit into
        // the fixed-size buffer are copied.
        #[allow(clippy::unused_io_amount)]
        self.sniff_buffer.write(chunk.as_ref()).ok();

        if self.sniff_buffer.position() == SNIFF_SIZE as u64 {
            let matching = MAGICS.iter().filter(|assoc| {
                assoc
                    .signatures
                    .iter()
                    .any(|sig| sig.matches(self.sniff_buffer.get_ref()))
            });

            let mut saw_unsafe = false;
            let mut winner: Option<&MIMEAssociation> = None;
            for cand in matching {
                // Prefer the candidate that carries the most signatures.
                let held = winner.map_or(0, |assoc| assoc.signatures.len());
                if cand.signatures.len() > held {
                    winner = Some(cand);
                }
                // The first unsafe candidate ends the scan.
                if !cand.safe {
                    saw_unsafe = true;
                    break;
                }
            }

            self.sniffed = Some(SniffResult {
                sniffed_mime: winner.map(|assoc| assoc.mime),
                maybe_unsafe: saw_unsafe,
            });
            self.sniff_buffer.set_position(0);
        }

        Poll::Ready(Some(Ok(chunk)))
    }
}