diff --git a/local.toml b/local.toml index c3c8b67..6440da4 100644 --- a/local.toml +++ b/local.toml @@ -20,5 +20,6 @@ normalization = "lazy" allow_svg_passthrough = false [rate_limit] +max_x_forwarded_for = 0 replenish_every = 2000 burst = 32 \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index fbf162a..a195578 100644 --- a/src/config.rs +++ b/src/config.rs @@ -75,6 +75,8 @@ pub struct AppArmorConfig { #[cfg(feature = "governor")] #[derive(Debug, Clone, serde::Deserialize)] pub struct RateLimitConfig { + /// The maximum number of X-Forwarded-For headers to allow + pub max_x_forwarded_for: u8, /// The rate limit replenish interval in milliseconds pub replenish_every: u64, /// The rate limit burst size @@ -130,6 +132,7 @@ impl Default for Config { }, #[cfg(feature = "governor")] rate_limit: RateLimitConfig { + max_x_forwarded_for: 0, replenish_every: 2000, burst: NonZero::new(32).unwrap(), }, diff --git a/src/lib.rs b/src/lib.rs index 96246a3..b6830e5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ #[cfg(feature = "governor")] use std::net::SocketAddr; -use std::{borrow::Cow, fmt::Display, marker::PhantomData, sync::Arc}; +use std::{borrow::Cow, fmt::Display, marker::PhantomData, net::IpAddr, sync::Arc}; #[cfg(feature = "governor")] use axum::extract::ConnectInfo; @@ -226,9 +226,33 @@ pub async fn rate_limit_middleware( request: axum::extract::Request, next: axum::middleware::Next, ) -> Response { - use std::time::SystemTime; + use std::{ + net::{IpAddr, Ipv6Addr}, + time::SystemTime, + }; - match state.limiter.check_key(&addr) { + let forwarded_ip = if state.config.rate_limit.max_x_forwarded_for > 0 { + std::iter::repeat(addr.ip()) + .chain( + request + .headers() + .get("x-forwarded-for") + .and_then(|x| x.to_str().ok()) + .unwrap_or("") + .split(',') + .map(|x| x.trim().parse().ok()) + .flatten(), + ) + .nth_back(state.config.rate_limit.max_x_forwarded_for as usize - 1) + .map(|addr| match addr { + IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 << 64)), + addr => addr, + }) + } else { + None + }; + + match state.limiter.check_key(&forwarded_ip.unwrap_or(addr.ip())) { Ok(ok) => { let mut resp = next.run(request).await; @@ -619,12 +643,8 @@ impl IntoResponse for ErrorResponse { #[allow(unused)] pub struct AppState { #[cfg(feature = "governor")] - limiter: RateLimiter< - SocketAddr, - DashMapStateStore, - SystemClock, - StateInformationMiddleware, - >, + limiter: + RateLimiter, SystemClock, StateInformationMiddleware>, config: Config, client: C, sandbox: S,