tiered rate limiting

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-11-19 03:23:28 -06:00
parent 5f2cd3ade7
commit 4c98ae337b
No known key found for this signature in database
6 changed files with 247 additions and 60 deletions

22
Cargo.lock generated
View file

@ -651,6 +651,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]] [[package]]
name = "fontconfig-parser" name = "fontconfig-parser"
version = "0.5.7" version = "0.5.7"
@ -871,6 +877,11 @@ name = "hashbrown"
version = "0.15.1" version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
]
[[package]] [[package]]
name = "heck" name = "heck"
@ -1334,6 +1345,15 @@ dependencies = [
"imgref", "imgref",
] ]
[[package]]
name = "lru"
version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [
"hashbrown 0.15.1",
]
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.7.3" version = "0.7.3"
@ -3200,6 +3220,7 @@ dependencies = [
"chumsky", "chumsky",
"clap", "clap",
"console_error_panic_hook", "console_error_panic_hook",
"dashmap",
"env_logger", "env_logger",
"fontdb", "fontdb",
"futures", "futures",
@ -3208,6 +3229,7 @@ dependencies = [
"image", "image",
"libc", "libc",
"log", "log",
"lru",
"quote", "quote",
"reqwest", "reqwest",
"resvg", "resvg",

View file

@ -81,6 +81,8 @@ fontdb = { version = "0.23", optional = true }
webp = { version = "0.3.0", optional = true } webp = { version = "0.3.0", optional = true }
url = { version = "2", optional = true } url = { version = "2", optional = true }
tower-http = { version = "0.6.2", features = ["catch-panic", "timeout"], optional = true } tower-http = { version = "0.6.2", features = ["catch-panic", "timeout"], optional = true }
dashmap = "6.1.0"
lru = "0.12.5"
[patch.crates-io] [patch.crates-io]
# licensing and webp dependencies # licensing and webp dependencies

View file

@ -16,6 +16,7 @@ Work in progress! Currently to do:
- [X] Rate-limiting on local deployment (untested) - [X] Rate-limiting on local deployment (untested)
- [X] Read config from Cloudflare - [X] Read config from Cloudflare
- [X] Timing and Rate-limiting headers (some not available on Cloudflare Workers) - [X] Timing and Rate-limiting headers (some not available on Cloudflare Workers)
- [X] Tiered rate-limiting
- [ ] Lossy WebP on CF Workers - [ ] Lossy WebP on CF Workers
- [ ] Cache Results on Cloudflare KV. - [ ] Cache Results on Cloudflare KV.
- [ ] Handle all possible panics reported by Clippy - [ ] Handle all possible panics reported by Clippy

View file

@ -2,6 +2,7 @@ listen = "127.0.0.1:3000"
enable_cache = false enable_cache = false
index_redirect = { permanent = false, url = "https://mi.yumechi.jp/" } index_redirect = { permanent = false, url = "https://mi.yumechi.jp/" }
allow_unknown = false allow_unknown = false
max_x_forwarded_for = 0
# you need AppArmor and the policy loaded to use this # you need AppArmor and the policy loaded to use this
[sandbox.apparmor] [sandbox.apparmor]
@ -19,7 +20,18 @@ enable_redirects = false
normalization = "lazy" normalization = "lazy"
allow_svg_passthrough = false allow_svg_passthrough = false
[rate_limit] [[rate_limit]]
max_x_forwarded_for = 0 replenish_every = 50
burst = 256
[[rate_limit]]
key = "500ms"
min_request_duration = 500
replenish_every = 200 replenish_every = 200
burst = 64 burst = 64
[[rate_limit]]
key = "1500ms"
min_request_duration = 1500
replenish_every = 1500
burst = 16

View file

@ -37,9 +37,13 @@ pub struct Config {
/// Post-processing configuration /// Post-processing configuration
pub post_process: PostProcessConfig, pub post_process: PostProcessConfig,
#[cfg(not(feature = "cf-worker"))]
/// The maximum number of X-Forwarded-For headers to allow
pub max_x_forwarded_for: u8,
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
/// Governor configuration /// Governor configuration
pub rate_limit: RateLimitConfig, pub rate_limit: Vec<RateLimitConfig>,
} }
/// Sandbox configuration /// Sandbox configuration
@ -75,12 +79,15 @@ pub struct AppArmorConfig {
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
#[derive(Debug, Clone, serde::Deserialize)] #[derive(Debug, Clone, serde::Deserialize)]
pub struct RateLimitConfig { pub struct RateLimitConfig {
/// The maximum number of X-Forwarded-For headers to allow /// The key to use for rate limiting headers
pub max_x_forwarded_for: u8, pub key: Option<String>,
/// The rate limit replenish interval in milliseconds /// The rate limit replenish interval in milliseconds
pub replenish_every: u64, pub replenish_every: u64,
/// The rate limit burst size /// The rate limit burst size
pub burst: NonZero<u32>, pub burst: NonZero<u32>,
/// The minimum request duration in milliseconds for this rate limit to apply
pub min_request_duration: Option<u64>,
} }
#[cfg(feature = "cf-worker")] #[cfg(feature = "cf-worker")]
@ -130,12 +137,15 @@ impl Default for Config {
normalization: NormalizationPolicy::Opportunistic, normalization: NormalizationPolicy::Opportunistic,
allow_svg_passthrough: false, allow_svg_passthrough: false,
}, },
#[cfg(feature = "governor")] #[cfg(not(feature = "cf-worker"))]
rate_limit: RateLimitConfig {
max_x_forwarded_for: 0, max_x_forwarded_for: 0,
#[cfg(feature = "governor")]
rate_limit: vec![RateLimitConfig {
key: None,
replenish_every: 2000, replenish_every: 2000,
burst: NonZero::new(32).unwrap(), burst: NonZero::new(32).unwrap(),
}, min_request_duration: None,
}],
} }
} }
} }

View file

@ -6,7 +6,12 @@
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
use std::net::SocketAddr; use std::net::SocketAddr;
use std::{borrow::Cow, fmt::Display, marker::PhantomData, sync::Arc}; use std::{
borrow::Cow,
fmt::Display,
marker::PhantomData,
sync::{atomic::AtomicU64, Arc, RwLock},
};
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
use axum::extract::ConnectInfo; use axum::extract::ConnectInfo;
@ -24,6 +29,7 @@ use governor::{
clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore, clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore,
RateLimiter, RateLimiter,
}; };
use lru::LruCache;
use post_process::{CompressionLevel, MediaResponse}; use post_process::{CompressionLevel, MediaResponse};
use sandbox::Sandboxing; use sandbox::Sandboxing;
@ -110,18 +116,29 @@ where
}; };
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
use std::time::Duration; use std::time::Duration;
use std::{num::NonZero, sync::RwLock};
#[cfg(not(feature = "cf-worker"))] #[cfg(not(feature = "cf-worker"))]
use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer}; use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer};
let state = AppState { let state = AppState {
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
limiter: RateLimiter::dashmap_with_clock( limiters: config
Quota::with_period(Duration::from_millis(config.rate_limit.replenish_every)) .rate_limit
.iter()
.map(|x| {
(
RateLimiter::dashmap_with_clock(
Quota::with_period(Duration::from_millis(x.replenish_every))
.unwrap() .unwrap()
.allow_burst(config.rate_limit.burst), .allow_burst(x.burst),
SystemClock, SystemClock,
) )
.with_middleware::<StateInformationMiddleware>(), .with_middleware::<StateInformationMiddleware>(),
RwLock::new(LruCache::new(NonZero::new(1024).unwrap())),
)
})
.collect::<Vec<_>>()
.into_boxed_slice(),
client: Upstream::new(&config.fetch), client: Upstream::new(&config.fetch),
sandbox: S::new(&config.sandbox), sandbox: S::new(&config.sandbox),
config, config,
@ -178,6 +195,17 @@ where
.layer(TimeoutLayer::new(Duration::from_secs(10))); .layer(TimeoutLayer::new(Duration::from_secs(10)));
} }
#[cfg(feature = "governor")]
{
let state_gc = Arc::clone(&state);
std::thread::spawn(move || loop {
std::thread::sleep(Duration::from_secs(300));
for (limiter, _) in state_gc.limiters.iter() {
limiter.retain_recent();
}
});
}
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware)); return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware));
#[cfg(not(feature = "governor"))] #[cfg(not(feature = "governor"))]
@ -232,6 +260,24 @@ pub async fn common_security_headers(
resp resp
} }
fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool {
loop {
let current = credits.load(std::sync::atomic::Ordering::Relaxed);
if current == 0 {
return false;
}
if credits.compare_exchange_weak(
current,
current - 1,
std::sync::atomic::Ordering::Relaxed,
std::sync::atomic::Ordering::Relaxed,
) == Ok(current)
{
return true;
}
}
}
/// Middleware for rate limiting /// Middleware for rate limiting
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
#[cfg_attr(feature = "cf-worker", worker::send)] #[cfg_attr(feature = "cf-worker", worker::send)]
@ -246,7 +292,9 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
time::SystemTime, time::SystemTime,
}; };
let forwarded_ip = if state.config.rate_limit.max_x_forwarded_for > 0 { use axum::http::HeaderName;
let forwarded_ip = if state.config.max_x_forwarded_for > 0 {
std::iter::repeat(addr.ip()) std::iter::repeat(addr.ip())
.chain( .chain(
request request
@ -257,7 +305,7 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
.split(',') .split(',')
.filter_map(|x| x.trim().parse().ok()), .filter_map(|x| x.trim().parse().ok()),
) )
.nth_back(state.config.rate_limit.max_x_forwarded_for as usize - 1) .nth_back(state.config.max_x_forwarded_for as usize - 1)
.map(|addr| match addr { .map(|addr| match addr {
IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 >> 64)), IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 >> 64)),
addr => addr, addr => addr,
@ -266,30 +314,90 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
None None
}; };
match state.limiter.check_key(&forwarded_ip.unwrap_or(addr.ip())) { let real_ip = forwarded_ip.unwrap_or_else(|| match addr.ip() {
Ok(ok) => { IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 >> 64)),
addr => addr,
});
let mut res = Vec::with_capacity(state.limiters.len());
match state.limiters.iter().fold(Ok(()), |acc, limiter| {
acc.and_then(|_| {
// if we have credits, we don't need to check the key
let credits = limiter.1.read().unwrap();
if let Some(credits_val) = credits.peek(&forwarded_ip.unwrap_or(real_ip)) {
if atomic_u64_saturating_dec(credits_val) {
res.push(None);
return Ok(());
}
drop(credits);
let mut credits = limiter.1.write().unwrap();
credits.pop(&forwarded_ip.unwrap_or(real_ip));
}
limiter
.0
.check_key(&forwarded_ip.unwrap_or(real_ip))
.map(|x| res.push(Some(x)))
})
}) {
Ok(_) => {
let begin = SystemTime::now();
let mut resp = next.run(request).await; let mut resp = next.run(request).await;
let elapsed = begin.elapsed().unwrap().as_millis() as u64;
// credit back limits that didn't get used
for (config, (_, credits)) in state.config.rate_limit.iter().zip(state.limiters.iter())
{
if elapsed < config.min_request_duration.unwrap_or(0) {
let mut credits = credits.write().unwrap();
credits
.get_or_insert(forwarded_ip.unwrap_or(real_ip), || 0.into())
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
let headers = resp.headers_mut(); let headers = resp.headers_mut();
headers.insert(
"X-RateLimit-Limit", for (config, snapshot) in
#[allow(clippy::unwrap_used)]
ok.quota().burst_size().to_string().parse().unwrap(),
);
headers.insert(
"X-RateLimit-Replenish-Interval",
state state
.config .config
.rate_limit .rate_limit
.replenish_every .iter()
.zip(res.iter())
.filter_map(|(config, snapshot)| match snapshot {
Some(snapshot) => Some((config, snapshot)),
None => None,
})
{
let header_prefix = match config.key {
Some(ref key) => format!("X-{}-RateLimit-", key),
None => "X-RateLimit-".to_string(),
};
let (header_limit, header_interval, header_remaining) = (
format!("{}Limit", header_prefix),
format!("{}Replenish-Interval", header_prefix),
format!("{}Remaining", header_prefix),
);
headers.insert(
HeaderName::from_bytes(header_limit.as_bytes()).unwrap(),
#[allow(clippy::unwrap_used)]
snapshot.quota().burst_size().to_string().parse().unwrap(),
);
headers.insert(
HeaderName::from_bytes(header_interval.as_bytes()).unwrap(),
config.replenish_every.to_string().parse().unwrap(),
);
headers.insert("X-RateLimit-Remaining", "0".parse().unwrap());
headers.insert(
HeaderName::from_bytes(header_remaining.as_bytes()).unwrap(),
snapshot
.remaining_burst_capacity()
.to_string() .to_string()
.parse() .parse()
.unwrap(), .unwrap(),
); );
headers.insert( }
"X-RateLimit-Remaining",
ok.remaining_burst_capacity().to_string().parse().unwrap(),
);
resp resp
} }
@ -298,29 +406,55 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
let mut resp = ErrorResponse::rate_limit_exceeded().into_response(); let mut resp = ErrorResponse::rate_limit_exceeded().into_response();
let headers = resp.headers_mut(); let headers = resp.headers_mut();
headers.insert(
"X-RateLimit-Limit", for (config, snapshot) in
#[allow(clippy::unwrap_used)]
err.quota().burst_size().to_string().parse().unwrap(),
);
headers.insert(
"X-RateLimit-Replenish-Interval",
state state
.config .config
.rate_limit .rate_limit
.replenish_every .iter()
.zip(res.iter())
.filter_map(|(config, snapshot)| match snapshot {
Some(snapshot) => Some((config, snapshot)),
None => None,
})
{
let header_prefix = match config.key {
Some(ref key) => format!("X-{}-RateLimit-", key),
None => "X-RateLimit-".to_string(),
};
let (header_limit, header_interval, header_remaining) = (
format!("{}Limit", header_prefix),
format!("{}Replenish-Interval", header_prefix),
format!("{}Remaining", header_prefix),
);
headers.insert(
HeaderName::from_bytes(header_limit.as_bytes()).unwrap(),
#[allow(clippy::unwrap_used)]
snapshot.quota().burst_size().to_string().parse().unwrap(),
);
headers.insert(
HeaderName::from_bytes(header_interval.as_bytes()).unwrap(),
config.replenish_every.to_string().parse().unwrap(),
);
headers.insert("X-RateLimit-Remaining", "0".parse().unwrap());
headers.insert(
HeaderName::from_bytes(header_remaining.as_bytes()).unwrap(),
snapshot
.remaining_burst_capacity()
.to_string() .to_string()
.parse() .parse()
.unwrap(), .unwrap(),
); );
headers.insert("X-RateLimit-Remaining", "0".parse().unwrap()); }
headers.insert( headers.insert(
"Retry-After", "Retry-After",
err.wait_time_from(SystemTime::now()) err.earliest_possible()
.duration_since(SystemTime::now())
.unwrap()
.as_secs() .as_secs()
.to_string() .into(),
.parse()
.unwrap(),
); );
resp resp
@ -464,7 +598,8 @@ impl std::error::Error for ErrorResponse {}
impl ErrorResponse { impl ErrorResponse {
#[cfg(not(feature = "cf-worker"))] #[cfg(not(feature = "cf-worker"))]
/// URL must be a DNS name /// URL must be a DNS name
#[must_use] pub const fn non_dns_name() -> Self { #[must_use]
pub const fn non_dns_name() -> Self {
Self { Self {
status: StatusCode::BAD_REQUEST, status: StatusCode::BAD_REQUEST,
message: Cow::Borrowed("URL must be a DNS name"), message: Cow::Borrowed("URL must be a DNS name"),
@ -657,12 +792,17 @@ impl IntoResponse for ErrorResponse {
#[allow(unused)] #[allow(unused)]
pub struct AppState<C: UpstreamClient, S: Sandboxing> { pub struct AppState<C: UpstreamClient, S: Sandboxing> {
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
limiter: RateLimiter< limiters: Box<
[(
RateLimiter<
std::net::IpAddr, std::net::IpAddr,
DashMapStateStore<std::net::IpAddr>, DashMapStateStore<std::net::IpAddr>,
SystemClock, SystemClock,
StateInformationMiddleware, StateInformationMiddleware,
>, >,
RwLock<LruCache<std::net::IpAddr, AtomicU64>>,
)],
>,
config: Config, config: Config,
client: C, client: C,
sandbox: S, sandbox: S,