644 lines
22 KiB
Rust
644 lines
22 KiB
Rust
use crate::{ErrorResponse, FetchConfig, MAX_SIZE};
|
|
use axum::{
|
|
body::Bytes,
|
|
extract::FromRequestParts,
|
|
http::{request::Parts, HeaderMap},
|
|
};
|
|
use std::{borrow::Cow, collections::HashSet, convert::Infallible, pin::Pin};
|
|
|
|
/// Redirect budget used by default when fetching from an upstream.
pub const DEFAULT_MAX_REDIRECTS: usize = 6;
|
|
|
|
/// Metadata about an upstream request, handed out alongside its response so
/// that the outgoing response can be written and the request can be logged.
#[allow(missing_docs)]
pub struct RequestCtx<'a> {
    /// Time the client measured between starting the upstream request and
    /// handing back a response whose body has not been read yet.
    pub time_to_body: std::time::Duration,
    /// URL the response is attributed to.
    pub url: &'a str,
    /// Whether `url` uses the `https` scheme.
    pub secure: bool,
}
|
|
|
|
/// Maps an HTTP protocol version to the version token written into the `Via`
/// header.
///
/// HTTP/1.1 and any version not listed explicitly are reported as `"1.1"`.
const fn http_version_to_via(v: axum::http::Version) -> &'static str {
    use axum::http::Version;

    match v {
        Version::HTTP_09 => "0.9",
        Version::HTTP_10 => "1.0",
        Version::HTTP_2 => "2.0",
        Version::HTTP_3 => "3.0",
        // Covers HTTP/1.1 itself as well as any version this function does not know.
        _ => "1.1",
    }
}
|
|
|
|
/// Trait for HTTP responses
|
|
pub trait HTTPResponse {
|
|
/// Type of the byte buffer
|
|
type Bytes: Into<Vec<u8>> + AsRef<[u8]> + Into<Bytes> + Send + 'static;
|
|
/// Type of body stream
|
|
type BodyStream: futures::Stream<Item = Result<Self::Bytes, ErrorResponse>> + Send + 'static;
|
|
|
|
/// Get some context about the request
|
|
fn request(&self) -> RequestCtx<'_>;
|
|
/// Get the status code
|
|
fn status(&self) -> u16;
|
|
/// Get a header value
|
|
fn header_one<'a>(&'a self, name: &str) -> Result<Option<Cow<'a, str>>, ErrorResponse>;
|
|
/// Walk through all headers with a callback
|
|
fn header_walk<F: FnMut(&str, &str) -> bool>(&self, f: F);
|
|
/// Collect all headers
|
|
fn header_collect(&self, out: &mut HeaderMap) -> Result<(), ErrorResponse>;
|
|
/// Get the body stream
|
|
fn body(self) -> Self::BodyStream;
|
|
}
|
|
|
|
/// Information about the incoming request
|
|
pub struct IncomingInfo {
|
|
version: axum::http::Version,
|
|
user_agent: String,
|
|
via: String,
|
|
}
|
|
|
|
impl IncomingInfo {
|
|
/// Check if the request is potentially looping
|
|
#[must_use]
|
|
pub fn looping(&self, self_via: &str) -> bool {
|
|
if self.user_agent.is_empty() {
|
|
return true;
|
|
}
|
|
|
|
if self.via.contains(self_via) {
|
|
return true;
|
|
}
|
|
|
|
// defense against upstream
|
|
if self.user_agent.contains("Misskey/") ||
|
|
// Purposefully typoed
|
|
// https://raw.githubusercontent.com/backrunner/misskey-media-proxy-worker/refs/heads/main/wrangler.toml
|
|
self.user_agent.contains("Edg/119.0.2109.1")
|
|
{
|
|
return true;
|
|
}
|
|
|
|
let split = self.via.split(", ");
|
|
|
|
let mut seen = HashSet::new();
|
|
for part in split {
|
|
if !seen.insert(part) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
false
|
|
}
|
|
}
|
|
|
|
#[axum::async_trait]
|
|
impl<S> FromRequestParts<S> for IncomingInfo {
|
|
type Rejection = Infallible;
|
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
|
Ok(Self {
|
|
version: parts.version,
|
|
user_agent: parts
|
|
.headers
|
|
.get("user-agent")
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or_default()
|
|
.to_string(),
|
|
via: parts
|
|
.headers
|
|
.get_all("via")
|
|
.into_iter()
|
|
.fold(String::new(), |mut acc, v| {
|
|
if !acc.is_empty() {
|
|
acc.push_str(", ");
|
|
}
|
|
acc.push_str(v.to_str().unwrap_or_default());
|
|
acc
|
|
}),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Trait for upstream clients
|
|
pub trait UpstreamClient {
|
|
/// Type of the response
|
|
type Response: HTTPResponse;
|
|
/// Create a new client
|
|
fn new(config: &FetchConfig) -> Self;
|
|
/// Request the upstream
|
|
fn request_upstream(
|
|
&self,
|
|
info: &IncomingInfo,
|
|
url: &str,
|
|
polish: bool,
|
|
secure: bool,
|
|
remaining: usize,
|
|
) -> impl std::future::Future<Output = Result<Self::Response, ErrorResponse>>;
|
|
}
|
|
|
|
/// Reqwest client
|
|
#[cfg(feature = "reqwest")]
|
|
pub mod reqwest {
|
|
use crate::config::AddrFamilyConfig;
|
|
|
|
use super::{
|
|
http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx,
|
|
UpstreamClient, MAX_SIZE,
|
|
};
|
|
use ::reqwest::{redirect::Policy, ClientBuilder, Url};
|
|
use axum::body::Bytes;
|
|
use futures::TryStreamExt;
|
|
use reqwest::dns::Resolve;
|
|
use std::{net::SocketAddrV4, sync::Arc, time::Duration};
|
|
use url::Host;
|
|
|
|
/// A Safe DNS resolver that only resolves to global addresses unless the requester itself is local.
|
|
pub struct SafeResolver(AddrFamilyConfig);
|
|
|
|
// pulled from https://doc.rust-lang.org/src/core/net/ip_addr.rs.html#1650
/// Returns `true` for IPv6 unique local addresses (`fc00::/7`).
const fn is_unicast_local_v6(ip: &std::net::Ipv6Addr) -> bool {
    // A /7 prefix of fc00 means the first octet is either 0xfc or 0xfd.
    let first_octet = ip.octets()[0];
    first_octet == 0xfc || first_octet == 0xfd
}
|
|
|
|
/// Returns `true` for IPv6 link-local unicast addresses (`fe80::/10`).
const fn is_unicast_link_local_v6(ip: &std::net::Ipv6Addr) -> bool {
    // A /10 prefix of fe80 fixes the first octet to 0xfe and the top two bits
    // of the second octet to `10`.
    let octets = ip.octets();
    octets[0] == 0xfe && (octets[1] & 0xc0) == 0x80
}
|
|
|
|
impl Resolve for SafeResolver {
|
|
fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
|
|
let af = self.0;
|
|
Box::pin(async move {
|
|
log::trace!("Resolving {}", name.as_str());
|
|
match tokio::net::lookup_host(format!("{}:443", name.as_str())).await {
|
|
Ok(lookup) => Ok(Box::new(
|
|
lookup
|
|
.map(|addr| match addr {
|
|
std::net::SocketAddr::V6(a) => {
|
|
if let Some(v4) = a.ip().to_ipv4() {
|
|
std::net::SocketAddr::V4(SocketAddrV4::new(v4, a.port()))
|
|
} else {
|
|
std::net::SocketAddr::V6(a)
|
|
}
|
|
}
|
|
std::net::SocketAddr::V4(a) => std::net::SocketAddr::V4(a),
|
|
})
|
|
.filter(move |addr| match addr {
|
|
std::net::SocketAddr::V4(a) if af != AddrFamilyConfig::V6Only => {
|
|
log::trace!("Resolved v4 addr {}", a);
|
|
!a.ip().is_loopback()
|
|
&& !a.ip().is_private()
|
|
&& !a.ip().is_link_local()
|
|
&& !a.ip().is_multicast()
|
|
&& !a.ip().is_documentation()
|
|
&& !a.ip().is_unspecified()
|
|
}
|
|
|
|
std::net::SocketAddr::V6(a) if af != AddrFamilyConfig::V4Only => {
|
|
log::trace!("Resolved v6 addr {}", a);
|
|
!a.ip().is_loopback()
|
|
&& !a.ip().is_multicast()
|
|
&& !a.ip().is_unspecified()
|
|
&& !is_unicast_local_v6(a.ip())
|
|
&& !is_unicast_link_local_v6(a.ip())
|
|
}
|
|
|
|
_ => false,
|
|
}),
|
|
)
|
|
as Box<dyn Iterator<Item = std::net::SocketAddr> + Send>),
|
|
Err(e) => {
|
|
log::error!("Failed to resolve {}: {}", name.as_str(), e);
|
|
Err(e.into())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Reqwest client
|
|
pub struct ReqwestClient {
|
|
https_only: bool,
|
|
via_ident: String,
|
|
client: ::reqwest::Client,
|
|
}
|
|
|
|
/// Response from Reqwest
|
|
pub struct ReqwestResponse {
|
|
time_to_body: std::time::Duration,
|
|
resp: ::reqwest::Response,
|
|
}
|
|
|
|
impl HTTPResponse for ReqwestResponse {
|
|
type Bytes = Bytes;
|
|
type BodyStream = Pin<
|
|
Box<dyn futures::Stream<Item = Result<Self::Bytes, ErrorResponse>> + Send + 'static>,
|
|
>;
|
|
|
|
fn request(&self) -> RequestCtx<'_> {
|
|
RequestCtx {
|
|
time_to_body: self.time_to_body,
|
|
url: self.resp.url().as_str(),
|
|
secure: self.resp.url().scheme().eq_ignore_ascii_case("https"),
|
|
}
|
|
}
|
|
|
|
fn status(&self) -> u16 {
|
|
self.resp.status().as_u16()
|
|
}
|
|
|
|
fn header_one<'a>(&'a self, name: &str) -> Result<Option<Cow<'a, str>>, ErrorResponse> {
|
|
self.resp
|
|
.headers()
|
|
.get(name)
|
|
.map(|v| v.to_str().map(Cow::Borrowed))
|
|
.transpose()
|
|
.map_err(|_| ErrorResponse::upstream_protocol_error())
|
|
}
|
|
|
|
fn header_walk<F: FnMut(&str, &str) -> bool>(&self, mut f: F) {
|
|
for (name, value) in self.resp.headers() {
|
|
if !f(name.as_str(), value.to_str().unwrap_or_default()) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn header_collect(&self, out: &mut HeaderMap) -> Result<(), ErrorResponse> {
|
|
for (name, value) in self.resp.headers() {
|
|
out.insert(name, value.clone());
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn body(self) -> Self::BodyStream {
|
|
Box::pin(self.resp.bytes_stream().map_err(Into::into))
|
|
}
|
|
}
|
|
|
|
impl UpstreamClient for ReqwestClient {
|
|
type Response = ReqwestResponse;
|
|
fn new(config: &crate::FetchConfig) -> Self {
|
|
Self {
|
|
https_only: !config.allow_http,
|
|
via_ident: config.via.clone(),
|
|
client: ClientBuilder::new()
|
|
.https_only(!config.allow_http)
|
|
.dns_resolver(Arc::new(SafeResolver(config.addr_family)))
|
|
.brotli(true)
|
|
.zstd(true)
|
|
.gzip(true)
|
|
.redirect(Policy::none())
|
|
.connect_timeout(Duration::from_secs(5))
|
|
.timeout(Duration::from_secs(15))
|
|
.user_agent(config.user_agent.clone())
|
|
.build()
|
|
.expect("Failed to create reqwest client"),
|
|
}
|
|
}
|
|
async fn request_upstream(
|
|
&self,
|
|
info: &super::IncomingInfo,
|
|
url: &str,
|
|
polish: bool,
|
|
mut secure: bool,
|
|
remaining: usize,
|
|
) -> Result<ReqwestResponse, ErrorResponse> {
|
|
if remaining == 0 {
|
|
return Err(ErrorResponse::too_many_redirects());
|
|
}
|
|
|
|
if info.looping(&self.via_ident) {
|
|
return Err(ErrorResponse::loop_detected());
|
|
}
|
|
|
|
let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?;
|
|
if !url_parsed.host().map_or(false, |h| match h {
|
|
Host::Domain(_) => true,
|
|
_ => false,
|
|
}) {
|
|
return Err(ErrorResponse::non_dns_name());
|
|
}
|
|
|
|
secure &= url_parsed.scheme().eq_ignore_ascii_case("https");
|
|
if self.https_only && !secure {
|
|
return Err(ErrorResponse::insecure_request());
|
|
}
|
|
|
|
let begin = crate::timing::Instant::now();
|
|
|
|
let resp = self
|
|
.client
|
|
.get(url_parsed)
|
|
.header(
|
|
"via",
|
|
format!(
|
|
"{}, {} {}",
|
|
info.via,
|
|
http_version_to_via(info.version),
|
|
self.via_ident
|
|
),
|
|
)
|
|
.send()
|
|
.await?;
|
|
|
|
if resp.status().is_redirection() {
|
|
if let Some(location) = resp.headers().get("location").and_then(|l| l.to_str().ok())
|
|
{
|
|
return Box::pin(self.request_upstream(
|
|
info,
|
|
location,
|
|
polish,
|
|
secure,
|
|
remaining - 1,
|
|
))
|
|
.await;
|
|
}
|
|
return Err(ErrorResponse::missing_location());
|
|
}
|
|
|
|
if !resp.status().is_success() {
|
|
return Err(ErrorResponse::unexpected_status(
|
|
url,
|
|
resp.status().as_u16(),
|
|
));
|
|
}
|
|
|
|
let content_length = resp.headers().get("content-length");
|
|
if let Some(content_length) = content_length.and_then(|c| c.to_str().ok()) {
|
|
if content_length.parse::<usize>().unwrap_or(0) > MAX_SIZE {
|
|
return Err(ErrorResponse::payload_too_large());
|
|
}
|
|
}
|
|
|
|
let content_type = resp.headers().get("content-type");
|
|
if let Some(content_type) = content_type.and_then(|c| c.to_str().ok()) {
|
|
if !["image/", "video/", "audio/", "application/octet-stream"]
|
|
.iter()
|
|
.any(|prefix| {
|
|
content_type[..prefix.len().min(content_type.len())]
|
|
.eq_ignore_ascii_case(prefix)
|
|
})
|
|
{
|
|
return Err(ErrorResponse::not_media());
|
|
}
|
|
}
|
|
|
|
Ok(ReqwestResponse {
|
|
time_to_body: begin.elapsed(),
|
|
resp,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Cloudflare Workers client
|
|
#[cfg(feature = "cf-worker")]
|
|
#[cfg_attr(
|
|
not(target_arch = "wasm32"),
|
|
deprecated = "You should use reqwest instead when not on Cloudflare Workers"
|
|
)]
|
|
pub mod cf_worker {
|
|
use std::time::Duration;
|
|
|
|
use super::{
|
|
http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx,
|
|
UpstreamClient, MAX_SIZE,
|
|
};
|
|
use axum::http::{HeaderName, HeaderValue};
|
|
use futures::{Stream, TryFutureExt};
|
|
use worker::{
|
|
AbortController, ByteStream, CfProperties, Fetch, Headers, Method, PolishConfig, Request,
|
|
RequestInit, RequestRedirect, Url,
|
|
};
|
|
|
|
/// Upstream client backed by the Cloudflare Workers `fetch` API.
pub struct CfWorkerClient {
    /// When set, only `https` upstream requests are allowed.
    https_only: bool,
    /// `User-Agent` sent with every upstream request.
    user_agent: String,
    /// Identifier this proxy appends to the outgoing `Via` header and looks
    /// for when detecting loops.
    via_ident: String,
}
|
|
|
|
/// Wrapper for the body stream
|
|
pub struct CfBodyStreamWrapper {
|
|
stream: Option<Result<ByteStream, ErrorResponse>>,
|
|
}
|
|
|
|
impl Stream for CfBodyStreamWrapper {
|
|
type Item = Result<Vec<u8>, ErrorResponse>;
|
|
|
|
fn poll_next(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Option<Self::Item>> {
|
|
let this = self.get_mut();
|
|
match this.stream.as_mut() {
|
|
Some(Ok(stream)) => match futures::ready!(std::pin::pin!(stream).poll_next(cx)) {
|
|
Some(Ok(chunk)) => std::task::Poll::Ready(Some(Ok(chunk))),
|
|
Some(Err(e)) => std::task::Poll::Ready(Some(Err(ErrorResponse::from(e)))),
|
|
None => std::task::Poll::Ready(None),
|
|
},
|
|
Some(Err(e)) => std::task::Poll::Ready(Some(Err(e.clone()))),
|
|
None => std::task::Poll::Ready(None),
|
|
}
|
|
}
|
|
|
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
|
match self.stream {
|
|
Some(Ok(ref stream)) => stream.size_hint(),
|
|
_ => (0, None),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(unsafe_code, reason = "this is never used concurrently")]
|
|
unsafe impl Send for CfBodyStreamWrapper {}
|
|
#[allow(unsafe_code, reason = "this is never used concurrently")]
|
|
unsafe impl Sync for CfBodyStreamWrapper {}
|
|
|
|
/// Response from Cloudflare Workers
|
|
pub struct CfWorkerResponse {
|
|
time_to_body: std::time::Duration,
|
|
resp: worker::Response,
|
|
url: Url,
|
|
}
|
|
|
|
impl HTTPResponse for CfWorkerResponse {
|
|
type Bytes = Vec<u8>;
|
|
type BodyStream = CfBodyStreamWrapper;
|
|
|
|
fn request(&self) -> RequestCtx<'_> {
|
|
RequestCtx {
|
|
time_to_body: self.time_to_body,
|
|
url: self.url.as_str(),
|
|
secure: self.url.scheme().eq_ignore_ascii_case("https"),
|
|
}
|
|
}
|
|
|
|
fn status(&self) -> u16 {
|
|
self.resp.status_code()
|
|
}
|
|
|
|
fn header_one<'a>(&'a self, name: &str) -> Result<Option<Cow<'a, str>>, ErrorResponse> {
|
|
self.resp
|
|
.headers()
|
|
.get(name)
|
|
.map(|v| v.map(|v| Cow::Owned(v.to_string())))
|
|
.map_err(|_| ErrorResponse::upstream_protocol_error())
|
|
}
|
|
|
|
fn header_walk<F: FnMut(&str, &str) -> bool>(&self, mut f: F) {
|
|
for (name, value) in self.resp.headers().entries() {
|
|
if !f(&name, &value) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn header_collect(&self, out: &mut HeaderMap) -> Result<(), ErrorResponse> {
|
|
for name in self.resp.headers().keys() {
|
|
out.insert(
|
|
HeaderName::from_bytes(name.as_bytes())
|
|
.map_err(|_| ErrorResponse::upstream_protocol_error())?,
|
|
self.resp
|
|
.headers()
|
|
.get(&name)
|
|
.map_err(|_| ErrorResponse::upstream_protocol_error())?
|
|
.map(HeaderValue::try_from)
|
|
.transpose()
|
|
.map_err(|_| ErrorResponse::upstream_protocol_error())?
|
|
.ok_or(ErrorResponse::upstream_protocol_error())?,
|
|
);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn body(mut self) -> Self::BodyStream {
|
|
let stream = self.resp.stream().map_err(ErrorResponse::from);
|
|
CfBodyStreamWrapper {
|
|
stream: Some(stream),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl UpstreamClient for CfWorkerClient {
|
|
type Response = CfWorkerResponse;
|
|
|
|
fn new(config: &crate::FetchConfig) -> Self {
|
|
Self {
|
|
https_only: !config.allow_http,
|
|
user_agent: config.user_agent.clone(),
|
|
via_ident: config.via.clone(),
|
|
}
|
|
}
|
|
|
|
async fn request_upstream(
|
|
&self,
|
|
info: &super::IncomingInfo,
|
|
url: &str,
|
|
polish: bool,
|
|
mut secure: bool,
|
|
remaining: usize,
|
|
) -> Result<Self::Response, ErrorResponse> {
|
|
if remaining == 0 {
|
|
return Err(ErrorResponse::too_many_redirects());
|
|
}
|
|
|
|
if info.looping(&self.via_ident) {
|
|
return Err(ErrorResponse::loop_detected());
|
|
}
|
|
|
|
let mut headers = Headers::new();
|
|
|
|
headers.set("user-agent", &self.user_agent)?;
|
|
headers.set(
|
|
"via",
|
|
&format!(
|
|
"{}, {} {}",
|
|
info.via,
|
|
http_version_to_via(info.version),
|
|
self.via_ident
|
|
),
|
|
)?;
|
|
|
|
let mut prop = CfProperties::new();
|
|
if polish {
|
|
prop.polish = Some(PolishConfig::Lossless);
|
|
}
|
|
let mut init = RequestInit::new();
|
|
let init = init
|
|
.with_method(Method::Get)
|
|
.with_headers(headers)
|
|
.with_cf_properties(prop)
|
|
.with_redirect(RequestRedirect::Manual);
|
|
|
|
let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?;
|
|
|
|
if self.https_only && !url_parsed.scheme().eq_ignore_ascii_case("https") {
|
|
return Err(ErrorResponse::insecure_request());
|
|
}
|
|
|
|
secure &= url_parsed.scheme().eq_ignore_ascii_case("http");
|
|
|
|
let begin = crate::timing::Instant::now();
|
|
|
|
let req = Request::new_with_init(url, init)?;
|
|
|
|
let abc = AbortController::default();
|
|
let abs = abc.signal();
|
|
let req = Fetch::Request(req);
|
|
|
|
worker::wasm_bindgen_futures::spawn_local(async move {
|
|
worker::Delay::from(Duration::from_secs(5)).await;
|
|
abc.abort();
|
|
});
|
|
|
|
let resp = std::pin::pin!(req
|
|
.send_with_signal(&abs)
|
|
.map_err(ErrorResponse::worker_fetch_error))
|
|
.await?;
|
|
|
|
if resp.status_code() >= 300 && resp.status_code() < 400 {
|
|
if let Ok(Some(location)) = resp.headers().get("location") {
|
|
return Box::pin(self.request_upstream(
|
|
info,
|
|
&location,
|
|
polish,
|
|
secure,
|
|
remaining - 1,
|
|
))
|
|
.await;
|
|
}
|
|
return Err(ErrorResponse::missing_location());
|
|
}
|
|
|
|
if resp.status_code() < 200 || resp.status_code() >= 300 {
|
|
return Err(ErrorResponse::unexpected_status(url, resp.status_code()));
|
|
}
|
|
|
|
let content_length = resp.headers().get("content-length").unwrap_or_default();
|
|
if let Some(content_length) = content_length {
|
|
if content_length.parse::<usize>().unwrap_or(0) > MAX_SIZE {
|
|
return Err(ErrorResponse::payload_too_large());
|
|
}
|
|
}
|
|
|
|
let content_type = resp.headers().get("content-type").unwrap_or_default();
|
|
if let Some(content_type) = content_type {
|
|
if !["image/", "video/", "audio/", "application/octet-stream"]
|
|
.iter()
|
|
.any(|prefix| {
|
|
content_type.as_str()[..prefix.len().min(content_type.len())]
|
|
.eq_ignore_ascii_case(prefix)
|
|
})
|
|
{
|
|
return Err(ErrorResponse::not_media());
|
|
}
|
|
}
|
|
|
|
Ok(CfWorkerResponse {
|
|
time_to_body: begin.elapsed(),
|
|
resp,
|
|
url: url_parsed,
|
|
})
|
|
}
|
|
}
|
|
}
|