From ecc17a714fd55cde273bb27d9fb52cd5223efc19 Mon Sep 17 00:00:00 2001 From: eternal-flame-AD Date: Wed, 13 Nov 2024 15:39:35 -0600 Subject: [PATCH] some address family customization on local env Signed-off-by: eternal-flame-AD --- src/config.rs | 27 ++++++++++++++++ src/fetch/mod.rs | 68 ++++++++++++++++++++++++++--------------- src/lib.rs | 2 ++ src/main.rs | 3 ++ src/post_process/mod.rs | 24 +++++---------- 5 files changed, 83 insertions(+), 41 deletions(-) diff --git a/src/config.rs b/src/config.rs index fe12917..ff499a4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -77,6 +77,8 @@ impl Default for Config { listen: Some("127.0.0.1:3000".to_string()), enable_cache: false, fetch: FetchConfig { + #[cfg(not(feature = "cf-worker"))] + addr_family: AddrFamilyConfig::Both, allow_http: false, via: concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")).to_string(), user_agent: concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")) @@ -104,6 +106,11 @@ impl Default for Config { /// Fetch configuration #[derive(Debug, Clone, serde::Deserialize)] pub struct FetchConfig { + /// The address family to use + #[cfg(not(feature = "cf-worker"))] + #[serde(default)] + pub addr_family: AddrFamilyConfig, + /// Whether to allow HTTP requests pub allow_http: bool, /// The via string to use when fetching media @@ -112,6 +119,26 @@ pub struct FetchConfig { pub user_agent: String, } +/// Address family configuration +#[cfg(not(feature = "cf-worker"))] +#[derive(Debug, Clone, Copy, serde::Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum AddrFamilyConfig { + /// Prefer IPv4 + V4Only, + /// Prefer IPv6 + V6Only, + /// Use both IPv4 and IPv6 + Both, +} + +#[cfg(not(feature = "cf-worker"))] +impl Default for AddrFamilyConfig { + fn default() -> Self { + Self::Both + } +} + /// Post-processing configuration #[derive(Debug, Clone, serde::Deserialize)] pub struct PostProcessConfig { diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 1bf1994..d2784ae 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -136,6 +136,8 @@ pub trait UpstreamClient { /// Reqwest client #[cfg(feature = "reqwest")] pub mod reqwest { + use crate::AddrFamilyConfig; + use super::{ http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx, UpstreamClient, MAX_SIZE, @@ -144,43 +146,61 @@ pub mod reqwest { use axum::body::Bytes; use futures::TryStreamExt; use reqwest::dns::Resolve; - use std::{sync::Arc, time::Duration}; + use std::{net::SocketAddrV4, sync::Arc, time::Duration}; /// A Safe DNS resolver that only resolves to global addresses unless the requester itself is local. - pub struct SafeResolver(); + pub struct SafeResolver(AddrFamilyConfig); // pulled from https://doc.rust-lang.org/src/core/net/ip_addr.rs.html#1650 const fn is_unicast_local_v6(ip: &std::net::Ipv6Addr) -> bool { - ip.segments()[0] & 0xfe00 == 0xfc00 + (ip.segments()[0] & 0xfe00) == 0xfc00 } const fn is_unicast_link_local_v6(ip: &std::net::Ipv6Addr) -> bool { - ip.segments()[0] & 0xffc0 == 0xfe80 + (ip.segments()[0] & 0xffc0) == 0xfe80 } impl Resolve for SafeResolver { fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving { + let af = self.0; Box::pin(async move { - match tokio::net::lookup_host(format!("{}:80", name.as_str())).await { - Ok(lookup) => Ok(Box::new(lookup.filter(|addr| match addr { - std::net::SocketAddr::V4(a) => { - !a.ip().is_loopback() - && !a.ip().is_private() - && !a.ip().is_link_local() - && !a.ip().is_multicast() - && !a.ip().is_documentation() - && !a.ip().is_unspecified() - } + log::trace!("Resolving {}", name.as_str()); + match tokio::net::lookup_host(format!("{}:443", name.as_str())).await { + Ok(lookup) => Ok(Box::new( + lookup + .map(|addr| match addr { + std::net::SocketAddr::V6(a) => { + if let Some(v4) = a.ip().to_ipv4() { + std::net::SocketAddr::V4(SocketAddrV4::new(v4, a.port())) + } else { + std::net::SocketAddr::V6(a) + } + } + o => o, + }) + .filter(move |addr| match addr { + std::net::SocketAddr::V4(a) if af != AddrFamilyConfig::V6Only => { + log::trace!("Resolved v4 addr {}", a); + !a.ip().is_loopback() + && !a.ip().is_private() + && !a.ip().is_link_local() + && !a.ip().is_multicast() + && !a.ip().is_documentation() + && !a.ip().is_unspecified() + } - std::net::SocketAddr::V6(a) => { - !a.ip().is_loopback() - && !a.ip().is_multicast() - && !a.ip().is_unspecified() - && is_unicast_local_v6(a.ip()) - && !is_unicast_link_local_v6(a.ip()) - && a.ip().to_ipv4_mapped().is_none() - } - })) + std::net::SocketAddr::V6(a) if af != AddrFamilyConfig::V4Only => { + log::trace!("Resolved v6 addr {}", a); + !a.ip().is_loopback() + && !a.ip().is_multicast() + && !a.ip().is_unspecified() + && !is_unicast_local_v6(a.ip()) + && !is_unicast_link_local_v6(a.ip()) + } + + _ => false, + }), + ) as Box + Send>), Err(e) => { log::error!("Failed to resolve {}: {}", name.as_str(), e); @@ -259,7 +279,7 @@ pub mod reqwest { via_ident: config.via.clone(), client: ClientBuilder::new() .https_only(!config.allow_http) - .dns_resolver(Arc::new(SafeResolver())) + .dns_resolver(Arc::new(SafeResolver(config.addr_family))) .brotli(true) .zstd(true) .gzip(true) diff --git a/src/lib.rs b/src/lib.rs index 05f238b..9bc3c32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -606,6 +606,8 @@ impl App { } } + log::info!("Proxying {}, options: {:?}", query.url, options); + let resp = state .client .request_upstream(&info, &query.url, false, true, DEFAULT_MAX_REDIRECTS) diff --git a/src/main.rs b/src/main.rs index 23d8406..b3c5dd0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,9 @@ struct Cli { #[tokio::main] async fn main() { + if std::env::var("RUST_LOG").is_err() { + std::env::set_var("RUST_LOG", "info"); + } env_logger::init(); let cli = Cli::parse(); diff --git a/src/post_process/mod.rs b/src/post_process/mod.rs index 81d07a5..5c9488b 100644 --- a/src/post_process/mod.rs +++ b/src/post_process/mod.rs @@ -238,6 +238,8 @@ where buf.extend_from_slice(bytes.as_ref()); } + let slurp_dur = slurp_begin.elapsed(); + let output_static_format = if options .format .as_deref() @@ -271,10 +273,7 @@ where is_https, }) .with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb) - .with_opt_timing_info( - SLURP_TIMING_KEY, - Some(slurp_begin.elapsed()), - ) + .with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur)) .with_timing_info(TIMING_KEY, begin.elapsed())) } None => Ok(MediaResponse::Buffer { @@ -282,10 +281,7 @@ where content_type: Some("image/png".into()), } .with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb) - .with_opt_timing_info( - SLURP_TIMING_KEY, - Some(slurp_begin.elapsed()), - ) + .with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur)) .with_timing_info(TIMING_KEY, begin.elapsed())), }; } @@ -305,10 +301,7 @@ where is_https, }) .with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb) - .with_opt_timing_info( - SLURP_TIMING_KEY, - Some(slurp_begin.elapsed()), - ) + .with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur)) .with_timing_info(TIMING_KEY, begin.elapsed())) } None => Ok(MediaResponse::Buffer { @@ -316,10 +309,7 @@ where content_type: Some("image/webp".into()), } .with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb) - .with_opt_timing_info( - SLURP_TIMING_KEY, - Some(slurp_begin.elapsed()), - ) + .with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur)) .with_timing_info(TIMING_KEY, begin.elapsed())), }; } @@ -336,7 +326,7 @@ where is_https, }) .with_timing_info(TIME_TO_FIRST_BYTE_KEY, ttfb) - .with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_begin.elapsed())) + .with_opt_timing_info(SLURP_TIMING_KEY, Some(slurp_dur)) .with_timing_info(TIMING_KEY, begin.elapsed())) } else { Ok(MediaResponse::PassThru(PassThru {