diff --git a/Cargo.lock b/Cargo.lock index 9a029fe..d506ff1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -651,6 +651,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" + [[package]] name = "fontconfig-parser" version = "0.5.7" @@ -871,6 +877,11 @@ name = "hashbrown" version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heck" @@ -1334,6 +1345,15 @@ dependencies = [ "imgref", ] +[[package]] +name = "lru" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" +dependencies = [ + "hashbrown 0.15.1", +] + [[package]] name = "matchit" version = "0.7.3" @@ -3200,6 +3220,7 @@ dependencies = [ "chumsky", "clap", "console_error_panic_hook", + "dashmap", "env_logger", "fontdb", "futures", @@ -3208,6 +3229,7 @@ dependencies = [ "image", "libc", "log", + "lru", "quote", "reqwest", "resvg", diff --git a/Cargo.toml b/Cargo.toml index 9f6ed34..39a6e3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,6 +81,8 @@ fontdb = { version = "0.23", optional = true } webp = { version = "0.3.0", optional = true } url = { version = "2", optional = true } tower-http = { version = "0.6.2", features = ["catch-panic", "timeout"], optional = true } +dashmap = "6.1.0" +lru = "0.12.5" [patch.crates-io] # licensing and webp dependencies diff --git a/README.md b/README.md index 126b4a4..95c2f00 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Work in progress! Currently to do: - [X] Rate-limiting on local deployment (untested) - [X] Read config from Cloudflare - [X] Timing and Rate-limiting headers (some not available on Cloudflare Workers) + - [X] Tiered rate-limiting - [ ] Lossy WebP on CF Workers - [ ] Cache Results on Cloudflare KV. - [ ] Handle all possible panics reported by Clippy diff --git a/local.toml b/local.toml index 26ed25c..9b21503 100644 --- a/local.toml +++ b/local.toml @@ -2,6 +2,7 @@ listen = "127.0.0.1:3000" enable_cache = false index_redirect = { permanent = false, url = "https://mi.yumechi.jp/" } allow_unknown = false +max_x_forwarded_for = 0 # you need AppArmor and the policy loaded to use this [sandbox.apparmor] @@ -19,7 +20,18 @@ enable_redirects = false normalization = "lazy" allow_svg_passthrough = false -[rate_limit] -max_x_forwarded_for = 0 +[[rate_limit]] +replenish_every = 50 +burst = 256 + +[[rate_limit]] +key = "500ms" +min_request_duration = 500 replenish_every = 200 -burst = 64 \ No newline at end of file +burst = 64 + +[[rate_limit]] +key = "1500ms" +min_request_duration = 1500 +replenish_every = 1500 +burst = 16 diff --git a/src/config.rs b/src/config.rs index a195578..6d9fa77 100644 --- a/src/config.rs +++ b/src/config.rs @@ -37,9 +37,13 @@ pub struct Config { /// Post-processing configuration pub post_process: PostProcessConfig, + #[cfg(not(feature = "cf-worker"))] + /// The maximum number of X-Forwarded-For headers to allow + pub max_x_forwarded_for: u8, + #[cfg(feature = "governor")] /// Governor configuration - pub rate_limit: RateLimitConfig, + pub rate_limit: Vec, } /// Sandbox configuration @@ -75,12 +79,15 @@ pub struct AppArmorConfig { #[cfg(feature = "governor")] #[derive(Debug, Clone, serde::Deserialize)] pub struct RateLimitConfig { - /// The maximum number of X-Forwarded-For headers to allow - pub max_x_forwarded_for: u8, + /// The key to use for rate limiting headers + pub key: Option, /// The rate limit replenish interval in milliseconds pub replenish_every: u64, /// The rate limit burst size pub burst: NonZero, + + /// The minimum request duration in milliseconds for this rate limit to apply + pub min_request_duration: Option, } #[cfg(feature = "cf-worker")] @@ -130,12 +137,15 @@ impl Default for Config { normalization: NormalizationPolicy::Opportunistic, allow_svg_passthrough: false, }, + #[cfg(not(feature = "cf-worker"))] + max_x_forwarded_for: 0, #[cfg(feature = "governor")] - rate_limit: RateLimitConfig { - max_x_forwarded_for: 0, + rate_limit: vec![RateLimitConfig { + key: None, replenish_every: 2000, burst: NonZero::new(32).unwrap(), - }, + min_request_duration: None, + }], } } } diff --git a/src/lib.rs b/src/lib.rs index c8740f4..5ffd964 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,12 @@ #[cfg(feature = "governor")] use std::net::SocketAddr; -use std::{borrow::Cow, fmt::Display, marker::PhantomData, sync::Arc}; +use std::{ + borrow::Cow, + fmt::Display, + marker::PhantomData, + sync::{atomic::AtomicU64, Arc, RwLock}, +}; #[cfg(feature = "governor")] use axum::extract::ConnectInfo; @@ -24,6 +29,7 @@ use governor::{ clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore, RateLimiter, }; +use lru::LruCache; use post_process::{CompressionLevel, MediaResponse}; use sandbox::Sandboxing; @@ -110,18 +116,29 @@ where }; #[cfg(feature = "governor")] use std::time::Duration; + use std::{num::NonZero, sync::RwLock}; #[cfg(not(feature = "cf-worker"))] use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer}; let state = AppState { #[cfg(feature = "governor")] - limiter: RateLimiter::dashmap_with_clock( - Quota::with_period(Duration::from_millis(config.rate_limit.replenish_every)) - .unwrap() - .allow_burst(config.rate_limit.burst), - SystemClock, - ) - .with_middleware::(), + limiters: config + .rate_limit + .iter() + .map(|x| { + ( + RateLimiter::dashmap_with_clock( + Quota::with_period(Duration::from_millis(x.replenish_every)) + .unwrap() + .allow_burst(x.burst), + SystemClock, + ) + .with_middleware::(), + RwLock::new(LruCache::new(NonZero::new(1024).unwrap())), + ) + }) + .collect::>() + .into_boxed_slice(), client: Upstream::new(&config.fetch), sandbox: S::new(&config.sandbox), config, @@ -178,6 +195,17 @@ where .layer(TimeoutLayer::new(Duration::from_secs(10))); } + #[cfg(feature = "governor")] + { + let state_gc = Arc::clone(&state); + std::thread::spawn(move || loop { + std::thread::sleep(Duration::from_secs(300)); + for (limiter, _) in state_gc.limiters.iter() { + limiter.retain_recent(); + } + }); + } + #[cfg(feature = "governor")] return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware)); #[cfg(not(feature = "governor"))] @@ -232,6 +260,24 @@ pub async fn common_security_headers( resp } +fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool { + loop { + let current = credits.load(std::sync::atomic::Ordering::Relaxed); + if current == 0 { + return false; + } + if credits.compare_exchange_weak( + current, + current - 1, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ) == Ok(current) + { + return true; + } + } +} + /// Middleware for rate limiting #[cfg(feature = "governor")] #[cfg_attr(feature = "cf-worker", worker::send)] @@ -246,7 +292,9 @@ pub async fn rate_limit_middleware( time::SystemTime, }; - let forwarded_ip = if state.config.rate_limit.max_x_forwarded_for > 0 { + use axum::http::HeaderName; + + let forwarded_ip = if state.config.max_x_forwarded_for > 0 { std::iter::repeat(addr.ip()) .chain( request @@ -257,7 +305,7 @@ pub async fn rate_limit_middleware( .split(',') .filter_map(|x| x.trim().parse().ok()), ) - .nth_back(state.config.rate_limit.max_x_forwarded_for as usize - 1) + .nth_back(state.config.max_x_forwarded_for as usize - 1) .map(|addr| match addr { IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 >> 64)), addr => addr, @@ -266,30 +314,90 @@ pub async fn rate_limit_middleware( None }; - match state.limiter.check_key(&forwarded_ip.unwrap_or(addr.ip())) { - Ok(ok) => { + let real_ip = forwarded_ip.unwrap_or_else(|| match addr.ip() { + IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 >> 64)), + addr => addr, + }); + + let mut res = Vec::with_capacity(state.limiters.len()); + match state.limiters.iter().fold(Ok(()), |acc, limiter| { + acc.and_then(|_| { + // if we have credits, we don't need to check the key + let credits = limiter.1.read().unwrap(); + if let Some(credits_val) = credits.peek(&forwarded_ip.unwrap_or(real_ip)) { + if atomic_u64_saturating_dec(credits_val) { + res.push(None); + return Ok(()); + } + + drop(credits); + let mut credits = limiter.1.write().unwrap(); + credits.pop(&forwarded_ip.unwrap_or(real_ip)); + } + + limiter + .0 + .check_key(&forwarded_ip.unwrap_or(real_ip)) + .map(|x| res.push(Some(x))) + }) + }) { + Ok(_) => { + let begin = SystemTime::now(); let mut resp = next.run(request).await; + let elapsed = begin.elapsed().unwrap().as_millis() as u64; + + // credit back limits that didn't get used + for (config, (_, credits)) in state.config.rate_limit.iter().zip(state.limiters.iter()) + { + if elapsed < config.min_request_duration.unwrap_or(0) { + let mut credits = credits.write().unwrap(); + credits + .get_or_insert(forwarded_ip.unwrap_or(real_ip), || 0.into()) + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + } let headers = resp.headers_mut(); - headers.insert( - "X-RateLimit-Limit", - #[allow(clippy::unwrap_used)] - ok.quota().burst_size().to_string().parse().unwrap(), - ); - headers.insert( - "X-RateLimit-Replenish-Interval", + + for (config, snapshot) in state .config .rate_limit - .replenish_every - .to_string() - .parse() - .unwrap(), - ); - headers.insert( - "X-RateLimit-Remaining", - ok.remaining_burst_capacity().to_string().parse().unwrap(), - ); + .iter() + .zip(res.iter()) + .filter_map(|(config, snapshot)| match snapshot { + Some(snapshot) => Some((config, snapshot)), + None => None, + }) + { + let header_prefix = match config.key { + Some(ref key) => format!("X-{}-RateLimit-", key), + None => "X-RateLimit-".to_string(), + }; + let (header_limit, header_interval, header_remaining) = ( + format!("{}Limit", header_prefix), + format!("{}Replenish-Interval", header_prefix), + format!("{}Remaining", header_prefix), + ); + headers.insert( + HeaderName::from_bytes(header_limit.as_bytes()).unwrap(), + #[allow(clippy::unwrap_used)] + snapshot.quota().burst_size().to_string().parse().unwrap(), + ); + headers.insert( + HeaderName::from_bytes(header_interval.as_bytes()).unwrap(), + config.replenish_every.to_string().parse().unwrap(), + ); + headers.insert("X-RateLimit-Remaining", "0".parse().unwrap()); + headers.insert( + HeaderName::from_bytes(header_remaining.as_bytes()).unwrap(), + snapshot + .remaining_burst_capacity() + .to_string() + .parse() + .unwrap(), + ); + } resp } @@ -298,29 +406,55 @@ pub async fn rate_limit_middleware( let mut resp = ErrorResponse::rate_limit_exceeded().into_response(); let headers = resp.headers_mut(); - headers.insert( - "X-RateLimit-Limit", - #[allow(clippy::unwrap_used)] - err.quota().burst_size().to_string().parse().unwrap(), - ); - headers.insert( - "X-RateLimit-Replenish-Interval", + + for (config, snapshot) in state .config .rate_limit - .replenish_every - .to_string() - .parse() - .unwrap(), - ); - headers.insert("X-RateLimit-Remaining", "0".parse().unwrap()); + .iter() + .zip(res.iter()) + .filter_map(|(config, snapshot)| match snapshot { + Some(snapshot) => Some((config, snapshot)), + None => None, + }) + { + let header_prefix = match config.key { + Some(ref key) => format!("X-{}-RateLimit-", key), + None => "X-RateLimit-".to_string(), + }; + let (header_limit, header_interval, header_remaining) = ( + format!("{}Limit", header_prefix), + format!("{}Replenish-Interval", header_prefix), + format!("{}Remaining", header_prefix), + ); + headers.insert( + HeaderName::from_bytes(header_limit.as_bytes()).unwrap(), + #[allow(clippy::unwrap_used)] + snapshot.quota().burst_size().to_string().parse().unwrap(), + ); + headers.insert( + HeaderName::from_bytes(header_interval.as_bytes()).unwrap(), + config.replenish_every.to_string().parse().unwrap(), + ); + + headers.insert("X-RateLimit-Remaining", "0".parse().unwrap()); + headers.insert( + HeaderName::from_bytes(header_remaining.as_bytes()).unwrap(), + snapshot + .remaining_burst_capacity() + .to_string() + .parse() + .unwrap(), + ); + } + headers.insert( "Retry-After", - err.wait_time_from(SystemTime::now()) + err.earliest_possible() + .duration_since(SystemTime::now()) + .unwrap() .as_secs() - .to_string() - .parse() - .unwrap(), + .into(), ); resp @@ -464,7 +598,8 @@ impl std::error::Error for ErrorResponse {} impl ErrorResponse { #[cfg(not(feature = "cf-worker"))] /// URL must be a DNS name - #[must_use] pub const fn non_dns_name() -> Self { + #[must_use] + pub const fn non_dns_name() -> Self { Self { status: StatusCode::BAD_REQUEST, message: Cow::Borrowed("URL must be a DNS name"), @@ -657,11 +792,16 @@ impl IntoResponse for ErrorResponse { #[allow(unused)] pub struct AppState { #[cfg(feature = "governor")] - limiter: RateLimiter< - std::net::IpAddr, - DashMapStateStore, - SystemClock, - StateInformationMiddleware, + limiters: Box< + [( + RateLimiter< + std::net::IpAddr, + DashMapStateStore, + SystemClock, + StateInformationMiddleware, + >, + RwLock>, + )], >, config: Config, client: C,