diff --git a/.vscode/settings.json b/.vscode/settings.json index a14aea1..74cebdb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,7 @@ { "rust-analyzer.cargo.features": [ - //"cf-worker" - "env-local" + //"cf-worker", + "env-local", + "url-summary" ], } \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8d1df1b..381f772 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,6 +213,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", + "axum-macros", "bytes", "futures-util", "http", @@ -258,6 +259,17 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "axum-server" version = "0.7.1" @@ -514,6 +526,29 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "cssparser" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c66d1cd8ed61bf80b38432613a7a2f09401ab8d0501110655f8b341484a3e3" +dependencies = [ + "cssparser-macros", + "dtoa-short", + "itoa", + "phf", + "smallvec", +] + +[[package]] +name = "cssparser-macros" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -528,6 +563,17 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "derive_more" +version = "0.99.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -539,6 +585,27 @@ dependencies = [ "syn", ] +[[package]] +name = "dtoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbb2bf8e87535c23f7a8a321e364ce21462d0ff10cb6407820e8e96dfff6653" + +[[package]] +name = "dtoa-short" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd1511a7b6a56299bd043a9c167a6d2bfb37bf84a6dfceaba651168adfb43c87" +dependencies = [ + "dtoa", +] + +[[package]] +name = "ego-tree" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2972feb8dffe7bc8c5463b1dacda1b0dfbed3710e50f977d965429692d74cd8" + [[package]] name = "either" version = "1.13.0" @@ -654,6 +721,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] + [[package]] name = "futures" version = "0.3.31" @@ -737,6 +814,24 @@ dependencies = [ "slab", ] +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "getopts" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -839,6 +934,20 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "html5ever" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e15626aaf9c351bc696217cbe29cb9b5e86c43f8a46b5e2f5c6c5cf7cb904ce" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "http" version = "1.2.0" @@ -1283,6 +1392,26 @@ dependencies = [ "hashbrown 0.15.2", ] +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + +[[package]] +name = "markup5ever" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82c88c6129bd24319e62a0359cb6b958fa7e8be6e19bb1663bc396b90883aca5" +dependencies = [ + "log", + "phf", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + [[package]] name = "matchit" version = "0.7.3" @@ -1532,6 +1661,77 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_macros", + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator 0.11.2", + "phf_shared 0.11.2", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared 0.10.0", + "rand", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared 0.11.2", + "rand", +] + +[[package]] +name = "phf_macros" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +dependencies = [ + "phf_generator 0.11.2", + "phf_shared 0.11.2", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher 0.3.11", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher 0.3.11", +] + [[package]] name = "pin-project" version = "1.1.7" @@ -1598,6 +1798,12 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "proc-macro2" version = "1.0.92" @@ -1964,6 +2170,21 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scraper" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc3d051b884f40e309de6c149734eab57aa8cc1347992710dc80bcc1c2194c15" +dependencies = [ + "cssparser", + "ego-tree", + "getopts", + "html5ever", + "precomputed-hash", + "selectors", + "tendril", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -1987,6 +2208,25 @@ dependencies = [ "libc", ] +[[package]] +name = "selectors" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd568a4c9bb598e291a08244a5c1f5a8a6650bee243b5b0f8dbb3d9cc1d87fe8" +dependencies = [ + "bitflags 2.6.0", + "cssparser", + "derive_more", + "fxhash", + "log", + "new_debug_unreachable", + "phf", + "phf_codegen", + "precomputed-hash", + "servo_arc", + "smallvec", +] + [[package]] name = "serde" version = "1.0.216" @@ -2072,6 +2312,15 @@ dependencies = [ "serde", ] +[[package]] +name = "servo_arc" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae65c4249478a2647db249fb43e23cec56a2c8974a427e7bd8cb5a1d0964921a" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2102,6 +2351,12 @@ dependencies = [ "quote", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "siphasher" version = "1.0.1" @@ -2167,6 +2422,32 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "string_cache" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot", + "phf_shared 0.10.0", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", + "proc-macro2", + "quote", +] + [[package]] name = "strsim" version = "0.11.1" @@ -2263,6 +2544,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -2501,6 +2793,12 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "untrusted" version = "0.9.0" @@ -2516,8 +2814,15 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" @@ -2965,20 +3270,24 @@ dependencies = [ "clap", "console_error_panic_hook", "dashmap", + "ego-tree", "env_logger", "futures", "getrandom", "governor", + "html5ever", "image", + "lazy_static", "libc", "log", "lru", "prometheus", "quote", "reqwest", + "scraper", "serde", "serde_json", - "siphasher", + "siphasher 1.0.1", "thiserror 2.0.8", "tokio", "toml", diff --git a/Cargo.toml b/Cargo.toml index 984de3d..1344027 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,9 @@ env-local = ["axum/http1", "axum/http2", "lossy-webp", "tower-http", "metrics", + "lazy_static", ] +url-summary = ["dep:html5ever", "dep:scraper", "dep:url", "url/serde", "dep:ego-tree"] cf-worker = ["dep:worker", "dep:worker-macros", "dep:wasm-bindgen", "image/ico", "panic-console-error"] # Observability and tracing features @@ -59,11 +61,13 @@ reqwest = ["dep:reqwest", "dep:url"] # Sandbox features apparmor = ["dep:siphasher", "dep:libc"] +lazy_static = ["dep:lazy_static"] [dependencies] worker = { version="0.4.2", features=['http', 'axum'], optional = true } worker-macros = { version="0.4.2", features=['http'], optional = true } -axum = { version = "0.7", default-features = false, features = ["query", "json"] } +axum = { version = "0.7", default-features = false, features = ["query", "json", "macros"] } +ego-tree = { version = "0.10", optional = true } tower-service = "0.3" console_error_panic_hook = { version = "0.1.1", optional = true } serde = { version = "1", features = ["derive"] } @@ -90,6 +94,9 @@ dashmap = "6.1.0" lru = "0.12.5" prometheus = { version = "0.13.4", optional = true } xml = "0.8.20" +html5ever = { version = "0.29", optional = true } +scraper = { version = "0.22.0", optional = true } +lazy_static = { version = "1.5.0", optional = true } [build-dependencies] chumsky = "0.9.3" diff --git a/README.md b/README.md index 98ece62..225d541 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ Currently to do: - [ ] Handle all possible panics reported by Clippy - [X] Sandboxing the image rendering - [X] Prometheus-format metrics +- [X] Experimental URL summarization replacement (local only for now) ## Spec Compliance diff --git a/build.rs b/build.rs index acd560e..c12fc3d 100644 --- a/build.rs +++ b/build.rs @@ -77,6 +77,15 @@ fn static_signatures() -> Vec { mask: vec![0xff, 0xff, 0xff, 0xff], }], }, + MIMEAssociation { + mime: "image/vnd.microsoft.icon".to_string().into(), + ext: vec![".ico".to_string()], + safe: true, + signatures: vec![FlattenedFileSignature { + test: vec![0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], + mask: vec![0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff], + }], + }, ] } diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 08020ac..b27e42d 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -30,7 +30,7 @@ const fn http_version_to_via(v: axum::http::Version) -> &'static str { } /// Trait for HTTP responses -pub trait HTTPResponse { +pub trait HTTPResponse: Send + 'static { /// Type of the byte buffer type Bytes: Into> + AsRef<[u8]> + Into + Send + 'static; /// Type of body stream @@ -51,10 +51,12 @@ pub trait HTTPResponse { } /// Information about the incoming request +#[derive(Debug, Clone)] pub struct IncomingInfo { version: axum::http::Version, user_agent: String, via: String, + accept_language: Option, } impl IncomingInfo { @@ -123,12 +125,40 @@ impl FromRequestParts for IncomingInfo { acc.push_str(v.to_str().unwrap_or_default()); acc }), + accept_language: parts + .headers + .get("accept-language") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.split(',').next()) + .map(|s| s.chars().filter(|c| !c.is_control()).collect()), }) } } +#[derive(Debug, Clone, PartialEq, Copy)] +/// Expected response type +pub enum ExpectType { + /// Multimedia content, such as images, videos, and audio + Media, + /// Markup content, such as HTML and XML + Markup, + /// Json content + Json, +} + +impl ExpectType { + /// Get the string representation + pub fn as_str(&self) -> &'static str { + match self { + Self::Media => "media", + Self::Markup => "markup", + Self::Json => "json", + } + } +} + /// Trait for upstream clients -pub trait UpstreamClient { +pub trait UpstreamClient: Send + Sync + 'static { /// Type of the response type Response: HTTPResponse; /// Create a new client @@ -137,11 +167,12 @@ pub trait UpstreamClient { fn request_upstream( &self, info: &IncomingInfo, + expect: ExpectType, url: &str, polish: bool, secure: bool, remaining: usize, - ) -> impl std::future::Future>; + ) -> impl std::future::Future> + Send; } /// Reqwest client @@ -157,7 +188,7 @@ pub mod reqwest { use axum::body::Bytes; use futures::TryStreamExt; use reqwest::dns::Resolve; - use std::{net::SocketAddrV4, sync::Arc, time::Duration}; + use std::{future::Future, net::SocketAddrV4, sync::Arc, time::Duration}; use url::Host; /// A Safe DNS resolver that only resolves to global addresses unless the requester itself is local. @@ -303,98 +334,129 @@ pub mod reqwest { .expect("Failed to create reqwest client"), } } - async fn request_upstream( + fn request_upstream( &self, info: &super::IncomingInfo, + expect: super::ExpectType, url: &str, polish: bool, mut secure: bool, remaining: usize, - ) -> Result { - if remaining == 0 { - return Err(ErrorResponse::too_many_redirects()); - } - - if info.looping(&self.via_ident) { - return Err(ErrorResponse::loop_detected()); - } - - let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?; - if !url_parsed.host().map_or(false, |h| match h { - Host::Domain(_) => true, - _ => false, - }) { - return Err(ErrorResponse::non_dns_name()); - } - - secure &= url_parsed.scheme().eq_ignore_ascii_case("https"); - if self.https_only && !secure { - return Err(ErrorResponse::insecure_request()); - } - - let begin = crate::timing::Instant::now(); - - let resp = self - .client - .get(url_parsed) - .header( - "via", - format!( - "{}, {} {}", - info.via, - http_version_to_via(info.version), - self.via_ident - ), - ) - .send() - .await?; - - if resp.status().is_redirection() { - if let Some(location) = resp.headers().get("location").and_then(|l| l.to_str().ok()) - { - return Box::pin(self.request_upstream( - info, - location, - polish, - secure, - remaining - 1, - )) - .await; + ) -> impl Future> + Send { + async move { + if remaining == 0 { + return Err(ErrorResponse::too_many_redirects()); } - return Err(ErrorResponse::missing_location()); - } - if !resp.status().is_success() { - return Err(ErrorResponse::unexpected_status( - url, - resp.status().as_u16(), - )); - } - - let content_length = resp.headers().get("content-length"); - if let Some(content_length) = content_length.and_then(|c| c.to_str().ok()) { - if content_length.parse::().unwrap_or(0) > MAX_SIZE { - return Err(ErrorResponse::payload_too_large()); + if info.looping(&self.via_ident) { + return Err(ErrorResponse::loop_detected()); } - } - let content_type = resp.headers().get("content-type"); - if let Some(content_type) = content_type.and_then(|c| c.to_str().ok()) { - if !["image/", "video/", "audio/", "application/octet-stream"] - .iter() - .any(|prefix| { - content_type[..prefix.len().min(content_type.len())] - .eq_ignore_ascii_case(prefix) - }) - { - return Err(ErrorResponse::not_media()); + let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?; + if !url_parsed.host().map_or(false, |h| match h { + Host::Domain(_) => true, + _ => false, + }) { + return Err(ErrorResponse::non_dns_name()); } - } - Ok(ReqwestResponse { - time_to_body: begin.elapsed(), - resp, - }) + secure &= url_parsed.scheme().eq_ignore_ascii_case("https"); + if self.https_only && !secure { + return Err(ErrorResponse::insecure_request()); + } + + let begin = crate::timing::Instant::now(); + + let resp = self + .client + .get(url_parsed) + .header( + "via", + format!( + "{}, {} {}", + info.via, + http_version_to_via(info.version), + self.via_ident + ), + ) + .header( + "accept", + match expect { + super::ExpectType::Media => { + "image/*, video/*, audio/*, application/octet-stream" + } + super::ExpectType::Markup => "text/html, application/xhtml+xml", + super::ExpectType::Json => "application/json", + }, + ) + .header( + "accept-language", + info.accept_language.as_deref().unwrap_or("en"), + ) + .send() + .await?; + + if resp.status().is_redirection() { + if let Some(location) = + resp.headers().get("location").and_then(|l| l.to_str().ok()) + { + return Box::pin(self.request_upstream( + info, + expect, + location, + polish, + secure, + remaining - 1, + )) + .await; + } + return Err(ErrorResponse::missing_location()); + } + + if !resp.status().is_success() { + return Err(ErrorResponse::unexpected_status( + url, + resp.status().as_u16(), + )); + } + + let content_length = resp.headers().get("content-length"); + if let Some(content_length) = content_length.and_then(|c| c.to_str().ok()) { + if content_length.parse::().unwrap_or(0) > MAX_SIZE { + return Err(ErrorResponse::payload_too_large()); + } + } + + let content_type = resp.headers().get("content-type"); + if let Some(content_type) = content_type.and_then(|c| c.to_str().ok()) { + if !match expect { + super::ExpectType::Media => { + ["image/", "video/", "audio/", "application/octet-stream"] + .iter() + .any(|prefix| { + content_type[..prefix.len().min(content_type.len())] + .eq_ignore_ascii_case(prefix) + }) + } + super::ExpectType::Markup => { + ["text/html", "application/xhtml+xml"].iter().any(|prefix| { + content_type[..prefix.len().min(content_type.len())] + .eq_ignore_ascii_case(prefix) + }) + } + super::ExpectType::Json => content_type + [.."application/json".len().min(content_type.len())] + .eq_ignore_ascii_case("application/json"), + } { + return Err(ErrorResponse::not_media(expect)); + } + } + + Ok(ReqwestResponse { + time_to_body: begin.elapsed(), + resp, + }) + } } } } @@ -406,7 +468,7 @@ pub mod reqwest { deprecated = "You should use reqwest instead when not on Cloudflare Workers" )] pub mod cf_worker { - use std::time::Duration; + use std::{future::Future, time::Duration}; use super::{ http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx, @@ -463,6 +525,22 @@ pub mod cf_worker { #[allow(unsafe_code, reason = "this is never used concurrently")] unsafe impl Sync for CfBodyStreamWrapper {} + struct SendFuture(Pin>); + + #[allow(unsafe_code, reason = "this is never used concurrently")] + unsafe impl Send for SendFuture {} + + impl Future for SendFuture { + type Output = T::Output; + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + Future::poll(self.get_mut().0.as_mut(), cx) + } + } + /// Response from Cloudflare Workers pub struct CfWorkerResponse { time_to_body: std::time::Duration, @@ -470,6 +548,9 @@ pub mod cf_worker { url: Url, } + #[allow(unsafe_code, reason = "this is never used concurrently")] + unsafe impl Send for CfWorkerResponse {} + impl HTTPResponse for CfWorkerResponse { type Bytes = Vec; type BodyStream = CfBodyStreamWrapper; @@ -539,115 +620,150 @@ pub mod cf_worker { } } - async fn request_upstream( + fn request_upstream( &self, info: &super::IncomingInfo, + expect: super::ExpectType, url: &str, polish: bool, mut secure: bool, remaining: usize, - ) -> Result { - if remaining == 0 { - return Err(ErrorResponse::too_many_redirects()); - } + ) -> impl Future> + Send { + let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url()); + let info = info.clone(); + SendFuture(Box::pin(async move { + let url_parsed = url_parsed?; - if info.looping(&self.via_ident) { - return Err(ErrorResponse::loop_detected()); - } - - let mut headers = Headers::new(); - - headers.set("user-agent", &self.user_agent)?; - headers.set( - "via", - &format!( - "{}, {} {}", - info.via, - http_version_to_via(info.version), - self.via_ident - ), - )?; - - let mut prop = CfProperties::new(); - if polish { - prop.polish = Some(PolishConfig::Lossless); - } - let mut init = RequestInit::new(); - let init = init - .with_method(Method::Get) - .with_headers(headers) - .with_cf_properties(prop) - .with_redirect(RequestRedirect::Manual); - - let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?; - - if self.https_only && !url_parsed.scheme().eq_ignore_ascii_case("https") { - return Err(ErrorResponse::insecure_request()); - } - - secure &= url_parsed.scheme().eq_ignore_ascii_case("http"); - - let begin = crate::timing::Instant::now(); - - let req = Request::new_with_init(url, init)?; - - let abc = AbortController::default(); - let abs = abc.signal(); - let req = Fetch::Request(req); - - worker::wasm_bindgen_futures::spawn_local(async move { - worker::Delay::from(Duration::from_secs(5)).await; - abc.abort(); - }); - - let resp = std::pin::pin!(req - .send_with_signal(&abs) - .map_err(ErrorResponse::worker_fetch_error)) - .await?; - - if resp.status_code() >= 300 && resp.status_code() < 400 { - if let Ok(Some(location)) = resp.headers().get("location") { - return Box::pin(self.request_upstream( - info, - &location, - polish, - secure, - remaining - 1, - )) - .await; + if remaining == 0 { + return Err(ErrorResponse::too_many_redirects()); } - return Err(ErrorResponse::missing_location()); - } - if resp.status_code() < 200 || resp.status_code() >= 300 { - return Err(ErrorResponse::unexpected_status(url, resp.status_code())); - } - - let content_length = resp.headers().get("content-length").unwrap_or_default(); - if let Some(content_length) = content_length { - if content_length.parse::().unwrap_or(0) > MAX_SIZE { - return Err(ErrorResponse::payload_too_large()); + if info.looping(&self.via_ident) { + return Err(ErrorResponse::loop_detected()); } - } - let content_type = resp.headers().get("content-type").unwrap_or_default(); - if let Some(content_type) = content_type { - if !["image/", "video/", "audio/", "application/octet-stream"] - .iter() - .any(|prefix| { - content_type.as_str()[..prefix.len().min(content_type.len())] - .eq_ignore_ascii_case(prefix) - }) - { - return Err(ErrorResponse::not_media()); + let mut headers = Headers::new(); + + headers.set("user-agent", &self.user_agent)?; + headers.set( + "via", + &format!( + "{}, {} {}", + info.via, + http_version_to_via(info.version), + self.via_ident + ), + )?; + headers.set( + "accept", + match expect { + super::ExpectType::Media => { + "image/*, video/*, audio/*, application/octet-stream" + } + super::ExpectType::Markup => "text/html, application/xhtml+xml", + super::ExpectType::Json => "application/json", + }, + )?; + headers.set( + "accept-language", + info.accept_language.as_deref().unwrap_or("en"), + )?; + + let mut prop = CfProperties::new(); + if polish { + prop.polish = Some(PolishConfig::Lossless); } - } + let mut init = RequestInit::new(); + let init = init + .with_method(Method::Get) + .with_headers(headers) + .with_cf_properties(prop) + .with_redirect(RequestRedirect::Manual); - Ok(CfWorkerResponse { - time_to_body: begin.elapsed(), - resp, - url: url_parsed, - }) + if self.https_only && !url_parsed.scheme().eq_ignore_ascii_case("https") { + return Err(ErrorResponse::insecure_request()); + } + + secure &= url_parsed.scheme().eq_ignore_ascii_case("http"); + + let begin = crate::timing::Instant::now(); + + let req = Request::new_with_init(url_parsed.as_str(), init)?; + + let abc = AbortController::default(); + let abs = abc.signal(); + let req = Fetch::Request(req); + + worker::wasm_bindgen_futures::spawn_local(async move { + worker::Delay::from(Duration::from_secs(5)).await; + abc.abort(); + }); + + let resp = std::pin::pin!(req + .send_with_signal(&abs) + .map_err(ErrorResponse::worker_fetch_error)) + .await?; + + if resp.status_code() >= 300 && resp.status_code() < 400 { + if let Ok(Some(location)) = resp.headers().get("location") { + return Box::pin(self.request_upstream( + &info, + expect, + &location, + polish, + secure, + remaining - 1, + )) + .await; + } + return Err(ErrorResponse::missing_location()); + } + + if resp.status_code() < 200 || resp.status_code() >= 300 { + return Err(ErrorResponse::unexpected_status( + url_parsed.as_str(), + resp.status_code(), + )); + } + + let content_length = resp.headers().get("content-length").unwrap_or_default(); + if let Some(content_length) = content_length { + if content_length.parse::().unwrap_or(0) > MAX_SIZE { + return Err(ErrorResponse::payload_too_large()); + } + } + + let content_type = resp.headers().get("content-type").unwrap_or_default(); + if let Some(content_type) = content_type { + if !match expect { + super::ExpectType::Media => { + ["image/", "video/", "audio/", "application/octet-stream"] + .iter() + .any(|prefix| { + content_type[..prefix.len().min(content_type.len())] + .eq_ignore_ascii_case(prefix) + }) + } + super::ExpectType::Markup => { + ["text/html", "application/xhtml+xml"].iter().any(|prefix| { + content_type[..prefix.len().min(content_type.len())] + .eq_ignore_ascii_case(prefix) + }) + } + super::ExpectType::Json => content_type + [.."application/json".len().min(content_type.len())] + .eq_ignore_ascii_case("application/json"), + } { + return Err(ErrorResponse::not_media(expect)); + } + } + + Ok(CfWorkerResponse { + time_to_body: begin.elapsed(), + resp, + url: url_parsed, + }) + })) } } } diff --git a/src/lib.rs b/src/lib.rs index c59d848..411e8ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ #[cfg(feature = "governor")] use std::net::SocketAddr; +#[cfg_attr(feature = "cf-worker", allow(unused_imports))] use std::{ borrow::Cow, fmt::Display, @@ -23,12 +24,13 @@ use axum::{ routing::get, Json, Router, }; -use fetch::{HTTPResponse, IncomingInfo, UpstreamClient, DEFAULT_MAX_REDIRECTS}; +use fetch::{ExpectType, HTTPResponse, IncomingInfo, UpstreamClient, DEFAULT_MAX_REDIRECTS}; #[cfg(feature = "governor")] use governor::{ clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore, RateLimiter, }; +#[cfg(feature = "governor")] use lru::LruCache; use post_process::{CompressionLevel, MediaResponse}; use sandbox::Sandboxing; @@ -58,6 +60,10 @@ pub mod config; /// Cross platform timing utilities pub mod timing; +#[cfg(feature = "url-summary")] +/// URL summarization utilities +pub mod url_summary; + /// Utilities for Cloudflare Workers #[cfg(feature = "cf-worker")] mod cf_utils; @@ -98,31 +104,25 @@ async fn fetch( #[cfg(feature = "panic-console-error")] console_error_panic_hook::set_once(); - Ok(router::(config) + Ok(router::(default_state(config)) .call(req) .await?) } -#[cfg(any(feature = "cf-worker", feature = "reqwest"))] -/// Application Router -pub fn router( +/// Create default Application state +pub fn default_state( config: Config, -) -> Router -where - <::Response as HTTPResponse>::BodyStream: Unpin, -{ - use axum::middleware; +) -> Arc> { #[cfg(feature = "governor")] use governor::{ clock::SystemClock, middleware::StateInformationMiddleware, Quota, RateLimiter, }; #[cfg(feature = "governor")] use std::time::Duration; + #[allow(unused_imports)] use std::{num::NonZero, sync::RwLock}; - #[cfg(not(feature = "cf-worker"))] - use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer}; - let state = AppState { + Arc::new(AppState { #[cfg(feature = "governor")] limiters: config .rate_limit @@ -144,21 +144,39 @@ where client: Upstream::new(&config.fetch), sandbox: S::new(&config.sandbox), config, - }; + }) +} - let state = Arc::new(state); +#[cfg(any(feature = "cf-worker", feature = "reqwest"))] +/// Application Router +pub fn router( + state: Arc>, +) -> Router +where + <::Response as HTTPResponse>::BodyStream: Unpin, +{ + use axum::middleware; + #[cfg(not(feature = "cf-worker"))] + use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer}; #[allow(unused_mut)] let mut router = Router::new() .route("/", get(App::::index)) + .route( + "/url", + get(App::::url_summary).route_layer(middleware::from_fn_with_state( + state.config.enable_cache, + set_cache_control::, + )), + ) .route( "/proxy", get(App::::proxy_without_filename) .head(App::::proxy_without_filename) .options(App::::proxy_options) .route_layer(middleware::from_fn_with_state( - state.clone(), - set_cache_control, + state.config.enable_cache, + set_cache_control::, )) .fallback(|| async { ErrorResponse::method_not_allowed() }), ) @@ -168,8 +186,8 @@ where .head(App::::proxy_without_filename) .options(App::::proxy_options) .route_layer(middleware::from_fn_with_state( - state.clone(), - set_cache_control, + state.config.enable_cache, + set_cache_control::, )) .fallback(|| async { ErrorResponse::method_not_allowed() }), ) @@ -179,8 +197,8 @@ where .head(App::::proxy_with_filename) .options(App::::proxy_options) .route_layer(middleware::from_fn_with_state( - state.clone(), - set_cache_control, + state.config.enable_cache, + set_cache_control::, )) .fallback(|| async { ErrorResponse::method_not_allowed() }), ) @@ -189,6 +207,8 @@ where #[cfg(not(feature = "cf-worker"))] { + use std::time::Duration; + router = router .layer(CatchPanicLayer::custom(|err| { log::error!("Panic in request: {:?}", err); @@ -199,6 +219,8 @@ where #[cfg(feature = "governor")] { + use std::time::Duration; + let state_gc = Arc::clone(&state); std::thread::spawn(move || loop { std::thread::sleep(Duration::from_secs(300)); @@ -209,20 +231,26 @@ where } #[cfg(feature = "governor")] - return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware)); + return router.route_layer(middleware::from_fn_with_state( + state, + rate_limit_middleware::, + )); #[cfg(not(feature = "governor"))] router } /// Set the Cache-Control header #[cfg_attr(feature = "cf-worker", worker::send)] -pub async fn set_cache_control( - State(state): State>>, +pub async fn set_cache_control< + C: UpstreamClient + 'static, + S: Sandboxing + Send + Sync + 'static, +>( + State(enabled): State, request: axum::extract::Request, next: axum::middleware::Next, ) -> Response { let mut resp = next.run(request).await; - if state.config.enable_cache { + if enabled { if resp.status() == StatusCode::OK { let headers = resp.headers_mut(); headers.insert( @@ -262,6 +290,7 @@ pub async fn common_security_headers( resp } +#[cfg_attr(not(feature = "governor"), allow(unused))] fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool { loop { let current = credits.load(std::sync::atomic::Ordering::Relaxed); @@ -283,8 +312,11 @@ fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool { /// Middleware for rate limiting #[cfg(feature = "governor")] #[cfg_attr(feature = "cf-worker", worker::send)] -pub async fn rate_limit_middleware( - State(state): State>>, +pub async fn rate_limit_middleware< + C: UpstreamClient + 'static, + S: Sandboxing + Send + Sync + 'static, +>( + State(state): State>>, ConnectInfo(addr): ConnectInfo, request: axum::extract::Request, next: axum::middleware::Next, @@ -686,6 +718,15 @@ impl ErrorResponse { message: Cow::Borrowed("Bad URL"), } } + /// Bad Host header + #[cfg(feature = "url-summary")] + #[must_use] + pub fn bad_host(host: &str) -> Self { + Self { + status: StatusCode::BAD_REQUEST, + message: format!("Bad host header: {host}").into(), + } + } /// Upstream sent invalid HTTP response #[must_use] pub fn upstream_protocol_error() -> Self { @@ -764,10 +805,31 @@ impl ErrorResponse { } /// Requested media is not a media file #[must_use] - pub const fn not_media() -> Self { + pub const fn not_media(expect: ExpectType) -> Self { Self { status: StatusCode::BAD_REQUEST, - message: Cow::Borrowed("Not a media file"), + message: Cow::Borrowed(match expect { + ExpectType::Media => "Not a media", + ExpectType::Markup => "Not a markup", + ExpectType::Json => "Not a JSON", + }), + } + } + /// Feature is not enabled + #[must_use] + pub const fn feature_disabled(msg: &'static str) -> Self { + Self { + status: StatusCode::NOT_IMPLEMENTED, + message: Cow::Borrowed(msg), + } + } + /// Unsupported encoding + #[cfg(feature = "url-summary")] + #[must_use] + pub const fn unsupported_encoding() -> Self { + Self { + status: StatusCode::NOT_IMPLEMENTED, + message: Cow::Borrowed("Unsupported encoding received"), } } } @@ -800,7 +862,7 @@ impl IntoResponse for ErrorResponse { /// Application state #[allow(unused)] -pub struct AppState { +pub struct AppState { #[cfg(feature = "governor")] limiters: Box< [( @@ -871,12 +933,21 @@ pub fn register_cancel_handler() { } } +/// An [`axum::response::IntoResponse`] that has no valid value +pub enum NeverResponse {} + +impl IntoResponse for NeverResponse { + fn into_response(self) -> axum::response::Response { + unreachable!() + } +} + #[cfg(any(feature = "cf-worker", feature = "reqwest"))] #[allow(clippy::unused_async)] impl App { /// Root endpoint #[cfg_attr(feature = "cf-worker", worker::send)] - pub async fn index(State(state): State>>) -> Response { + pub async fn index(State(state): State>>) -> Response { match &state.clone().config.index_redirect { &IndexConfig::Redirect { ref permanent, @@ -892,11 +963,200 @@ impl App Result { + Err(ErrorResponse::feature_disabled( + "URL summary feature is not enabled, build with `--features url-summary`", + )) + } + + #[cfg(feature = "url-summary")] + #[cfg_attr(feature = "cf-worker", worker::send)] + async fn url_summary( + State(state): State>>, + Query(query): Query, + axum::extract::Host(host): axum::extract::Host, + info: IncomingInfo, + ) -> Result + where + <::Response as HTTPResponse>::BodyStream: Unpin, + { + use axum::http::{HeaderName, HeaderValue}; + use futures::{TryFutureExt, TryStreamExt}; + use html5ever::tendril::{Tendril, TendrilSink}; + use stream::LimitedStream; + use url::Url; + use url_summary::{summaly, Extractor, ReadOnlyScraperHtml, SummaryResponse}; + + let original_url = query.url.parse().map_err(|_| ErrorResponse::bad_url())?; + + let resp = state + .client + .request_upstream( + &info, + ExpectType::Markup, + &query.url, + false, + true, + DEFAULT_MAX_REDIRECTS, + ) + .await?; + + let final_url = resp + .request() + .url + .parse() + .map_err(|_| ErrorResponse::bad_url())?; + + let resp_header = [( + HeaderName::from_static("x-forwarded-proto"), + HeaderValue::from_static(if resp.request().secure { + "https" + } else { + "http" + }), + )]; + + // we can guarantee here that we have exclusive access to all the node references + // so we will force Send on the parser + struct SendCell(T); + + #[allow(unsafe_code)] + unsafe impl Send for SendCell {} + + #[cfg(not(feature = "cf-worker"))] + let dom = { + use futures::StreamExt; + use std::sync::atomic::AtomicBool; + + let stop_signal = AtomicBool::new(false); + let should_stop = || stop_signal.load(std::sync::atomic::Ordering::Acquire); + let set_stop = || stop_signal.store(true, std::sync::atomic::Ordering::Release); + + let parser = LimitedStream::new( + resp.body().map_ok(|b| { + let mut x = vec![0; 4]; + x.extend::>(b.into()); + x + }), + 3 << 20, + ) + .map_ok(Some) + .take_while(|r| { + let res = if let Err(None) = r { + false + } else { + !should_stop() + }; + + futures::future::ready(res) + }) + .map_ok(|r| r.unwrap()) + .map_err(Option::unwrap) + .try_fold( + (SendCell(url_summary::parser(set_stop)), [0u8; 4], 0usize), + |(mut parser, mut resid, resid_len), mut chunk| async move { + chunk[..4].copy_from_slice(&resid); // copy the padding from the last chunk + let start_offset = 4 - resid_len; // start offset for the valid bytes in the chunk + + let resid_len = match std::str::from_utf8(&chunk[start_offset..]) { + Ok(s) => { + parser.0.process( + #[allow(unsafe_code)] + unsafe { + Tendril::from_byte_slice_without_validating(s.as_bytes()) + }, + ); + 0 + } + Err(e) => { + if e.error_len().is_some() { + // if we cannot decode a character in the middle + return Err(ErrorResponse::unsupported_encoding()); + } + + let valid_len = e.valid_up_to(); + + if valid_len == 0 { + // if we received less than 1 whole character + return Err(ErrorResponse::unsupported_encoding()); + } + + parser.0.process( + #[allow(unsafe_code)] + unsafe { + Tendril::from_byte_slice_without_validating( + &chunk[start_offset..(start_offset + valid_len)], + ) + }, + ); + + // compute how many bytes are left in the chunk + chunk.len() - valid_len - start_offset + } + }; + + // fair yield + tokio::task::yield_now().await; + + // this is guaranteed to be inbounds as we already padded 4 bytes + resid[..4].copy_from_slice(&chunk[chunk.len() - 4..]); + Ok((parser, resid, resid_len)) + }, + ) + .await? + .0; + + parser.0.finish() + }; + + #[cfg(feature = "cf-worker")] + let dom = { + let parser = LimitedStream::new(resp.body().map_ok(|b| b.into()), 3 << 20) + .or_else(|e| async { + match e { + None => Ok(Vec::new()), + Some(e) => Err(e), + } + }) + .try_fold( + SendCell(url_summary::parser(|| {})), + |mut parser, chunk| async move { + parser.0.process(Tendril::from_slice(chunk.as_slice())); + Ok(parser) + }, + ) + .await?; + parser.0.finish() + }; + + let proxy_url: Url = format!("https://{}/proxy/", host) + .parse() + .map_err(|_| ErrorResponse::bad_host(&host))?; + let resp = summaly::extractor(&info, &state.client) + .try_extract( + None, + &original_url, + &final_url, + &ReadOnlyScraperHtml::from(dom), + ) + .map_err(|_| ErrorResponse::unsupported_media()) + .await? + .transform_assets(|url| { + let mut ret = proxy_url.clone(); + ret.query_pairs_mut().append_pair("url", url.as_str()); + ret + }); + + Ok((resp_header, Json(resp))) + } + #[cfg_attr(feature = "cf-worker", worker::send)] async fn proxy_impl<'a>( method: http::Method, filename: Option<&str>, - State(state): State>>, + State(state): State>>, Query(query): Query, info: IncomingInfo, ) -> Result @@ -934,7 +1194,14 @@ impl App( @@ -961,7 +1228,7 @@ impl App, - State(state): State>>, + State(state): State>>, info: IncomingInfo, ) -> Result where @@ -987,7 +1254,7 @@ impl App, - State(state): State>>, + State(state): State>>, Query(query): Query, info: IncomingInfo, ) -> Result diff --git a/src/main.rs b/src/main.rs index 76290a3..ff47a4a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,7 +15,7 @@ use tokio::sync::mpsc; use tower_service::Service; use yumechi_no_kuni_proxy_worker::{ config::{Config, SandboxConfig}, - router, + default_state, router, sandbox::NoSandbox, Upstream, }; @@ -84,9 +84,9 @@ fn main() { if label.is_empty() || label == "unconfined" { panic!("Refusing to start in unconfined AppArmor profile when AppArmor is enabled"); } - router::(config) + router::(default_state(config)) } - SandboxConfig::NoSandbox => router::(config), + SandboxConfig::NoSandbox => router::(default_state(config)), _ => panic!("Unsupported sandbox configuration, did you forget to enable the feature?"), }; diff --git a/src/post_process/sniff.rs b/src/post_process/sniff.rs index 0aef5f2..53754d0 100644 --- a/src/post_process/sniff.rs +++ b/src/post_process/sniff.rs @@ -119,6 +119,7 @@ impl SniffingStream { .iter() .any(|sig| sig.matches(self.sniff_buffer.get_ref())) }); + let mut all_safe = true; let mut best_match = None; diff --git a/src/sandbox.rs b/src/sandbox.rs index 2ca2ec7..d5de768 100644 --- a/src/sandbox.rs +++ b/src/sandbox.rs @@ -19,6 +19,7 @@ pub(crate) mod pthread { /// Run a function with explicit immediate cancellation #[allow(unsafe_code)] +#[cfg(target_os = "linux")] pub(crate) fn pthread_cancelable R, R>(f: F) -> R { unsafe { let mut oldstate = 0; diff --git a/src/stream.rs b/src/stream.rs index e82c106..6f2e077 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -9,7 +9,7 @@ type Bytes = Vec; use futures::{Stream, StreamExt}; /// A stream that limits the amount of data it can receive -pub struct LimitedStream +pub struct LimitedStream where T: Stream> + 'static, { @@ -18,7 +18,7 @@ where limit: AtomicUsize, } -impl LimitedStream +impl LimitedStream where T: Stream> + 'static, { @@ -32,7 +32,7 @@ where } } -impl Stream for LimitedStream +impl Stream for LimitedStream where T: Stream> + 'static, { diff --git a/src/url_summary/combinator.rs b/src/url_summary/combinator.rs new file mode 100644 index 0000000..dca25df --- /dev/null +++ b/src/url_summary/combinator.rs @@ -0,0 +1,205 @@ +#![allow(unused)] + +use futures::TryFutureExt; + +use super::{Extractor, ExtractorError, ReadOnlyScraperHtml, SummaryResponse}; + +/// A combinator that maps the error of an extractor. +pub(crate) struct ExtractorMapError< + R: SummaryResponse, + A: Extractor, + F: Fn(A::Error) -> E + Send, + E, +> { + a: A, + f: F, + _r: std::marker::PhantomData, +} + +impl, F: Fn(A::Error) -> E + Send, E> + ExtractorMapError +{ + pub fn new(a: A, f: F) -> Self { + Self { + a, + f, + _r: std::marker::PhantomData, + } + } +} + +impl, F: Fn(A::Error) -> E + Send + Sync, E> Extractor + for ExtractorMapError +where + E: std::error::Error + Send + 'static, +{ + type Error = E; + + async fn try_extract( + &mut self, + prev: Option, + original_url: &url::Url, + final_url: &url::Url, + dom: &ReadOnlyScraperHtml, + ) -> Result> { + self.a + .try_extract(prev, original_url, final_url, dom) + .map_err(|err| match err { + ExtractorError::Unsupported => ExtractorError::Unsupported, + ExtractorError::InternalError(err) => ExtractorError::InternalError((self.f)(err)), + }) + .await + } +} + +/// A combinator that steers the extractor to another extractor based on a predicate. +pub(crate) struct ExtractorSteer< + R: SummaryResponse + serde::Serialize, + E: Send, + A, + B, + F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool, +> { + a: A, + b: B, + f: F, + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, +} + +impl ExtractorSteer +where + F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool + Send, +{ + pub fn new(a: A, b: B, f: F) -> Self { + Self { + a, + b, + f, + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, + } + } +} + +impl< + R: SummaryResponse, + E: Send + std::error::Error + 'static, + A: Extractor, + B: Extractor, + F, + > Extractor for ExtractorSteer +where + F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool + Send, + >::Error: From<>::Error>, +{ + type Error = E; + + async fn try_extract( + &mut self, + prev: Option, + original_url: &url::Url, + final_url: &url::Url, + dom: &ReadOnlyScraperHtml, + ) -> Result> { + let (a, b) = (&mut self.a, &mut self.b); + if (self.f)(final_url, &dom) { + b.try_extract(prev, original_url, final_url, dom).await + } else { + a.try_extract(prev, original_url, final_url, dom).await + } + } +} + +/// A combinator that chains two extractors using an OR operation. +pub(crate) struct ExtractorOr { + a: A, + b: B, + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, +} + +impl ExtractorOr { + pub fn new(a: A, b: B) -> Self { + Self { + a, + b, + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, + } + } +} + +impl Extractor + for ExtractorOr +where + R: Clone, + A: Extractor, + B: Extractor, + >::Error: From<>::Error>, +{ + type Error = E; + + async fn try_extract( + &mut self, + prev: Option, + original_url: &url::Url, + final_url: &url::Url, + dom: &ReadOnlyScraperHtml, + ) -> Result> { + match self + .a + .try_extract(prev.clone(), original_url, final_url, dom) + .await + { + Ok(res) => Ok(res), + Err(ExtractorError::Unsupported) => { + self.b.try_extract(prev, original_url, final_url, dom).await + } + Err(ExtractorError::InternalError(err)) => Err(ExtractorError::InternalError(err)), + } + } +} + +pub(crate) struct ExtractorThen { + a: A, + b: B, + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, +} + +impl ExtractorThen { + pub fn new(a: A, b: B) -> Self { + Self { + a, + b, + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, + } + } +} + +impl Extractor + for ExtractorThen +where + A: Extractor, + B: Extractor, +{ + type Error = E; + + async fn try_extract( + &mut self, + prev: Option, + original_url: &url::Url, + final_url: &url::Url, + dom: &ReadOnlyScraperHtml, + ) -> Result> { + let res = self + .a + .try_extract(prev, original_url, final_url, dom) + .await?; + self.b + .try_extract(Some(res), original_url, final_url, dom) + .await + } +} diff --git a/src/url_summary/mod.rs b/src/url_summary/mod.rs new file mode 100644 index 0000000..4aa21e6 --- /dev/null +++ b/src/url_summary/mod.rs @@ -0,0 +1,392 @@ +use std::{borrow::Cow, future::Future, ops::Deref}; + +use axum::{response::IntoResponse, Json}; +use html5ever::{ + driver, expanded_name, + interface::{ElementFlags, NextParserState, NodeOrText, QuirksMode, TreeSink}, + local_name, namespace_url, ns, + tendril::StrTendril, + tokenizer::TokenizerOpts, + tree_builder::TreeBuilderOpts, + Attribute, Parser, QualName, +}; +use scraper::{Html, HtmlTreeSink}; +use thiserror::Error; + +/// Combinator types for extractors. +pub mod combinator; +/// Extractors for summarizing URLs. +pub mod summaly; + +#[cfg(feature = "lazy_static")] +#[macro_export] +/// Create a CSS selector +macro_rules! selector { + ($content:literal) => {{ + lazy_static::lazy_static! { + static ref SELECTOR: scraper::Selector = scraper::Selector::parse($content).unwrap(); + } + + std::ops::Deref::deref(&SELECTOR) + }}; +} + +#[cfg(not(feature = "lazy_static"))] +#[macro_export] +/// Create a CSS selector +macro_rules! selector { + ($content:literal) => { + &scraper::Selector::parse($content).unwrap() + }; +} + +#[derive(Debug, Clone)] +/// A force Send wrapper for `scraper::Html` to share access after parsing. +pub struct ReadOnlyScraperHtml(scraper::Html); + +impl AsRef for ReadOnlyScraperHtml { + fn as_ref(&self) -> &scraper::Html { + &self.0 + } +} + +impl From for ReadOnlyScraperHtml { + fn from(html: scraper::Html) -> Self { + Self(html) + } +} + +#[allow(unsafe_code)] +unsafe impl Send for ReadOnlyScraperHtml {} + +#[allow(unsafe_code)] +unsafe impl Sync for ReadOnlyScraperHtml {} + +#[derive(Debug, Error)] +/// An error that occurs during extraction. +pub enum ExtractorError { + #[error("unsupported document")] + /// The document is not supported by the extractor. + Unsupported, + #[error("failed to extract summary: {0}")] + /// An internal error occurred during extraction. + InternalError(#[from] E), +} + +/// A trait for extracting summary information from a URL. +pub trait Extractor: Sized + Send { + /// The error type that the extractor may return. + type Error: std::error::Error + Send + 'static; + + /// Extract summary information from a URL. + fn try_extract( + &mut self, + prev: Option, + original_url: &url::Url, + final_url: &url::Url, + dom: &ReadOnlyScraperHtml, + ) -> impl Future>> + Send; + + /// Chain this extractor with another extractor. + fn or_else>( + self, + b: B, + ) -> impl Extractor + where + R: Clone, + { + combinator::ExtractorOr::new(self, b) + } + + /// Steer the extractor to another extractor based on a predicate. + fn steer, F>( + self, + b: B, + f: F, + ) -> impl Extractor + where + F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool + Send, + { + combinator::ExtractorSteer::new(self, b, f) + } + + /// Map the error of the extractor. + fn map_err(self, f: F) -> impl Extractor + where + E: std::error::Error + Send + 'static, + F: Fn(Self::Error) -> E + Send + Sync, + { + combinator::ExtractorMapError::new(self, f) + } + + /// Execute this extractor and then another extractor. + fn then>( + self, + b: B, + ) -> impl Extractor + where + R: Clone, + { + combinator::ExtractorThen::new(self, b) + } +} + +/// A dummy extractor that always returns an error. +pub struct NilExtractor { + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, +} + +impl Default for NilExtractor { + fn default() -> Self { + Self { + _r: std::marker::PhantomData, + _e: std::marker::PhantomData, + } + } +} + +impl Extractor + for NilExtractor +{ + type Error = E; + + fn try_extract( + &mut self, + _prev: Option, + _original_url: &url::Url, + _final_url: &url::Url, + _dom: &ReadOnlyScraperHtml, + ) -> impl Future>> + Send { + futures::future::ready(Err(ExtractorError::Unsupported)) + } +} + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +/// A query for summarizing a URL. +pub struct SummaryQuery { + /// The URL to summarize. + pub url: String, +} + +/// Trait for types that can be used as a summary response. +pub trait SummaryResponse: serde::Serialize + Send + Sync + 'static +where + Self: Sized, + Json: IntoResponse, +{ + /// Transform assets URLs in the response. + fn transform_assets(self, f: impl FnMut(url::Url) -> url::Url) -> Self; +} + +/// An HTML sink that notifies when a `body` element is encountered. +pub struct FilteringSinkWrapper, F: Fn()> { + signal: F, + sink: T, +} + +impl, F: Fn()> FilteringSinkWrapper { + /// Create a new `MetaOnlySinkWrapper` with the given sink and signal. + /// + /// When the sink receives a `body` element, the signal will be called. + pub fn new(sink: T, signal: F) -> Self { + Self { signal, sink } + } +} + +impl, F: Fn()> TreeSink for FilteringSinkWrapper +where + for<'a> ::ElemName<'a>: Deref, +{ + type Handle = T::Handle; + + /// The overall result of parsing. + /// + /// This should default to Self, but default associated types are not stable yet. + /// [rust-lang/rust#29661](https://github.com/rust-lang/rust/issues/29661) + type Output = T::Output; + + type ElemName<'a> = T::ElemName<'a> where T: 'a , F: 'a; + + fn finish(self) -> Self::Output { + self.sink.finish() + } + + fn parse_error(&self, msg: Cow<'static, str>) { + self.sink.parse_error(msg) + } + + fn get_document(&self) -> Self::Handle { + self.sink.get_document() + } + + fn elem_name<'a>(&'a self, target: &'a Self::Handle) -> Self::ElemName<'a> { + self.sink.elem_name(target) + } + + /// Create an element. + /// + /// When creating a template element (`name.ns.expanded() == expanded_name!(html "template")`), + /// an associated document fragment called the "template contents" should + /// also be created. Later calls to self.get_template_contents() with that + /// given element return it. + /// See [the template element in the whatwg spec][whatwg template]. + /// + /// [whatwg template]: https://html.spec.whatwg.org/multipage/#the-template-element + fn create_element( + &self, + name: QualName, + attrs: Vec, + flags: ElementFlags, + ) -> Self::Handle { + if name.expanded() == expanded_name!(html "body") || name.local == local_name!("body") { + (self.signal)(); + }; + let handle = self.sink.create_element(name, attrs, flags); + handle + } + + /// Create a comment node. + fn create_comment(&self, _text: StrTendril) -> Self::Handle { + self.sink.create_comment(StrTendril::new()) + } + + /// Create a Processing Instruction node. + fn create_pi(&self, target: StrTendril, data: StrTendril) -> Self::Handle { + self.sink.create_pi(target, data) + } + + /// The child node will not already have a parent. + fn append(&self, parent: &Self::Handle, child: NodeOrText) { + // get rid of some blocks of text that are not useful + if let NodeOrText::AppendText(_) = &child { + let name = self.elem_name(parent); + if name.deref().expanded() == expanded_name!(html "script") + || name.deref().expanded() == expanded_name!(html "style") + || name.deref().expanded() == expanded_name!(html "noscript") + || name.deref().local == local_name!("script") + || name.deref().local == local_name!("style") + || name.deref().local == local_name!("noscript") + { + return; + } + } + self.sink.append(parent, child) + } + + fn append_based_on_parent_node( + &self, + element: &Self::Handle, + prev_element: &Self::Handle, + child: NodeOrText, + ) { + self.sink + .append_based_on_parent_node(element, prev_element, child) + } + + /// Append a `DOCTYPE` element to the `Document` node. + fn append_doctype_to_document( + &self, + name: StrTendril, + public_id: StrTendril, + system_id: StrTendril, + ) { + self.sink + .append_doctype_to_document(name, public_id, system_id) + } + + /// Mark a HTML `