commit 146e317849
ゆめ 2024-11-12 22:22:48 -06:00
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
20 changed files with 6448 additions and 0 deletions

5
.gitignore vendored Normal file

@@ -0,0 +1,5 @@
target
node_modules
.wrangler
/bundled
/build

3
.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "submodules/file"]
path = submodules/file
url = https://github.com/file/file

6
.vscode/settings.json vendored Normal file

@@ -0,0 +1,6 @@
{
"rust-analyzer.cargo.features": [
//"cf-worker"
"env-local"
],
}

3286
Cargo.lock generated Normal file

File diff suppressed because it is too large

66
Cargo.toml Normal file

@@ -0,0 +1,66 @@
[package]
name = "yumechi-no-kuni-proxy-worker"
version = "0.1.0"
edition = "2021"
authors = [ "eternal-flame-AD <yume@yumechi.jp>" ]
build = "build.rs"
[package.metadata.release]
release = false
[lib]
crate-type = ["cdylib", "rlib"]
[profile.release]
lto = true
strip = true
opt-level = "z"
codegen-units = 1
[profile.release-local]
inherits = "release"
opt-level = 3
strip = false
[features]
default = []
env-local = ["axum/tokio", "axum/http1", "axum/http2", "reqwest", "tokio", "env_logger", "governor", "clap", "toml", "image/rayon"]
cf-worker = ["dep:worker", "dep:worker-macros", "dep:console_error_panic_hook"]
apparmor = ["dep:rand_core", "dep:siphasher"]
reqwest = ["dep:reqwest"]
svg-text = ["resvg/text"]
tokio = ["dep:tokio"]
env_logger = ["dep:env_logger"]
governor = ["dep:governor"]
[dependencies]
worker = { version="0.4.2", features=['http', 'axum'], optional = true }
worker-macros = { version="0.4.2", features=['http'], optional = true }
axum = { version = "0.7", default-features = false, features = ["query", "json"] }
tower-service = "0.3.2"
console_error_panic_hook = { version = "0.1.1", optional = true }
serde = { version = "1.0.214", features = ["derive"] }
futures = "0.3.31"
image = { version = "0.25.5", default-features = false, features = ["avif", "bmp", "gif", "ico", "jpeg", "png", "webp"] }
reqwest = { version = "0.12.9", features = ["brotli", "gzip", "stream", "zstd"], optional = true }
rand_core = { version = "0.6.4", features = ["getrandom"], optional = true }
siphasher = { version = "1.0.1", optional = true }
tokio = { version = "1.41.1", features = ["rt", "rt-multi-thread", "macros"], optional = true }
clap = { version = "4.5.20", features = ["derive"], optional = true }
toml = { version = "0.8", optional = true }
log = "0.4.22"
env_logger = { version = "0.11.5", optional = true }
governor = { version = "0.7.0", features = ["dashmap"], optional = true }
resvg = { version = "0.44.0", default-features = false, features = ["gif", "image-webp"] }
thiserror = "2.0.3"
[build-dependencies]
chumsky = "0.9.3"
quote = "1.0.37"
serde = { version = "1.0.214", features = ["derive"] }
serde_json = "1.0.132"
[[bin]]
name = "yumechi-no-kuni-proxy-worker"
path = "src/main.rs"
required-features = ["env-local"]
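# Hedged usage note (illustrative, not confirmed by this commit): the default
# feature set is empty, so a build must pick a runtime explicitly, e.g.
#   cargo build --release --features env-local
# for the local binary, or
#   cargo build --features cf-worker --target wasm32-unknown-unknown
# for the Cloudflare Workers artifact (typically driven by wrangler/worker-build).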

73
LICENSE Normal file

@@ -0,0 +1,73 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright 2024 Yumechi <yume@yumechi.jp>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

38
README.md Normal file

@@ -0,0 +1,38 @@
# Yumechi-no-kuni-proxy-worker
This is a Misskey media proxy worker for the ゆめちのくに (Yumechi-no-kuni) instance. It runs natively in both local and Cloudflare Workers environments!
Work in progress! Current to-do list:
- [X] Content-Type sniffing
- [X] SVG rendering
- [ ] Font rendering (likely will not run on Cloudflare Workers Free plan)
- [X] Preset image resizing
- [X] Opportunistic Redirection on large video files
- [X] RFC9110 compliant proxy loop detection with defensive programming against known vulnerable proxies
- [X] HTTPS-only mode and X-Forwarded-Proto reflection
- [X] Cache-Control header
- [X] Rate-limiting on local deployment (untested)
## Demo
### Avatar resizing
Preview at:
[https://yumechi-no-kuni-proxy-worker.eternal-flame-ad.workers.dev/proxy/avatar.webp?url=https://media.misskeyusercontent.com/io/274cc4f7-4674-4db1-9439-9fac08a66aa1.png](https://yumechi-no-kuni-proxy-worker.eternal-flame-ad.workers.dev/proxy/avatar.webp?url=https://media.misskeyusercontent.com/io/274cc4f7-4674-4db1-9439-9fac08a66aa1.png)
Image:
![Syuilo Avatar resized.png](https://yumechi-no-kuni-proxy-worker.eternal-flame-ad.workers.dev/proxy/avatar.webp?url=https://media.misskeyusercontent.com/io/274cc4f7-4674-4db1-9439-9fac08a66aa1.png)
### SVG rendering
(font rendering disabled due to size restrictions)
[https://yumechi-no-kuni-proxy-worker.eternal-flame-ad.workers.dev/proxy/static.webp?url=https://upload.wikimedia.org/wikipedia/commons/a/ad/AES-AddRoundKey.svg](https://yumechi-no-kuni-proxy-worker.eternal-flame-ad.workers.dev/proxy/static.webp?url=https://upload.wikimedia.org/wikipedia/commons/a/ad/AES-AddRoundKey.svg)
![AES-AddRoundKey.svg](https://yumechi-no-kuni-proxy-worker.eternal-flame-ad.workers.dev/proxy/static.webp?url=https://upload.wikimedia.org/wikipedia/commons/a/ad/AES-AddRoundKey.svg)
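## Endpoint format
The routes in `src/lib.rs` expose the proxy as follows (a summary of this commit's router, not upstream documentation):
```
/proxy?url=<upstream media URL>
/proxy/<stem>.<ext>?url=<upstream media URL>
```
where `<stem>` selects a preset (`avatar`, `static`, `preview`, `badge`, `emoji`) and `<ext>` selects the preferred output format (`png`, `jpg`/`jpeg`, `webp`, `tiff`, `gif`, `bmp`), both parsed by `ImageOptions::apply_filename`.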

556
build.rs Normal file

@@ -0,0 +1,556 @@
use std::{env, ops::Range, vec};
use chumsky::prelude::*;
use quote::ToTokens;
use text::whitespace;
const FILES_TO_PARSE: &[&str] = &[
"jpeg",
"images",
"sgml",
"riff",
"animation",
"audio",
"matroska",
"vorbis",
"audio",
"msdos",
"webassembly",
"elf",
"mach",
];
const BLACKLISTED: &[&str] = &["fuji-", "canon-", "corel-", "dicom", "garmin"];
// remove extension entries not in this list from safe entries
const SAFE_EXTENSIONS: &[&str] = &[
".png", ".gif", ".jpeg", ".webp", ".avif", ".apng", ".bmp", ".tiff", ".x-icon", ".opus",
".ogg", ".mp4", ".m4v", ".3gpp", ".mpeg", ".webm", ".aac", ".flac", ".wav",
];
const SAFE_WHITELISTED: &[&str] = &[
"png", "gif", "jpeg", "webp", "avif", "apng", "bmp", "tiff", "x-icon", "opus", "ogg", "mp4",
"m4v", "3gpp", "mpeg", "webm", "aac", "flac", "wav", "svg", "rss",
];
// we want to have signatures for these to be able to detect them
const UNSAFE_WHITELISTED: &[&str] = &[
".exe",
".wasm",
"elf",
"mach",
"javascript",
"bios",
"firmware",
"driver",
"mpegurl",
];
fn static_signatures() -> Vec<MIMEAssociation> {
vec![
MIMEAssociation {
mime: "application/x-elf-executable".to_string().into(),
ext: vec![".pie".to_string(), ".elf".to_string(), ".so".to_string()],
safe: false,
signatures: vec![FlattenedFileSignature {
test: vec![0x7f, b'E', b'L', b'F'],
mask: vec![0xff, 0xff, 0xff, 0xff],
}],
},
MIMEAssociation {
mime: "application/x-mach-binary".to_string().into(),
ext: vec![".dylib".to_string(), ".bundle".to_string()],
safe: false,
signatures: vec![FlattenedFileSignature {
test: vec![0xfe, 0xed, 0xfa, 0xce],
mask: vec![0xff, 0xff, 0xff, 0xff],
}],
},
MIMEAssociation {
mime: "application/vnd.microsoft.portable-executable"
.to_string()
.into(),
ext: vec![".exe".to_string(), ".dll".to_string(), ".sys".to_string()],
safe: false,
signatures: vec![FlattenedFileSignature {
test: b"PE\0\0".to_vec(),
mask: vec![0xff, 0xff, 0xff, 0xff],
}],
},
]
}
#[derive(Debug, Clone)]
pub enum MagicFileLine {
Nop,
Unknown,
Magic {
indent: u8,
offset: u64,
ty: MagicType,
},
AssignAttr {
attr: String,
value: String,
},
}
#[derive(Debug, Clone)]
pub enum MagicType {
Unknown(String),
Belong {
test: Vec<u8>,
mask: Option<Vec<u8>>,
},
String {
test: Vec<u8>,
},
}
pub fn parse_string_repr() -> impl Parser<char, Vec<u8>, Error = Simple<char>> {
just('\\')
.ignore_then(choice((
just('\\').to(b'\\'),
just('n').to(b'\n'),
just('r').to(b'\r'),
just('t').to(b'\t'),
just('x').ignore_then(
one_of("0123456789abcdefABCDEF")
.repeated()
.exactly(2)
.map(|s| u8::from_str_radix(&s.iter().collect::<String>(), 16).unwrap()),
),
)))
.or(none_of("\\").map(|c| c as u8))
.repeated()
.at_least(1)
.map(|s| s.to_vec())
.then_ignore(end())
}
pub fn parse_hex_repr() -> impl Parser<char, Vec<u8>, Error = Simple<char>> {
just("0x")
.ignore_then(
one_of("0123456789abcdef")
.repeated()
.exactly(2)
.map(|s| u8::from_str_radix(&s.iter().collect::<String>(), 16).unwrap())
.repeated()
.at_least(1),
)
.map(|s| s.to_vec())
.then_ignore(end())
}
pub fn parse_magic_line() -> impl Parser<char, MagicFileLine, Error = Simple<char>> {
choice((
just('#')
.then_ignore(any().repeated())
.to(MagicFileLine::Nop),
just('>')
.repeated()
.map(|i| i.len() as u8)
.then(
one_of("0123456789")
.repeated()
.at_least(1)
.try_map(|s, span| {
s.iter()
.collect::<String>()
.parse::<u64>()
.map_err(|_| Simple::custom(span, "Failed to parse number"))
})
.or(just("0x").ignore_then(
one_of("0123456789abcdefABCDEF")
.repeated()
.at_least(1)
.try_map(|s, span| {
u64::from_str_radix(&s.iter().collect::<String>(), 16)
.map_err(|_| Simple::custom(span, "Failed to parse number"))
}),
)),
)
.then_ignore(whitespace().at_least(1))
.then(
none_of(" \t\n")
.repeated()
.at_least(1)
.map(String::from_iter),
)
.then_ignore(whitespace().at_least(1))
.then(
none_of(" \t\n")
.repeated()
.at_least(1)
.map(String::from_iter),
)
.try_map(|(((indent, offset), ty), rep), span: Range<usize>| {
Ok(MagicFileLine::Magic {
indent,
offset,
ty: match ty.as_str() {
"string" => MagicType::String {
test: parse_string_repr().parse(rep).map_err(|_| {
Simple::custom(span, "Failed to parse string pattern")
})?,
},
"belong" => MagicType::Belong {
test: parse_hex_repr()
.parse(rep)
.map_err(|_| Simple::custom(span, "Failed to parse hex pattern"))?,
mask: None,
},
s if s.starts_with("belong&") => {
let mask = &s["belong&".len()..];
let span_clone = span.clone();
MagicType::Belong {
test: parse_hex_repr().parse(rep).map_err(|_| {
Simple::custom(span, "Failed to parse hex pattern")
})?,
mask: Some(parse_hex_repr().parse(mask).map_err(|_| {
Simple::custom(span_clone, "Failed to parse hex pattern")
})?),
}
}
_ => MagicType::Unknown(ty),
},
})
})
.then_ignore(any().repeated()),
just("!:")
.ignore_then(
one_of("abcdefghijklmnopqrstuvwxyz")
.repeated()
.at_least(1)
.map(|s| s.iter().collect()),
)
.then_ignore(whitespace().at_least(1))
.then(any().repeated().map(String::from_iter))
.map(|(attr, value)| MagicFileLine::AssignAttr { attr, value }),
))
.then_ignore(whitespace())
.then_ignore(end())
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub struct FileSignature {
pub offset: u64,
pub test: Vec<u8>,
pub mask: Option<Vec<u8>>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub struct FlattenedFileSignature {
pub test: Vec<u8>,
pub mask: Vec<u8>,
}
impl FlattenedFileSignature {
fn codegen(&self) -> impl ToTokens {
let data = self
.test
.iter()
.copied()
.zip(self.mask.iter().copied())
.map(|(t, m)| {
quote::quote! {
(#t, #m)
}
});
quote::quote! {
FlattenedFileSignature(&[#(#data),*])
}
}
}
impl From<FileSignature> for FlattenedFileSignature {
fn from(sig: FileSignature) -> Self {
let len = sig.test.len();
FlattenedFileSignature {
test: std::iter::repeat(0)
.take(sig.offset as usize)
.chain(sig.test)
.collect(),
mask: sig.mask.unwrap_or_else(|| {
std::iter::repeat(0)
.take(sig.offset as usize)
.chain(std::iter::repeat(!0).take(len))
.collect()
}),
}
}
}
impl std::ops::BitAnd<FlattenedFileSignature> for FlattenedFileSignature {
type Output = FlattenedFileSignature;
fn bitand(mut self, mut rhs: FlattenedFileSignature) -> Self::Output {
if self.test.len() < rhs.test.len() {
std::mem::swap(&mut self, &mut rhs);
}
let test = self
.test
.iter()
.zip(
rhs.test
.iter()
.chain(std::iter::repeat(&0).take(self.test.len() - rhs.test.len())),
)
.map(|(a, b)| a | b)
.collect();
let mask = self
.mask
.iter()
.zip(
rhs.mask
.iter()
.chain(std::iter::repeat(&0).take(self.test.len() - rhs.test.len())),
)
.map(|(a, b)| a | b)
.collect();
FlattenedFileSignature { test, mask }
}
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub struct MIMEAssociation {
pub mime: Option<String>,
pub ext: Vec<String>,
pub safe: bool,
pub signatures: Vec<FlattenedFileSignature>,
}
impl MIMEAssociation {
fn codegen(&self) -> impl ToTokens {
let mime = self.mime.as_deref().unwrap_or("");
let ext = self.ext.first().map(|s| s.as_str()).unwrap_or("");
let safe = self.safe;
let signatures = self.signatures.iter().map(|s| s.codegen());
quote::quote! {
MIMEAssociation {
mime: #mime,
ext: #ext,
safe: #safe,
signatures: &[#(#signatures),*],
}
}
}
fn build_vec(lines: Vec<MagicFileLine>) -> Vec<MIMEAssociation> {
let mut stack = Vec::new();
let mut out: Vec<MIMEAssociation> = Vec::new();
for line in lines {
match line {
MagicFileLine::Magic { ty, offset, indent } => match ty {
MagicType::Belong { test, mask } => {
stack.truncate(indent as usize);
stack.push(FileSignature { offset, test, mask });
}
MagicType::String { test } => {
stack.truncate(indent as usize);
stack.push(FileSignature {
offset,
test,
mask: None,
});
}
_ => {}
},
MagicFileLine::AssignAttr { attr, value } => match attr.as_str() {
"mime" if !stack.is_empty() => {
let mime = value;
let flattened = stack.iter().map(|sig| sig.clone().into()).fold(
FlattenedFileSignature {
test: Vec::new(),
mask: Vec::new(),
},
|a, b| a & b,
);
if flattened.test.len() > 64 {
eprintln!("Signature too long: {:?}", flattened.test.len());
continue;
}
if let Some(existing) = out
.iter_mut()
.find(|m| m.mime.as_deref().map(|m| m == mime).unwrap_or(false))
{
existing.signatures.push(flattened);
} else {
out.push(MIMEAssociation {
mime: Some(mime),
safe: false,
ext: vec![],
signatures: vec![flattened],
});
}
}
"ext" if !stack.is_empty() => {
let ext = value;
let flattened = stack.iter().map(|sig| sig.clone().into()).fold(
FlattenedFileSignature {
test: Vec::new(),
mask: Vec::new(),
},
|a, b| a & b,
);
if flattened.test.len() > 64 {
eprintln!("Signature too long: {:?}", flattened.test.len());
continue;
}
if let Some(existing) =
out.iter_mut().find(|m| m.signatures.contains(&flattened))
{
existing
.ext
.extend(ext.split('/').map(|s| format!(".{}", s)))
} else {
out.push(MIMEAssociation {
mime: None,
safe: false,
ext: ext.split('/').map(|s| format!(".{}", s)).collect(),
signatures: vec![flattened],
});
}
}
_ => {}
},
_ => {}
}
}
out.iter_mut().for_each(|m| {
m.ext.sort();
m.ext.dedup();
m.signatures.sort_by(|a, b| a.test.cmp(&b.test));
m.signatures.dedup();
});
out.dedup();
out
}
}
const BASE_DIR: &str = "submodules/file/magic/Magdir/";
fn main() {
let signatures = static_signatures()
.into_iter()
.chain(FILES_TO_PARSE.iter().flat_map(|file| {
println!("cargo:rerun-if-changed={}", file);
eprintln!("Using file: {}", file);
let path = format!("{}{}", BASE_DIR, file);
let content = std::fs::read(&path)
.map(|v| String::from_utf8_lossy(&v).to_string())
.unwrap();
let lines = content
.lines()
.filter(|line| !line.is_empty())
.map(|line| {
parse_magic_line().parse(line).unwrap_or_else(|e| {
eprintln!("Failed to parse line: {:?}", line);
eprintln!("Error: {:?}", e);
MagicFileLine::Unknown
})
})
.collect::<Vec<_>>();
MIMEAssociation::build_vec(lines)
.into_iter()
.map(|mut m| {
if m.mime
.as_ref()
.map(|m| UNSAFE_WHITELISTED.iter().any(|u| m.contains(u)))
.unwrap_or(false)
{
m.safe = false;
return m;
}
if m.ext
.iter()
.any(|ext| UNSAFE_WHITELISTED.iter().any(|u| ext.contains(u)))
{
m.safe = false;
return m;
}
if m.mime
.as_ref()
.map(|m| SAFE_WHITELISTED.iter().any(|w| m.contains(w)))
.unwrap_or(false)
{
m.safe = true;
}
if m.ext
.iter()
.any(|ext| SAFE_WHITELISTED.iter().any(|w| ext.contains(w)))
{
m.safe = true;
}
if m.safe {
m.ext
.retain(|ext| SAFE_EXTENSIONS.iter().any(|s| ext.contains(s)));
}
m
})
.filter(|m| {
if let Some(incoming) = &m.mime {
let mime = incoming.to_lowercase();
if BLACKLISTED.iter().any(|b| mime.contains(b)) {
return false;
}
if SAFE_WHITELISTED.iter().any(|w| mime.contains(w))
|| UNSAFE_WHITELISTED.iter().any(|u| mime.contains(u))
{
return true;
}
}
if m.ext
.iter()
.any(|ext| BLACKLISTED.iter().any(|b| ext.contains(b)))
{
return false;
}
if let Some(incoming) = &m.mime {
let mime = incoming.to_lowercase();
if SAFE_WHITELISTED.iter().all(|w| mime.contains(w))
|| UNSAFE_WHITELISTED.iter().any(|u| mime.contains(u))
{
return true;
}
}
if m.ext.iter().any(|ext| {
SAFE_WHITELISTED.iter().any(|w| ext.contains(w))
|| UNSAFE_WHITELISTED.iter().any(|u| ext.contains(u))
}) {
return true;
}
false
})
}))
.collect::<Vec<_>>();
let max_size = signatures
.iter()
.map(|s| s.signatures.iter().map(|s| s.test.len()).max().unwrap())
.max()
.unwrap();
if max_size > 128 {
panic!("Max signature size is too large: {}", max_size);
}
std::fs::write(env::var("OUT_DIR").unwrap() + "/magic.rs", {
let signatures = signatures.iter().map(|s| s.codegen());
quote::quote! {
/// Maximum size of a signature
pub const SNIFF_SIZE: usize = #max_size;
#[allow(clippy::all)]
const MAGICS: &[MIMEAssociation] = &[#(#signatures),*];
}
.into_token_stream()
.to_string()
})
.unwrap();
}
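// A hedged sketch (not part of this commit) of how the generated `MAGICS`
// table could be consumed at runtime, assuming the crate-side counterpart of
// `FlattenedFileSignature` mirrors the codegen above as a tuple struct over
// (test_byte, mask_byte) pairs; named differently here to avoid clashing with
// the build-script struct of the same name.
#[allow(dead_code)]
struct SniffSignature(&'static [(u8, u8)]);
#[allow(dead_code)]
impl SniffSignature {
    /// True when every masked signature byte matches the buffer prefix.
    fn matches(&self, buf: &[u8]) -> bool {
        self.0.len() <= buf.len()
            && self.0.iter().zip(buf).all(|(&(t, m), &b)| b & m == t & m)
    }
}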

1
rust-toolchain Normal file

@@ -0,0 +1 @@
1.81

1
src/deliver/mod.rs Normal file

@@ -0,0 +1 @@

597
src/fetch/mod.rs Normal file

@@ -0,0 +1,597 @@
use crate::{ErrorResponse, FetchConfig, MAX_SIZE};
use axum::{
body::Bytes,
extract::FromRequestParts,
http::{request::Parts, HeaderMap},
};
use futures::stream::TryStreamExt;
use std::{borrow::Cow, collections::HashSet, convert::Infallible, pin::Pin};
/// Default maximum number of redirects to follow
pub const DEFAULT_MAX_REDIRECTS: usize = 6;
/// Some context about the request for writing response and logging
#[allow(missing_docs)]
pub struct RequestCtx<'a> {
pub url: &'a str,
pub secure: bool,
}
const fn http_version_to_via(v: axum::http::Version) -> &'static str {
match v {
axum::http::Version::HTTP_09 => "0.9",
axum::http::Version::HTTP_10 => "1.0",
axum::http::Version::HTTP_11 => "1.1",
axum::http::Version::HTTP_2 => "2.0",
axum::http::Version::HTTP_3 => "3.0",
_ => "1.1",
}
}
/// Trait for HTTP responses
pub trait HTTPResponse {
/// Type of the byte buffer
type Bytes: Into<Vec<u8>> + AsRef<[u8]> + Into<Bytes> + Send + 'static;
/// Type of body stream
type BodyStream: futures::Stream<Item = Result<Self::Bytes, ErrorResponse>> + Send + 'static;
/// Get some context about the request
fn request(&self) -> RequestCtx<'_>;
/// Get the status code
fn status(&self) -> u16;
/// Get a header value
fn header_one<'a>(&'a self, name: &str) -> Result<Option<Cow<'a, str>>, ErrorResponse>;
/// Walk through all headers with a callback
fn header_walk<F: FnMut(&str, &str) -> bool>(&self, f: F);
/// Collect all headers
fn header_collect(&self, out: &mut HeaderMap) -> Result<(), ErrorResponse>;
/// Get the body stream
fn body(self) -> Self::BodyStream;
}
/// Information about the incoming request
pub struct IncomingInfo {
version: axum::http::Version,
user_agent: String,
via: String,
}
impl IncomingInfo {
/// Check if the request is potentially looping
pub fn looping(&self, self_via: &str) -> bool {
if self.user_agent.is_empty() {
return true;
}
if self.via.contains(self_via) {
return true;
}
// defense against upstream
if self.user_agent.contains("Misskey/") ||
// Purposefully typoed
// https://raw.githubusercontent.com/backrunner/misskey-media-proxy-worker/refs/heads/main/wrangler.toml
self.user_agent.contains("Edg/119.0.2109.1")
{
return true;
}
let split = self.via.split(", ");
let mut seen = HashSet::new();
for part in split {
if !seen.insert(part) {
return true;
}
}
false
}
}
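// A hedged illustration (hypothetical test, not part of this commit) of the
// loop rules above: an empty User-Agent, our own ident in the Via chain, or a
// repeated Via entry all count as loops.
#[cfg(test)]
mod loop_tests {
    use super::IncomingInfo;
    fn info(user_agent: &str, via: &str) -> IncomingInfo {
        IncomingInfo {
            version: axum::http::Version::HTTP_11,
            user_agent: user_agent.to_string(),
            via: via.to_string(),
        }
    }
    #[test]
    fn detects_loops() {
        // our own Via ident already present
        assert!(info("curl/8.0", "1.1 proxy/0.1").looping("proxy/0.1"));
        // duplicated Via entry inserted by another proxy
        assert!(info("curl/8.0", "1.1 a, 1.1 a").looping("proxy/0.1"));
        // an empty User-Agent is treated as suspicious
        assert!(info("", "").looping("proxy/0.1"));
        // a clean first hop passes
        assert!(!info("curl/8.0", "1.1 upstream").looping("proxy/0.1"));
    }
}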
#[axum::async_trait]
impl<S> FromRequestParts<S> for IncomingInfo {
type Rejection = Infallible;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Ok(Self {
version: parts.version,
user_agent: parts
.headers
.get("user-agent")
.and_then(|v| v.to_str().ok())
.unwrap_or_default()
.to_string(),
via: parts
.headers
.get_all("via")
.into_iter()
.fold(String::new(), |mut acc, v| {
acc.push_str(v.to_str().unwrap_or_default());
acc.push_str(", ");
acc
}),
})
}
}
/// Trait for upstream clients
pub trait UpstreamClient {
/// Type of the response
type Response: HTTPResponse;
/// Create a new client
fn new(config: &FetchConfig) -> Self;
/// Request the upstream
fn request_upstream(
&self,
info: &IncomingInfo,
url: &str,
polish: bool,
secure: bool,
remaining: usize,
) -> impl std::future::Future<Output = Result<Self::Response, ErrorResponse>>;
}
/// Reqwest client
#[cfg(feature = "reqwest")]
pub mod reqwest {
use super::{
http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx,
TryStreamExt, UpstreamClient, MAX_SIZE,
};
use ::reqwest::{redirect::Policy, ClientBuilder, Url};
use axum::body::Bytes;
use reqwest::dns::Resolve;
use std::{sync::Arc, time::Duration};
/// A Safe DNS resolver that only resolves to global addresses unless the requester itself is local.
pub struct SafeResolver();
// pulled from https://doc.rust-lang.org/src/core/net/ip_addr.rs.html#1650
const fn is_unicast_local_v6(ip: &std::net::Ipv6Addr) -> bool {
ip.segments()[0] & 0xfe00 == 0xfc00
}
const fn is_unicast_link_local_v6(ip: &std::net::Ipv6Addr) -> bool {
ip.segments()[0] & 0xffc0 == 0xfe80
}
impl Resolve for SafeResolver {
fn resolve(&self, name: reqwest::dns::Name) -> reqwest::dns::Resolving {
Box::pin(async move {
match tokio::net::lookup_host(format!("{}:80", name.as_str())).await {
Ok(lookup) => Ok(Box::new(lookup.filter(|addr| match addr {
std::net::SocketAddr::V4(a) => {
!a.ip().is_loopback()
&& !a.ip().is_private()
&& !a.ip().is_link_local()
&& !a.ip().is_multicast()
&& !a.ip().is_documentation()
&& !a.ip().is_unspecified()
}
std::net::SocketAddr::V6(a) => {
!a.ip().is_loopback()
&& !a.ip().is_multicast()
&& !a.ip().is_unspecified()
&& is_unicast_local_v6(a.ip())
&& !is_unicast_link_local_v6(a.ip())
&& a.ip().to_ipv4_mapped().is_none()
}
}))
as Box<dyn Iterator<Item = std::net::SocketAddr> + Send>),
Err(e) => {
log::error!("Failed to resolve {}: {}", name.as_str(), e);
Err(e.into())
}
}
})
}
}
/// Reqwest client
pub struct ReqwestClient {
https_only: bool,
via_ident: String,
client: ::reqwest::Client,
}
/// Response from Reqwest
pub struct ReqwestResponse(::reqwest::Response);
impl HTTPResponse for ReqwestResponse {
type Bytes = Bytes;
type BodyStream = Pin<
Box<dyn futures::Stream<Item = Result<Self::Bytes, ErrorResponse>> + Send + 'static>,
>;
fn request(&self) -> RequestCtx<'_> {
RequestCtx {
url: self.0.url().as_str(),
secure: self.0.url().scheme().eq_ignore_ascii_case("https"),
}
}
fn status(&self) -> u16 {
self.0.status().as_u16()
}
fn header_one<'a>(&'a self, name: &str) -> Result<Option<Cow<'a, str>>, ErrorResponse> {
self.0
.headers()
.get(name)
.map(|v| v.to_str().map(Cow::Borrowed))
.transpose()
.map_err(|_| ErrorResponse::upstream_protocol_error())
}
fn header_walk<F: FnMut(&str, &str) -> bool>(&self, mut f: F) {
for (name, value) in self.0.headers() {
if !f(name.as_str(), value.to_str().unwrap_or_default()) {
break;
}
}
}
fn header_collect(&self, out: &mut HeaderMap) -> Result<(), ErrorResponse> {
for (name, value) in self.0.headers() {
out.insert(name, value.clone());
}
Ok(())
}
fn body(self) -> Self::BodyStream {
Box::pin(self.0.bytes_stream().map_err(Into::into))
}
}
impl UpstreamClient for ReqwestClient {
type Response = ReqwestResponse;
fn new(config: &crate::FetchConfig) -> Self {
Self {
https_only: !config.allow_http,
via_ident: config.via.clone(),
client: ClientBuilder::new()
.https_only(!config.allow_http)
.dns_resolver(Arc::new(SafeResolver()))
.brotli(true)
.zstd(true)
.gzip(true)
.redirect(Policy::none())
.connect_timeout(Duration::from_secs(5))
.timeout(Duration::from_secs(15))
.user_agent(config.user_agent.clone())
.build()
.expect("Failed to create reqwest client"),
}
}
async fn request_upstream(
&self,
info: &super::IncomingInfo,
url: &str,
polish: bool,
mut secure: bool,
remaining: usize,
) -> Result<ReqwestResponse, ErrorResponse> {
if remaining == 0 {
return Err(ErrorResponse::too_many_redirects());
}
if info.looping(&self.via_ident) {
return Err(ErrorResponse::loop_detected());
}
let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?;
secure &= url_parsed.scheme().eq_ignore_ascii_case("https");
if self.https_only && !secure {
return Err(ErrorResponse::insecure_request());
}
let resp = self
.client
.get(url_parsed)
.header(
"via",
format!(
"{}, {} {}",
info.via,
http_version_to_via(info.version),
self.via_ident
),
)
.send()
.await?;
if resp.status().is_redirection() {
if let Some(location) = resp.headers().get("location").and_then(|l| l.to_str().ok())
{
return Box::pin(self.request_upstream(
info,
location,
polish,
secure,
remaining - 1,
))
.await;
}
}
if !resp.status().is_success() {
return Err(ErrorResponse::unexpected_status(
url,
resp.status().as_u16(),
));
}
let content_length = resp.headers().get("content-length");
if let Some(content_length) = content_length.and_then(|c| c.to_str().ok()) {
if content_length.parse::<usize>().unwrap_or(0) > MAX_SIZE {
return Err(ErrorResponse::payload_too_large());
}
}
let content_type = resp.headers().get("content-type");
if let Some(content_type) = content_type.and_then(|c| c.to_str().ok()) {
if !["image/", "video/", "audio/", "application/octet-stream"]
.iter()
.any(|prefix| {
content_type[..prefix.len().min(content_type.len())]
.eq_ignore_ascii_case(prefix)
})
{
return Err(ErrorResponse::not_media());
}
}
Ok(ReqwestResponse(resp))
}
}
}
/// Cloudflare Workers client
#[cfg(feature = "cf-worker")]
#[cfg_attr(
not(target_arch = "wasm32"),
deprecated = "You should use reqwest instead when not on Cloudflare Workers"
)]
pub mod cf_worker {
use std::time::Duration;
use super::{
http_version_to_via, Cow, ErrorResponse, HTTPResponse, HeaderMap, Pin, RequestCtx,
UpstreamClient, MAX_SIZE,
};
use axum::http::{HeaderName, HeaderValue};
use futures::{FutureExt, Stream, TryFutureExt};
use worker::{
AbortController, ByteStream, CfProperties, Fetch, Headers, Method, PolishConfig, Request,
RequestInit, RequestRedirect, Url,
};
/// Cloudflare Workers client
pub struct CfWorkerClient {
https_only: bool,
user_agent: String,
via_ident: String,
}
/// Wrapper for the body stream
pub struct CfBodyStreamWrapper {
stream: Option<Result<ByteStream, ErrorResponse>>,
}
impl Stream for CfBodyStreamWrapper {
type Item = Result<Vec<u8>, ErrorResponse>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.get_mut();
match this.stream.as_mut() {
Some(Ok(stream)) => match futures::ready!(std::pin::pin!(stream).poll_next(cx)) {
Some(Ok(chunk)) => std::task::Poll::Ready(Some(Ok(chunk))),
Some(Err(e)) => std::task::Poll::Ready(Some(Err(ErrorResponse::from(e)))),
None => std::task::Poll::Ready(None),
},
Some(Err(e)) => std::task::Poll::Ready(Some(Err(e.clone()))),
None => std::task::Poll::Ready(None),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
match self.stream {
Some(Ok(ref stream)) => stream.size_hint(),
_ => (0, None),
}
}
}
#[allow(unsafe_code, reason = "this is never used concurrently")]
unsafe impl Send for CfBodyStreamWrapper {}
#[allow(unsafe_code, reason = "this is never used concurrently")]
unsafe impl Sync for CfBodyStreamWrapper {}
/// Response from Cloudflare Workers
pub struct CfWorkerResponse {
resp: worker::Response,
url: Url,
}
impl HTTPResponse for CfWorkerResponse {
type Bytes = Vec<u8>;
type BodyStream = CfBodyStreamWrapper;
fn request(&self) -> RequestCtx<'_> {
RequestCtx {
url: self.url.as_str(),
secure: self.url.scheme().eq_ignore_ascii_case("https"),
}
}
fn status(&self) -> u16 {
self.resp.status_code()
}
fn header_one<'a>(&'a self, name: &str) -> Result<Option<Cow<'a, str>>, ErrorResponse> {
self.resp
.headers()
.get(name)
.map(|v| v.map(|v| Cow::Owned(v.to_string())))
.map_err(|_| ErrorResponse::upstream_protocol_error())
}
fn header_walk<F: FnMut(&str, &str) -> bool>(&self, mut f: F) {
for (name, value) in self.resp.headers().entries() {
if !f(&name, &value) {
break;
}
}
}
fn header_collect(&self, out: &mut HeaderMap) -> Result<(), ErrorResponse> {
for name in self.resp.headers().keys() {
out.insert(
HeaderName::from_bytes(name.as_bytes())
.map_err(|_| ErrorResponse::upstream_protocol_error())?,
self.resp
.headers()
.get(&name)
.map_err(|_| ErrorResponse::upstream_protocol_error())?
.map(HeaderValue::try_from)
.transpose()
.map_err(|_| ErrorResponse::upstream_protocol_error())?
.ok_or(ErrorResponse::upstream_protocol_error())?,
);
}
Ok(())
}
fn body(mut self) -> Self::BodyStream {
let stream = self.resp.stream().map_err(ErrorResponse::from);
CfBodyStreamWrapper {
stream: Some(stream),
}
}
}
impl UpstreamClient for CfWorkerClient {
type Response = CfWorkerResponse;
fn new(config: &crate::FetchConfig) -> Self {
Self {
https_only: !config.allow_http,
user_agent: config.user_agent.clone(),
via_ident: config.via.clone(),
}
}
async fn request_upstream(
&self,
info: &super::IncomingInfo,
url: &str,
polish: bool,
mut secure: bool,
remaining: usize,
) -> Result<Self::Response, ErrorResponse> {
if remaining == 0 {
return Err(ErrorResponse::too_many_redirects());
}
if info.looping(&self.via_ident) {
return Err(ErrorResponse::loop_detected());
}
let mut headers = Headers::new();
headers.set("user-agent", &self.user_agent)?;
headers.set(
"via",
&format!(
"{}, {} {}",
info.via,
http_version_to_via(info.version),
self.via_ident
),
)?;
let mut prop = CfProperties::new();
if polish {
prop.polish = Some(PolishConfig::Lossless);
}
let mut init = RequestInit::new();
let init = init
.with_method(Method::Get)
.with_headers(headers)
.with_cf_properties(prop)
.with_redirect(RequestRedirect::Manual);
let url_parsed = Url::parse(url).map_err(|_| ErrorResponse::bad_url())?;
if self.https_only && !url_parsed.scheme().eq_ignore_ascii_case("https") {
return Err(ErrorResponse::insecure_request());
}
secure &= url_parsed.scheme().eq_ignore_ascii_case("http");
let req = Request::new_with_init(url, init)?;
let abc = AbortController::default();
let abs = abc.signal();
let req = Fetch::Request(req);
let mut resp_fut = std::pin::pin!(req
.send_with_signal(&abs)
.map_err(ErrorResponse::worker_fetch_error));
let timeout = std::pin::pin!(worker::Delay::from(Duration::from_secs(5)));
let resp = futures::select! {
resp = resp_fut => resp?,
_ = timeout.fuse() => return Err(ErrorResponse::upstream_timeout()),
};
if resp.status_code() == 301 || resp.status_code() == 302 {
if let Ok(Some(location)) = resp.headers().get("location") {
return Box::pin(self.request_upstream(
info,
&location,
polish,
secure,
remaining - 1,
))
.await;
}
}
if resp.status_code() < 200 || resp.status_code() >= 300 {
return Err(ErrorResponse::unexpected_status(url, resp.status_code()));
}
let content_length = resp.headers().get("content-length").unwrap_or_default();
if let Some(content_length) = content_length {
if content_length.parse::<usize>().unwrap_or(0) > MAX_SIZE {
return Err(ErrorResponse::payload_too_large());
}
}
let content_type = resp.headers().get("content-type").unwrap_or_default();
if let Some(content_type) = content_type {
if !["image/", "video/", "audio/", "application/octet-stream"]
.iter()
.any(|prefix| {
content_type.as_str()[..prefix.len().min(content_type.len())]
.eq_ignore_ascii_case(prefix)
})
{
return Err(ErrorResponse::not_media());
}
}
Ok(CfWorkerResponse {
resp,
url: url_parsed,
})
}
}
}

755
src/lib.rs Normal file

@@ -0,0 +1,755 @@
#![doc = include_str!("../README.md")]
#![warn(clippy::all, clippy::pedantic)]
#![warn(unsafe_code)]
#![warn(missing_docs)]
#![allow(clippy::missing_errors_doc, clippy::module_name_repetitions)]
use std::{
borrow::Cow, fmt::Display, marker::PhantomData, net::SocketAddr, num::NonZero, sync::Arc,
};
#[cfg(feature = "governor")]
use axum::extract::ConnectInfo;
use axum::{
body::Body,
extract::{Path, Query, State},
http::{self, HeaderMap, StatusCode},
response::{IntoResponse, Redirect, Response},
routing::get,
Json, Router,
};
use fetch::{HTTPResponse, IncomingInfo, UpstreamClient, DEFAULT_MAX_REDIRECTS};
#[cfg(feature = "governor")]
use governor::{
clock::SystemClock, middleware::StateInformationMiddleware, state::keyed::DashMapStateStore,
RateLimiter,
};
use post_process::MediaResponse;
use sandbox::{NoSandbox, Sandboxing};
use serde::Deserialize;
#[cfg(feature = "cf-worker")]
use worker::{event, Context, Env, HttpRequest, Result as WorkerResult};
/// Module for delivering the final processed media to the client
pub mod deliver;
/// Module for fetching media from upstream
pub mod fetch;
/// Module for post-processing media
pub mod post_process;
/// Sandbox utilities for processing media
pub mod sandbox;
/// Stream utilities
pub mod stream;
const MAX_SIZE: usize = 32 << 20;
#[cfg(all(not(feature = "cf-worker"), not(feature = "reqwest")))]
compile_error!("At least one of the `cf-worker` or `reqwest` features must be enabled. hint: '--features env-local' enables everything related to local runtime");
#[cfg(feature = "cf-worker")]
/// The upstream client chosen by the build configuration
pub type Upstream = crate::fetch::cf_worker::CfWorkerClient;
#[cfg(all(not(feature = "cf-worker"), feature = "reqwest"))]
/// The upstream client chosen by the build configuration
pub type Upstream = crate::fetch::reqwest::ReqwestClient;
/// Application configuration
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
/// The listen address
pub listen: String,
/// Send Cache-Control headers
pub enable_cache: bool,
/// Index page configuration
pub index_redirect: IndexConfig,
/// Whether to only serve media with a known safe signature
pub strictly_secure: bool,
/// Fetch configuration
pub fetch: FetchConfig,
/// Post-processing configuration
pub post_process: PostProcessConfig,
#[cfg(feature = "governor")]
/// Governor configuration
pub rate_limit: RateLimitConfig,
}
/// Governor configuration
#[cfg(feature = "governor")]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RateLimitConfig {
/// The rate limit replenish interval in milliseconds
pub replenish_every: u64,
/// The rate limit burst size
pub burst: NonZero<u32>,
}
impl Default for Config {
fn default() -> Self {
Config {
listen: "127.0.0.1:3000".to_string(),
enable_cache: false,
fetch: FetchConfig {
allow_http: false,
via: concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")).to_string(),
user_agent: concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"))
.to_string(),
},
index_redirect: IndexConfig::Message(format!(
"Welcome to {}",
concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")),
)),
strictly_secure: true,
post_process: PostProcessConfig {
enable_redirects: false,
normalization: NormalizationPolicy::Opportunistic,
allow_svg_passthrough: false,
},
#[cfg(feature = "governor")]
rate_limit: RateLimitConfig {
replenish_every: 2000,
burst: NonZero::new(32).unwrap(),
},
}
}
}
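// A hedged example of a local TOML configuration matching the structs above
// (field names come from the `Deserialize` derives; values are illustrative,
// based on the defaults in `impl Default for Config`):
//
// listen = "127.0.0.1:3000"
// enable_cache = true
// strictly_secure = true
// index_redirect = "Welcome to yumechi-no-kuni-proxy-worker"
//
// [fetch]
// allow_http = false
// via = "yumechi-no-kuni-proxy-worker/0.1.0"
// user_agent = "yumechi-no-kuni-proxy-worker/0.1.0"
//
// [post_process]
// enable_redirects = false
// normalization = "opportunistic"
// allow_svg_passthrough = false
//
// [rate_limit]  # only read when the `governor` feature is enabled
// replenish_every = 2000
// burst = 32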
/// Fetch configuration
#[derive(Debug, Clone, serde::Deserialize)]
pub struct FetchConfig {
/// Whether to allow HTTP requests
pub allow_http: bool,
/// The via string to use when fetching media
pub via: String,
/// The user agent to use when fetching media
pub user_agent: String,
}
/// Post-processing configuration
#[derive(Debug, Clone, serde::Deserialize)]
pub struct PostProcessConfig {
/// Opportunistically redirect to the original URL if the media is not modified
///
/// Potentially leaks the user's IP address and other metadata
pub enable_redirects: bool,
/// Whether to normalize media files when the request specifically asks for a format
pub normalization: NormalizationPolicy,
/// Whether to allow SVG passthrough
///
/// This opens up the possibility of SVG-based attacks
pub allow_svg_passthrough: bool,
}
/// Normalization policy
#[derive(Copy, Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NormalizationPolicy {
/// Only return the requested format and fail if it can't be provided
Enforce,
/// Always make an attempt to return the requested format
Always,
/// Only return the requested format if the conversion does not result in significant changes
///
/// This is the default
Opportunistic,
/// Ignore the requested format and return the original
Never,
}
impl Default for NormalizationPolicy {
fn default() -> Self {
Self::Opportunistic
}
}
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
/// Index page configuration
pub enum IndexConfig {
/// Redirect to a URL
#[allow(missing_docs)]
Redirect { permanent: bool, url: String },
/// Display a message
Message(String),
}
#[cfg(any(feature = "cf-worker", feature = "reqwest"))]
/// Application Router
pub fn router<C: UpstreamClient + 'static, S: Sandboxing + 'static>(config: Config) -> Router
where
<<C as UpstreamClient>::Response as HTTPResponse>::BodyStream: Unpin,
{
use std::time::Duration;
use axum::middleware;
#[cfg(feature = "governor")]
use governor::{
clock::SystemClock, middleware::StateInformationMiddleware, Quota, RateLimiter,
};
let state = AppState {
#[cfg(feature = "governor")]
limiter: RateLimiter::dashmap_with_clock(
Quota::with_period(Duration::from_millis(config.rate_limit.replenish_every))
.unwrap()
.allow_burst(config.rate_limit.burst),
SystemClock::default(),
)
.with_middleware::<StateInformationMiddleware>(),
client: Upstream::new(&config.fetch),
sandbox: NoSandbox,
config,
};
let state = Arc::new(state);
let router = Router::new()
.route("/", get(App::<C, S>::index))
.route(
"/proxy",
get(App::<C, S>::proxy_without_filename)
.head(App::<C, S>::proxy_without_filename)
.options(App::<C, S>::proxy_options)
.route_layer(middleware::from_fn_with_state(
state.clone(),
set_cache_control,
))
.fallback(|| async { ErrorResponse::method_not_allowed() }),
)
.route(
"/proxy/:filename",
get(App::<C, S>::proxy_with_filename)
.head(App::<C, S>::proxy_with_filename)
.options(App::<C, S>::proxy_options)
.route_layer(middleware::from_fn_with_state(
state.clone(),
set_cache_control,
))
.fallback(|| async { ErrorResponse::method_not_allowed() }),
)
.with_state(Arc::clone(&state));
#[cfg(feature = "governor")]
return router.route_layer(middleware::from_fn_with_state(state, rate_limit_middleware));
#[cfg(not(feature = "governor"))]
router
}
/// Set the Cache-Control header
#[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn set_cache_control(
State(state): State<Arc<AppState<Upstream, NoSandbox>>>,
request: axum::extract::Request,
next: axum::middleware::Next,
) -> Response {
let mut resp = next.run(request).await;
if state.config.enable_cache {
if resp.status() == StatusCode::OK {
let headers = resp.headers_mut();
headers.insert(
"Cache-Control",
"public, max-age=31536000, immutable".parse().unwrap(),
);
} else {
let headers = resp.headers_mut();
headers.insert("Cache-Control", "max-age=300".parse().unwrap());
}
} else {
let headers = resp.headers_mut();
headers.insert("Cache-Control", "no-store".parse().unwrap());
}
resp
}
/// Middleware for rate limiting
#[cfg(feature = "governor")]
#[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn rate_limit_middleware(
State(state): State<Arc<AppState<Upstream, NoSandbox>>>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
request: axum::extract::Request,
next: axum::middleware::Next,
) -> Response {
use std::time::SystemTime;
match state.limiter.check_key(&addr) {
Ok(ok) => {
let mut resp = next.run(request).await;
let headers = resp.headers_mut();
headers.insert(
"X-RateLimit-Limit",
#[allow(clippy::unwrap_used)]
ok.quota().burst_size().to_string().parse().unwrap(),
);
headers.insert(
"X-RateLimit-Replenish-Interval",
state
.config
.rate_limit
.replenish_every
.to_string()
.parse()
.unwrap(),
);
headers.insert(
"X-RateLimit-Remaining",
ok.remaining_burst_capacity().to_string().parse().unwrap(),
);
resp
}
Err(err) => {
log::warn!("Rate limit exceeded for {}: {}", addr, err);
let mut resp = ErrorResponse::rate_limit_exceeded().into_response();
let headers = resp.headers_mut();
headers.insert(
"X-RateLimit-Limit",
#[allow(clippy::unwrap_used)]
err.quota().burst_size().to_string().parse().unwrap(),
);
headers.insert(
"X-RateLimit-Replenish-Interval",
state
.config
.rate_limit
.replenish_every
.to_string()
.parse()
.unwrap(),
);
headers.insert("X-RateLimit-Remaining", "0".parse().unwrap());
headers.insert(
"Retry-After",
err.wait_time_from(SystemTime::now().into())
.as_secs()
.to_string()
.parse()
.unwrap(),
);
resp
}
}
}
#[cfg(feature = "cf-worker")]
#[event(fetch)]
async fn fetch(
req: HttpRequest,
_env: Env,
_ctx: Context,
) -> WorkerResult<axum::http::Response<axum::body::Body>> {
use fetch::cf_worker::CfWorkerClient;
use tower_service::Service;
#[cfg(all(feature = "cf-worker", target_arch = "wasm32"))]
console_error_panic_hook::set_once();
Ok(router::<CfWorkerClient, NoSandbox>(Default::default())
.call(req)
.await?)
}
/// Query parameters for the proxy endpoint
#[derive(Debug, Clone, serde::Deserialize)]
#[allow(missing_docs)]
pub struct ProxyQuery {
pub url: String,
#[serde(flatten)]
pub image_options: ImageOptions,
}
fn deserialize_query_bool<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
where
D: serde::Deserializer<'de>,
{
Option::<String>::deserialize(deserializer)?.map_or(Ok(None), |s| match s.as_str() {
"true" | "True" | "TRUE" | "1" | "Y" | "yes" | "Yes" | "YES" => Ok(Some(true)),
"false" | "False" | "FALSE" | "0" | "N" | "no" | "No" | "NO" => Ok(Some(false)),
_ => Err(serde::de::Error::custom("expected 'true' or 'false'")),
})
}
/// Query options for the proxy endpoint
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ImageOptions {
/// If set to true, always proxy the image instead of redirecting
#[serde(default, deserialize_with = "deserialize_query_bool")]
pub origin: Option<bool>,
/// See upstream specification
#[serde(default, deserialize_with = "deserialize_query_bool")]
pub avatar: Option<bool>,
/// See upstream specification
#[serde(default, deserialize_with = "deserialize_query_bool")]
pub static_: Option<bool>,
/// See upstream specification
#[serde(default, deserialize_with = "deserialize_query_bool")]
pub preview: Option<bool>,
/// See upstream specification
#[serde(default, deserialize_with = "deserialize_query_bool")]
pub badge: Option<bool>,
/// See upstream specification
#[serde(default, deserialize_with = "deserialize_query_bool")]
pub emoji: Option<bool>,
/// Set the preferred format, see [`NormalizationPolicy`] for more information
pub format: Option<String>,
}
impl ImageOptions {
/// Whether post-processing is requested
pub fn requested_postprocess(&self) -> bool {
self.format.is_some()
|| self.avatar.is_some()
|| self.static_.is_some()
|| self.preview.is_some()
|| self.badge.is_some()
|| self.emoji.is_some()
}
/// Apply preferred format and image type from filename
pub fn apply_filename(&mut self, filename: &str) {
let mut split = filename.split('.');
let stem = split.next().unwrap_or_default();
let ext = split.last().unwrap_or_default();
match ext {
"png" => self.format = Some("png".to_string()),
"jpg" | "jpeg" => self.format = Some("jpeg".to_string()),
"webp" => self.format = Some("webp".to_string()),
"tiff" => self.format = Some("tiff".to_string()),
"gif" => self.format = Some("gif".to_string()),
"bmp" => self.format = Some("bmp".to_string()),
_ => {}
}
match stem {
"avatar" => self.avatar = Some(true),
"static" => self.static_ = Some(true),
"preview" => self.preview = Some(true),
"badge" => self.badge = Some(true),
"emoji" => self.emoji = Some(true),
_ => {}
}
}
}
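// A hedged illustration (hypothetical test, not part of this commit) of how a
// request filename maps onto the options above:
#[cfg(test)]
mod filename_tests {
    use super::ImageOptions;
    #[test]
    fn avatar_webp_sets_preset_and_format() {
        let mut opts = ImageOptions {
            origin: None,
            avatar: None,
            static_: None,
            preview: None,
            badge: None,
            emoji: None,
            format: None,
        };
        opts.apply_filename("avatar.webp");
        assert_eq!(opts.avatar, Some(true));
        assert_eq!(opts.format.as_deref(), Some("webp"));
    }
}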
#[allow(
clippy::trivially_copy_pass_by_ref,
reason = "Serde requires references"
)]
fn serialize_status<S>(status: &StatusCode, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_u16(status.as_u16())
}
#[derive(Debug, Clone, serde::Serialize)]
/// Error response
#[allow(missing_docs)]
pub struct ErrorResponse {
#[serde(serialize_with = "serialize_status")]
pub status: StatusCode,
pub message: Cow<'static, str>,
}
impl Display for ErrorResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.status, self.message)
}
}
impl std::error::Error for ErrorResponse {}
impl ErrorResponse {
/// Method not allowed
pub const fn method_not_allowed() -> Self {
Self {
status: StatusCode::METHOD_NOT_ALLOWED,
message: Cow::Borrowed("Method not allowed"),
}
}
/// Upstream request timed out
#[must_use]
pub const fn upstream_timeout() -> Self {
Self {
status: StatusCode::GATEWAY_TIMEOUT,
message: Cow::Borrowed("Upstream request timed out"),
}
}
/// Unexpected status code
#[must_use]
pub fn unexpected_status(url: &str, status: u16) -> Self {
Self {
status: StatusCode::BAD_GATEWAY,
message: format!("Unexpected status code when accessing {}: {}", url, status).into(),
}
}
/// Insecure request
#[must_use]
pub const fn insecure_request() -> Self {
Self {
status: StatusCode::FORBIDDEN,
message: Cow::Borrowed("HTTP requests are disabled"),
}
}
/// Rate limit exceeded
#[must_use]
pub const fn rate_limit_exceeded() -> Self {
Self {
status: StatusCode::TOO_MANY_REQUESTS,
message: Cow::Borrowed("Rate limit exceeded"),
}
}
/// Bad URL
#[must_use]
pub const fn bad_url() -> Self {
Self {
status: StatusCode::BAD_REQUEST,
message: Cow::Borrowed("Bad URL"),
}
}
/// Upstream sent invalid HTTP response
#[must_use]
pub fn upstream_protocol_error() -> Self {
Self {
status: StatusCode::BAD_GATEWAY,
message: Cow::Borrowed("Upstream protocol error"),
}
}
/// Too many redirects
#[must_use]
pub const fn too_many_redirects() -> Self {
Self {
status: StatusCode::TOO_MANY_REQUESTS,
message: Cow::Borrowed("Too many redirects"),
}
}
/// Cloudflare worker reported an error
#[cfg(feature = "cf-worker")]
#[must_use]
#[allow(clippy::needless_pass_by_value)]
pub fn worker_fetch_error(e: worker::Error) -> Self {
Self {
status: StatusCode::BAD_GATEWAY,
message: Cow::Owned(e.to_string()),
}
}
/// Proxy loop detected
#[must_use]
pub const fn loop_detected() -> Self {
Self {
status: StatusCode::LOOP_DETECTED,
message: Cow::Borrowed(
"Loop detected, please make sure your User-Agent and Via headers are not being stripped",
),
}
}
/// Received more data than allowed in one receive
#[must_use]
pub const fn mtu_buffer_overflow() -> Self {
Self {
status: StatusCode::PAYLOAD_TOO_LARGE,
message: Cow::Borrowed("MTU buffer overflow"),
}
}
/// Requested media is too large
#[must_use]
pub const fn payload_too_large() -> Self {
Self {
status: StatusCode::PAYLOAD_TOO_LARGE,
message: Cow::Borrowed("Payload too large"),
}
}
/// Post-processing failed
#[must_use]
pub const fn postprocess_failed(msg: Cow<'static, str>) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: msg,
}
}
/// Requested media is unsafe
#[must_use]
pub const fn unsafe_media() -> Self {
Self {
status: StatusCode::FORBIDDEN,
message: Cow::Borrowed("Unsafe media type"),
}
}
/// Requested media can not be processed
#[must_use]
pub const fn unsupported_media() -> Self {
Self {
status: StatusCode::UNSUPPORTED_MEDIA_TYPE,
message: Cow::Borrowed("Unsupported media type"),
}
}
/// Requested media is not a media file
#[must_use]
pub const fn not_media() -> Self {
Self {
status: StatusCode::BAD_REQUEST,
message: Cow::Borrowed("Not a media file"),
}
}
}
#[cfg(feature = "cf-worker")]
impl From<worker::Error> for ErrorResponse {
fn from(e: worker::Error) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: Cow::Owned(e.to_string()),
}
}
}
#[cfg(feature = "reqwest")]
impl From<reqwest::Error> for ErrorResponse {
fn from(e: reqwest::Error) -> Self {
Self {
status: StatusCode::BAD_GATEWAY,
message: Cow::Owned(e.to_string()),
}
}
}
impl IntoResponse for ErrorResponse {
fn into_response(self) -> axum::response::Response {
(self.status, Json(self)).into_response()
}
}
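// Illustrative sketch (not part of this commit): returning Err(ErrorResponse)
// from an axum handler goes through the IntoResponse impl above, so the
// client receives the mapped status code with the error serialized as JSON.
#[allow(dead_code)]
async fn example_error_handler() -> Result<&'static str, ErrorResponse> {
// e.g. surfaces as HTTP 429 with the struct as the JSON body
Err(ErrorResponse::rate_limit_exceeded())
}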
/// Application state
#[allow(unused)]
pub struct AppState<C: UpstreamClient, S: Sandboxing> {
#[cfg(feature = "governor")]
limiter: RateLimiter<
SocketAddr,
DashMapStateStore<SocketAddr>,
SystemClock,
StateInformationMiddleware,
>,
config: Config,
client: C,
sandbox: S,
}
/// App routes
pub struct App<C: UpstreamClient, S: Sandboxing> {
_marker: PhantomData<(C, S)>,
}
#[cfg(any(feature = "cf-worker", feature = "reqwest"))]
#[allow(clippy::unused_async)]
impl<C: UpstreamClient + 'static, S: Sandboxing + 'static> App<C, S> {
/// Root endpoint
#[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn index(State(state): State<Arc<AppState<Upstream, NoSandbox>>>) -> Response {
match &state.config.index_redirect {
IndexConfig::Redirect { permanent, ref url } => {
if *permanent {
Redirect::permanent(url).into_response()
} else {
Redirect::temporary(url).into_response()
}
}
IndexConfig::Message(msg) => (StatusCode::OK, msg.to_string()).into_response(),
}
}
#[cfg_attr(feature = "cf-worker", worker::send)]
async fn proxy_impl<'a>(
method: http::Method,
filename: Option<&str>,
State(state): State<Arc<AppState<Upstream, NoSandbox>>>,
Query(query): Query<ProxyQuery>,
info: IncomingInfo,
) -> Result<Response, ErrorResponse>
where
<<C as UpstreamClient>::Response as HTTPResponse>::BodyStream: Unpin,
{
let mut options = query.image_options;
if let Some(filename) = filename {
options.apply_filename(filename);
}
match method {
http::Method::GET => {}
http::Method::HEAD => {
let mut resp = Response::new(Body::empty());
resp.headers_mut().insert(
"Content-Type",
match options.format.as_deref() {
Some("png") => "image/png",
Some("jpeg") | Some("jpg") => "image/jpeg",
Some("webp") => "image/webp",
_ => "image/webp",
}
.parse()
.unwrap(),
);
return Ok(resp);
}
_ => {
return Err(ErrorResponse::method_not_allowed());
}
}
let resp = state
.client
.request_upstream(&info, &query.url, false, true, DEFAULT_MAX_REDIRECTS)
.await?;
let media =
MediaResponse::from_upstream_response(resp, &state.config.post_process, options)
.await?;
Ok(media.into_response())
}
/// Proxy endpoint without filename
#[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn proxy_without_filename(
method: http::Method,
Query(query): Query<ProxyQuery>,
State(state): State<Arc<AppState<Upstream, NoSandbox>>>,
info: IncomingInfo,
) -> Result<Response, ErrorResponse>
where
<<C as UpstreamClient>::Response as HTTPResponse>::BodyStream: Unpin,
{
Self::proxy_impl(method, None, State(state), Query(query), info).await
}
/// Proxy OPTIONS endpoint
#[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn proxy_options() -> HeaderMap {
let mut hm = HeaderMap::new();
hm.insert(
"Access-Control-Allow-Methods",
"GET, OPTIONS".parse().unwrap(),
);
hm
}
/// Proxy endpoint with filename
#[cfg_attr(feature = "cf-worker", worker::send)]
pub async fn proxy_with_filename(
method: http::Method,
Path(filename): Path<String>,
State(state): State<Arc<AppState<Upstream, NoSandbox>>>,
Query(query): Query<ProxyQuery>,
info: IncomingInfo,
) -> Result<Response, ErrorResponse>
where
<<C as UpstreamClient>::Response as HTTPResponse>::BodyStream: Unpin,
{
Self::proxy_impl(method, Some(&filename), State(state), Query(query), info).await
}
}
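// Illustrative sketch (not part of this commit): these handlers are expected
// to be mounted by this crate's `router` function (see src/main.rs below); a
// hypothetical wiring could look like:
//
// let app: axum::Router = axum::Router::new()
//     .route("/", axum::routing::get(App::<Upstream, NoSandbox>::index))
//     .route("/proxy", axum::routing::get(App::<Upstream, NoSandbox>::proxy_without_filename)
//         .options(App::<Upstream, NoSandbox>::proxy_options))
//     .route("/proxy/:filename", axum::routing::get(App::<Upstream, NoSandbox>::proxy_with_filename))
//     .with_state(state); // state: Arc<AppState<Upstream, NoSandbox>>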

33
src/main.rs Normal file
View file

@ -0,0 +1,33 @@
use std::net::SocketAddr;
use clap::Parser;
use yumechi_no_kuni_proxy_worker::{router, sandbox::NoSandbox, Config, Upstream};
#[derive(Parser)]
struct Cli {
#[clap(short, long)]
config: String,
}
#[tokio::main]
async fn main() {
env_logger::init();
let cli = Cli::parse();
let config_text = tokio::fs::read_to_string(&cli.config)
.await
.expect("Failed to read config file");
let config: Config = toml::from_str(&config_text).expect("Failed to parse config file");
let listen = config.listen.clone();
let router = router::<Upstream, NoSandbox>(config);
let ms = router.into_make_service_with_connect_info::<SocketAddr>();
let listener = tokio::net::TcpListener::bind(listen)
.await
.expect("Failed to bind listener");
axum::serve(listener, ms).await.expect("Failed to serve");
}
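// Illustrative sketch (not part of this commit): a minimal config file for
// the local binary, assuming only the `listen` field used above (remaining
// fields follow this crate's Config type):
//
// # local.toml
// listen = "127.0.0.1:3000"
//
// Run with: yumechi-no-kuni-proxy-worker --config local.toml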

156
src/post_process/image_processing.rs Normal file
View file

@ -0,0 +1,156 @@
use std::io::Cursor;
use image::{
codecs::{png::PngDecoder, webp::WebPDecoder},
AnimationDecoder, DynamicImage, GenericImageView, ImageResult,
};
use crate::ImageOptions;
pub const fn clamp_width(input: (u32, u32), max_width: u32) -> (u32, u32) {
if input.0 > max_width {
(max_width, input.1 * max_width / input.0)
} else {
input
}
}
pub const fn clamp_height(input: (u32, u32), max_height: u32) -> (u32, u32) {
if input.1 > max_height {
(input.0 * max_height / input.1, max_height)
} else {
input
}
}
pub const fn clamp_dimensions(input: (u32, u32), max_width: u32, max_height: u32) -> (u32, u32) {
clamp_height(clamp_width(input, max_width), max_height)
}
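// e.g. clamp_dimensions((1600, 1200), 800, 800) == (800, 600): the width is
// clamped to 800 first, and the rescaled height 1200 * 800 / 1600 = 600
// already fits within the height bound.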
// All constants below follow https://github.com/misskey-dev/media-proxy/blob/master/SPECIFICATION.md
pub fn postprocess_webp_image(
data: &[u8],
opt: &ImageOptions,
) -> ImageResult<Option<DynamicImage>> {
let dec = WebPDecoder::new(Cursor::new(data))?;
if !dec.has_animation() {
return Ok(Some(postprocess_static_image(data, opt)?));
}
if opt.static_ == Some(true) {
let first_frame = match dec.into_frames().next() {
Some(Ok(frame)) => frame,
_ => return Ok(None),
};
return Ok(Some(process_static_image_impl(
first_frame.into_buffer().into(),
opt,
)));
}
Ok(None)
}
pub fn postprocess_png_image(data: &[u8], opt: &ImageOptions) -> ImageResult<Option<DynamicImage>> {
let dec = PngDecoder::new(Cursor::new(data))?;
// a plain (non-animated) PNG can go straight through static processing
if !dec.is_apng()? {
return Ok(Some(postprocess_static_image(data, opt)?));
}
if opt.static_ == Some(true) {
let first_frame = match dec.apng()?.into_frames().next() {
Some(Ok(frame)) => frame,
_ => return Ok(None),
};
return Ok(Some(process_static_image_impl(
first_frame.into_buffer().into(),
opt,
)));
}
Ok(None)
}
#[derive(Debug, thiserror::Error)]
pub enum SvgPostprocessError {
#[error("Image error: {0}")]
Image(#[from] image::ImageError),
#[error("SVG error: {0}")]
Svg(#[from] resvg::usvg::Error),
}
pub fn postprocess_svg_image(
data: &[u8],
opt: &ImageOptions,
) -> Result<DynamicImage, SvgPostprocessError> {
use resvg::{
tiny_skia::Pixmap,
usvg::{self, Transform},
};
let svg = usvg::Tree::from_data(
data,
&usvg::Options {
default_size: usvg::Size::from_wh(256., 256.).unwrap(),
..Default::default()
},
)?;
let size = svg.size();
let clamped = clamp_dimensions((size.width() as u32, size.height() as u32), 800, 800);
let transform = Transform::from_scale(
clamped.0 as f32 / size.width() as f32,
clamped.1 as f32 / size.height() as f32,
);
let mut pm = Pixmap::new(clamped.0 as _, clamped.1 as _).unwrap();
resvg::render(&svg, transform, &mut pm.as_mut());
let image =
image::RgbaImage::from_vec(pm.width() as _, pm.height() as _, pm.data().to_vec()).unwrap();
Ok(process_static_image_impl(
DynamicImage::ImageRgba8(image),
opt,
))
}
pub fn postprocess_static_image(data: &[u8], opt: &ImageOptions) -> ImageResult<DynamicImage> {
Ok(process_static_image_impl(
image::load_from_memory(data)?,
opt,
))
}
fn process_static_image_impl(mut image: DynamicImage, opt: &ImageOptions) -> DynamicImage {
let mut out_dim = image.dimensions();
if opt.badge == Some(true) {
return image.resize_exact(96, 96, image::imageops::FilterType::Nearest);
}
if opt.emoji == Some(true) {
out_dim = clamp_height(out_dim, 128);
} else if opt.avatar == Some(true) {
out_dim = clamp_height(out_dim, 320);
}
if opt.preview == Some(true) {
out_dim = clamp_dimensions(out_dim, 200, 200);
}
if opt.static_ == Some(true) {
out_dim = clamp_dimensions(out_dim, 498, 422);
}
if out_dim != image.dimensions() {
image = image.resize_exact(out_dim.0, out_dim.1, image::imageops::FilterType::Lanczos3);
}
image
}
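// Illustrative sketch (not part of this commit): a minimal test of the
// clamping helpers above.
#[cfg(test)]
mod clamp_tests {
use super::*;

#[test]
fn clamps_preserving_aspect_ratio() {
assert_eq!(clamp_dimensions((1600, 1200), 800, 800), (800, 600));
// inputs already inside the bounds pass through unchanged
assert_eq!(clamp_dimensions((640, 480), 800, 800), (640, 480));
}
}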

469
src/post_process/mod.rs Normal file
View file

@ -0,0 +1,469 @@
use std::{
io::{BufWriter, Cursor, Write},
pin::Pin,
};
use axum::{
body::{Body, Bytes},
http::HeaderValue,
response::{IntoResponse, Redirect},
};
use futures::StreamExt;
use image::{
codecs::gif::{GifEncoder, Repeat},
Frames, ImageFormat,
};
use sniff::SniffingStream;
use crate::{fetch::HTTPResponse, ErrorResponse, ImageOptions, PostProcessConfig};
const SLURP_LIMIT: usize = 32 << 20;
const MTU_BUFFER_SIZE: usize = 8192;
/// Module for sniffing media types and filtering out unsafe media
pub mod sniff;
/// Module for processing images
pub mod image_processing;
/// The kind of data passed to the client
pub enum MediaResponse<'a, R: HTTPResponse + 'a>
where
<R as HTTPResponse>::BodyStream: Unpin,
{
/// Send a redirect
Redirect(String),
/// A buffer of data
#[allow(missing_docs)]
Buffer {
data: Vec<u8>,
content_type: Option<String>,
},
/// A static image
ProcessedStaticImage(StaticImage),
/// An animated image
ProcessedAnimatedImage(AnimatedImage<'a>),
/// Pipe the response through (video, audio, etc.)
PassThru(PassThru<R>),
}
impl<'a, R: HTTPResponse + 'a> MediaResponse<'a, R>
where
<R as HTTPResponse>::BodyStream: Unpin,
{
/// Create a new media response from an upstream response
pub async fn from_upstream_response(
response: R,
config: &PostProcessConfig,
options: ImageOptions,
) -> Result<Self, ErrorResponse> {
let content_length = response
.header_one("content-length")
.ok()
.flatten()
.and_then(|cl| cl.parse::<usize>().ok());
let claimed_ct = response.header_one("content-type").ok().flatten();
// SVGs need special handling, so deal with them first
let is_svg = claimed_ct
.as_deref()
.map(|ct| ct.starts_with("image/svg"))
.unwrap_or(false);
// if no post-processing was requested, or the media type is not something we process,
// redirect or pass the response through
if !is_svg
&& (!options.requested_postprocess()
|| claimed_ct
.map(|ct| ct.starts_with("video/") || ct.starts_with("audio/"))
.unwrap_or(false))
{
if config.enable_redirects
&& options.origin != Some(true)
&& content_length.map_or(false, |cl| cl > 1 << 20)
{
return Ok(MediaResponse::Redirect(response.request().url.to_string()));
} else {
return Ok(MediaResponse::probe_then_through(response).await?);
}
}
let is_https = response.request().secure;
let mut sniffer = Box::pin(SniffingStream::new(response));
let mut header = Cursor::new([0; MTU_BUFFER_SIZE]);
let mut overflow = false;
while futures::future::poll_fn(|cx| {
sniffer.as_mut().poll_sniff(cx, |buf| {
if header.write_all(buf).is_err() {
overflow = true;
}
})
})
.await
.is_some()
{}
if overflow {
return Err(ErrorResponse::mtu_buffer_overflow());
}
let result = sniffer.result_ref().cloned();
let mut remaining_body = Pin::into_inner(sniffer).into_remaining_body();
match result {
_ if is_svg => {
let header_len = header.position();
let header = header.into_inner();
let mut buf = Vec::with_capacity(header_len as usize);
buf.extend_from_slice(&header[..header_len as usize]);
while let Some(Ok(bytes)) = remaining_body.next().await {
if buf.len() + bytes.as_ref().len() > 1 << 20 {
return Err(ErrorResponse::payload_too_large());
}
buf.extend_from_slice(bytes.as_ref());
}
let img = image_processing::postprocess_svg_image(&buf, &options)
.map_err(|e| ErrorResponse::postprocess_failed(e.to_string().into()))?;
Ok(MediaResponse::ProcessedStaticImage(StaticImage {
data: img,
format: ImageFormat::WebP,
is_https,
}))
}
Some(rs) => {
if rs.maybe_unsafe {
return Err(ErrorResponse::unsafe_media());
}
match rs.sniffed_mime {
Some(mime) => {
if mime.starts_with("image/") {
// slurp it up
let header_len = header.position();
let header = header.into_inner();
let mut buf = if let Some(cl) = content_length {
// cap the pre-allocation so a bogus Content-Length cannot force a huge allocation
let mut ret = Vec::with_capacity(cl.min(SLURP_LIMIT));
ret.extend_from_slice(&header[..header_len as usize]);
ret
} else {
header[..header_len as usize].to_vec()
};
while let Some(Ok(bytes)) = remaining_body.next().await {
if buf.len() + bytes.as_ref().len() > SLURP_LIMIT {
return Err(ErrorResponse::payload_too_large());
}
buf.extend_from_slice(bytes.as_ref());
}
let output_static_format = if options
.format
.as_deref()
.map_or(false, |f| f.eq_ignore_ascii_case("png"))
{
ImageFormat::Png
} else if options
.format
.as_deref()
.map_or(false, |f| f.eq_ignore_ascii_case("jpeg"))
{
ImageFormat::Jpeg
} else {
ImageFormat::WebP
};
if options.format.is_none() {
if mime.starts_with("image/png") || mime.starts_with("image/apng") {
let result =
image_processing::postprocess_png_image(&buf, &options)
.map_err(|e| {
ErrorResponse::postprocess_failed(
e.to_string().into(),
)
})?;
return match result {
Some(img) => {
Ok(MediaResponse::ProcessedStaticImage(StaticImage {
data: img,
format: output_static_format,
is_https,
}))
}
None => Ok(MediaResponse::Buffer {
data: buf,
content_type: Some("image/png".into()),
}),
};
}
if mime.starts_with("image/webp") {
let result =
image_processing::postprocess_webp_image(&buf, &options)
.map_err(|e| {
ErrorResponse::postprocess_failed(
e.to_string().into(),
)
})?;
return match result {
Some(img) => {
Ok(MediaResponse::ProcessedStaticImage(StaticImage {
data: img,
format: output_static_format,
is_https,
}))
}
None => Ok(MediaResponse::Buffer {
data: buf,
content_type: Some("image/webp".into()),
}),
};
}
}
let result = image_processing::postprocess_static_image(&buf, &options)
.map_err(|e| {
ErrorResponse::postprocess_failed(e.to_string().into())
})?;
Ok(MediaResponse::ProcessedStaticImage(StaticImage {
data: result,
format: output_static_format,
is_https,
}))
} else {
Ok(MediaResponse::PassThru(PassThru {
header_len: header.position() as _,
header: header.into_inner(),
remaining_body,
content_type: Some(mime.to_string()),
is_https,
}))
}
}
None => Ok(MediaResponse::PassThru(PassThru {
header_len: header.position() as _,
header: header.into_inner(),
remaining_body,
content_type: None,
is_https,
})),
}
}
None => Err(ErrorResponse::unsupported_media()),
}
}
}
impl<'a, R: HTTPResponse + 'a> IntoResponse for MediaResponse<'a, R>
where
<R as HTTPResponse>::BodyStream: Unpin,
{
fn into_response(self) -> axum::response::Response {
match self {
MediaResponse::Buffer { data, content_type } => {
let mut resp = axum::http::Response::new(Body::from(data));
if let Some(ct) = content_type {
resp.headers_mut()
.insert("content-type", HeaderValue::from_str(&ct).unwrap());
}
resp.into_response()
}
MediaResponse::Redirect(ret) => Redirect::permanent(&ret).into_response(),
MediaResponse::ProcessedStaticImage(img) => img.into_response(),
MediaResponse::ProcessedAnimatedImage(img) => img.into_response(),
MediaResponse::PassThru(pt) => pt.into_response(),
}
}
}
impl<'a, R: HTTPResponse + 'a> MediaResponse<'a, R>
where
<R as HTTPResponse>::BodyStream: Unpin,
{
/// Probe the response for media type and then pass through
///
/// Mainly used for the /proxy/?url= endpoint
pub async fn probe_then_through(response: R) -> Result<Self, ErrorResponse> {
let content_type = response
.header_one("content-type")
.ok()
.flatten()
.map(|ct| ct.to_string());
let is_https = response.request().secure;
let mut sniffed = Box::pin(sniff::SniffingStream::new(response));
let mut header = Cursor::new([0; MTU_BUFFER_SIZE]);
let mut overflow = false;
while futures::future::poll_fn(|cx| {
sniffed.as_mut().poll_sniff(cx, |buf| {
if header.write_all(buf).is_err() {
overflow = true;
}
})
})
.await
.is_some()
{}
if overflow {
return Err(ErrorResponse::mtu_buffer_overflow());
}
let result = sniffed.result_ref().cloned();
let body = Pin::into_inner(sniffed).into_remaining_body();
match result {
Some(rs) => {
if rs.maybe_unsafe {
return Err(ErrorResponse::unsafe_media());
}
Ok(MediaResponse::PassThru(PassThru {
header_len: header
.position()
.try_into()
.map_err(|_| ErrorResponse::payload_too_large())?,
header: header.into_inner(),
remaining_body: body,
content_type,
is_https,
}))
}
None => Err(ErrorResponse::unsupported_media()),
}
}
}
/// Pass through the response
pub struct PassThru<R: HTTPResponse> {
header: [u8; MTU_BUFFER_SIZE],
header_len: usize,
content_type: Option<String>,
is_https: bool,
remaining_body: <R as HTTPResponse>::BodyStream,
}
impl<R: HTTPResponse> IntoResponse for PassThru<R> {
fn into_response(self) -> axum::response::Response {
if self
.content_type
.as_deref()
.map_or(false, |ct| ct.starts_with("image/svg"))
{
// reject svg pass through
return ErrorResponse::unsupported_media().into_response();
}
let content_type = HeaderValue::from_str(
self.content_type
.as_deref()
.unwrap_or("application/octet-stream"),
);
let proto = if self.is_https {
HeaderValue::from_static("https")
} else {
HeaderValue::from_static("http")
};
let header = Bytes::from(self.header[..self.header_len].to_vec());
let header_stream = futures::stream::once(async { Ok(header) });
let mut resp = axum::http::Response::new(Body::from_stream(header_stream.chain(
self.remaining_body.map(|res| match res {
Ok(bytes) => Ok(bytes.into()),
Err(e) => Err(e.message),
}),
)));
resp.headers_mut().insert("x-forwarded-proto", proto);
if let Ok(ct) = content_type {
resp.headers_mut().insert("content-type", ct);
}
resp.into_response()
}
}
/// Processed static image
pub struct StaticImage {
data: image::DynamicImage,
format: ImageFormat,
is_https: bool,
}
impl IntoResponse for StaticImage {
fn into_response(self) -> axum::response::Response {
let mut buf = BufWriter::new(Cursor::new(Vec::new()));
self.data.write_to(&mut buf, self.format).unwrap();
let mut resp =
axum::http::Response::new(Body::from(buf.into_inner().unwrap().into_inner()));
let mime = match self.format {
ImageFormat::Png => "image/png",
ImageFormat::Jpeg => "image/jpeg",
ImageFormat::Gif => "image/gif",
ImageFormat::WebP => "image/webp",
ImageFormat::Tiff => "image/tiff",
ImageFormat::Bmp => "image/bmp",
ImageFormat::Avif => "image/avif",
ImageFormat::OpenExr => "image/openexr",
ImageFormat::Ico => "image/vnd.microsoft.icon",
_ => "application/octet-stream",
};
resp.headers_mut().insert(
"x-forwarded-proto",
HeaderValue::from_static(if self.is_https { "https" } else { "http" }),
);
resp.headers_mut()
.insert("content-type", HeaderValue::from_static(mime));
resp.into_response()
}
}
/// Processed animated image
pub struct AnimatedImage<'a> {
frames: Frames<'a>,
is_https: bool,
}
impl<'a> IntoResponse for AnimatedImage<'a> {
fn into_response(self) -> axum::response::Response {
let mut buf = BufWriter::new(Cursor::new(Vec::new()));
let mut gif = GifEncoder::new(&mut buf);
gif.set_repeat(Repeat::Infinite).unwrap();
if let Err(e) = gif.try_encode_frames(self.frames) {
return ErrorResponse::postprocess_failed(format!("gif encoding failed: {e}").into())
.into_response();
}
drop(gif);
let mut resp =
axum::http::Response::new(Body::from(buf.into_inner().unwrap().into_inner()));
resp.headers_mut().insert(
"x-forwarded-proto",
HeaderValue::from_static(if self.is_https { "https" } else { "http" }),
);
resp.headers_mut()
.insert("content-type", HeaderValue::from_static("image/gif"));
resp.into_response()
}
}

221
src/post_process/sniff.rs Normal file
View file

@ -0,0 +1,221 @@
use std::{
io::{Cursor, Write},
pin::Pin,
task::Poll,
};
use futures::{Stream, StreamExt};
use crate::{fetch::HTTPResponse, ErrorResponse};
// MIME sniffing data
include!(concat!(env!("OUT_DIR"), "/magic.rs"));
/// An association between a MIME type and a file extension
#[derive(Clone)]
pub struct MIMEAssociation {
/// The MIME type
pub mime: &'static str,
/// The file extension
pub ext: &'static str,
/// Whether the file is safe to display
pub safe: bool,
/// The file signatures
signatures: &'static [FlattenedFileSignature],
}
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
struct FlattenedFileSignature(&'static [(u8, u8)]);
impl FlattenedFileSignature {
#[inline]
fn matches(&self, test: &[u8]) -> bool {
if self.0.len() > test.len() {
return false;
}
self.0
.iter()
.zip(test.iter())
.all(|((sig, mask), byte)| sig & mask == *byte & mask)
}
}
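// e.g. a masked signature &[(0x47, 0xFF), (0x49, 0xFF), (0x46, 0xFF)] ("GIF")
// matches any buffer beginning with those three bytes; a mask byte of 0x00
// turns that position into a wildcard.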
/// A stream that sniffs the MIME type of the data it receives
pub struct SniffingStream<R: HTTPResponse> {
body: <R as HTTPResponse>::BodyStream,
sniff_buffer: Cursor<[u8; SNIFF_SIZE]>,
sniffed: Option<SniffResult>,
}
/// The result of MIME sniffing
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub struct SniffResult {
/// The MIME type that was sniffed
pub sniffed_mime: Option<&'static str>,
/// Whether the file may be unsafe
pub maybe_unsafe: bool,
}
impl<R: HTTPResponse> SniffingStream<R> {
/// Create a new `SniffingStream` from a response
pub fn new(response: R) -> Self {
Self {
body: response.body(),
sniff_buffer: Cursor::new([0; SNIFF_SIZE]),
sniffed: None,
}
}
/// Create a new `SniffingStream` from a body stream
pub fn new_from_body_stream(response: <R as HTTPResponse>::BodyStream) -> Self {
Self {
body: response,
sniff_buffer: Cursor::new([0; SNIFF_SIZE]),
sniffed: None,
}
}
/// Get a shared reference to the result of MIME sniffing
pub fn result_ref(&self) -> Option<&SniffResult> {
self.sniffed.as_ref()
}
/// Get the result of MIME sniffing, consuming the stream
pub fn result(self) -> Option<SniffResult> {
self.sniffed
}
/// Get the remaining body stream
pub fn into_remaining_body(self) -> <R as HTTPResponse>::BodyStream {
self.body
}
/// Poll the stream for MIME sniffing, writing any data consumed to a buffer
pub fn poll_sniff<'a>(
mut self: Pin<&'a mut Self>,
cx: &mut std::task::Context<'_>,
mut notify: impl FnMut(&[u8]),
) -> Poll<Option<Result<usize, ErrorResponse>>>
where
<R as HTTPResponse>::BodyStream: Unpin,
{
#[allow(clippy::cast_possible_truncation)]
let remaining_sniff_buffer = SNIFF_SIZE - self.sniff_buffer.position() as usize;
if remaining_sniff_buffer > 0 {
match self.body.poll_next_unpin(cx) {
Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Some(Ok(bytes))) => {
notify(bytes.as_ref());
self.sniff_buffer.write_all(bytes.as_ref()).ok();
if self.sniff_buffer.position() == SNIFF_SIZE as u64 {
let cands = MAGICS.iter().filter(|assoc| {
assoc
.signatures
.iter()
.any(|sig| sig.matches(self.sniff_buffer.get_ref()))
});
let mut all_safe = true;
let mut best_match = None;
for cand in cands {
match best_match {
None => best_match = Some(cand),
Some(assoc) if assoc.signatures.len() < cand.signatures.len() => {
best_match = Some(cand);
}
_ => {}
}
if !cand.safe {
all_safe = false;
break;
}
}
self.sniffed = Some(SniffResult {
sniffed_mime: best_match.map(|assoc| assoc.mime),
maybe_unsafe: !all_safe,
});
self.sniff_buffer.set_position(0);
return Poll::Ready(None);
}
return Poll::Ready(Some(Ok(bytes.as_ref().len())));
}
}
}
Poll::Ready(None)
}
}
impl<R: HTTPResponse> Stream for SniffingStream<R>
where
<R as HTTPResponse>::BodyStream: Unpin,
{
type Item = Result<<R as HTTPResponse>::Bytes, ErrorResponse>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
#[allow(clippy::cast_possible_truncation)]
let remaining_sniff_buffer = SNIFF_SIZE - self.sniff_buffer.position() as usize;
if self.sniffed.is_none() && remaining_sniff_buffer > 0 {
match self.body.poll_next_unpin(cx) {
Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Some(Ok(bytes))) => {
#[allow(clippy::unused_io_amount)]
self.sniff_buffer.write(bytes.as_ref()).ok();
if self.sniff_buffer.position() == SNIFF_SIZE as u64 {
let cands = MAGICS.iter().filter(|assoc| {
assoc
.signatures
.iter()
.any(|sig| sig.matches(self.sniff_buffer.get_ref()))
});
let mut all_safe = true;
let mut best_match = None;
for cand in cands {
if best_match
.map_or(0, |assoc: &MIMEAssociation| assoc.signatures.len())
< cand.signatures.len()
{
best_match = Some(cand);
}
if !cand.safe {
all_safe = false;
break;
}
}
self.sniffed = Some(SniffResult {
sniffed_mime: best_match.map(|assoc| assoc.mime),
maybe_unsafe: !all_safe,
});
self.sniff_buffer.set_position(0);
}
return Poll::Ready(Some(Ok(bytes)));
}
}
}
self.body.poll_next_unpin(cx)
}
}
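// Illustrative sketch (not part of this commit): how the callers in
// post_process/mod.rs drive this type. `poll_sniff` is polled until it
// returns Ready(None); `result_ref()` then holds the verdict and
// `into_remaining_body()` yields the unread tail of the upstream stream,
// while the sniffed prefix has already been copied out via the `notify`
// callback.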

103
src/sandbox.rs Normal file
View file

@ -0,0 +1,103 @@
/// A trait for setting up a thread sandboxing environment
pub trait Sandboxing {
/// The type of the guard that is returned by the setup function
type Guard;
/// Set up the sandboxing environment
fn setup(&self, key: &[u8]) -> Self::Guard;
}
/// A sandboxing environment that does nothing
#[derive(Default, Clone, Copy)]
pub struct NoSandbox;
impl Sandboxing for NoSandbox {
type Guard = ();
fn setup(&self, _key: &[u8]) -> Self::Guard {}
}
/// A sandboxing environment that uses `AppArmor` hats
#[cfg(feature = "apparmor")]
pub mod apparmor {
use std::{
ffi::{c_int, c_ulong, CString},
hash::Hasher,
};
use rand_core::RngCore;
use siphasher::sip::SipHasher;
use super::Sandboxing;
#[link(name = "apparmor")]
extern "C" {
fn aa_change_hat(profile: *const i8, token: c_ulong) -> c_int;
}
/// An `AppArmor` hat environment
pub struct AppArmorHat {
profile: CString,
hasher: SipHasher,
}
impl AppArmorHat {
/// Create a new `AppArmor` hat environment
///
/// # Panics
Panics if the profile name is not a valid C string
#[must_use]
pub fn new(profile: &str) -> Self {
let cstr = std::ffi::CString::new(profile).expect("Invalid profile name");
let mut buf = [0; 16];
rand_core::OsRng.fill_bytes(&mut buf);
Self {
profile: cstr,
hasher: SipHasher::new_with_key(&buf),
}
}
}
/// A 'challenge' to the less secure task to prove it can still function.
/// The key here is to hide the token reasonably well against ROP attacks.
pub struct AppArmorHandle {
hasher: SipHasher,
hash_1x: Box<u64>,
}
impl Drop for AppArmorHandle {
#[allow(clippy::inline_always, reason = "Intentional")]
#[inline(always)]
fn drop(&mut self) {
self.hasher.write_u64(*self.hash_1x);
let hash_2x = self.hasher.finish();
#[allow(unsafe_code)]
let ret = unsafe { aa_change_hat(std::ptr::null(), hash_2x as c_ulong) };
// This should never happen as aa_change_hat in return mode will kill on any failed call
assert!(ret == 0, "AppArmor hat return failed: {ret}");
}
}
impl Sandboxing for AppArmorHat {
type Guard = AppArmorHandle;
fn setup(&self, key: &[u8]) -> Self::Guard {
let mut hash = self.hasher;
hash.write(key);
let hash_1x = hash.finish();
let mut hash_2 = self.hasher;
hash_2.write_u64(hash_1x);
let hash_2x = hash_2.finish();
#[allow(unsafe_code)]
let ret = unsafe { aa_change_hat(self.profile.as_ptr(), hash_2x as c_ulong) };
assert!(ret == 0, "AppArmor hat change failed: {ret}");
AppArmorHandle {
hasher: hash,
hash_1x: Box::new(hash_1x),
}
}
}
}
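// Illustrative sketch (not part of this commit): typical use, assuming an
// AppArmor profile with a hat named "postprocess" is loaded for this binary.
//
// let sandbox = apparmor::AppArmorHat::new("postprocess");
// let guard = sandbox.setup(b"per-request-key");
// // ... decode untrusted media with reduced permissions ...
// drop(guard); // returns from the hat; aa_change_hat kills the process on failure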

72
src/stream.rs Normal file
View file

@ -0,0 +1,72 @@
use std::{
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
task::Poll,
};
type Bytes = Vec<u8>;
use futures::{Stream, StreamExt};
/// A stream that limits the amount of data it can receive
pub struct LimitedStream<T, E>
where
T: Stream<Item = Result<Bytes, E>> + 'static,
{
ended: bool,
stream: Pin<Box<T>>,
limit: AtomicUsize,
}
impl<T, E> LimitedStream<T, E>
where
T: Stream<Item = Result<Bytes, E>> + 'static,
{
/// Create a new `LimitedStream` with a given stream and limit
pub fn new(stream: T, limit: usize) -> Self {
Self {
ended: false,
stream: Box::pin(stream),
limit: AtomicUsize::new(limit),
}
}
}
impl<T, E> Stream for LimitedStream<T, E>
where
T: Stream<Item = Result<Bytes, E>> + 'static,
{
type Item = Result<Bytes, Option<E>>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.limit.load(Ordering::SeqCst) {
0 => {
if self.ended {
Poll::Ready(None)
} else {
self.ended = true;
Poll::Ready(Some(Err(None)))
}
}
remaining_len => {
let p = self.stream.poll_next_unpin(cx);
match p {
Poll::Ready(Some(Ok(mut data))) => {
if data.len() > remaining_len {
self.limit.store(0, Ordering::SeqCst);
data.truncate(remaining_len);
Poll::Ready(Some(Ok(data)))
} else {
self.limit.fetch_sub(data.len(), Ordering::SeqCst);
Poll::Ready(Some(Ok(data)))
}
}
_ => p.map_err(Some),
}
}
}
}
}
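// Illustrative sketch (not part of this commit): with a 5-byte limit an
// 8-byte chunk is truncated to 5 bytes, the next poll yields Err(None) to
// mark the limit being hit (a real upstream error is Err(Some(e))), and the
// stream ends afterwards.
//
// let limited = LimitedStream::new(
//     futures::stream::iter([Ok::<_, ()>(vec![0u8; 8])]),
//     5,
// );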

1
submodules/file Submodule

@ -0,0 +1 @@
Subproject commit 87ed2d47d61450bd1ee3f25d568c26116810437c

6
wrangler.toml Normal file
View file

@ -0,0 +1,6 @@
name = "yumechi-no-kuni-proxy-worker"
main = "build/worker/shim.mjs"
compatibility_date = "2024-11-11"
[build]
command = "cargo install -q worker-build && worker-build --release --features cf-worker"