URL summary support

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-12-21 19:31:26 -06:00
parent 81063c2c5e
commit e6afa180bb
No known key found for this signature in database
15 changed files with 2028 additions and 227 deletions

View file

@ -1,6 +1,7 @@
{ {
"rust-analyzer.cargo.features": [ "rust-analyzer.cargo.features": [
//"cf-worker" //"cf-worker",
"env-local" "env-local",
"url-summary"
], ],
} }

311
Cargo.lock generated
View file

@ -213,6 +213,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"axum-macros",
"bytes", "bytes",
"futures-util", "futures-util",
"http", "http",
@ -258,6 +259,17 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-macros"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "axum-server" name = "axum-server"
version = "0.7.1" version = "0.7.1"
@ -514,6 +526,29 @@ version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "cssparser"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7c66d1cd8ed61bf80b38432613a7a2f09401ab8d0501110655f8b341484a3e3"
dependencies = [
"cssparser-macros",
"dtoa-short",
"itoa",
"phf",
"smallvec",
]
[[package]]
name = "cssparser-macros"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331"
dependencies = [
"quote",
"syn",
]
[[package]] [[package]]
name = "dashmap" name = "dashmap"
version = "6.1.0" version = "6.1.0"
@ -528,6 +563,17 @@ dependencies = [
"parking_lot_core", "parking_lot_core",
] ]
[[package]]
name = "derive_more"
version = "0.99.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "displaydoc" name = "displaydoc"
version = "0.2.5" version = "0.2.5"
@ -539,6 +585,27 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "dtoa"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcbb2bf8e87535c23f7a8a321e364ce21462d0ff10cb6407820e8e96dfff6653"
[[package]]
name = "dtoa-short"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd1511a7b6a56299bd043a9c167a6d2bfb37bf84a6dfceaba651168adfb43c87"
dependencies = [
"dtoa",
]
[[package]]
name = "ego-tree"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2972feb8dffe7bc8c5463b1dacda1b0dfbed3710e50f977d965429692d74cd8"
[[package]] [[package]]
name = "either" name = "either"
version = "1.13.0" version = "1.13.0"
@ -654,6 +721,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "futf"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843"
dependencies = [
"mac",
"new_debug_unreachable",
]
[[package]] [[package]]
name = "futures" name = "futures"
version = "0.3.31" version = "0.3.31"
@ -737,6 +814,24 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]]
name = "getopts"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5"
dependencies = [
"unicode-width",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.15" version = "0.2.15"
@ -839,6 +934,20 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "html5ever"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e15626aaf9c351bc696217cbe29cb9b5e86c43f8a46b5e2f5c6c5cf7cb904ce"
dependencies = [
"log",
"mac",
"markup5ever",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "http" name = "http"
version = "1.2.0" version = "1.2.0"
@ -1283,6 +1392,26 @@ dependencies = [
"hashbrown 0.15.2", "hashbrown 0.15.2",
] ]
[[package]]
name = "mac"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4"
[[package]]
name = "markup5ever"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82c88c6129bd24319e62a0359cb6b958fa7e8be6e19bb1663bc396b90883aca5"
dependencies = [
"log",
"phf",
"phf_codegen",
"string_cache",
"string_cache_codegen",
"tendril",
]
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.7.3" version = "0.7.3"
@ -1532,6 +1661,77 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "phf"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
dependencies = [
"phf_macros",
"phf_shared 0.11.2",
]
[[package]]
name = "phf_codegen"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a"
dependencies = [
"phf_generator 0.11.2",
"phf_shared 0.11.2",
]
[[package]]
name = "phf_generator"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6"
dependencies = [
"phf_shared 0.10.0",
"rand",
]
[[package]]
name = "phf_generator"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
dependencies = [
"phf_shared 0.11.2",
"rand",
]
[[package]]
name = "phf_macros"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b"
dependencies = [
"phf_generator 0.11.2",
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "phf_shared"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096"
dependencies = [
"siphasher 0.3.11",
]
[[package]]
name = "phf_shared"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b"
dependencies = [
"siphasher 0.3.11",
]
[[package]] [[package]]
name = "pin-project" name = "pin-project"
version = "1.1.7" version = "1.1.7"
@ -1598,6 +1798,12 @@ dependencies = [
"zerocopy", "zerocopy",
] ]
[[package]]
name = "precomputed-hash"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.92" version = "1.0.92"
@ -1964,6 +2170,21 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scraper"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc3d051b884f40e309de6c149734eab57aa8cc1347992710dc80bcc1c2194c15"
dependencies = [
"cssparser",
"ego-tree",
"getopts",
"html5ever",
"precomputed-hash",
"selectors",
"tendril",
]
[[package]] [[package]]
name = "security-framework" name = "security-framework"
version = "2.11.1" version = "2.11.1"
@ -1987,6 +2208,25 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "selectors"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd568a4c9bb598e291a08244a5c1f5a8a6650bee243b5b0f8dbb3d9cc1d87fe8"
dependencies = [
"bitflags 2.6.0",
"cssparser",
"derive_more",
"fxhash",
"log",
"new_debug_unreachable",
"phf",
"phf_codegen",
"precomputed-hash",
"servo_arc",
"smallvec",
]
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.216" version = "1.0.216"
@ -2072,6 +2312,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "servo_arc"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae65c4249478a2647db249fb43e23cec56a2c8974a427e7bd8cb5a1d0964921a"
dependencies = [
"stable_deref_trait",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.3.0" version = "1.3.0"
@ -2102,6 +2351,12 @@ dependencies = [
"quote", "quote",
] ]
[[package]]
name = "siphasher"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
[[package]] [[package]]
name = "siphasher" name = "siphasher"
version = "1.0.1" version = "1.0.1"
@ -2167,6 +2422,32 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "string_cache"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b"
dependencies = [
"new_debug_unreachable",
"once_cell",
"parking_lot",
"phf_shared 0.10.0",
"precomputed-hash",
"serde",
]
[[package]]
name = "string_cache_codegen"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988"
dependencies = [
"phf_generator 0.10.0",
"phf_shared 0.10.0",
"proc-macro2",
"quote",
]
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.11.1" version = "0.11.1"
@ -2263,6 +2544,17 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "tendril"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0"
dependencies = [
"futf",
"mac",
"utf-8",
]
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.69" version = "1.0.69"
@ -2501,6 +2793,12 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
[[package]]
name = "unicode-width"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"
@ -2516,8 +2814,15 @@ dependencies = [
"form_urlencoded", "form_urlencoded",
"idna", "idna",
"percent-encoding", "percent-encoding",
"serde",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "utf16_iter" name = "utf16_iter"
version = "1.0.5" version = "1.0.5"
@ -2965,20 +3270,24 @@ dependencies = [
"clap", "clap",
"console_error_panic_hook", "console_error_panic_hook",
"dashmap", "dashmap",
"ego-tree",
"env_logger", "env_logger",
"futures", "futures",
"getrandom", "getrandom",
"governor", "governor",
"html5ever",
"image", "image",
"lazy_static",
"libc", "libc",
"log", "log",
"lru", "lru",
"prometheus", "prometheus",
"quote", "quote",
"reqwest", "reqwest",
"scraper",
"serde", "serde",
"serde_json", "serde_json",
"siphasher", "siphasher 1.0.1",
"thiserror 2.0.8", "thiserror 2.0.8",
"tokio", "tokio",
"toml", "toml",

View file

@ -37,7 +37,9 @@ env-local = ["axum/http1", "axum/http2",
"lossy-webp", "lossy-webp",
"tower-http", "tower-http",
"metrics", "metrics",
"lazy_static",
] ]
url-summary = ["dep:html5ever", "dep:scraper", "dep:url", "url/serde", "dep:ego-tree"]
cf-worker = ["dep:worker", "dep:worker-macros", "dep:wasm-bindgen", "image/ico", "panic-console-error"] cf-worker = ["dep:worker", "dep:worker-macros", "dep:wasm-bindgen", "image/ico", "panic-console-error"]
# Observability and tracing features # Observability and tracing features
@ -59,11 +61,13 @@ reqwest = ["dep:reqwest", "dep:url"]
# Sandbox features # Sandbox features
apparmor = ["dep:siphasher", "dep:libc"] apparmor = ["dep:siphasher", "dep:libc"]
lazy_static = ["dep:lazy_static"]
[dependencies] [dependencies]
worker = { version="0.4.2", features=['http', 'axum'], optional = true } worker = { version="0.4.2", features=['http', 'axum'], optional = true }
worker-macros = { version="0.4.2", features=['http'], optional = true } worker-macros = { version="0.4.2", features=['http'], optional = true }
axum = { version = "0.7", default-features = false, features = ["query", "json"] } axum = { version = "0.7", default-features = false, features = ["query", "json", "macros"] }
ego-tree = { version = "0.10", optional = true }
tower-service = "0.3" tower-service = "0.3"
console_error_panic_hook = { version = "0.1.1", optional = true } console_error_panic_hook = { version = "0.1.1", optional = true }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
@ -90,6 +94,9 @@ dashmap = "6.1.0"
lru = "0.12.5" lru = "0.12.5"
prometheus = { version = "0.13.4", optional = true } prometheus = { version = "0.13.4", optional = true }
xml = "0.8.20" xml = "0.8.20"
html5ever = { version = "0.29", optional = true }
scraper = { version = "0.22.0", optional = true }
lazy_static = { version = "1.5.0", optional = true }
[build-dependencies] [build-dependencies]
chumsky = "0.9.3" chumsky = "0.9.3"

View file

@ -21,6 +21,7 @@ Currently to do:
- [ ] Handle all possible panics reported by Clippy - [ ] Handle all possible panics reported by Clippy
- [X] Sandboxing the image rendering - [X] Sandboxing the image rendering
- [X] Prometheus-format metrics - [X] Prometheus-format metrics
- [X] Experimental URL summarization replacement (local only for now)
## Spec Compliance ## Spec Compliance

View file

@ -77,6 +77,15 @@ fn static_signatures() -> Vec<MIMEAssociation> {
mask: vec![0xff, 0xff, 0xff, 0xff], mask: vec![0xff, 0xff, 0xff, 0xff],
}], }],
}, },
MIMEAssociation {
mime: "image/vnd.microsoft.icon".to_string().into(),
ext: vec![".ico".to_string()],
safe: true,
signatures: vec![FlattenedFileSignature {
test: vec![0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
mask: vec![0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff],
}],
},
] ]
} }

View file

@ -30,7 +30,7 @@ const fn http_version_to_via(v: axum::http::Version) -> &'static str {
} }
/// Trait for HTTP responses /// Trait for HTTP responses
pub trait HTTPResponse { pub trait HTTPResponse: Send + 'static {
/// Type of the byte buffer /// Type of the byte buffer
type Bytes: Into<Vec<u8>> + AsRef<[u8]> + Into<Bytes> + Send + 'static; type Bytes: Into<Vec<u8>> + AsRef<[u8]> + Into<Bytes> + Send + 'static;
/// Type of body stream /// Type of body stream
@ -51,10 +51,12 @@ pub trait HTTPResponse {
} }
/// Information about the incoming request /// Information about the incoming request
#[derive(Debug, Clone)]
pub struct IncomingInfo { pub struct IncomingInfo {
version: axum::http::Version, version: axum::http::Version,
user_agent: String, user_agent: String,
via: String, via: String,
accept_language: Option<String>,
} }
impl IncomingInfo { impl IncomingInfo {
@ -123,12 +125,40 @@ impl<S> FromRequestParts<S> for IncomingInfo {
acc.push_str(v.to_str().unwrap_or_default()); acc.push_str(v.to_str().unwrap_or_default());
acc acc
}), }),
accept_language: parts
.headers
.get("accept-language")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.split(',').next())
.map(|s| s.chars().filter(|c| !c.is_control()).collect()),
}) })
} }
} }
#[derive(Debug, Clone, PartialEq, Copy)]
/// Expected response type
pub enum ExpectType {
/// Multimedia content, such as images, videos, and audio
Media,
/// Markup content, such as HTML and XML
Markup,
/// Json content
Json,
}
impl ExpectType {
/// Get the string representation
pub fn as_str(&self) -> &'static str {
match self {
Self::Media => "media",
Self::Markup => "markup",
Self::Json => "json",
}
}
}
/// Trait for upstream clients /// Trait for upstream clients
pub trait UpstreamClient { pub trait UpstreamClient: Send + Sync + 'static {
/// Type of the response /// Type of the response
type Response: HTTPResponse; type Response: HTTPResponse;
/// Create a new client /// Create a new client
@ -137,11 +167,12 @@ pub trait UpstreamClient {
fn request_upstream( fn request_upstream(
&self, &self,
info: &IncomingInfo, info: &IncomingInfo,
expect: ExpectType,
url: &str, url: &str,
polish: bool, polish: bool,
secure: bool, secure: bool,
remaining: usize, remaining: usize,
) -> impl std::future::Future<Output = Result<Self::Response, ErrorResponse>>; ) -> impl std::future::Future<Output = Result<Self::Response, ErrorResponse>> + Send;
} }
/// Reqwest client /// Reqwest client
@ -157,7 +188,7 @@ pub mod reqwest {
use axum::body::Bytes; use axum::body::Bytes;
use futures::TryStreamExt; use futures::TryStreamExt;
use reqwest::dns::Resolve; use reqwest::dns::Resolve;
use std::{net::SocketAddrV4, sync::Arc, time::Duration}; use std::{future::Future, net::SocketAddrV4, sync::Arc, time::Duration};
use url::Host; use url::Host;
/// A Safe DNS resolver that only resolves to global addresses unless the requester itself is local. /// A Safe DNS resolver that only resolves to global addresses unless the requester itself is local.
@ -303,14 +334,16 @@ pub mod reqwest {
.expect("Failed to create reqwest client"), .expect("Failed to create reqwest client"),
} }
} }
async fn request_upstream( fn request_upstream(
&self, &self,
info: &super::IncomingInfo, info: &super::IncomingInfo,
expect: super::ExpectType,
url: &str, url: &str,
polish: bool, polish: bool,
mut secure: bool, mut secure: bool,
remaining: usize, remaining: usize,
) -> Result<ReqwestResponse, ErrorResponse> { ) -> impl Future<Output = Result<ReqwestResponse, ErrorResponse>> + Send {
async move {
if remaining == 0 { if remaining == 0 {
return Err(ErrorResponse::too_many_redirects()); return Err(ErrorResponse::too_many_redirects());
} }
@ -346,14 +379,30 @@ pub mod reqwest {
self.via_ident self.via_ident
), ),
) )
.header(
"accept",
match expect {
super::ExpectType::Media => {
"image/*, video/*, audio/*, application/octet-stream"
}
super::ExpectType::Markup => "text/html, application/xhtml+xml",
super::ExpectType::Json => "application/json",
},
)
.header(
"accept-language",
info.accept_language.as_deref().unwrap_or("en"),
)
.send() .send()
.await?; .await?;
if resp.status().is_redirection() { if resp.status().is_redirection() {
if let Some(location) = resp.headers().get("location").and_then(|l| l.to_str().ok()) if let Some(location) =
resp.headers().get("location").and_then(|l| l.to_str().ok())
{ {
return Box::pin(self.request_upstream( return Box::pin(self.request_upstream(
info, info,
expect,
location, location,
polish, polish,
secure, secure,
@ -380,14 +429,26 @@ pub mod reqwest {
let content_type = resp.headers().get("content-type"); let content_type = resp.headers().get("content-type");
if let Some(content_type) = content_type.and_then(|c| c.to_str().ok()) { if let Some(content_type) = content_type.and_then(|c| c.to_str().ok()) {
if !["image/", "video/", "audio/", "application/octet-stream"] if !match expect {
super::ExpectType::Media => {
["image/", "video/", "audio/", "application/octet-stream"]
.iter() .iter()
.any(|prefix| { .any(|prefix| {
content_type[..prefix.len().min(content_type.len())] content_type[..prefix.len().min(content_type.len())]
.eq_ignore_ascii_case(prefix) .eq_ignore_ascii_case(prefix)
}) })
{ }
return Err(ErrorResponse::not_media()); super::ExpectType::Markup => {
["text/html", "application/xhtml+xml"].iter().any(|prefix| {
content_type[..prefix.len().min(content_type.len())]
.eq_ignore_ascii_case(prefix)
})
}
super::ExpectType::Json => content_type
[.."application/json".len().min(content_type.len())]
.eq_ignore_ascii_case("application/json"),
} {
return Err(ErrorResponse::not_media(expect));
} }
} }
@ -398,6 +459,7 @@ pub mod reqwest {
} }
} }
} }
}
/// Cloudflare Workers client /// Cloudflare Workers client
#[cfg(feature = "cf-worker")] #[cfg(feature = "cf-worker")]
@ -406,7 +468,7 @@ pub mod reqwest {
deprecated = "You should use reqwest instead when not on Cloudflare Workers" deprecated = "You should use reqwest instead when not on Cloudflare Workers"
)] )]
pub mod cf_worker { pub mod cf_worker {
use std::time::Duration; use std::{future::Future, time::Duration};
use super::{ use super::{
http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx, http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx,
@ -463,6 +525,22 @@ pub mod cf_worker {
#[allow(unsafe_code, reason = "this is never used concurrently")] #[allow(unsafe_code, reason = "this is never used concurrently")]
unsafe impl Sync for CfBodyStreamWrapper {} unsafe impl Sync for CfBodyStreamWrapper {}
struct SendFuture<T: Future>(Pin<Box<T>>);
#[allow(unsafe_code, reason = "this is never used concurrently")]
unsafe impl<T: Future> Send for SendFuture<T> {}
impl<T: Future> Future for SendFuture<T> {
type Output = T::Output;
fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
Future::poll(self.get_mut().0.as_mut(), cx)
}
}
/// Response from Cloudflare Workers /// Response from Cloudflare Workers
pub struct CfWorkerResponse { pub struct CfWorkerResponse {
time_to_body: std::time::Duration, time_to_body: std::time::Duration,
@ -470,6 +548,9 @@ pub mod cf_worker {
url: Url, url: Url,
} }
#[allow(unsafe_code, reason = "this is never used concurrently")]
unsafe impl Send for CfWorkerResponse {}
impl HTTPResponse for CfWorkerResponse { impl HTTPResponse for CfWorkerResponse {
type Bytes = Vec<u8>; type Bytes = Vec<u8>;
type BodyStream = CfBodyStreamWrapper; type BodyStream = CfBodyStreamWrapper;
@ -539,14 +620,20 @@ pub mod cf_worker {
} }
} }
async fn request_upstream( fn request_upstream(
&self, &self,
info: &super::IncomingInfo, info: &super::IncomingInfo,
expect: super::ExpectType,
url: &str, url: &str,
polish: bool, polish: bool,
mut secure: bool, mut secure: bool,
remaining: usize, remaining: usize,
) -> Result<Self::Response, ErrorResponse> { ) -> impl Future<Output = Result<Self::Response, ErrorResponse>> + Send {
let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url());
let info = info.clone();
SendFuture(Box::pin(async move {
let url_parsed = url_parsed?;
if remaining == 0 { if remaining == 0 {
return Err(ErrorResponse::too_many_redirects()); return Err(ErrorResponse::too_many_redirects());
} }
@ -567,6 +654,20 @@ pub mod cf_worker {
self.via_ident self.via_ident
), ),
)?; )?;
headers.set(
"accept",
match expect {
super::ExpectType::Media => {
"image/*, video/*, audio/*, application/octet-stream"
}
super::ExpectType::Markup => "text/html, application/xhtml+xml",
super::ExpectType::Json => "application/json",
},
)?;
headers.set(
"accept-language",
info.accept_language.as_deref().unwrap_or("en"),
)?;
let mut prop = CfProperties::new(); let mut prop = CfProperties::new();
if polish { if polish {
@ -579,8 +680,6 @@ pub mod cf_worker {
.with_cf_properties(prop) .with_cf_properties(prop)
.with_redirect(RequestRedirect::Manual); .with_redirect(RequestRedirect::Manual);
let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?;
if self.https_only && !url_parsed.scheme().eq_ignore_ascii_case("https") { if self.https_only && !url_parsed.scheme().eq_ignore_ascii_case("https") {
return Err(ErrorResponse::insecure_request()); return Err(ErrorResponse::insecure_request());
} }
@ -589,7 +688,7 @@ pub mod cf_worker {
let begin = crate::timing::Instant::now(); let begin = crate::timing::Instant::now();
let req = Request::new_with_init(url, init)?; let req = Request::new_with_init(url_parsed.as_str(), init)?;
let abc = AbortController::default(); let abc = AbortController::default();
let abs = abc.signal(); let abs = abc.signal();
@ -608,7 +707,8 @@ pub mod cf_worker {
if resp.status_code() >= 300 && resp.status_code() < 400 { if resp.status_code() >= 300 && resp.status_code() < 400 {
if let Ok(Some(location)) = resp.headers().get("location") { if let Ok(Some(location)) = resp.headers().get("location") {
return Box::pin(self.request_upstream( return Box::pin(self.request_upstream(
info, &info,
expect,
&location, &location,
polish, polish,
secure, secure,
@ -620,7 +720,10 @@ pub mod cf_worker {
} }
if resp.status_code() < 200 || resp.status_code() >= 300 { if resp.status_code() < 200 || resp.status_code() >= 300 {
return Err(ErrorResponse::unexpected_status(url, resp.status_code())); return Err(ErrorResponse::unexpected_status(
url_parsed.as_str(),
resp.status_code(),
));
} }
let content_length = resp.headers().get("content-length").unwrap_or_default(); let content_length = resp.headers().get("content-length").unwrap_or_default();
@ -632,14 +735,26 @@ pub mod cf_worker {
let content_type = resp.headers().get("content-type").unwrap_or_default(); let content_type = resp.headers().get("content-type").unwrap_or_default();
if let Some(content_type) = content_type { if let Some(content_type) = content_type {
if !["image/", "video/", "audio/", "application/octet-stream"] if !match expect {
super::ExpectType::Media => {
["image/", "video/", "audio/", "application/octet-stream"]
.iter() .iter()
.any(|prefix| { .any(|prefix| {
content_type.as_str()[..prefix.len().min(content_type.len())] content_type[..prefix.len().min(content_type.len())]
.eq_ignore_ascii_case(prefix) .eq_ignore_ascii_case(prefix)
}) })
{ }
return Err(ErrorResponse::not_media()); super::ExpectType::Markup => {
["text/html", "application/xhtml+xml"].iter().any(|prefix| {
content_type[..prefix.len().min(content_type.len())]
.eq_ignore_ascii_case(prefix)
})
}
super::ExpectType::Json => content_type
[.."application/json".len().min(content_type.len())]
.eq_ignore_ascii_case("application/json"),
} {
return Err(ErrorResponse::not_media(expect));
} }
} }
@ -648,6 +763,7 @@ pub mod cf_worker {
resp, resp,
url: url_parsed, url: url_parsed,
}) })
}))
} }
} }
} }

View file

@ -6,6 +6,7 @@
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
use std::net::SocketAddr; use std::net::SocketAddr;
#[cfg_attr(feature = "cf-worker", allow(unused_imports))]
use std::{ use std::{
borrow::Cow, borrow::Cow,
fmt::Display, fmt::Display,
@ -23,12 +24,13 @@ use axum::{
routing::get, routing::get,
Json, Router, Json, Router,
}; };
use fetch::{HTTPResponse, IncomingInfo, UpstreamClient, DEFAULT_MAX_REDIRECTS}; use fetch::{ExpectType, HTTPResponse, IncomingInfo, UpstreamClient, DEFAULT_MAX_REDIRECTS};
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
use governor::{ use governor::{
clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore, clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore,
RateLimiter, RateLimiter,
}; };
#[cfg(feature = "governor")]
use lru::LruCache; use lru::LruCache;
use post_process::{CompressionLevel, MediaResponse}; use post_process::{CompressionLevel, MediaResponse};
use sandbox::Sandboxing; use sandbox::Sandboxing;
@ -58,6 +60,10 @@ pub mod config;
/// Cross platform timing utilities /// Cross platform timing utilities
pub mod timing; pub mod timing;
#[cfg(feature = "url-summary")]
/// URL summarization utilities
pub mod url_summary;
/// Utilities for Cloudflare Workers /// Utilities for Cloudflare Workers
#[cfg(feature = "cf-worker")] #[cfg(feature = "cf-worker")]
mod cf_utils; mod cf_utils;
@ -98,31 +104,25 @@ async fn fetch(
#[cfg(feature = "panic-console-error")] #[cfg(feature = "panic-console-error")]
console_error_panic_hook::set_once(); console_error_panic_hook::set_once();
Ok(router::<CfWorkerClient, NoSandbox>(config) Ok(router::<CfWorkerClient, NoSandbox>(default_state(config))
.call(req) .call(req)
.await?) .await?)
} }
#[cfg(any(feature = "cf-worker", feature = "reqwest"))] /// Create default Application state
/// Application Router pub fn default_state<S: Sandboxing + Send + Sync + 'static>(
pub fn router<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static>(
config: Config, config: Config,
) -> Router ) -> Arc<AppState<Upstream, S>> {
where
<<C as UpstreamClient>::Response as HTTPResponse>::BodyStream: Unpin,
{
use axum::middleware;
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
use governor::{ use governor::{
clock::SystemClock, middleware::StateInformationMiddleware, Quota, RateLimiter, clock::SystemClock, middleware::StateInformationMiddleware, Quota, RateLimiter,
}; };
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
use std::time::Duration; use std::time::Duration;
#[allow(unused_imports)]
use std::{num::NonZero, sync::RwLock}; use std::{num::NonZero, sync::RwLock};
#[cfg(not(feature = "cf-worker"))]
use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer};
let state = AppState { Arc::new(AppState {
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
limiters: config limiters: config
.rate_limit .rate_limit
@ -144,21 +144,39 @@ where
client: Upstream::new(&config.fetch), client: Upstream::new(&config.fetch),
sandbox: S::new(&config.sandbox), sandbox: S::new(&config.sandbox),
config, config,
}; })
}
let state = Arc::new(state); #[cfg(any(feature = "cf-worker", feature = "reqwest"))]
/// Application Router
pub fn router<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static>(
state: Arc<AppState<C, S>>,
) -> Router
where
<<C as UpstreamClient>::Response as HTTPResponse>::BodyStream: Unpin,
{
use axum::middleware;
#[cfg(not(feature = "cf-worker"))]
use tower_http::{catch_panic::CatchPanicLayer, timeout::TimeoutLayer};
#[allow(unused_mut)] #[allow(unused_mut)]
let mut router = Router::new() let mut router = Router::new()
.route("/", get(App::<C, S>::index)) .route("/", get(App::<C, S>::index))
.route(
"/url",
get(App::<C, S>::url_summary).route_layer(middleware::from_fn_with_state(
state.config.enable_cache,
set_cache_control::<C, S>,
)),
)
.route( .route(
"/proxy", "/proxy",
get(App::<C, S>::proxy_without_filename) get(App::<C, S>::proxy_without_filename)
.head(App::<C, S>::proxy_without_filename) .head(App::<C, S>::proxy_without_filename)
.options(App::<C, S>::proxy_options) .options(App::<C, S>::proxy_options)
.route_layer(middleware::from_fn_with_state( .route_layer(middleware::from_fn_with_state(
state.clone(), state.config.enable_cache,
set_cache_control, set_cache_control::<C, S>,
)) ))
.fallback(|| async { ErrorResponse::method_not_allowed() }), .fallback(|| async { ErrorResponse::method_not_allowed() }),
) )
@ -168,8 +186,8 @@ where
.head(App::<C, S>::proxy_without_filename) .head(App::<C, S>::proxy_without_filename)
.options(App::<C, S>::proxy_options) .options(App::<C, S>::proxy_options)
.route_layer(middleware::from_fn_with_state( .route_layer(middleware::from_fn_with_state(
state.clone(), state.config.enable_cache,
set_cache_control, set_cache_control::<C, S>,
)) ))
.fallback(|| async { ErrorResponse::method_not_allowed() }), .fallback(|| async { ErrorResponse::method_not_allowed() }),
) )
@ -179,8 +197,8 @@ where
.head(App::<C, S>::proxy_with_filename) .head(App::<C, S>::proxy_with_filename)
.options(App::<C, S>::proxy_options) .options(App::<C, S>::proxy_options)
.route_layer(middleware::from_fn_with_state( .route_layer(middleware::from_fn_with_state(
state.clone(), state.config.enable_cache,
set_cache_control, set_cache_control::<C, S>,
)) ))
.fallback(|| async { ErrorResponse::method_not_allowed() }), .fallback(|| async { ErrorResponse::method_not_allowed() }),
) )
@ -189,6 +207,8 @@ where
#[cfg(not(feature = "cf-worker"))] #[cfg(not(feature = "cf-worker"))]
{ {
use std::time::Duration;
router = router router = router
.layer(CatchPanicLayer::custom(|err| { .layer(CatchPanicLayer::custom(|err| {
log::error!("Panic in request: {:?}", err); log::error!("Panic in request: {:?}", err);
@ -199,6 +219,8 @@ where
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
{ {
use std::time::Duration;
let state_gc = Arc::clone(&state); let state_gc = Arc::clone(&state);
std::thread::spawn(move || loop { std::thread::spawn(move || loop {
std::thread::sleep(Duration::from_secs(300)); std::thread::sleep(Duration::from_secs(300));
@ -209,20 +231,26 @@ where
} }
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware)); return router.route_layer(middleware::from_fn_with_state(
state,
rate_limit_middleware::<C, S>,
));
#[cfg(not(feature = "governor"))] #[cfg(not(feature = "governor"))]
router router
} }
/// Set the Cache-Control header /// Set the Cache-Control header
#[cfg_attr(feature = "cf-worker", worker::send)] #[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn set_cache_control<S: Sandboxing + Send + Sync + 'static>( pub async fn set_cache_control<
State(state): State<Arc<AppState<Upstream, S>>>, C: UpstreamClient + 'static,
S: Sandboxing + Send + Sync + 'static,
>(
State(enabled): State<bool>,
request: axum::extract::Request, request: axum::extract::Request,
next: axum::middleware::Next, next: axum::middleware::Next,
) -> Response { ) -> Response {
let mut resp = next.run(request).await; let mut resp = next.run(request).await;
if state.config.enable_cache { if enabled {
if resp.status() == StatusCode::OK { if resp.status() == StatusCode::OK {
let headers = resp.headers_mut(); let headers = resp.headers_mut();
headers.insert( headers.insert(
@ -262,6 +290,7 @@ pub async fn common_security_headers(
resp resp
} }
#[cfg_attr(not(feature = "governor"), allow(unused))]
fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool { fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool {
loop { loop {
let current = credits.load(std::sync::atomic::Ordering::Relaxed); let current = credits.load(std::sync::atomic::Ordering::Relaxed);
@ -283,8 +312,11 @@ fn atomic_u64_saturating_dec(credits: &AtomicU64) -> bool {
/// Middleware for rate limiting /// Middleware for rate limiting
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
#[cfg_attr(feature = "cf-worker", worker::send)] #[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn rate_limit_middleware<S: Sandboxing + Send + Sync + 'static>( pub async fn rate_limit_middleware<
State(state): State<Arc<AppState<Upstream, S>>>, C: UpstreamClient + 'static,
S: Sandboxing + Send + Sync + 'static,
>(
State(state): State<Arc<AppState<C, S>>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
request: axum::extract::Request, request: axum::extract::Request,
next: axum::middleware::Next, next: axum::middleware::Next,
@ -686,6 +718,15 @@ impl ErrorResponse {
message: Cow::Borrowed("Bad URL"), message: Cow::Borrowed("Bad URL"),
} }
} }
/// Bad Host header
#[cfg(feature = "url-summary")]
#[must_use]
pub fn bad_host(host: &str) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
message: format!("Bad host header: {host}").into(),
}
}
/// Upstream sent invalid HTTP response /// Upstream sent invalid HTTP response
#[must_use] #[must_use]
pub fn upstream_protocol_error() -> Self { pub fn upstream_protocol_error() -> Self {
@ -764,10 +805,31 @@ impl ErrorResponse {
} }
/// Requested media is not a media file /// Requested media is not a media file
#[must_use] #[must_use]
pub const fn not_media() -> Self { pub const fn not_media(expect: ExpectType) -> Self {
Self { Self {
status: StatusCode::BAD_REQUEST, status: StatusCode::BAD_REQUEST,
message: Cow::Borrowed("Not a media file"), message: Cow::Borrowed(match expect {
ExpectType::Media => "Not a media",
ExpectType::Markup => "Not a markup",
ExpectType::Json => "Not a JSON",
}),
}
}
/// Feature is not enabled
#[must_use]
pub const fn feature_disabled(msg: &'static str) -> Self {
Self {
status: StatusCode::NOT_IMPLEMENTED,
message: Cow::Borrowed(msg),
}
}
/// Unsupported encoding
#[cfg(feature = "url-summary")]
#[must_use]
pub const fn unsupported_encoding() -> Self {
Self {
status: StatusCode::NOT_IMPLEMENTED,
message: Cow::Borrowed("Unsupported encoding received"),
} }
} }
} }
@ -800,7 +862,7 @@ impl IntoResponse for ErrorResponse {
/// Application state /// Application state
#[allow(unused)] #[allow(unused)]
pub struct AppState<C: UpstreamClient, S: Sandboxing> { pub struct AppState<C: UpstreamClient, S: Sandboxing + Send + Sync + 'static> {
#[cfg(feature = "governor")] #[cfg(feature = "governor")]
limiters: Box< limiters: Box<
[( [(
@ -871,12 +933,21 @@ pub fn register_cancel_handler() {
} }
} }
/// An [`axum::response::IntoResponse`] that has no valid value
pub enum NeverResponse {}
impl IntoResponse for NeverResponse {
fn into_response(self) -> axum::response::Response {
unreachable!()
}
}
#[cfg(any(feature = "cf-worker", feature = "reqwest"))] #[cfg(any(feature = "cf-worker", feature = "reqwest"))]
#[allow(clippy::unused_async)] #[allow(clippy::unused_async)]
impl<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static> App<C, S> { impl<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static> App<C, S> {
/// Root endpoint /// Root endpoint
#[cfg_attr(feature = "cf-worker", worker::send)] #[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn index(State(state): State<Arc<AppState<Upstream, S>>>) -> Response { pub async fn index(State(state): State<Arc<AppState<C, S>>>) -> Response {
match &state.clone().config.index_redirect { match &state.clone().config.index_redirect {
&IndexConfig::Redirect { &IndexConfig::Redirect {
ref permanent, ref permanent,
@ -892,11 +963,200 @@ impl<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static> App<C,
} }
} }
#[cfg(not(feature = "url-summary"))]
#[cfg_attr(feature = "cf-worker", worker::send)]
async fn url_summary() -> Result<NeverResponse, ErrorResponse> {
Err(ErrorResponse::feature_disabled(
"URL summary feature is not enabled, build with `--features url-summary`",
))
}
#[cfg(feature = "url-summary")]
#[cfg_attr(feature = "cf-worker", worker::send)]
async fn url_summary(
State(state): State<Arc<AppState<C, S>>>,
Query(query): Query<url_summary::SummaryQuery>,
axum::extract::Host(host): axum::extract::Host,
info: IncomingInfo,
) -> Result<impl IntoResponse, ErrorResponse>
where
<<C as UpstreamClient>::Response as HTTPResponse>::BodyStream: Unpin,
{
use axum::http::{HeaderName, HeaderValue};
use futures::{TryFutureExt, TryStreamExt};
use html5ever::tendril::{Tendril, TendrilSink};
use stream::LimitedStream;
use url::Url;
use url_summary::{summaly, Extractor, ReadOnlyScraperHtml, SummaryResponse};
let original_url = query.url.parse().map_err(|_| ErrorResponse::bad_url())?;
let resp = state
.client
.request_upstream(
&info,
ExpectType::Markup,
&query.url,
false,
true,
DEFAULT_MAX_REDIRECTS,
)
.await?;
let final_url = resp
.request()
.url
.parse()
.map_err(|_| ErrorResponse::bad_url())?;
let resp_header = [(
HeaderName::from_static("x-forwarded-proto"),
HeaderValue::from_static(if resp.request().secure {
"https"
} else {
"http"
}),
)];
// we can guarantee here that we have exclusive access to all the node references
// so we will force Send on the parser
struct SendCell<T>(T);
#[allow(unsafe_code)]
unsafe impl<T> Send for SendCell<T> {}
#[cfg(not(feature = "cf-worker"))]
let dom = {
use futures::StreamExt;
use std::sync::atomic::AtomicBool;
let stop_signal = AtomicBool::new(false);
let should_stop = || stop_signal.load(std::sync::atomic::Ordering::Acquire);
let set_stop = || stop_signal.store(true, std::sync::atomic::Ordering::Release);
let parser = LimitedStream::new(
resp.body().map_ok(|b| {
let mut x = vec![0; 4];
x.extend::<Vec<_>>(b.into());
x
}),
3 << 20,
)
.map_ok(Some)
.take_while(|r| {
let res = if let Err(None) = r {
false
} else {
!should_stop()
};
futures::future::ready(res)
})
.map_ok(|r| r.unwrap())
.map_err(Option::unwrap)
.try_fold(
(SendCell(url_summary::parser(set_stop)), [0u8; 4], 0usize),
|(mut parser, mut resid, resid_len), mut chunk| async move {
chunk[..4].copy_from_slice(&resid); // copy the padding from the last chunk
let start_offset = 4 - resid_len; // start offset for the valid bytes in the chunk
let resid_len = match std::str::from_utf8(&chunk[start_offset..]) {
Ok(s) => {
parser.0.process(
#[allow(unsafe_code)]
unsafe {
Tendril::from_byte_slice_without_validating(s.as_bytes())
},
);
0
}
Err(e) => {
if e.error_len().is_some() {
// if we cannot decode a character in the middle
return Err(ErrorResponse::unsupported_encoding());
}
let valid_len = e.valid_up_to();
if valid_len == 0 {
// if we received less than 1 whole character
return Err(ErrorResponse::unsupported_encoding());
}
parser.0.process(
#[allow(unsafe_code)]
unsafe {
Tendril::from_byte_slice_without_validating(
&chunk[start_offset..(start_offset + valid_len)],
)
},
);
// compute how many bytes are left in the chunk
chunk.len() - valid_len - start_offset
}
};
// fair yield
tokio::task::yield_now().await;
// this is guaranteed to be inbounds as we already padded 4 bytes
resid[..4].copy_from_slice(&chunk[chunk.len() - 4..]);
Ok((parser, resid, resid_len))
},
)
.await?
.0;
parser.0.finish()
};
#[cfg(feature = "cf-worker")]
let dom = {
let parser = LimitedStream::new(resp.body().map_ok(|b| b.into()), 3 << 20)
.or_else(|e| async {
match e {
None => Ok(Vec::new()),
Some(e) => Err(e),
}
})
.try_fold(
SendCell(url_summary::parser(|| {})),
|mut parser, chunk| async move {
parser.0.process(Tendril::from_slice(chunk.as_slice()));
Ok(parser)
},
)
.await?;
parser.0.finish()
};
let proxy_url: Url = format!("https://{}/proxy/", host)
.parse()
.map_err(|_| ErrorResponse::bad_host(&host))?;
let resp = summaly::extractor(&info, &state.client)
.try_extract(
None,
&original_url,
&final_url,
&ReadOnlyScraperHtml::from(dom),
)
.map_err(|_| ErrorResponse::unsupported_media())
.await?
.transform_assets(|url| {
let mut ret = proxy_url.clone();
ret.query_pairs_mut().append_pair("url", url.as_str());
ret
});
Ok((resp_header, Json(resp)))
}
#[cfg_attr(feature = "cf-worker", worker::send)] #[cfg_attr(feature = "cf-worker", worker::send)]
async fn proxy_impl<'a>( async fn proxy_impl<'a>(
method: http::Method, method: http::Method,
filename: Option<&str>, filename: Option<&str>,
State(state): State<Arc<AppState<Upstream, S>>>, State(state): State<Arc<AppState<C, S>>>,
Query(query): Query<ProxyQuery>, Query(query): Query<ProxyQuery>,
info: IncomingInfo, info: IncomingInfo,
) -> Result<Response, ErrorResponse> ) -> Result<Response, ErrorResponse>
@ -934,7 +1194,14 @@ impl<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static> App<C,
let resp = state let resp = state
.client .client
.request_upstream(&info, &query.url, false, true, DEFAULT_MAX_REDIRECTS) .request_upstream(
&info,
ExpectType::Media,
&query.url,
false,
true,
DEFAULT_MAX_REDIRECTS,
)
.await?; .await?;
let media = Box::pin(MediaResponse::from_upstream_response::<S>( let media = Box::pin(MediaResponse::from_upstream_response::<S>(
@ -961,7 +1228,7 @@ impl<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static> App<C,
pub async fn proxy_without_filename( pub async fn proxy_without_filename(
method: http::Method, method: http::Method,
Query(query): Query<ProxyQuery>, Query(query): Query<ProxyQuery>,
State(state): State<Arc<AppState<Upstream, S>>>, State(state): State<Arc<AppState<C, S>>>,
info: IncomingInfo, info: IncomingInfo,
) -> Result<Response, ErrorResponse> ) -> Result<Response, ErrorResponse>
where where
@ -987,7 +1254,7 @@ impl<C: UpstreamClient + 'static, S: Sandboxing + Send + Sync + 'static> App<C,
pub async fn proxy_with_filename( pub async fn proxy_with_filename(
method: http::Method, method: http::Method,
Path(filename): Path<String>, Path(filename): Path<String>,
State(state): State<Arc<AppState<Upstream, S>>>, State(state): State<Arc<AppState<C, S>>>,
Query(query): Query<ProxyQuery>, Query(query): Query<ProxyQuery>,
info: IncomingInfo, info: IncomingInfo,
) -> Result<Response, ErrorResponse> ) -> Result<Response, ErrorResponse>

View file

@ -15,7 +15,7 @@ use tokio::sync::mpsc;
use tower_service::Service; use tower_service::Service;
use yumechi_no_kuni_proxy_worker::{ use yumechi_no_kuni_proxy_worker::{
config::{Config, SandboxConfig}, config::{Config, SandboxConfig},
router, default_state, router,
sandbox::NoSandbox, sandbox::NoSandbox,
Upstream, Upstream,
}; };
@ -84,9 +84,9 @@ fn main() {
if label.is_empty() || label == "unconfined" { if label.is_empty() || label == "unconfined" {
panic!("Refusing to start in unconfined AppArmor profile when AppArmor is enabled"); panic!("Refusing to start in unconfined AppArmor profile when AppArmor is enabled");
} }
router::<Upstream, apparmor::AppArmorHat>(config) router::<Upstream, apparmor::AppArmorHat>(default_state(config))
} }
SandboxConfig::NoSandbox => router::<Upstream, NoSandbox>(config), SandboxConfig::NoSandbox => router::<Upstream, NoSandbox>(default_state(config)),
_ => panic!("Unsupported sandbox configuration, did you forget to enable the feature?"), _ => panic!("Unsupported sandbox configuration, did you forget to enable the feature?"),
}; };

View file

@ -119,6 +119,7 @@ impl<R: HTTPResponse> SniffingStream<R> {
.iter() .iter()
.any(|sig| sig.matches(self.sniff_buffer.get_ref())) .any(|sig| sig.matches(self.sniff_buffer.get_ref()))
}); });
let mut all_safe = true; let mut all_safe = true;
let mut best_match = None; let mut best_match = None;

View file

@ -19,6 +19,7 @@ pub(crate) mod pthread {
/// Run a function with explicit immediate cancellation /// Run a function with explicit immediate cancellation
#[allow(unsafe_code)] #[allow(unsafe_code)]
#[cfg(target_os = "linux")]
pub(crate) fn pthread_cancelable<F: FnOnce() -> R, R>(f: F) -> R { pub(crate) fn pthread_cancelable<F: FnOnce() -> R, R>(f: F) -> R {
unsafe { unsafe {
let mut oldstate = 0; let mut oldstate = 0;

View file

@ -9,7 +9,7 @@ type Bytes = Vec<u8>;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
/// A stream that limits the amount of data it can receive /// A stream that limits the amount of data it can receive
pub struct LimitedStream<T, E> pub struct LimitedStream<T, E: Send>
where where
T: Stream<Item = Result<Bytes, E>> + 'static, T: Stream<Item = Result<Bytes, E>> + 'static,
{ {
@ -18,7 +18,7 @@ where
limit: AtomicUsize, limit: AtomicUsize,
} }
impl<T, E> LimitedStream<T, E> impl<T, E: Send> LimitedStream<T, E>
where where
T: Stream<Item = Result<Bytes, E>> + 'static, T: Stream<Item = Result<Bytes, E>> + 'static,
{ {
@ -32,7 +32,7 @@ where
} }
} }
impl<T, E> Stream for LimitedStream<T, E> impl<T, E: Send> Stream for LimitedStream<T, E>
where where
T: Stream<Item = Result<Bytes, E>> + 'static, T: Stream<Item = Result<Bytes, E>> + 'static,
{ {

View file

@ -0,0 +1,205 @@
#![allow(unused)]
use futures::TryFutureExt;
use super::{Extractor, ExtractorError, ReadOnlyScraperHtml, SummaryResponse};
/// A combinator that maps the error of an extractor.
pub(crate) struct ExtractorMapError<
R: SummaryResponse,
A: Extractor<R>,
F: Fn(A::Error) -> E + Send,
E,
> {
a: A,
f: F,
_r: std::marker::PhantomData<R>,
}
impl<R: SummaryResponse, A: Extractor<R>, F: Fn(A::Error) -> E + Send, E>
ExtractorMapError<R, A, F, E>
{
pub fn new(a: A, f: F) -> Self {
Self {
a,
f,
_r: std::marker::PhantomData,
}
}
}
impl<R: SummaryResponse, A: Extractor<R>, F: Fn(A::Error) -> E + Send + Sync, E> Extractor<R>
for ExtractorMapError<R, A, F, E>
where
E: std::error::Error + Send + 'static,
{
type Error = E;
async fn try_extract(
&mut self,
prev: Option<R>,
original_url: &url::Url,
final_url: &url::Url,
dom: &ReadOnlyScraperHtml,
) -> Result<R, ExtractorError<Self::Error>> {
self.a
.try_extract(prev, original_url, final_url, dom)
.map_err(|err| match err {
ExtractorError::Unsupported => ExtractorError::Unsupported,
ExtractorError::InternalError(err) => ExtractorError::InternalError((self.f)(err)),
})
.await
}
}
/// A combinator that steers the extractor to another extractor based on a predicate.
pub(crate) struct ExtractorSteer<
R: SummaryResponse + serde::Serialize,
E: Send,
A,
B,
F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool,
> {
a: A,
b: B,
f: F,
_r: std::marker::PhantomData<R>,
_e: std::marker::PhantomData<E>,
}
impl<R: SummaryResponse + serde::Serialize, E: Send, A, B, F> ExtractorSteer<R, E, A, B, F>
where
F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool + Send,
{
pub fn new(a: A, b: B, f: F) -> Self {
Self {
a,
b,
f,
_r: std::marker::PhantomData,
_e: std::marker::PhantomData,
}
}
}
impl<
R: SummaryResponse,
E: Send + std::error::Error + 'static,
A: Extractor<R, Error = E>,
B: Extractor<R, Error = E>,
F,
> Extractor<R> for ExtractorSteer<R, E, A, B, F>
where
F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool + Send,
<B as Extractor<R>>::Error: From<<A as Extractor<R>>::Error>,
{
type Error = E;
async fn try_extract(
&mut self,
prev: Option<R>,
original_url: &url::Url,
final_url: &url::Url,
dom: &ReadOnlyScraperHtml,
) -> Result<R, ExtractorError<Self::Error>> {
let (a, b) = (&mut self.a, &mut self.b);
if (self.f)(final_url, &dom) {
b.try_extract(prev, original_url, final_url, dom).await
} else {
a.try_extract(prev, original_url, final_url, dom).await
}
}
}
/// A combinator that chains two extractors using an OR operation.
pub(crate) struct ExtractorOr<R: SummaryResponse, E: std::error::Error + Send + 'static, A, B> {
a: A,
b: B,
_r: std::marker::PhantomData<R>,
_e: std::marker::PhantomData<E>,
}
impl<R: SummaryResponse, E: std::error::Error + Send + 'static, A, B> ExtractorOr<R, E, A, B> {
pub fn new(a: A, b: B) -> Self {
Self {
a,
b,
_r: std::marker::PhantomData,
_e: std::marker::PhantomData,
}
}
}
impl<R: SummaryResponse, E: std::error::Error + Send + 'static, A, B> Extractor<R>
for ExtractorOr<R, E, A, B>
where
R: Clone,
A: Extractor<R, Error = E>,
B: Extractor<R, Error = E>,
<B as Extractor<R>>::Error: From<<A as Extractor<R>>::Error>,
{
type Error = E;
async fn try_extract(
&mut self,
prev: Option<R>,
original_url: &url::Url,
final_url: &url::Url,
dom: &ReadOnlyScraperHtml,
) -> Result<R, ExtractorError<Self::Error>> {
match self
.a
.try_extract(prev.clone(), original_url, final_url, dom)
.await
{
Ok(res) => Ok(res),
Err(ExtractorError::Unsupported) => {
self.b.try_extract(prev, original_url, final_url, dom).await
}
Err(ExtractorError::InternalError(err)) => Err(ExtractorError::InternalError(err)),
}
}
}
pub(crate) struct ExtractorThen<R: SummaryResponse, E: std::error::Error + Send + 'static, A, B> {
a: A,
b: B,
_r: std::marker::PhantomData<R>,
_e: std::marker::PhantomData<E>,
}
impl<R: SummaryResponse, E: std::error::Error + Send + 'static, A, B> ExtractorThen<R, E, A, B> {
pub fn new(a: A, b: B) -> Self {
Self {
a,
b,
_r: std::marker::PhantomData,
_e: std::marker::PhantomData,
}
}
}
impl<R: SummaryResponse, E: std::error::Error + Send + 'static, A, B> Extractor<R>
for ExtractorThen<R, E, A, B>
where
A: Extractor<R, Error = E>,
B: Extractor<R, Error = E>,
{
type Error = E;
async fn try_extract(
&mut self,
prev: Option<R>,
original_url: &url::Url,
final_url: &url::Url,
dom: &ReadOnlyScraperHtml,
) -> Result<R, ExtractorError<Self::Error>> {
let res = self
.a
.try_extract(prev, original_url, final_url, dom)
.await?;
self.b
.try_extract(Some(res), original_url, final_url, dom)
.await
}
}

392
src/url_summary/mod.rs Normal file
View file

@ -0,0 +1,392 @@
use std::{borrow::Cow, future::Future, ops::Deref};
use axum::{response::IntoResponse, Json};
use html5ever::{
driver, expanded_name,
interface::{ElementFlags, NextParserState, NodeOrText, QuirksMode, TreeSink},
local_name, namespace_url, ns,
tendril::StrTendril,
tokenizer::TokenizerOpts,
tree_builder::TreeBuilderOpts,
Attribute, Parser, QualName,
};
use scraper::{Html, HtmlTreeSink};
use thiserror::Error;
/// Combinator types for extractors.
pub mod combinator;
/// Extractors for summarizing URLs.
pub mod summaly;
#[cfg(feature = "lazy_static")]
#[macro_export]
/// Create a CSS selector
macro_rules! selector {
($content:literal) => {{
lazy_static::lazy_static! {
static ref SELECTOR: scraper::Selector = scraper::Selector::parse($content).unwrap();
}
std::ops::Deref::deref(&SELECTOR)
}};
}
#[cfg(not(feature = "lazy_static"))]
#[macro_export]
/// Create a CSS selector
macro_rules! selector {
($content:literal) => {
&scraper::Selector::parse($content).unwrap()
};
}
#[derive(Debug, Clone)]
/// A force Send wrapper for `scraper::Html` to share access after parsing.
pub struct ReadOnlyScraperHtml(scraper::Html);
impl AsRef<scraper::Html> for ReadOnlyScraperHtml {
fn as_ref(&self) -> &scraper::Html {
&self.0
}
}
impl From<scraper::Html> for ReadOnlyScraperHtml {
fn from(html: scraper::Html) -> Self {
Self(html)
}
}
#[allow(unsafe_code)]
unsafe impl Send for ReadOnlyScraperHtml {}
#[allow(unsafe_code)]
unsafe impl Sync for ReadOnlyScraperHtml {}
#[derive(Debug, Error)]
/// An error that occurs during extraction.
pub enum ExtractorError<E: Send + 'static> {
#[error("unsupported document")]
/// The document is not supported by the extractor.
Unsupported,
#[error("failed to extract summary: {0}")]
/// An internal error occurred during extraction.
InternalError(#[from] E),
}
/// A trait for extracting summary information from a URL.
pub trait Extractor<R: SummaryResponse>: Sized + Send {
/// The error type that the extractor may return.
type Error: std::error::Error + Send + 'static;
/// Extract summary information from a URL.
fn try_extract(
&mut self,
prev: Option<R>,
original_url: &url::Url,
final_url: &url::Url,
dom: &ReadOnlyScraperHtml,
) -> impl Future<Output = Result<R, ExtractorError<Self::Error>>> + Send;
/// Chain this extractor with another extractor.
fn or_else<B: Extractor<R, Error = Self::Error>>(
self,
b: B,
) -> impl Extractor<R, Error = Self::Error>
where
R: Clone,
{
combinator::ExtractorOr::new(self, b)
}
/// Steer the extractor to another extractor based on a predicate.
fn steer<B: Extractor<R, Error = Self::Error>, F>(
self,
b: B,
f: F,
) -> impl Extractor<R, Error = Self::Error>
where
F: Fn(&url::Url, &ReadOnlyScraperHtml) -> bool + Send,
{
combinator::ExtractorSteer::new(self, b, f)
}
/// Map the error of the extractor.
fn map_err<F, E>(self, f: F) -> impl Extractor<R, Error = E>
where
E: std::error::Error + Send + 'static,
F: Fn(Self::Error) -> E + Send + Sync,
{
combinator::ExtractorMapError::new(self, f)
}
/// Execute this extractor and then another extractor.
fn then<B: Extractor<R, Error = Self::Error>>(
self,
b: B,
) -> impl Extractor<R, Error = Self::Error>
where
R: Clone,
{
combinator::ExtractorThen::new(self, b)
}
}
/// A dummy extractor that always returns an error.
pub struct NilExtractor<R, E> {
_r: std::marker::PhantomData<R>,
_e: std::marker::PhantomData<E>,
}
impl<R, E> Default for NilExtractor<R, E> {
fn default() -> Self {
Self {
_r: std::marker::PhantomData,
_e: std::marker::PhantomData,
}
}
}
impl<R: SummaryResponse, E: std::error::Error + Send + 'static> Extractor<R>
for NilExtractor<R, E>
{
type Error = E;
fn try_extract(
&mut self,
_prev: Option<R>,
_original_url: &url::Url,
_final_url: &url::Url,
_dom: &ReadOnlyScraperHtml,
) -> impl Future<Output = Result<R, ExtractorError<Self::Error>>> + Send {
futures::future::ready(Err(ExtractorError::Unsupported))
}
}
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
/// A query for summarizing a URL.
pub struct SummaryQuery {
/// The URL to summarize.
pub url: String,
}
/// Trait for types that can be used as a summary response.
pub trait SummaryResponse: serde::Serialize + Send + Sync + 'static
where
Self: Sized,
Json<Self>: IntoResponse,
{
/// Transform assets URLs in the response.
fn transform_assets(self, f: impl FnMut(url::Url) -> url::Url) -> Self;
}
/// An HTML sink that notifies when a `body` element is encountered.
pub struct FilteringSinkWrapper<T: TreeSink<Handle = ego_tree::NodeId>, F: Fn()> {
signal: F,
sink: T,
}
impl<T: TreeSink<Handle = ego_tree::NodeId>, F: Fn()> FilteringSinkWrapper<T, F> {
/// Create a new `MetaOnlySinkWrapper` with the given sink and signal.
///
/// When the sink receives a `body` element, the signal will be called.
pub fn new(sink: T, signal: F) -> Self {
Self { signal, sink }
}
}
impl<T: TreeSink<Handle = ego_tree::NodeId>, F: Fn()> TreeSink for FilteringSinkWrapper<T, F>
where
for<'a> <T as TreeSink>::ElemName<'a>: Deref<Target = QualName>,
{
type Handle = T::Handle;
/// The overall result of parsing.
///
/// This should default to Self, but default associated types are not stable yet.
/// [rust-lang/rust#29661](https://github.com/rust-lang/rust/issues/29661)
type Output = T::Output;
type ElemName<'a> = T::ElemName<'a> where T: 'a , F: 'a;
fn finish(self) -> Self::Output {
self.sink.finish()
}
fn parse_error(&self, msg: Cow<'static, str>) {
self.sink.parse_error(msg)
}
fn get_document(&self) -> Self::Handle {
self.sink.get_document()
}
fn elem_name<'a>(&'a self, target: &'a Self::Handle) -> Self::ElemName<'a> {
self.sink.elem_name(target)
}
/// Create an element.
///
/// When creating a template element (`name.ns.expanded() == expanded_name!(html "template")`),
/// an associated document fragment called the "template contents" should
/// also be created. Later calls to self.get_template_contents() with that
/// given element return it.
/// See [the template element in the whatwg spec][whatwg template].
///
/// [whatwg template]: https://html.spec.whatwg.org/multipage/#the-template-element
fn create_element(
&self,
name: QualName,
attrs: Vec<Attribute>,
flags: ElementFlags,
) -> Self::Handle {
if name.expanded() == expanded_name!(html "body") || name.local == local_name!("body") {
(self.signal)();
};
let handle = self.sink.create_element(name, attrs, flags);
handle
}
/// Create a comment node.
fn create_comment(&self, _text: StrTendril) -> Self::Handle {
self.sink.create_comment(StrTendril::new())
}
/// Create a Processing Instruction node.
fn create_pi(&self, target: StrTendril, data: StrTendril) -> Self::Handle {
self.sink.create_pi(target, data)
}
/// The child node will not already have a parent.
fn append(&self, parent: &Self::Handle, child: NodeOrText<Self::Handle>) {
// get rid of some blocks of text that are not useful
if let NodeOrText::AppendText(_) = &child {
let name = self.elem_name(parent);
if name.deref().expanded() == expanded_name!(html "script")
|| name.deref().expanded() == expanded_name!(html "style")
|| name.deref().expanded() == expanded_name!(html "noscript")
|| name.deref().local == local_name!("script")
|| name.deref().local == local_name!("style")
|| name.deref().local == local_name!("noscript")
{
return;
}
}
self.sink.append(parent, child)
}
fn append_based_on_parent_node(
&self,
element: &Self::Handle,
prev_element: &Self::Handle,
child: NodeOrText<Self::Handle>,
) {
self.sink
.append_based_on_parent_node(element, prev_element, child)
}
/// Append a `DOCTYPE` element to the `Document` node.
fn append_doctype_to_document(
&self,
name: StrTendril,
public_id: StrTendril,
system_id: StrTendril,
) {
self.sink
.append_doctype_to_document(name, public_id, system_id)
}
/// Mark a HTML `<script>` as "already started".
fn mark_script_already_started(&self, node: &Self::Handle) {
self.sink.mark_script_already_started(node)
}
fn pop(&self, node: &Self::Handle) {
let name = self.elem_name(node);
if name.deref().expanded() == expanded_name!(html "head")
|| name.deref().local == local_name!("head")
{
(self.signal)();
}
self.sink.pop(node)
}
fn get_template_contents(&self, target: &Self::Handle) -> Self::Handle {
self.sink.get_template_contents(target)
}
fn same_node(&self, x: &Self::Handle, y: &Self::Handle) -> bool {
self.sink.same_node(x, y)
}
fn set_quirks_mode(&self, mode: QuirksMode) {
self.sink.set_quirks_mode(mode)
}
fn append_before_sibling(&self, sibling: &Self::Handle, new_node: NodeOrText<Self::Handle>) {
self.sink.append_before_sibling(sibling, new_node)
}
/// Add each attribute to the given element, if no attribute with that name
/// already exists. The tree builder promises this will never be called
/// with something else than an element.
fn add_attrs_if_missing(&self, target: &Self::Handle, attrs: Vec<Attribute>) {
self.sink.add_attrs_if_missing(target, attrs)
}
/// Associate the given form-associatable element with the form element
fn associate_with_form(
&self,
target: &Self::Handle,
form: &Self::Handle,
nodes: (&Self::Handle, Option<&Self::Handle>),
) {
self.sink.associate_with_form(target, form, nodes)
}
/// Detach the given node from its parent.
fn remove_from_parent(&self, target: &Self::Handle) {
self.sink.remove_from_parent(target)
}
/// Remove all the children from node and append them to new_parent.
fn reparent_children(&self, node: &Self::Handle, new_parent: &Self::Handle) {
self.sink.reparent_children(node, new_parent)
}
/// Returns true if the adjusted current node is an HTML integration point
/// and the token is a start tag.
fn is_mathml_annotation_xml_integration_point(&self, handle: &Self::Handle) -> bool {
self.sink.is_mathml_annotation_xml_integration_point(handle)
}
/// Called whenever the line number changes.
fn set_current_line(&self, line_number: u64) {
self.sink.set_current_line(line_number);
}
/// Indicate that a `script` element is complete.
fn complete_script(&self, node: &Self::Handle) -> NextParserState {
self.sink.complete_script(node)
}
}
/// Create an HTML parser with the default configuration.
pub fn parser<F: Fn()>(stop: F) -> Parser<FilteringSinkWrapper<HtmlTreeSink, F>> {
let opts = driver::ParseOpts {
tokenizer: TokenizerOpts {
exact_errors: false,
..Default::default()
},
tree_builder: TreeBuilderOpts {
exact_errors: false,
drop_doctype: false,
..Default::default()
},
..Default::default()
};
driver::parse_document(
FilteringSinkWrapper::new(HtmlTreeSink::new(Html::new_document()), stop),
opts,
)
}

492
src/url_summary/summaly.rs Normal file
View file

@ -0,0 +1,492 @@
use crate::{
fetch::{HTTPResponse, IncomingInfo, UpstreamClient},
selector,
stream::LimitedStream,
ErrorResponse,
};
use futures::{TryFutureExt, TryStreamExt};
use super::{Extractor, SummaryResponse};
const TWITTER_DOMAINS: [&str; 3] = ["twitter.com", "twimg.com", "x.com"];
/// Create a Summaly extractor that behaves similarly to the original implementation.
pub fn extractor<'a, C: UpstreamClient>(
incoming: &'a IncomingInfo,
client: &'a C,
) -> impl Extractor<SummalyFormatResponse> + 'a {
SummalyCommonExtractor
.steer(
SummalyWikipediaExtractor {
incoming: &incoming,
client,
},
|url, _| {
url.host_str().map_or(false, |host| {
host == "wikipedia.org" || host.ends_with(".wikipedia.org")
})
},
)
.steer(SummalyAmazonExtractor, |url, _| {
url.host_str().map_or(false, |host| {
[
".com", ".co.jp", ".co.uk", ".de", ".fr", ".it", ".es", ".nl", ".cn", ".in",
".au",
]
.iter()
.any(|d| {
host.strip_suffix(d)
.map_or(false, |s| s.ends_with(".amazon"))
})
})
})
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
/// A response from Summaly.
pub struct SummalyFormatResponse {
_meta: SummalyFormatResponseMeta,
/// The title of the URL.
pub title: Option<String>,
/// The description of the URL.
pub description: Option<String>,
/// The icon of the URL.
pub icon: Option<url::Url>,
/// The thumbnail of the URL.
pub thumbnail: Option<url::Url>,
/// The player of the URL.
pub player: Option<SummalyFormatResponsePlayer>,
/// Whether the URL is sensitive.
pub sensitive: bool,
/// The site name .
pub sitename: String,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
/// Metadata for a Summaly response.
pub struct SummalyFormatResponseMeta {
/// The original URL.
pub original_url: url::Url,
/// The final URL.
pub final_url: url::Url,
/// Whether the common extractor was used.
pub common_extractor_used: bool,
/// Whether the Wikipedia extractor was used.
pub wikipedia_extractor_used: bool,
/// Whether the Amazon extractor was used.
pub amazon_extractor_used: bool,
}
impl SummaryResponse for SummalyFormatResponse {
fn transform_assets(mut self, mut f: impl FnMut(url::Url) -> url::Url) -> Self {
self.icon = self.icon.as_ref().map(|u| f(u.clone()));
self.thumbnail = self.thumbnail.as_ref().map(|u| f(u.clone()));
self
}
}
impl SummalyFormatResponse {
/// Create a new Summaly response.
pub fn new(original_url: url::Url, final_url: url::Url) -> Self {
Self {
title: None,
sitename: final_url.host_str().unwrap_or_default().to_string(),
_meta: SummalyFormatResponseMeta {
original_url,
final_url,
common_extractor_used: false,
wikipedia_extractor_used: false,
amazon_extractor_used: false,
},
description: None,
icon: None,
thumbnail: None,
player: None,
sensitive: false,
}
}
/// Set the sensitivity of the URL.
pub fn set_sensitive(&mut self, new: bool) {
self.sensitive = new;
}
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// A player for a video.
pub struct SummalyFormatResponsePlayer {
url: url::Url,
width: Option<u32>,
height: Option<u32>,
}
impl SummalyFormatResponsePlayer {
/// Create a new player.
pub fn new(url: url::Url) -> Self {
Self {
url,
width: None,
height: None,
}
}
}
struct SummalyCommonExtractor;
impl Extractor<SummalyFormatResponse> for SummalyCommonExtractor {
type Error = std::convert::Infallible;
async fn try_extract(
&mut self,
prev: Option<SummalyFormatResponse>,
original_url: &url::Url,
final_url: &url::Url,
dom: &super::ReadOnlyScraperHtml,
) -> std::result::Result<SummalyFormatResponse, super::ExtractorError<Self::Error>> {
let mut response = prev
.unwrap_or_else(|| SummalyFormatResponse::new(original_url.clone(), final_url.clone()));
response._meta.common_extractor_used = true;
let dom = dom.as_ref();
let title = dom
.select(selector!(r#"title"#))
.next()
.and_then(|element| element.text().next())
.or_else(|| {
dom.select(selector!(r#"meta[property="og:title"]"#))
.next()
.and_then(|element| element.attr("content"))
})
.or_else(|| {
dom.select(selector!(r#"#title"#))
.next()
.and_then(|element| element.text().next())
})
.map(|s| s.trim().chars().take(100).collect());
let twitter_card = if TWITTER_DOMAINS
.iter()
.any(|d| final_url.domain().map_or(false, |u| u.ends_with(d)))
{
dom.select(selector!(r#"meta[name="twitter:card"]"#))
.next()
.and_then(|element| element.attr("content"))
} else {
None
};
let thumb = [
(
selector!(r#"meta[property="og:image:secure_url"]"#),
"content",
),
(selector!(r#"meta[property="og:image"]"#), "content"),
(selector!(r#"meta[name="twitter:image"]"#), "content"),
(selector!(r#"link[rel="image_src"]"#), "href"),
(selector!(r#"link[rel="apple-touch-icon"]"#), "href"),
(
selector!(r#"link[rel="apple-touch-icon image_src"]"#),
"href",
),
]
.iter()
.find_map(|(selector, attr)| {
dom.select(selector)
.next()
.and_then(|element| element.attr(attr))
});
let player_url = [
selector!(r#"meta[property="og:video:secure_url"]"#),
selector!(r#"meta[property="og:video:url"]"#),
selector!(r#"meta[property="og:video"]"#),
]
.iter()
.find_map(|selector| {
dom.select(selector)
.next()
.and_then(|element| element.attr("content"))
})
.or_else(|| {
if twitter_card.map_or(false, |c| c == "summary_large_image") {
[
selector!(r#"meta[name="twitter:player"]"#),
selector!(r#"meta[property="twitter:player"]"#),
]
.iter()
.find_map(|selector| {
dom.select(selector)
.next()
.and_then(|element| element.attr("content"))
})
} else {
None
}
});
let player_width = [
selector!(r#"meta[name="twitter:player:width"]"#),
selector!(r#"meta[property="twitter:player:width"]"#),
selector!(r#"meta[property="og:video:width"]"#),
]
.iter()
.find_map(|selector| {
dom.select(selector)
.next()
.and_then(|element| element.attr("content"))
.and_then(|width| width.parse().ok())
});
let player_height = [
selector!(r#"meta[name="twitter:player:height"]"#),
selector!(r#"meta[property="twitter:player:height"]"#),
selector!(r#"meta[property="og:video:height"]"#),
]
.iter()
.find_map(|selector| {
dom.select(selector)
.next()
.and_then(|element| element.attr("content"))
.and_then(|height| height.parse().ok())
});
let description = [
selector!(r#"meta[name="description"]"#),
selector!(r#"meta[property="og:description"]"#),
selector!(r#"meta[name="twitter:description"]"#),
]
.iter()
.find_map(|selector| {
dom.select(selector).next().and_then(|element| {
element.attr("content").map(|s| {
s.chars()
.into_iter()
.filter(|c| !c.is_control())
.take(300)
.collect::<String>()
})
})
});
let site_name = [
selector!(r#"meta[property="og:site_name"]"#),
selector!(r#"meta[name="application-name"]"#),
selector!(r#"meta[name="twitter:site"]"#),
selector!(r#"meta[name="twitter:site:id"]"#),
]
.iter()
.find_map(|selector| {
dom.select(selector).next().and_then(|element| {
element.attr("content").map(|s| {
s.chars()
.into_iter()
.filter(|c| !c.is_control())
.take(100)
.collect::<String>()
})
})
});
let favicon = [
selector!(r#"link[rel="icon"]"#),
selector!(r#"link[rel="shortcut icon"]"#),
selector!(r#"link[rel="apple-touch-icon"]"#),
selector!(r#"link[rel="apple-touch-icon-precomposed"]"#),
]
.iter()
.find_map(|selector| {
dom.select(selector)
.next()
.and_then(|element| element.attr("href"))
})
.unwrap_or("/favicon.ico");
response.title = title;
response.description = description;
response.icon = final_url.join(favicon).ok();
response.thumbnail = thumb.and_then(|rel| final_url.join(rel).ok());
player_url
.and_then(|player_url| final_url.join(player_url).ok())
.map(|url| {
response.player = Some(SummalyFormatResponsePlayer {
url,
width: player_width,
height: player_height,
})
});
response.sitename =
site_name.unwrap_or_else(|| final_url.host_str().unwrap_or_default().to_string());
Ok(response)
}
}
struct SummalyAmazonExtractor;
impl Extractor<SummalyFormatResponse> for SummalyAmazonExtractor {
type Error = std::convert::Infallible;
async fn try_extract(
&mut self,
prev: Option<SummalyFormatResponse>,
original_url: &url::Url,
final_url: &url::Url,
dom: &super::ReadOnlyScraperHtml,
) -> std::result::Result<SummalyFormatResponse, super::ExtractorError<Self::Error>> {
let mut response = prev
.unwrap_or_else(|| SummalyFormatResponse::new(original_url.clone(), final_url.clone()));
response._meta.amazon_extractor_used = true;
let dom = dom.as_ref();
let description = dom
.select(selector!(r#"#productDescription"#))
.next()
.and_then(|element| {
element
.text()
.next()
.map(|s| s.trim().chars().take(150).collect::<String>())
})
.or_else(|| {
dom.select(selector!(r#"meta[name="description"]"#))
.next()
.and_then(|element| element.attr("content"))
.map(|s| s.trim().chars().take(150).collect::<String>())
});
let thumbnail = dom
.select(selector!(r#"img#landingImage"#))
.next()
.and_then(|element| element.attr("src"))
.and_then(|rel| final_url.join(rel).ok());
let player_url = dom
.select(selector!(r#"iframe#productVideoSource"#))
.next()
.and_then(|element| element.attr("src"))
.and_then(|rel| final_url.join(rel).ok())
.or_else(|| {
dom.select(selector!(r#"meta[property="twitter:player"]"#))
.next()
.and_then(|element| element.attr("content"))
.and_then(|rel| final_url.join(rel).ok())
});
let player_width = dom
.select(selector!(r#"meta[property="twitter:player:width"]"#))
.next()
.and_then(|element| element.attr("content"))
.and_then(|s| s.parse().ok());
let player_height = dom
.select(selector!(r#"meta[property="twitter:player:height"]"#))
.next()
.and_then(|element| element.attr("content"))
.and_then(|s| s.parse().ok());
let site_name = dom
.select(selector!(r#"meta[property="og:site_name"]"#))
.next()
.and_then(|element| element.attr("content"))
.map(|s| s.chars().take(100).collect::<String>());
response.icon = "https://www.amazon.com/favicon.ico".parse().ok();
response.description = description;
response.thumbnail = thumbnail;
player_url.map(|url| {
response.player = Some(SummalyFormatResponsePlayer {
url,
width: player_width,
height: player_height,
})
});
response.sitename =
site_name.unwrap_or_else(|| final_url.host_str().unwrap_or_default().to_string());
Ok(response)
}
}
struct SummalyWikipediaExtractor<'a, C: UpstreamClient> {
incoming: &'a IncomingInfo,
client: &'a C,
}
impl<'a, C: UpstreamClient> Extractor<SummalyFormatResponse> for SummalyWikipediaExtractor<'a, C> {
type Error = std::convert::Infallible;
async fn try_extract(
&mut self,
prev: Option<SummalyFormatResponse>,
original_url: &url::Url,
final_url: &url::Url,
_dom: &super::ReadOnlyScraperHtml,
) -> std::result::Result<SummalyFormatResponse, super::ExtractorError<Self::Error>> {
let mut response = prev
.unwrap_or_else(|| SummalyFormatResponse::new(original_url.clone(), final_url.clone()));
response._meta.wikipedia_extractor_used = true;
let lang = final_url
.host_str()
.and_then(|host| host.split('.').next())
.filter(|s| s.chars().all(|c| c.is_ascii_alphabetic()))
.unwrap_or("en");
let title = final_url.path().rsplit('/').next().unwrap_or("Wikipedia");
let endpoint = format!("https://{}.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&exintro&explaintext&titles={}", lang, title);
#[derive(serde::Deserialize)]
struct WikipediaPageResponse {
title: String,
extract: Option<String>,
}
let api = self
.client
.request_upstream(
&self.incoming,
crate::fetch::ExpectType::Json,
&endpoint,
false,
true,
2,
)
.and_then(|r| async move {
LimitedStream::new(r.body().map_ok(Into::into), 512 << 10)
.try_fold(Vec::new(), |mut v, chunk| async move {
v.extend_from_slice(&chunk);
Ok(v)
})
.await
.map_err(|e| e.unwrap_or(ErrorResponse::payload_too_large()))
})
.and_then(|v| async move {
serde_json::from_slice::<WikipediaPageResponse>(&v)
.map_err(|_| ErrorResponse::unsupported_encoding())
})
.await
.ok();
response.title = api.as_ref().map(|api| api.title.clone());
response.description =
api.and_then(|api| api.extract.map(|s| s.chars().take(300).collect()));
response.sitename = "Wikipedia".to_string();
response.icon = "https://wikipedia.org/static/favicon/wikipedia.ico"
.parse()
.ok()
.into();
response.player = None;
response.thumbnail = format!(
"https://wikipedia.org/static/images/project-logos/{}wiki.png",
lang
)
.parse()
.ok()
.into();
Ok(response)
}
}

View file

@ -3,7 +3,7 @@ main = "build/worker/shim.mjs"
compatibility_date = "2024-11-11" compatibility_date = "2024-11-11"
[build] [build]
command = "cargo install -q worker-build && worker-build --release --features cf-worker" command = "cargo install -q worker-build && worker-build --release --features \"cf-worker url-summary\""
[observability] [observability]
enabled = true enabled = true