diff --git a/Cargo.lock b/Cargo.lock index 1aab2b1..9a029fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,9 +213,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.7" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", @@ -1267,9 +1267,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libfuzzer-sys" @@ -1958,9 +1958,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustix" -version = "0.38.40" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -1971,9 +1971,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" dependencies = [ "once_cell", "rustls-pki-types", @@ -2120,9 +2120,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" 
dependencies = [ "itoa", "memchr", @@ -2612,6 +2612,25 @@ dependencies = [ "tower-service", ] +[[package]] +name = "tower-http" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" +dependencies = [ + "bitflags 2.6.0", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -3198,6 +3217,7 @@ dependencies = [ "thiserror 2.0.3", "tokio", "toml", + "tower-http", "tower-service", "url", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 51f706e..9f6ed34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ env-local = ["axum/http1", "axum/http2", "clap", "toml", "image/ico", "lossy-webp", + "tower-http", "svg-text", "resvg/system-fonts", "resvg/raster-images", "fontdb/fontconfig" ] reuse-port = [] @@ -45,11 +46,12 @@ panic-console-error = ["dep:console_error_panic_hook"] apparmor = ["dep:siphasher", "dep:libc"] reqwest = ["dep:reqwest", "dep:url"] svg-text = ["resvg/text", "dep:fontdb"] -tokio = ["dep:tokio", "axum/tokio"] +tokio = ["dep:tokio", "axum/tokio", "dep:libc"] env_logger = ["dep:env_logger"] governor = ["dep:governor"] -axum-server = ["dep:axum-server"] +axum-server = ["dep:axum-server", "tower-http"] lossy-webp = ["dep:webp"] +tower-http = ["dep:tower-http"] [dependencies] worker = { version="0.4.2", features=['http', 'axum'], optional = true } @@ -78,6 +80,7 @@ axum-server = { version = "0.7.1", optional = true } fontdb = { version = "0.23", optional = true } webp = { version = "0.3.0", optional = true } url = { version = "2", optional = true } +tower-http = { version = "0.6.2", features = ["catch-panic", "timeout"], optional = true } [patch.crates-io] # licensing and webp dependencies diff --git a/README.md b/README.md index ca2c22f..126b4a4 100644 --- a/README.md +++ 
b/README.md @@ -67,7 +67,7 @@ Image: 6. Have a working JS environment. - 7. Install `wrangler` with you JS package manager of choice. See https://developers.cloudflare.com/workers/wrangler/install-and-update/. `npx` also works. + 7. Install `wrangler` with your JS package manager of choice. See https://developers.cloudflare.com/workers/wrangler/install-and-update/. `npx` also works. 8. Edit `wrangler.toml` to your liking. Everything in the `[vars]` section maps directly into the `config` section of the TOML configuration file. There is a `cf-worker-paid` feature set which enable some additional features that will never fit in the free plan, mainly SVG font rendering and some debugging features. diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 3bb624d..09cf871 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -136,7 +136,7 @@ pub trait UpstreamClient { /// Reqwest client #[cfg(feature = "reqwest")] pub mod reqwest { - use crate::AddrFamilyConfig; + use crate::config::AddrFamilyConfig; use super::{ http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx, @@ -177,7 +177,7 @@ pub mod reqwest { std::net::SocketAddr::V6(a) } } - o => o, + std::net::SocketAddr::V4(a) => std::net::SocketAddr::V4(a), }) .filter(move |addr| match addr { std::net::SocketAddr::V4(a) if af != AddrFamilyConfig::V6Only => { diff --git a/src/lib.rs b/src/lib.rs index ce1833c..c8740f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ #[cfg(feature = "governor")] use std::net::SocketAddr; -use std::{borrow::Cow, fmt::Display, marker::PhantomData, net::IpAddr, sync::Arc}; +use std::{borrow::Cow, fmt::Display, marker::PhantomData, sync::Arc}; #[cfg(feature = "governor")] use axum::extract::ConnectInfo; @@ -31,7 +31,9 @@ use serde::Deserialize; #[cfg(feature = "cf-worker")] use worker::{event, Context, Env, HttpRequest, Result as WorkerResult}; -use config::*; +use config::{ + Config, FetchConfig, IndexConfig, NormalizationPolicy, PostProcessConfig, SandboxConfig, +}; /// Module for fetching media from upstream pub mod fetch;
@@ -108,6 +110,8 @@ where }; #[cfg(feature = "governor")] use std::time::Duration; + #[cfg(not(feature = "cf-worker"))] + use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer}; let state = AppState { #[cfg(feature = "governor")] @@ -125,7 +129,8 @@ where let state = Arc::new(state); - let router = Router::new() + #[allow(unused_mut)] + let mut router = Router::new() .route("/", get(App::::index)) .route( "/proxy", @@ -163,6 +168,16 @@ where .layer(middleware::from_fn(common_security_headers)) .with_state(Arc::clone(&state)); + #[cfg(not(feature = "cf-worker"))] + { + router = router + .layer(CatchPanicLayer::custom(|err| { + log::error!("Panic in request: {:?}", err); + ErrorResponse::postprocess_failed("Internal server error".into()).into_response() + })) + .layer(TimeoutLayer::new(Duration::from_secs(10))); + } + #[cfg(feature = "governor")] return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware)); #[cfg(not(feature = "governor"))] @@ -240,8 +255,7 @@ pub async fn rate_limit_middleware( .and_then(|x| x.to_str().ok()) .unwrap_or("") .split(',') - .map(|x| x.trim().parse().ok()) - .flatten(), + .filter_map(|x| x.trim().parse().ok()), ) .nth_back(state.config.rate_limit.max_x_forwarded_for as usize - 1) .map(|addr| match addr { @@ -450,7 +464,7 @@ impl std::error::Error for ErrorResponse {} impl ErrorResponse { #[cfg(not(feature = "cf-worker"))] /// URL must be a DNS name - pub const fn non_dns_name() -> Self { + #[must_use] pub const fn non_dns_name() -> Self { Self { status: StatusCode::BAD_REQUEST, message: Cow::Borrowed("URL must be a DNS name"), @@ -475,7 +489,7 @@ impl ErrorResponse { }, CfConfigError::JsonParse(e) => Self { status: StatusCode::INTERNAL_SERVER_ERROR, - message: format!("Failed to parse config object: {}", e).into(), + message: format!("Failed to parse config object: {e}").into(), }, } } @@ -643,8 +657,12 @@ impl IntoResponse for ErrorResponse { #[allow(unused)] pub struct AppState { #[cfg(feature 
= "governor")] - limiter: - RateLimiter, SystemClock, StateInformationMiddleware>, + limiter: RateLimiter< + std::net::IpAddr, + DashMapStateStore, + SystemClock, + StateInformationMiddleware, + >, config: Config, client: C, sandbox: S, diff --git a/src/post_process/mod.rs b/src/post_process/mod.rs index 5314b53..7b5771d 100644 --- a/src/post_process/mod.rs +++ b/src/post_process/mod.rs @@ -59,14 +59,34 @@ where #[cfg(feature = "tokio")] macro_rules! sandboxed { ($sandbox:expr => $($tt:tt)*) => { - tokio::task::block_in_place(|| { + { + let (timeout_tx, timeout_rx) = tokio::sync::oneshot::channel(); + let (done_tx, done_rx) = tokio::sync::oneshot::channel(); + + tokio::spawn(async move { + tokio::select! { + _ = done_rx => {}, + _ = tokio::time::sleep(std::time::Duration::from_secs(15)) => { + timeout_tx.send(()).ok(); + } + } + }); + let mut key = [0u8; 8]; getrandom::getrandom(&mut key).map_err(|_| ErrorResponse::entropy_exhausted())?; - let guard = $sandbox.setup(&key); - let ret = $($tt)*; - drop(guard); + + let ret = crate::sandbox::tokio_block( + $sandbox, + &key, + timeout_rx, + || { + $($tt)* + } + ); + + done_tx.send(()).ok(); ret - }) + } }; } @@ -96,10 +116,10 @@ where options: ImageOptions, sandbox: &S, ) -> Result>>, ErrorResponse> { - let begin = crate::timing::Instant::now(); const TIME_TO_FIRST_BYTE_KEY: &str = "fetch-first-byte"; const TIMING_KEY: &str = "post-process"; const SLURP_TIMING_KEY: &str = "slurp-data"; + let begin = crate::timing::Instant::now(); let ttfb = response.request().time_to_body; diff --git a/src/sandbox.rs b/src/sandbox.rs index 2843247..25e14e2 100644 --- a/src/sandbox.rs +++ b/src/sandbox.rs @@ -1,5 +1,33 @@ use crate::SandboxConfig; +#[cfg(all(target_family = "unix", feature = "tokio"))] +#[allow(unsafe_code)] +/// A blocking tokio task with a cancellation signal +pub fn tokio_block R + Send, R: Send>( + sandbox: &impl Sandboxing, + key: &[u8], + signal: tokio::sync::oneshot::Receiver<()>, + f: F, +) -> R { + let tid = unsafe { 
libc::pthread_self() }; + + tokio::spawn(async move { + if signal.await.is_ok() { + unsafe { + libc::pthread_cancel(tid); + } + } + }); + + tokio::task::block_in_place(move || { + let guard = sandbox.setup(key); + let res = f(); + drop(guard); + + res + }) +} + /// A trait for setting up a thread sandboxing environment pub trait Sandboxing { /// The type of the guard that is returned by the setup function diff --git a/src/timing.rs b/src/timing.rs index e5f247e..38b02c2 100644 --- a/src/timing.rs +++ b/src/timing.rs @@ -10,13 +10,13 @@ pub struct Instant(std::time::Instant); impl Instant { /// Create a new `Instant` from the current time #[cfg(not(target_arch = "wasm32"))] - pub fn now() -> Self { + #[must_use] pub fn now() -> Self { Self(std::time::Instant::now()) } /// Get the elapsed time since the instant was created #[cfg(not(target_arch = "wasm32"))] - pub fn elapsed(&self) -> std::time::Duration { + #[must_use] pub fn elapsed(&self) -> std::time::Duration { self.0.elapsed() } }