some address family customization on local env

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-11-13 15:39:35 -06:00
parent bb3af85275
commit ecc17a714f
No known key found for this signature in database
5 changed files with 83 additions and 41 deletions

View file

@ -77,6 +77,8 @@ impl Default for Config {
listen: Some("127.0.0.1:3000".to_string()),
enable_cache: false,
fetch: FetchConfig {
#[cfg(not(feature = "cf-worker"))]
addr_family: AddrFamilyConfig::Both,
allow_http: false,
via: concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")).to_string(),
user_agent: concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"))
@ -104,6 +106,11 @@ impl Default for Config {
/// Fetch configuration
#[derive(Debug, Clone, serde::Deserialize)]
pub struct FetchConfig {
/// The address family to use
#[cfg(not(feature = "cf-worker"))]
#[serde(default)]
pub addr_family: AddrFamilyConfig,
/// Whether to allow HTTP requests
pub allow_http: bool,
/// The via string to use when fetching media
@ -112,6 +119,26 @@ pub struct FetchConfig {
pub user_agent: String,
}
/// Address family configuration
#[cfg(not(feature = "cf-worker"))]
#[derive(Debug, Clone, Copy, serde::Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AddrFamilyConfig {
/// Prefer IPv4
V4Only,
/// Prefer IPv6
V6Only,
/// Use both IPv4 and IPv6
Both,
}
#[cfg(not(feature = "cf-worker"))]
impl Default for AddrFamilyConfig {
fn default() -> Self {
Self::Both
}
}
/// Post-processing configuration
#[derive(Debug, Clone, serde::Deserialize)]
pub struct PostProcessConfig {

View file

@ -136,6 +136,8 @@ pub trait UpstreamClient {
/// Reqwest client
#[cfg(feature = "reqwest")]
pub mod reqwest {
use crate::AddrFamilyConfig;
use super::{
http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx,
UpstreamClient, MAX_SIZE,
@ -144,26 +146,41 @@ pub mod reqwest {
use axum::body::Bytes;
use futures::TryStreamExt;
use reqwest::dns::Resolve;
use std::{sync::Arc, time::Duration};
use std::{net::SocketAddrV4, sync::Arc, time::Duration};
/// A Safe DNS resolver that only resolves to global addresses unless the requester itself is local.
pub struct SafeResolver();
pub struct SafeResolver(AddrFamilyConfig);
// pulled from https://doc.rust-lang.org/src/core/net/ip_addr.rs.html#1650
const fn is_unicast_local_v6(ip: &std::net::Ipv6Addr) -> bool {
ip.segments()[0] & 0xfe00 == 0xfc00
(ip.segments()[0] & 0xfe00) == 0xfc00
}
const fn is_unicast_link_local_v6(ip: &std::net::Ipv6Addr) -> bool {
ip.segments()[0] & 0xffc0 == 0xfe80
(ip.segments()[0] & 0xffc0) == 0xfe80
}
impl Resolve for SafeResolver {
fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
let af = self.0;
Box::pin(async move {
match tokio::net::lookup_host(format!("{}:80", name.as_str())).await {
Ok(lookup) => Ok(Box::new(lookup.filter(|addr| match addr {
std::net::SocketAddr::V4(a) => {
log::trace!("Resolving {}", name.as_str());
match tokio::net::lookup_host(format!("{}:443", name.as_str())).await {
Ok(lookup) => Ok(Box::new(
lookup
.map(|addr| match addr {
std::net::SocketAddr::V6(a) => {
if let Some(v4) = a.ip().to_ipv4() {
std::net::SocketAddr::V4(SocketAddrV4::new(v4, a.port()))
} else {
std::net::SocketAddr::V6(a)
}
}
o => o,
})
.filter(move |addr| match addr {
std::net::SocketAddr::V4(a) if af != AddrFamilyConfig::V6Only => {
log::trace!("Resolved v4 addr {}", a);
!a.ip().is_loopback()
&& !a.ip().is_private()
&& !a.ip().is_link_local()
@ -172,15 +189,18 @@ pub mod reqwest {
&& !a.ip().is_unspecified()
}
std::net::SocketAddr::V6(a) => {
std::net::SocketAddr::V6(a) if af != AddrFamilyConfig::V4Only => {
log::trace!("Resolved v6 addr {}", a);
!a.ip().is_loopback()
&& !a.ip().is_multicast()
&& !a.ip().is_unspecified()
&& is_unicast_local_v6(a.ip())
&& !is_unicast_local_v6(a.ip())
&& !is_unicast_link_local_v6(a.ip())
&& a.ip().to_ipv4_mapped().is_none()
}
}))
_ => false,
}),
)
as Box<dyn Iterator<Item = std::net::SocketAddr> + Send>),
Err(e) => {
log::error!("Failed to resolve {}: {}", name.as_str(), e);
@ -259,7 +279,7 @@ pub mod reqwest {
via_ident: config.via.clone(),
client: ClientBuilder::new()
.https_only(!config.allow_http)
.dns_resolver(Arc::new(SafeResolver()))
.dns_resolver(Arc::new(SafeResolver(config.addr_family)))
.brotli(true)
.zstd(true)
.gzip(true)

View file

@ -606,6 +606,8 @@ impl<C: UpstreamClient + 'static, S: Sandboxing + 'static> App<C, S> {
}
}
log::info!("Proxying {}, options: {:?}", query.url, options);
let resp = state
.client
.request_upstream(&info, &query.url, false, true, DEFAULT_MAX_REDIRECTS)

View file

@ -11,6 +11,9 @@ struct Cli {
#[tokio::main]
async fn main() {
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "info");
}
env_logger::init();
let cli = Cli::parse();

View file

@ -238,6 +238,8 @@ where
buf.extend_from_slice(bytes.as_ref());
}
let slurp_dur = slurp_begin.elapsed();
let output_static_format = if options
.format
.as_deref()
@ -271,10 +273,7 @@ where
is_https,
})
.with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb)
.with_opt_timing_info(
SLURP_TIMING_KEY,
Some(slurp_begin.elapsed()),
)
.with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur))
.with_timing_info(TIMING_KEY, begin.elapsed()))
}
None => Ok(MediaResponse::Buffer {
@ -282,10 +281,7 @@ where
content_type: Some("image/png".into()),
}
.with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb)
.with_opt_timing_info(
SLURP_TIMING_KEY,
Some(slurp_begin.elapsed()),
)
.with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur))
.with_timing_info(TIMING_KEY, begin.elapsed())),
};
}
@ -305,10 +301,7 @@ where
is_https,
})
.with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb)
.with_opt_timing_info(
SLURP_TIMING_KEY,
Some(slurp_begin.elapsed()),
)
.with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur))
.with_timing_info(TIMING_KEY, begin.elapsed()))
}
None => Ok(MediaResponse::Buffer {
@ -316,10 +309,7 @@ where
content_type: Some("image/webp".into()),
}
.with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb)
.with_opt_timing_info(
SLURP_TIMING_KEY,
Some(slurp_begin.elapsed()),
)
.with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur))
.with_timing_info(TIMING_KEY, begin.elapsed())),
};
}
@ -336,7 +326,7 @@ where
is_https,
})
.with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb)
.with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_begin.elapsed()))
.with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur))
.with_timing_info(TIMING_KEY, begin.elapsed()))
} else {
Ok(MediaResponse::PassThru(PassThru {