tiered rate limiting
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
parent
5f2cd3ade7
commit
4c98ae337b
6 changed files with 247 additions and 60 deletions
22
Cargo.lock
generated
22
Cargo.lock
generated
|
@ -651,6 +651,12 @@ version = "1.0.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "foldhash"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
|
||||
|
||||
[[package]]
|
||||
name = "fontconfig-parser"
|
||||
version = "0.5.7"
|
||||
|
@ -871,6 +877,11 @@ name = "hashbrown"
|
|||
version = "0.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3"
|
||||
dependencies = [
|
||||
"allocator-api2",
|
||||
"equivalent",
|
||||
"foldhash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
|
@ -1334,6 +1345,15 @@ dependencies = [
|
|||
"imgref",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru"
|
||||
version = "0.12.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
|
||||
dependencies = [
|
||||
"hashbrown 0.15.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.7.3"
|
||||
|
@ -3200,6 +3220,7 @@ dependencies = [
|
|||
"chumsky",
|
||||
"clap",
|
||||
"console_error_panic_hook",
|
||||
"dashmap",
|
||||
"env_logger",
|
||||
"fontdb",
|
||||
"futures",
|
||||
|
@ -3208,6 +3229,7 @@ dependencies = [
|
|||
"image",
|
||||
"libc",
|
||||
"log",
|
||||
"lru",
|
||||
"quote",
|
||||
"reqwest",
|
||||
"resvg",
|
||||
|
|
|
@ -81,6 +81,8 @@ fontdb = { version = "0.23", optional = true }
|
|||
webp = { version = "0.3.0", optional = true }
|
||||
url = { version = "2", optional = true }
|
||||
tower-http = { version = "0.6.2", features = ["catch-panic", "timeout"], optional = true }
|
||||
dashmap = "6.1.0"
|
||||
lru = "0.12.5"
|
||||
|
||||
[patch.crates-io]
|
||||
# licensing and webp dependencies
|
||||
|
|
|
@ -16,6 +16,7 @@ Work in progress! Currently to do:
|
|||
- [X] Rate-limiting on local deployment (untested)
|
||||
- [X] Read config from Cloudflare
|
||||
- [X] Timing and Rate-limiting headers (some not available on Cloudflare Workers)
|
||||
- [X] Tiered rate-limiting
|
||||
- [ ] Lossy WebP on CF Workers
|
||||
- [ ] Cache Results on Cloudflare KV.
|
||||
- [ ] Handle all possible panics reported by Clippy
|
||||
|
|
16
local.toml
16
local.toml
|
@ -2,6 +2,7 @@ listen = "127.0.0.1:3000"
|
|||
enable_cache = false
|
||||
index_redirect = { permanent = false, url = "https://mi.yumechi.jp/" }
|
||||
allow_unknown = false
|
||||
max_x_forwarded_for = 0
|
||||
|
||||
# you need AppArmor and the policy loaded to use this
|
||||
[sandbox.apparmor]
|
||||
|
@ -19,7 +20,18 @@ enable_redirects = false
|
|||
normalization = "lazy"
|
||||
allow_svg_passthrough = false
|
||||
|
||||
[rate_limit]
|
||||
max_x_forwarded_for = 0
|
||||
[[rate_limit]]
|
||||
replenish_every = 50
|
||||
burst = 256
|
||||
|
||||
[[rate_limit]]
|
||||
key = "500ms"
|
||||
min_request_duration = 500
|
||||
replenish_every = 200
|
||||
burst = 64
|
||||
|
||||
[[rate_limit]]
|
||||
key = "1500ms"
|
||||
min_request_duration = 1500
|
||||
replenish_every = 1500
|
||||
burst = 16
|
||||
|
|
|
@ -37,9 +37,13 @@ pub struct Config {
|
|||
/// Post-processing configuration
|
||||
pub post_process: PostProcessConfig,
|
||||
|
||||
#[cfg(not(feature = "cf-worker"))]
|
||||
/// The maximum number of X-Forwarded-For headers to allow
|
||||
pub max_x_forwarded_for: u8,
|
||||
|
||||
#[cfg(feature = "governor")]
|
||||
/// Governor configuration
|
||||
pub rate_limit: RateLimitConfig,
|
||||
pub rate_limit: Vec<RateLimitConfig>,
|
||||
}
|
||||
|
||||
/// Sandbox configuration
|
||||
|
@ -75,12 +79,15 @@ pub struct AppArmorConfig {
|
|||
#[cfg(feature = "governor")]
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct RateLimitConfig {
|
||||
/// The maximum number of X-Forwarded-For headers to allow
|
||||
pub max_x_forwarded_for: u8,
|
||||
/// The key to use for rate limiting headers
|
||||
pub key: Option<String>,
|
||||
/// The rate limit replenish interval in milliseconds
|
||||
pub replenish_every: u64,
|
||||
/// The rate limit burst size
|
||||
pub burst: NonZero<u32>,
|
||||
|
||||
/// The minimum request duration in milliseconds for this rate limit to apply
|
||||
pub min_request_duration: Option<u64>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "cf-worker")]
|
||||
|
@ -130,12 +137,15 @@ impl Default for Config {
|
|||
normalization: NormalizationPolicy::Opportunistic,
|
||||
allow_svg_passthrough: false,
|
||||
},
|
||||
#[cfg(feature = "governor")]
|
||||
rate_limit: RateLimitConfig {
|
||||
#[cfg(not(feature = "cf-worker"))]
|
||||
max_x_forwarded_for: 0,
|
||||
#[cfg(feature = "governor")]
|
||||
rate_limit: vec![RateLimitConfig {
|
||||
key: None,
|
||||
replenish_every: 2000,
|
||||
burst: NonZero::new(32).unwrap(),
|
||||
},
|
||||
min_request_duration: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
210
src/lib.rs
210
src/lib.rs
|
@ -6,7 +6,12 @@
|
|||
|
||||
#[cfg(feature = "governor")]
|
||||
use std::net::SocketAddr;
|
||||
use std::{borrow::Cow, fmt::Display, marker::PhantomData, sync::Arc};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
fmt::Display,
|
||||
marker::PhantomData,
|
||||
sync::{atomic::AtomicU64, Arc, RwLock},
|
||||
};
|
||||
|
||||
#[cfg(feature = "governor")]
|
||||
use axum::extract::ConnectInfo;
|
||||
|
@ -24,6 +29,7 @@ use governor::{
|
|||
clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore,
|
||||
RateLimiter,
|
||||
};
|
||||
use lru::LruCache;
|
||||
use post_process::{CompressionLevel, MediaResponse};
|
||||
use sandbox::Sandboxing;
|
||||
|
||||
|
@ -110,18 +116,29 @@ where
|
|||
};
|
||||
#[cfg(feature = "governor")]
|
||||
use std::time::Duration;
|
||||
use std::{num::NonZero, sync::RwLock};
|
||||
#[cfg(not(feature = "cf-worker"))]
|
||||
use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer};
|
||||
|
||||
let state = AppState {
|
||||
#[cfg(feature = "governor")]
|
||||
limiter: RateLimiter::dashmap_with_clock(
|
||||
Quota::with_period(Duration::from_millis(config.rate_limit.replenish_every))
|
||||
limiters: config
|
||||
.rate_limit
|
||||
.iter()
|
||||
.map(|x| {
|
||||
(
|
||||
RateLimiter::dashmap_with_clock(
|
||||
Quota::with_period(Duration::from_millis(x.replenish_every))
|
||||
.unwrap()
|
||||
.allow_burst(config.rate_limit.burst),
|
||||
.allow_burst(x.burst),
|
||||
SystemClock,
|
||||
)
|
||||
.with_middleware::<StateInformationMiddleware>(),
|
||||
RwLock::new(LruCache::new(NonZero::new(1024).unwrap())),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into_boxed_slice(),
|
||||
client: Upstream::new(&config.fetch),
|
||||
sandbox: S::new(&config.sandbox),
|
||||
config,
|
||||
|
@ -178,6 +195,17 @@ where
|
|||
.layer(TimeoutLayer::new(Duration::from_secs(10)));
|
||||
}
|
||||
|
||||
#[cfg(feature = "governor")]
|
||||
{
|
||||
let state_gc = Arc::clone(&state);
|
||||
std::thread::spawn(move || loop {
|
||||
std::thread::sleep(Duration::from_secs(300));
|
||||
for (limiter, _) in state_gc.limiters.iter() {
|
||||
limiter.retain_recent();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(feature = "governor")]
|
||||
return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware));
|
||||
#[cfg(not(feature = "governor"))]
|
||||
|
@ -232,6 +260,24 @@ pub async fn common_security_headers(
|
|||
resp
|
||||
}
|
||||
|
||||
fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool {
|
||||
loop {
|
||||
let current = credits.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if current == 0 {
|
||||
return false;
|
||||
}
|
||||
if credits.compare_exchange_weak(
|
||||
current,
|
||||
current - 1,
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
std::sync::atomic::Ordering::Relaxed,
|
||||
) == Ok(current)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Middleware for rate limiting
|
||||
#[cfg(feature = "governor")]
|
||||
#[cfg_attr(feature = "cf-worker", worker::send)]
|
||||
|
@ -246,7 +292,9 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
|
|||
time::SystemTime,
|
||||
};
|
||||
|
||||
let forwarded_ip = if state.config.rate_limit.max_x_forwarded_for > 0 {
|
||||
use axum::http::HeaderName;
|
||||
|
||||
let forwarded_ip = if state.config.max_x_forwarded_for > 0 {
|
||||
std::iter::repeat(addr.ip())
|
||||
.chain(
|
||||
request
|
||||
|
@ -257,7 +305,7 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
|
|||
.split(',')
|
||||
.filter_map(|x| x.trim().parse().ok()),
|
||||
)
|
||||
.nth_back(state.config.rate_limit.max_x_forwarded_for as usize - 1)
|
||||
.nth_back(state.config.max_x_forwarded_for as usize - 1)
|
||||
.map(|addr| match addr {
|
||||
IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 >> 64)),
|
||||
addr => addr,
|
||||
|
@ -266,30 +314,90 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
|
|||
None
|
||||
};
|
||||
|
||||
match state.limiter.check_key(&forwarded_ip.unwrap_or(addr.ip())) {
|
||||
Ok(ok) => {
|
||||
let real_ip = forwarded_ip.unwrap_or_else(|| match addr.ip() {
|
||||
IpAddr::V6(addr) => IpAddr::V6(addr & Ipv6Addr::from_bits(!0u128 >> 64)),
|
||||
addr => addr,
|
||||
});
|
||||
|
||||
let mut res = Vec::with_capacity(state.limiters.len());
|
||||
match state.limiters.iter().fold(Ok(()), |acc, limiter| {
|
||||
acc.and_then(|_| {
|
||||
// if we have credits, we don't need to check the key
|
||||
let credits = limiter.1.read().unwrap();
|
||||
if let Some(credits_val) = credits.peek(&forwarded_ip.unwrap_or(real_ip)) {
|
||||
if atomic_u64_saturating_dec(credits_val) {
|
||||
res.push(None);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
drop(credits);
|
||||
let mut credits = limiter.1.write().unwrap();
|
||||
credits.pop(&forwarded_ip.unwrap_or(real_ip));
|
||||
}
|
||||
|
||||
limiter
|
||||
.0
|
||||
.check_key(&forwarded_ip.unwrap_or(real_ip))
|
||||
.map(|x| res.push(Some(x)))
|
||||
})
|
||||
}) {
|
||||
Ok(_) => {
|
||||
let begin = SystemTime::now();
|
||||
let mut resp = next.run(request).await;
|
||||
let elapsed = begin.elapsed().unwrap().as_millis() as u64;
|
||||
|
||||
// credit back limits that didn't get used
|
||||
for (config, (_, credits)) in state.config.rate_limit.iter().zip(state.limiters.iter())
|
||||
{
|
||||
if elapsed < config.min_request_duration.unwrap_or(0) {
|
||||
let mut credits = credits.write().unwrap();
|
||||
credits
|
||||
.get_or_insert(forwarded_ip.unwrap_or(real_ip), || 0.into())
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
let headers = resp.headers_mut();
|
||||
headers.insert(
|
||||
"X-RateLimit-Limit",
|
||||
#[allow(clippy::unwrap_used)]
|
||||
ok.quota().burst_size().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Replenish-Interval",
|
||||
|
||||
for (config, snapshot) in
|
||||
state
|
||||
.config
|
||||
.rate_limit
|
||||
.replenish_every
|
||||
.iter()
|
||||
.zip(res.iter())
|
||||
.filter_map(|(config, snapshot)| match snapshot {
|
||||
Some(snapshot) => Some((config, snapshot)),
|
||||
None => None,
|
||||
})
|
||||
{
|
||||
let header_prefix = match config.key {
|
||||
Some(ref key) => format!("X-{}-RateLimit-", key),
|
||||
None => "X-RateLimit-".to_string(),
|
||||
};
|
||||
let (header_limit, header_interval, header_remaining) = (
|
||||
format!("{}Limit", header_prefix),
|
||||
format!("{}Replenish-Interval", header_prefix),
|
||||
format!("{}Remaining", header_prefix),
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_bytes(header_limit.as_bytes()).unwrap(),
|
||||
#[allow(clippy::unwrap_used)]
|
||||
snapshot.quota().burst_size().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_bytes(header_interval.as_bytes()).unwrap(),
|
||||
config.replenish_every.to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert("X-RateLimit-Remaining", "0".parse().unwrap());
|
||||
headers.insert(
|
||||
HeaderName::from_bytes(header_remaining.as_bytes()).unwrap(),
|
||||
snapshot
|
||||
.remaining_burst_capacity()
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Remaining",
|
||||
ok.remaining_burst_capacity().to_string().parse().unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
resp
|
||||
}
|
||||
|
@ -298,29 +406,55 @@ pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>(
|
|||
let mut resp = ErrorResponse::rate_limit_exceeded().into_response();
|
||||
|
||||
let headers = resp.headers_mut();
|
||||
headers.insert(
|
||||
"X-RateLimit-Limit",
|
||||
#[allow(clippy::unwrap_used)]
|
||||
err.quota().burst_size().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Replenish-Interval",
|
||||
|
||||
for (config, snapshot) in
|
||||
state
|
||||
.config
|
||||
.rate_limit
|
||||
.replenish_every
|
||||
.iter()
|
||||
.zip(res.iter())
|
||||
.filter_map(|(config, snapshot)| match snapshot {
|
||||
Some(snapshot) => Some((config, snapshot)),
|
||||
None => None,
|
||||
})
|
||||
{
|
||||
let header_prefix = match config.key {
|
||||
Some(ref key) => format!("X-{}-RateLimit-", key),
|
||||
None => "X-RateLimit-".to_string(),
|
||||
};
|
||||
let (header_limit, header_interval, header_remaining) = (
|
||||
format!("{}Limit", header_prefix),
|
||||
format!("{}Replenish-Interval", header_prefix),
|
||||
format!("{}Remaining", header_prefix),
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_bytes(header_limit.as_bytes()).unwrap(),
|
||||
#[allow(clippy::unwrap_used)]
|
||||
snapshot.quota().burst_size().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
HeaderName::from_bytes(header_interval.as_bytes()).unwrap(),
|
||||
config.replenish_every.to_string().parse().unwrap(),
|
||||
);
|
||||
|
||||
headers.insert("X-RateLimit-Remaining", "0".parse().unwrap());
|
||||
headers.insert(
|
||||
HeaderName::from_bytes(header_remaining.as_bytes()).unwrap(),
|
||||
snapshot
|
||||
.remaining_burst_capacity()
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers.insert("X-RateLimit-Remaining", "0".parse().unwrap());
|
||||
}
|
||||
|
||||
headers.insert(
|
||||
"Retry-After",
|
||||
err.wait_time_from(SystemTime::now())
|
||||
err.earliest_possible()
|
||||
.duration_since(SystemTime::now())
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
.into(),
|
||||
);
|
||||
|
||||
resp
|
||||
|
@ -464,7 +598,8 @@ impl std::error::Error for ErrorResponse {}
|
|||
impl ErrorResponse {
|
||||
#[cfg(not(feature = "cf-worker"))]
|
||||
/// URL must be a DNS name
|
||||
#[must_use] pub const fn non_dns_name() -> Self {
|
||||
#[must_use]
|
||||
pub const fn non_dns_name() -> Self {
|
||||
Self {
|
||||
status: StatusCode::BAD_REQUEST,
|
||||
message: Cow::Borrowed("URL must be a DNS name"),
|
||||
|
@ -657,12 +792,17 @@ impl IntoResponse for ErrorResponse {
|
|||
#[allow(unused)]
|
||||
pub struct AppState<C: UpstreamClient, S: Sandboxing> {
|
||||
#[cfg(feature = "governor")]
|
||||
limiter: RateLimiter<
|
||||
limiters: Box<
|
||||
[(
|
||||
RateLimiter<
|
||||
std::net::IpAddr,
|
||||
DashMapStateStore<std::net::IpAddr>,
|
||||
SystemClock,
|
||||
StateInformationMiddleware,
|
||||
>,
|
||||
RwLock<LruCache<std::net::IpAddr, AtomicU64>>,
|
||||
)],
|
||||
>,
|
||||
config: Config,
|
||||
client: C,
|
||||
sandbox: S,
|
||||
|
|
Loading…
Reference in a new issue