properly detect x_forwarded_for header
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
parent
c91adb0995
commit
58bfea6643
3 changed files with 33 additions and 9 deletions
|
@ -20,5 +20,6 @@ normalization = "lazy"
|
|||
allow_svg_passthrough = false
|
||||
|
||||
[rate_limit]
|
||||
max_x_forwarded_for = 0
|
||||
replenish_every = 2000
|
||||
burst = 32
|
|
@ -75,6 +75,8 @@ pub struct AppArmorConfig {
|
|||
#[cfg(feature = "governor")]
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct RateLimitConfig {
|
||||
/// The maximum number of X-Forwarded-For headers to allow
|
||||
pub max_x_forwarded_for: u8,
|
||||
/// The rate limit replenish interval in milliseconds
|
||||
pub replenish_every: u64,
|
||||
/// The rate limit burst size
|
||||
|
@ -130,6 +132,7 @@ impl Default for Config {
|
|||
},
|
||||
#[cfg(feature = "governor")]
|
||||
rate_limit: RateLimitConfig {
|
||||
max_x_forwarded_for: 0,
|
||||
replenish_every: 2000,
|
||||
burst: NonZero::new(32).unwrap(),
|
||||
},
|
||||
|
|
38
src/lib.rs
38
src/lib.rs
|
@ -6,7 +6,7 @@
|
|||
|
||||
#[cfg(feature = "governor")]
|
||||
use std::net::SocketAddr;
|
||||
use std::{borrow::Cow, fmt::Display, marker::PhantomData, sync::Arc};
|
||||
use std::{borrow::Cow, fmt::Display, marker::PhantomData, net::IpAddr, sync::Arc};
|
||||
|
||||
#[cfg(feature = "governor")]
|
||||
use axum::extract::ConnectInfo;
|
||||
|
@ -226,9 +226,33 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
|
|||
request: axum::extract::Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> Response {
|
||||
use std::time::SystemTime;
|
||||
use std::{
|
||||
net::{IpAddr, Ipv6Addr},
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
match state.limiter.check_key(&addr) {
|
||||
let forwarded_ip = if state.config.rate_limit.max_x_forwarded_for > 0 {
|
||||
std::iter::repeat(addr.ip())
|
||||
.chain(
|
||||
request
|
||||
.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|x| x.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.split(',')
|
||||
.map(|x| x.trim().parse().ok())
|
||||
.flatten(),
|
||||
)
|
||||
.nth_back(state.config.rate_limit.max_x_forwarded_for as usize - 1)
|
||||
.map(|addr| match addr {
|
||||
IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 << 64)),
|
||||
addr => addr,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
match state.limiter.check_key(&forwarded_ip.unwrap_or(addr.ip())) {
|
||||
Ok(ok) => {
|
||||
let mut resp = next.run(request).await;
|
||||
|
||||
|
@ -619,12 +643,8 @@ impl IntoResponse for ErrorResponse {
|
|||
#[allow(unused)]
|
||||
pub struct AppState<C: UpstreamClient, S: Sandboxing> {
|
||||
#[cfg(feature = "governor")]
|
||||
limiter: RateLimiter<
|
||||
SocketAddr,
|
||||
DashMapStateStore<SocketAddr>,
|
||||
SystemClock,
|
||||
StateInformationMiddleware,
|
||||
>,
|
||||
limiter:
|
||||
RateLimiter<IpAddr, DashMapStateStore<IpAddr>, SystemClock, StateInformationMiddleware>,
|
||||
config: Config,
|
||||
client: C,
|
||||
sandbox: S,
|
||||
|
|
Loading…
Reference in a new issue