properly detect x_forwarded_for header

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-11-15 14:28:16 -06:00
parent c91adb0995
commit 58bfea6643
No known key found for this signature in database
3 changed files with 33 additions and 9 deletions

View file

@ -20,5 +20,6 @@ normalization = "lazy"
allow_svg_passthrough = false
[rate_limit]
max_x_forwarded_for = 0
replenish_every = 2000
burst = 32

View file

@ -75,6 +75,8 @@ pub struct AppArmorConfig {
#[cfg(feature = "governor")]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RateLimitConfig {
/// The maximum number of X-Forwarded-For headers to allow
pub max_x_forwarded_for: u8,
/// The rate limit replenish interval in milliseconds
pub replenish_every: u64,
/// The rate limit burst size
@ -130,6 +132,7 @@ impl Default for Config {
},
#[cfg(feature = "governor")]
rate_limit: RateLimitConfig {
max_x_forwarded_for: 0,
replenish_every: 2000,
burst: NonZero::new(32).unwrap(),
},

View file

@ -6,7 +6,7 @@
#[cfg(feature = "governor")]
use std::net::SocketAddr;
use std::{borrow::Cow, fmt::Display, marker::PhantomData, sync::Arc};
use std::{borrow::Cow, fmt::Display, marker::PhantomData, net::IpAddr, sync::Arc};
#[cfg(feature = "governor")]
use axum::extract::ConnectInfo;
@ -226,9 +226,33 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
request: axum::extract::Request,
next: axum::middleware::Next,
) -> Response {
use std::time::SystemTime;
use std::{
net::{IpAddr, Ipv6Addr},
time::SystemTime,
};
match state.limiter.check_key(&addr) {
let forwarded_ip = if state.config.rate_limit.max_x_forwarded_for > 0 {
std::iter::repeat(addr.ip())
.chain(
request
.headers()
.get("x-forwarded-for")
.and_then(|x| x.to_str().ok())
.unwrap_or("")
.split(',')
.map(|x| x.trim().parse().ok())
.flatten(),
)
.nth_back(state.config.rate_limit.max_x_forwarded_for as usize - 1)
.map(|addr| match addr {
IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 << 64)),
addr => addr,
})
} else {
None
};
match state.limiter.check_key(&forwarded_ip.unwrap_or(addr.ip())) {
Ok(ok) => {
let mut resp = next.run(request).await;
@ -619,12 +643,8 @@ impl IntoResponse for ErrorResponse {
#[allow(unused)]
pub struct AppState<C: UpstreamClient, S: Sandboxing> {
#[cfg(feature = "governor")]
limiter: RateLimiter<
SocketAddr,
DashMapStateStore<SocketAddr>,
SystemClock,
StateInformationMiddleware,
>,
limiter:
RateLimiter<IpAddr, DashMapStateStore<IpAddr>, SystemClock, StateInformationMiddleware>,
config: Config,
client: C,
sandbox: S,