Refactor stack logic
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
parent
205958b0cc
commit
ab15f3e3ff
13 changed files with 727 additions and 132 deletions
113
Cargo.lock
generated
113
Cargo.lock
generated
|
@ -32,6 +32,21 @@ version = "0.2.18"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.15"
|
||||
|
@ -304,6 +319,21 @@ version = "1.0.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.38"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.8.1"
|
||||
|
@ -386,6 +416,15 @@ version = "0.8.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.20"
|
||||
|
@ -479,18 +518,31 @@ dependencies = [
|
|||
"async-trait",
|
||||
"axum",
|
||||
"axum-server",
|
||||
"chrono",
|
||||
"clap",
|
||||
"dashmap",
|
||||
"env_logger",
|
||||
"flate2",
|
||||
"futures",
|
||||
"log",
|
||||
"lru",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.0.34"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
|
@ -827,6 +879,29 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.61"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.5.0"
|
||||
|
@ -1029,6 +1104,15 @@ dependencies = [
|
|||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.5"
|
||||
|
@ -1565,6 +1649,26 @@ dependencies = [
|
|||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.64"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.8.0"
|
||||
|
@ -1878,6 +1982,15 @@ dependencies = [
|
|||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-registry"
|
||||
version = "0.2.0"
|
||||
|
|
|
@ -3,17 +3,23 @@ name = "fedivet"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
tls = ["axum-server/tls-rustls", "axum-server/rustls-pemfile", "axum-server/tokio-rustls"]
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.83"
|
||||
axum = "0.7.7"
|
||||
axum-server = { version = "0.7.1", features = ["tokio-rustls", "rustls-pemfile", "tls-rustls"] }
|
||||
axum-server = { version = "0.7.1" }
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
clap = { version = "4.5.20", features = ["derive"] }
|
||||
dashmap = "6.1.0"
|
||||
env_logger = "0.11.5"
|
||||
flate2 = "1.0.34"
|
||||
futures = "0.3.31"
|
||||
log = "0.4.22"
|
||||
lru = "0.12.5"
|
||||
reqwest = { version = "0.12.8", features = ["stream"] }
|
||||
serde = { version = "1.0.210", features = ["derive"] }
|
||||
serde_json = "1.0.128"
|
||||
thiserror = "1.0.64"
|
||||
tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros", "net", "sync", "fs", "signal", "time"] }
|
||||
|
|
244
src/evaluate/chain/audit.rs
Normal file
244
src/evaluate/chain/audit.rs
Normal file
|
@ -0,0 +1,244 @@
|
|||
use axum::response::IntoResponse;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::sync::atomic::AtomicU32;
|
||||
use std::{collections::HashMap, fmt::Debug, ops::DerefMut, path::PathBuf, sync::Arc};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
use crate::{
|
||||
delegate,
|
||||
evaluate::{Disposition, Evaluator},
|
||||
HasAppState,
|
||||
};
|
||||
|
||||
/// Save audit logs to a file
|
||||
pub struct AuditOptions {
|
||||
/// The path to save the audit logs
|
||||
pub output: PathBuf,
|
||||
/// The maximum size of the audit log file before it is rotated
|
||||
pub rotate_size: Option<u64>,
|
||||
/// The number of days to keep audit logs before vacuuming
|
||||
pub vacuum_days: Option<u64>,
|
||||
}
|
||||
|
||||
impl AuditOptions {
|
||||
/// Create a new set of audit options
|
||||
pub fn new(output: PathBuf) -> Self {
|
||||
Self {
|
||||
output,
|
||||
rotate_size: None,
|
||||
vacuum_days: None,
|
||||
}
|
||||
}
|
||||
/// Set the maximum size of the audit log file before it is rotated
|
||||
pub fn rotate_size(mut self, rotate_size: u64) -> Self {
|
||||
self.rotate_size = Some(rotate_size);
|
||||
self
|
||||
}
|
||||
/// Set the number of days to keep audit logs before vacuuming
|
||||
pub fn vacuum_days(mut self, vacuum_days: u64) -> Self {
|
||||
self.vacuum_days = Some(vacuum_days);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
/// Errors that can occur while writing audit logs
|
||||
#[non_exhaustive]
|
||||
#[allow(missing_docs)]
|
||||
pub enum AuditError {
|
||||
#[error("Failed to write audit log: {0}")]
|
||||
WriteError(#[from] std::io::Error),
|
||||
|
||||
#[error("Failed to serialize audit log: {0}")]
|
||||
SerializeError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub struct AuditState {
|
||||
options: AuditOptions,
|
||||
cur_file: Option<RwLock<HashMap<String, Mutex<File>>>>,
|
||||
vacuum_counter: AtomicU32,
|
||||
}
|
||||
|
||||
impl AuditState {
|
||||
pub fn new(options: AuditOptions) -> Self {
|
||||
if !options.output.exists() {
|
||||
std::fs::create_dir_all(&options.output).expect("Failed to create audit log directory");
|
||||
}
|
||||
Self {
|
||||
options,
|
||||
cur_file: None,
|
||||
vacuum_counter: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn vacuum(&self) -> Result<(), AuditError> {
|
||||
let read = self.cur_file.as_ref().unwrap().read().await;
|
||||
let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
|
||||
|
||||
let files = std::fs::read_dir(&self.options.output)?
|
||||
.filter_map(|f| f.ok())
|
||||
.filter(|f| {
|
||||
f.file_type().map(|t| t.is_file()).unwrap_or(false)
|
||||
&& f.file_name().to_string_lossy().ends_with(".json")
|
||||
})
|
||||
.filter(|f| !in_use_files.contains(&f.file_name().to_string_lossy().to_string()));
|
||||
|
||||
for file in files {
|
||||
let meta = file.metadata()?;
|
||||
|
||||
if let Some(days) = self.options.vacuum_days {
|
||||
let duration = meta
|
||||
.modified()?
|
||||
.elapsed()
|
||||
.unwrap_or(std::time::Duration::new(0, 0));
|
||||
|
||||
if duration.as_secs() >= days * 24 * 60 * 60 {
|
||||
std::fs::remove_file(file.path())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
|
||||
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
|
||||
|
||||
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
||||
|
||||
let full_name = format!("{}_{}.json", name, time_str);
|
||||
|
||||
let file = OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.append(true)
|
||||
.open(&self.options.output.join(&full_name))?;
|
||||
|
||||
write.insert(name.to_string(), Mutex::new(file));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn write<E: IntoResponse + Serialize>(
|
||||
&self,
|
||||
name: &str,
|
||||
item: &AuditItem<'_, E>,
|
||||
) -> Result<(), AuditError> {
|
||||
if self
|
||||
.vacuum_counter
|
||||
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||
>= 512
|
||||
{
|
||||
self.vacuum().await?;
|
||||
self.vacuum_counter
|
||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let read = self.cur_file.as_ref().unwrap().read().await;
|
||||
|
||||
if let Some(file) = read.get(name) {
|
||||
let mut f = file.lock().await;
|
||||
|
||||
serde_json::to_writer(f.deref_mut(), &item)?;
|
||||
|
||||
let meta = f.metadata()?;
|
||||
|
||||
if let Some(size) = self.options.rotate_size {
|
||||
if meta.len() >= size {
|
||||
drop(f);
|
||||
drop(read);
|
||||
|
||||
self.create_new_file(name).await?;
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
||||
|
||||
let file = File::create(&self.options.output.join(name))?;
|
||||
|
||||
write.insert(name.to_string(), Mutex::new(file));
|
||||
|
||||
// this is deliberately out of order to make sure we don't create endless files if serialization fails
|
||||
serde_json::to_writer(write.get(name).unwrap().lock().await.deref_mut(), &item)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Audit<E: IntoResponse + 'static, I: HasAppState<E>> {
|
||||
inner: I,
|
||||
state: Arc<AuditState>,
|
||||
_marker: std::marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: IntoResponse + 'static, I: HasAppState<E>> Audit<E, I> {
|
||||
pub fn new(inner: I, options: AuditOptions) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
state: Arc::new(AuditState::new(options)),
|
||||
_marker: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: IntoResponse + 'static, I: HasAppState<E>> Clone for Audit<E, I> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
state: self.state.clone(),
|
||||
_marker: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delegate!(state Audit::<E, I>.inner);
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AuditItem<'r, E: IntoResponse + Serialize> {
|
||||
info: &'r crate::APRequestInfo<'r>,
|
||||
ctx: &'r Option<serde_json::Value>,
|
||||
disposition: &'r Disposition<E>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<
|
||||
E: IntoResponse + Serialize + Send + Sync + 'static,
|
||||
I: HasAppState<E> + Evaluator<E> + Sync,
|
||||
> Evaluator<E> for Audit<E, I>
|
||||
{
|
||||
fn name() -> &'static str {
|
||||
"Audit"
|
||||
}
|
||||
async fn evaluate<'r>(
|
||||
&self,
|
||||
ctx: Option<serde_json::Value>,
|
||||
info: &crate::APRequestInfo<'r>,
|
||||
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||
let (disp, ctx) = self.inner.evaluate(ctx, info).await;
|
||||
|
||||
if ctx
|
||||
.as_ref()
|
||||
.map(|c| c.get("skip_audit").is_some())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return (disp, ctx);
|
||||
}
|
||||
|
||||
let item = AuditItem {
|
||||
info,
|
||||
ctx: &ctx,
|
||||
disposition: &disp,
|
||||
};
|
||||
|
||||
let state = self.state.clone();
|
||||
|
||||
state.write(Self::name(), &item).await.ok();
|
||||
|
||||
(disp, ctx)
|
||||
}
|
||||
}
|
52
src/evaluate/chain/meta.rs
Normal file
52
src/evaluate/chain/meta.rs
Normal file
|
@ -0,0 +1,52 @@
|
|||
use axum::response::IntoResponse;
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
delegate,
|
||||
model::ap::{AnyObject, NoteObject},
|
||||
APRequestInfo, HasAppState,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Meta<E: IntoResponse + 'static, I: HasAppState<E>> {
|
||||
inner: I,
|
||||
_marker: std::marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
pub fn extract_host(act: &APRequestInfo) -> Option<String> {
|
||||
act.activity
|
||||
.as_ref()
|
||||
.ok()?
|
||||
.actor
|
||||
.as_ref()
|
||||
.and_then(|s| Url::parse(s).ok()?.host_str().map(|s| s.to_string()))
|
||||
}
|
||||
|
||||
pub fn extract_attributed_to(act: &APRequestInfo) -> Option<String> {
|
||||
match &act.activity.as_ref().ok()?.object {
|
||||
Some(arr) => match arr.clone().into_iter().next() {
|
||||
Some(AnyObject::NoteObject(NoteObject {
|
||||
attributed_to: Some(s),
|
||||
..
|
||||
})) => Some(s),
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_meta(act: &APRequestInfo) -> Option<MetaItem> {
|
||||
Some(MetaItem {
|
||||
instance_host: extract_host(act),
|
||||
attributed_to: extract_attributed_to(act),
|
||||
})
|
||||
}
|
||||
|
||||
delegate!(state Meta::<E, I>.inner);
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MetaItem {
|
||||
instance_host: Option<String>,
|
||||
attributed_to: Option<String>,
|
||||
}
|
2
src/evaluate/chain/mod.rs
Normal file
2
src/evaluate/chain/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
|||
pub mod audit;
|
||||
pub mod meta;
|
|
@ -1,5 +1,24 @@
|
|||
use axum::response::IntoResponse;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Serialize;
|
||||
|
||||
pub mod chain;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! delegate {
|
||||
(state $str:ident::<E, I $(,$t:ty),*>.$field:ident) => {
|
||||
use crate::{BaseAppState, ClientCache};
|
||||
impl<E: IntoResponse + Clone + 'static, I: HasAppState<E>> HasAppState<E> for $str::<E, I, $($t),*> {
|
||||
fn app_state(&self) -> &BaseAppState<E> {
|
||||
self.inner.app_state()
|
||||
}
|
||||
|
||||
fn client_pool_ref(&self) -> &ClientCache {
|
||||
self.inner.client_pool_ref()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
use crate::{model::error::MisskeyError, APRequestInfo};
|
||||
|
||||
|
@ -10,14 +29,50 @@ pub const ERROR_DENIED: MisskeyError = MisskeyError::new_const(
|
|||
"This server cannot accept this activity.",
|
||||
);
|
||||
|
||||
pub const ERR_BAD_REQUEST: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"UNPARSABLE_REQUEST",
|
||||
"659c1254-1392-458e-aa77-557444031da8",
|
||||
"The request is not HTTP compliant.",
|
||||
);
|
||||
|
||||
pub const ERR_INTERNAL_SERVER_ERROR: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"INTERNAL_SERVER_ERROR",
|
||||
"3a19659e-89e9-4c37-ada8-bbf620f2fc1a",
|
||||
"An internal server error occurred.",
|
||||
);
|
||||
|
||||
pub const ERR_SERVICE_TEMPORARILY_UNAVAILABLE: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"SERVICE_TEMPORARILY_UNAVAILABLE",
|
||||
"39992bed-f58f-484d-bb67-8c8db6f0b224",
|
||||
"The service is temporarily unavailable.",
|
||||
);
|
||||
|
||||
pub const ERR_PAYLOAD_TOO_LARGE: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"PAYLOAD_TOO_LARGE",
|
||||
"8007d1b7-0eab-41b2-bb17-95f06926ba2b",
|
||||
"The request payload is too large.",
|
||||
);
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub enum Disposition<E: IntoResponse> {
|
||||
Allow,
|
||||
Next,
|
||||
Intercept(E),
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Evaluator<E: IntoResponse> {
|
||||
async fn evaluate<'r>(&self, info: &APRequestInfo<'r>) -> Disposition<E>;
|
||||
fn name() -> &'static str;
|
||||
|
||||
async fn evaluate<'r>(
|
||||
&self,
|
||||
ctx: Option<serde_json::Value>,
|
||||
info: &APRequestInfo<'r>,
|
||||
) -> (Disposition<E>, Option<serde_json::Value>);
|
||||
}
|
||||
|
||||
impl<E: IntoResponse> Into<Disposition<E>> for Result<(), E> {
|
||||
|
@ -37,7 +92,14 @@ impl<
|
|||
F: Fn(&APRequestInfo<'_>) -> Fut + Send + Sync,
|
||||
> Evaluator<E> for F
|
||||
{
|
||||
async fn evaluate<'r>(&self, info: &APRequestInfo<'r>) -> Disposition<E> {
|
||||
(*self)(info).await.into()
|
||||
fn name() -> &'static str {
|
||||
"Closure"
|
||||
}
|
||||
async fn evaluate<'r>(
|
||||
&self,
|
||||
ctx: Option<serde_json::Value>,
|
||||
info: &APRequestInfo<'r>,
|
||||
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||
(self(info).await.into(), ctx)
|
||||
}
|
||||
}
|
||||
|
|
187
src/lib.rs
187
src/lib.rs
|
@ -1,9 +1,11 @@
|
|||
#![doc = include_str!("../README.md")]
|
||||
#![feature(ip)]
|
||||
#![feature(async_closure)]
|
||||
#![warn(clippy::unwrap_used, clippy::expect_used)]
|
||||
#![warn(unsafe_code)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
use std::{fmt::Display, net::SocketAddr};
|
||||
|
||||
|
@ -14,28 +16,65 @@ use axum::{
|
|||
response::{IntoResponse, Response},
|
||||
};
|
||||
use client::ClientCache;
|
||||
use evaluate::{Disposition, Evaluator};
|
||||
use evaluate::chain::audit::{Audit, AuditOptions};
|
||||
use evaluate::{
|
||||
Disposition, Evaluator, ERR_BAD_REQUEST, ERR_INTERNAL_SERVER_ERROR, ERR_PAYLOAD_TOO_LARGE,
|
||||
ERR_SERVICE_TEMPORARILY_UNAVAILABLE,
|
||||
};
|
||||
use futures::TryStreamExt;
|
||||
use model::ap::Object;
|
||||
use model::ap::AnyObject;
|
||||
use model::{ap, error::MisskeyError};
|
||||
use network::stream::LimitedStream;
|
||||
use network::Either;
|
||||
use reqwest::{self, StatusCode};
|
||||
use reqwest::{self};
|
||||
use serde::ser::{SerializeMap, SerializeStruct};
|
||||
use serde::Serialize;
|
||||
|
||||
pub(crate) mod client;
|
||||
/// Evaluation framework
|
||||
pub mod evaluate;
|
||||
/// Data models
|
||||
pub mod model;
|
||||
pub(crate) mod network;
|
||||
/// Server implementation
|
||||
pub mod serve;
|
||||
/// Data sources
|
||||
pub mod source;
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Information about an incoming ActivityPub request.
|
||||
#[allow(missing_docs)]
|
||||
pub struct APRequestInfo<'r> {
|
||||
pub method: &'r Method,
|
||||
pub uri: &'r Uri,
|
||||
pub header: &'r HeaderMap,
|
||||
pub connect: &'r SocketAddr,
|
||||
pub activity: &'r Result<ap::Activity<Object>, (serde_json::Value, serde_json::Error)>,
|
||||
pub activity: &'r Result<ap::Activity<AnyObject>, (serde_json::Value, serde_json::Error)>,
|
||||
}
|
||||
|
||||
impl Serialize for APRequestInfo<'_> {
|
||||
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let mut state = serializer.serialize_struct("APRequestInfo", 5)?;
|
||||
state.serialize_field("method", &self.method.as_str())?;
|
||||
state.serialize_field("uri", &self.uri.to_string())?;
|
||||
pub struct SerializeHeader<'a>(&'a HeaderMap);
|
||||
impl<'a> Serialize for SerializeHeader<'a> {
|
||||
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
let mut header = serializer.serialize_map(Some(self.0.len()))?;
|
||||
for (key, value) in self.0.iter() {
|
||||
header.serialize_entry(key.as_str(), &value.to_str().unwrap_or_default())?;
|
||||
}
|
||||
header.end()
|
||||
}
|
||||
}
|
||||
state.serialize_field("header", &SerializeHeader(self.header))?;
|
||||
state.serialize_field("connect", &self.connect)?;
|
||||
state.serialize_field(
|
||||
"activity",
|
||||
&self.activity.as_ref().map_err(|(v, e)| (v, e.to_string())),
|
||||
)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for APRequestInfo<'_> {
|
||||
|
@ -48,80 +87,108 @@ impl Display for APRequestInfo<'_> {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct AppState<E: IntoResponse + 'static> {
|
||||
backend: reqwest::Url,
|
||||
clients: ClientCache,
|
||||
inbox_stack: Vec<Box<dyn Evaluator<E> + Send + Sync>>,
|
||||
/// Trait for accessing the application state.
|
||||
pub trait HasAppState<E: IntoResponse + 'static>: Clone {
|
||||
/// Get a reference to the application state.
|
||||
fn app_state(&self) -> &BaseAppState<E>;
|
||||
|
||||
/// Get a reference to the client pool.
|
||||
fn client_pool_ref(&self) -> &ClientCache;
|
||||
|
||||
/// Wrap the evaluator in an audit chain.
|
||||
fn audited(self, opts: AuditOptions) -> Audit<E, Self>
|
||||
where
|
||||
Self: Sized + Send + Sync + Evaluator<E>,
|
||||
E: Send + Sync,
|
||||
{
|
||||
Audit::new(self, opts)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: IntoResponse> AppState<E> {
|
||||
pub fn new(backend: reqwest::Url) -> Self {
|
||||
Self {
|
||||
backend,
|
||||
clients: ClientCache::new(),
|
||||
inbox_stack: Vec::new(),
|
||||
}
|
||||
/// Application state.
|
||||
pub struct BaseAppState<E: IntoResponse + 'static> {
|
||||
backend: reqwest::Url,
|
||||
clients: ClientCache,
|
||||
ctx_template: Option<serde_json::Value>,
|
||||
_marker: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: IntoResponse> HasAppState<E> for Arc<BaseAppState<E>> {
|
||||
fn app_state(&self) -> &BaseAppState<E> {
|
||||
&self
|
||||
}
|
||||
pub fn push_evaluator(&mut self, evaluator: Box<dyn Evaluator<E> + Send + Sync>) -> &mut Self {
|
||||
self.inbox_stack.push(evaluator);
|
||||
self
|
||||
}
|
||||
pub fn client_pool_ref(&self) -> &ClientCache {
|
||||
|
||||
fn client_pool_ref(&self) -> &ClientCache {
|
||||
&self.clients
|
||||
}
|
||||
}
|
||||
|
||||
const ERR_BAD_REQUEST: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"UNPARSABLE_REQUEST",
|
||||
"659c1254-1392-458e-aa77-557444031da8",
|
||||
"The request is not HTTP compliant.",
|
||||
);
|
||||
#[async_trait::async_trait]
|
||||
impl<E: IntoResponse + Send + Sync> Evaluator<E> for Arc<BaseAppState<E>> {
|
||||
fn name() -> &'static str {
|
||||
"BaseAllow"
|
||||
}
|
||||
async fn evaluate<'r>(
|
||||
&self,
|
||||
ctx: Option<serde_json::Value>,
|
||||
_info: &APRequestInfo<'r>,
|
||||
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||
log::trace!("Evaluator fell through, accepting request");
|
||||
(Disposition::Allow, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
const ERR_INTERNAL_SERVER_ERROR: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"INTERNAL_SERVER_ERROR",
|
||||
"3a19659e-89e9-4c37-ada8-bbf620f2fc1a",
|
||||
"An internal server error occurred.",
|
||||
);
|
||||
impl<E: IntoResponse> BaseAppState<E> {
|
||||
/// Create a new application state.
|
||||
pub fn new(backend: reqwest::Url) -> Self {
|
||||
Self {
|
||||
backend,
|
||||
ctx_template: None,
|
||||
clients: ClientCache::new(),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
/// Set the context template.
|
||||
pub fn with_ctx_template(mut self, ctx: serde_json::Value) -> Self {
|
||||
self.ctx_template = Some(ctx);
|
||||
self
|
||||
}
|
||||
/// With an empty context template.
|
||||
pub fn with_empty_ctx(self) -> Self {
|
||||
self.with_ctx_template(serde_json::Value::Object(Default::default()))
|
||||
}
|
||||
}
|
||||
|
||||
const ERR_SERVICE_TEMPORARILY_UNAVAILABLE: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"SERVICE_TEMPORARILY_UNAVAILABLE",
|
||||
"39992bed-f58f-484d-bb67-8c8db6f0b224",
|
||||
"The service is temporarily unavailable.",
|
||||
);
|
||||
|
||||
const ERR_PAYLOAD_TOO_LARGE: MisskeyError = MisskeyError::new_const(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"PAYLOAD_TOO_LARGE",
|
||||
"8007d1b7-0eab-41b2-bb17-95f06926ba2b",
|
||||
"The request payload is too large.",
|
||||
);
|
||||
|
||||
pub struct ProxyApp<E>(std::marker::PhantomData<E>);
|
||||
/// The main application.
|
||||
pub struct ProxyApp<S, E>(std::marker::PhantomData<(S, E)>);
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl<E: IntoResponse + 'static> ProxyApp<E> {
|
||||
impl<
|
||||
E: IntoResponse + 'static + Send + Sync,
|
||||
S: HasAppState<E> + Evaluator<E> + Send + Sync + 'static,
|
||||
> ProxyApp<S, E>
|
||||
{
|
||||
/// Pass through the request to the backend with basic error handling.
|
||||
pub async fn pass_through(
|
||||
method: Method,
|
||||
State(app): State<Arc<AppState<E>>>,
|
||||
State(app): State<S>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
OriginalUri(uri): OriginalUri,
|
||||
header: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<impl IntoResponse, MisskeyError> {
|
||||
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
|
||||
let state = app.app_state();
|
||||
|
||||
app.clone()
|
||||
app.app_state()
|
||||
.clients
|
||||
.with_client(&addr, |client| {
|
||||
Box::pin(async move {
|
||||
let req = client
|
||||
.request(
|
||||
method.clone(),
|
||||
app.backend
|
||||
state
|
||||
.backend
|
||||
.join(&path_and_query.to_string())
|
||||
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?,
|
||||
)
|
||||
|
@ -165,8 +232,9 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
|
|||
.await
|
||||
}
|
||||
|
||||
/// Handle incoming ActivityPub requests.
|
||||
pub async fn inbox_handler(
|
||||
State(app): State<Arc<AppState<E>>>,
|
||||
State(app): State<S>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
OriginalUri(uri): OriginalUri,
|
||||
header: HeaderMap,
|
||||
|
@ -196,7 +264,7 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
|
|||
})?;
|
||||
|
||||
let decode = {
|
||||
let activity_decode = serde_json::from_slice::<ap::Activity<Object>>(&body);
|
||||
let activity_decode = serde_json::from_slice::<ap::Activity<AnyObject>>(&body);
|
||||
|
||||
match activity_decode {
|
||||
Ok(activity) => Ok(activity),
|
||||
|
@ -215,21 +283,22 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
|
|||
activity: &decode,
|
||||
};
|
||||
|
||||
for evaluator in &app.inbox_stack {
|
||||
match evaluator.evaluate(&info).await {
|
||||
Disposition::Next => {}
|
||||
Disposition::Intercept(e) => return Err(Either::B(e)),
|
||||
}
|
||||
let ctx = app.app_state().ctx_template.clone();
|
||||
match app.evaluate(ctx, &info).await.0 {
|
||||
Disposition::Intercept(e) => return Err(Either::B(e)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
app.clone()
|
||||
.app_state()
|
||||
.clients
|
||||
.with_client(&addr, |client| {
|
||||
Box::pin(async move {
|
||||
let req = client
|
||||
.request(
|
||||
Method::POST,
|
||||
app.backend
|
||||
app.app_state()
|
||||
.backend
|
||||
.join(&path_and_query.to_string())
|
||||
.map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?,
|
||||
)
|
||||
|
|
85
src/main.rs
85
src/main.rs
|
@ -1,14 +1,15 @@
|
|||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::response::IntoResponse;
|
||||
use clap::Parser;
|
||||
use fedivet::evaluate::ERROR_DENIED;
|
||||
use fedivet::evaluate::chain::audit::AuditOptions;
|
||||
use fedivet::evaluate::Evaluator;
|
||||
use fedivet::model::error::MisskeyError;
|
||||
use fedivet::serve;
|
||||
use fedivet::source::LruData;
|
||||
use fedivet::APRequestInfo;
|
||||
use fedivet::AppState;
|
||||
use reqwest::Url;
|
||||
use fedivet::BaseAppState;
|
||||
use fedivet::HasAppState;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Parser)]
|
||||
pub struct Args {
|
||||
|
@ -23,53 +24,11 @@ pub struct Args {
|
|||
}
|
||||
|
||||
#[allow(clippy::unused_async)]
|
||||
async fn build_state(args: &Args) -> AppState<MisskeyError> {
|
||||
let mut state = AppState::new(args.backend.parse().expect("Invalid backend URL"));
|
||||
|
||||
let instance_history = Arc::new(LruData::sized(
|
||||
&|host| async move { Ok::<_, ()>("Todo") },
|
||||
512.try_into().unwrap(),
|
||||
Some(Duration::from_secs(600)),
|
||||
));
|
||||
|
||||
state.push_evaluator(Box::new(move |info: &APRequestInfo<'_>| {
|
||||
let act = info.activity.as_ref().map_err(|_| ERROR_DENIED).cloned();
|
||||
let instance_history = Arc::clone(&instance_history);
|
||||
|
||||
async move {
|
||||
let act = act?;
|
||||
let host = act
|
||||
.actor
|
||||
.as_ref()
|
||||
.and_then(|s| Url::parse(s).ok())
|
||||
.ok_or(ERROR_DENIED)?;
|
||||
|
||||
let instance = instance_history
|
||||
.query(host.host_str().unwrap().to_owned())
|
||||
.await;
|
||||
|
||||
match instance {
|
||||
Ok(i) => {
|
||||
log::info!("Instance history: {:?}", i);
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err::<(), _>(ERROR_DENIED),
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
// let user_history = Arc::new(LruData::sized( ... ));
|
||||
|
||||
state.push_evaluator(Box::new(|info: &APRequestInfo<'_>| {
|
||||
let act = info.activity.as_ref().map_err(|_| ERROR_DENIED).cloned();
|
||||
async move {
|
||||
let act = act?;
|
||||
log::debug!("Activity: {:?}", act);
|
||||
Ok(())
|
||||
}
|
||||
}));
|
||||
|
||||
state
|
||||
async fn build_state<E: IntoResponse + Clone + Serialize + Send + Sync + 'static>(
|
||||
base: Arc<BaseAppState<E>>,
|
||||
_args: &Args,
|
||||
) -> impl HasAppState<E> + Evaluator<E> {
|
||||
base.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
@ -81,7 +40,13 @@ async fn main() {
|
|||
|
||||
let args = Args::parse();
|
||||
|
||||
let state = Arc::new(build_state(&args).await);
|
||||
let state = build_state::<MisskeyError>(
|
||||
Arc::new(BaseAppState::new(
|
||||
args.backend.parse().expect("Invalid backend URL"),
|
||||
)),
|
||||
&args,
|
||||
)
|
||||
.await;
|
||||
|
||||
let (jh, handle) = serve::serve(
|
||||
state.clone(),
|
||||
|
@ -97,10 +62,16 @@ async fn main() {
|
|||
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("Failed to register SIGTERM handler");
|
||||
|
||||
tokio::select! {
|
||||
_ = gc_ticker.tick() => {
|
||||
state.client_pool_ref().gc(std::time::Duration::from_secs(120));
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
gc_ticker.tick().await;
|
||||
state
|
||||
.client_pool_ref()
|
||||
.gc(std::time::Duration::from_secs(120));
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
res = jh => {
|
||||
if let Err(e) = res {
|
||||
log::error!("Server error: {}", e);
|
||||
|
|
|
@ -6,6 +6,8 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
/// Serialize as either A or B.
|
||||
#[allow(missing_docs)]
|
||||
pub enum Either<A, B> {
|
||||
A(A),
|
||||
B(B),
|
||||
|
@ -13,14 +15,16 @@ pub enum Either<A, B> {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
/// Represents a single value or multiple values.
|
||||
#[allow(missing_docs)]
|
||||
pub enum FlatArray<T> {
|
||||
Single(T),
|
||||
Multiple(Vec<T>),
|
||||
}
|
||||
|
||||
impl IntoIterator for FlatArray<Object> {
|
||||
type Item = Object;
|
||||
type IntoIter = std::vec::IntoIter<Object>;
|
||||
impl<O> IntoIterator for FlatArray<O> {
|
||||
type Item = O;
|
||||
type IntoIter = std::vec::IntoIter<O>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
match self {
|
||||
|
@ -31,6 +35,8 @@ impl IntoIterator for FlatArray<Object> {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
/// Represents an ActivityStreams object.
|
||||
#[allow(missing_docs)]
|
||||
pub struct Object {
|
||||
#[serde(rename = "@context")]
|
||||
pub context: FlatArray<Either<String, serde_json::Value>>,
|
||||
|
@ -38,14 +44,18 @@ pub struct Object {
|
|||
#[serde(rename = "type")]
|
||||
pub ty_: String,
|
||||
pub name: Option<String>,
|
||||
pub published: Option<chrono::DateTime<chrono::Utc>>,
|
||||
pub to: Option<FlatArray<String>>,
|
||||
pub cc: Option<FlatArray<String>>,
|
||||
pub bcc: Option<FlatArray<String>>,
|
||||
pub url: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub rest: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
/// Represents an ActivityStreams activity.
|
||||
#[allow(missing_docs)]
|
||||
pub struct Activity<O> {
|
||||
pub actor: Option<String>,
|
||||
pub object: Option<FlatArray<O>>,
|
||||
|
@ -54,10 +64,42 @@ pub struct Activity<O> {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
/// Represents an ActivityStreams note object.
|
||||
#[allow(missing_docs)]
|
||||
pub struct NoteObject {
|
||||
pub summary: Option<String>,
|
||||
#[serde(rename = "inReplyTo")]
|
||||
pub in_reply_to: Option<String>,
|
||||
#[serde(rename = "attributedTo")]
|
||||
pub attributed_to: Option<String>,
|
||||
|
||||
pub sensitive: Option<bool>,
|
||||
pub content: Option<String>,
|
||||
#[serde(rename = "contentMap")]
|
||||
pub content_map: Option<HashMap<String, String>>,
|
||||
|
||||
pub attachment: Option<FlatArray<Attachment>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
/// Represents an ActivityStreams attachment.
|
||||
#[allow(missing_docs)]
|
||||
pub struct Attachment {
|
||||
#[serde(rename = "type")]
|
||||
pub ty_: String,
|
||||
pub media_type: Option<String>,
|
||||
pub url: String,
|
||||
pub name: Option<String>,
|
||||
pub blurhash: Option<String>,
|
||||
pub width: Option<u32>,
|
||||
pub height: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
/// Represents any parsable ActivityStreams object.
|
||||
#[allow(missing_docs)]
|
||||
pub enum AnyObject {
|
||||
NoteObject(NoteObject),
|
||||
Object(Object),
|
||||
}
|
||||
|
|
|
@ -10,14 +10,17 @@ use axum::{
|
|||
use reqwest::StatusCode;
|
||||
use serde::Serialize;
|
||||
|
||||
/// A trait for responses that can be converted into an axum response.
|
||||
pub trait APResponse: Debug + Display + Send + IntoResponse {}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// A response that should be ignored.
|
||||
pub struct Ignore {
|
||||
reason: Cow<'static, str>,
|
||||
}
|
||||
|
||||
impl Ignore {
|
||||
/// Create a new Ignore response.
|
||||
pub fn new(reason: impl Into<Cow<'static, str>>) -> Self {
|
||||
Self {
|
||||
reason: reason.into(),
|
||||
|
@ -40,6 +43,7 @@ impl IntoResponse for Ignore {
|
|||
impl APResponse for Ignore {}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Errors that look like Misskey's error response.
|
||||
pub struct MisskeyError {
|
||||
#[serde(skip)]
|
||||
status: StatusCode,
|
||||
|
@ -49,6 +53,7 @@ pub struct MisskeyError {
|
|||
}
|
||||
|
||||
impl MisskeyError {
|
||||
/// Create a new MisskeyError using const fn.
|
||||
pub const fn new_const(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
|
@ -62,6 +67,7 @@ impl MisskeyError {
|
|||
message: Cow::Borrowed(message),
|
||||
}
|
||||
}
|
||||
/// Create a new MisskeyError.
|
||||
pub fn new(
|
||||
status: StatusCode,
|
||||
code: &'static str,
|
||||
|
@ -76,6 +82,7 @@ impl MisskeyError {
|
|||
}
|
||||
}
|
||||
|
||||
/// HTTP 413 Payload Too Large
|
||||
pub fn s_413(
|
||||
code: &'static str,
|
||||
id: impl Into<Cow<'static, str>>,
|
||||
|
@ -84,6 +91,7 @@ impl MisskeyError {
|
|||
Self::new(StatusCode::PAYLOAD_TOO_LARGE, code, id, message)
|
||||
}
|
||||
|
||||
/// HTTP 500 Internal Server Error
|
||||
pub fn s_500(
|
||||
code: &'static str,
|
||||
id: impl Into<Cow<'static, str>>,
|
||||
|
@ -92,6 +100,7 @@ impl MisskeyError {
|
|||
Self::new(StatusCode::INTERNAL_SERVER_ERROR, code, id, message)
|
||||
}
|
||||
|
||||
/// HTTP 400 Bad Request
|
||||
pub fn s_400(
|
||||
code: &'static str,
|
||||
id: impl Into<Cow<'static, str>>,
|
||||
|
|
|
@ -1,2 +1,4 @@
|
|||
/// ActivityPub types
|
||||
pub mod ap;
|
||||
/// Error handling
|
||||
pub mod error;
|
||||
|
|
30
src/serve.rs
30
src/serve.rs
|
@ -1,14 +1,17 @@
|
|||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use axum::{
|
||||
response::IntoResponse,
|
||||
routing::{any, post},
|
||||
Router,
|
||||
};
|
||||
use axum_server::{tls_rustls::RustlsConfig, Handle};
|
||||
use axum_server::Handle;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::{APRequestInfo, AppState, ProxyApp};
|
||||
use crate::{evaluate::Evaluator, HasAppState, ProxyApp};
|
||||
|
||||
/// Flag indicating whether the TLS feature is enabled
|
||||
pub const HAS_TLS_FEATURE: bool = cfg!(feature = "tls");
|
||||
|
||||
#[allow(clippy::panic)]
|
||||
#[allow(clippy::unwrap_used)]
|
||||
|
@ -21,8 +24,11 @@ use crate::{APRequestInfo, AppState, ProxyApp};
|
|||
/// - `tls_key`: The path to the TLS key file
|
||||
///
|
||||
/// Use [`tokio::select`] to listen on multiple addresses
|
||||
pub async fn serve<E: IntoResponse + 'static>(
|
||||
state: Arc<AppState<E>>,
|
||||
pub async fn serve<
|
||||
S: HasAppState<E> + Evaluator<E> + Send + Sync + 'static,
|
||||
E: IntoResponse + Send + Sync + 'static,
|
||||
>(
|
||||
state: S,
|
||||
listen: &str,
|
||||
tls_cert: Option<&str>,
|
||||
tls_key: Option<&str>,
|
||||
|
@ -30,11 +36,11 @@ pub async fn serve<E: IntoResponse + 'static>(
|
|||
let app = Router::new()
|
||||
.route(
|
||||
"/inbox",
|
||||
post(ProxyApp::inbox_handler)
|
||||
.put(ProxyApp::inbox_handler)
|
||||
.patch(ProxyApp::inbox_handler),
|
||||
post(ProxyApp::<S, E>::inbox_handler)
|
||||
.put(ProxyApp::<S, E>::inbox_handler)
|
||||
.patch(ProxyApp::<S, E>::inbox_handler),
|
||||
)
|
||||
.fallback(any(ProxyApp::pass_through));
|
||||
.fallback(any(ProxyApp::<S, E>::pass_through));
|
||||
|
||||
let ms = app
|
||||
.with_state(state)
|
||||
|
@ -42,7 +48,9 @@ pub async fn serve<E: IntoResponse + 'static>(
|
|||
|
||||
let handle = Handle::new();
|
||||
match (tls_cert, tls_key) {
|
||||
#[cfg(feature = "tls")]
|
||||
(Some(cert), Some(key)) => {
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
let tls_config = RustlsConfig::from_pem_file(cert, key)
|
||||
.await
|
||||
.expect("Failed to load TLS certificate and key");
|
||||
|
@ -54,6 +62,10 @@ pub async fn serve<E: IntoResponse + 'static>(
|
|||
let jh = tokio::spawn(server);
|
||||
(jh, handle)
|
||||
}
|
||||
#[cfg(not(feature = "tls"))]
|
||||
(Some(_), Some(_)) => {
|
||||
panic!("TLS support is not enabled")
|
||||
}
|
||||
(None, None) => {
|
||||
log::info!("Listening on http://{}", listen);
|
||||
let server = axum_server::bind(listen.parse().expect("invalid listen addr"))
|
||||
|
|
|
@ -10,9 +10,12 @@ use std::{
|
|||
|
||||
use lru::LruCache;
|
||||
|
||||
/// Friend instance data source
|
||||
pub mod friend;
|
||||
/// Nodeinfo data source
|
||||
pub mod nodeinfo;
|
||||
|
||||
/// Lazy future that caches its result
|
||||
pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> {
|
||||
future: Mutex<Pin<Box<F>>>,
|
||||
ttl: Option<Duration>,
|
||||
|
@ -20,6 +23,7 @@ pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> {
|
|||
}
|
||||
|
||||
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFuture<T, E, F> {
|
||||
/// Create a new lazy future
|
||||
pub fn new(future: F, ttl: Option<Duration>) -> Self {
|
||||
Self {
|
||||
future: Mutex::new(Box::pin(future)),
|
||||
|
@ -73,6 +77,8 @@ struct LazyFutureHandle<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> {
|
|||
}
|
||||
|
||||
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFutureHandle<T, E, F> {
|
||||
/// Get a reference to the inner future
|
||||
#[allow(dead_code)]
|
||||
pub fn inner_ref(&self) -> &LazyFuture<T, E, F> {
|
||||
&self.inner
|
||||
}
|
||||
|
@ -94,6 +100,7 @@ impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> Future for LazyFuture
|
|||
}
|
||||
}
|
||||
|
||||
/// LRU cache for futures
|
||||
pub struct LruData<
|
||||
'a,
|
||||
K: Hash,
|
||||
|
@ -117,6 +124,7 @@ impl<
|
|||
FF: Fn(&K) -> F,
|
||||
> LruData<'a, K, T, E, F, FF>
|
||||
{
|
||||
/// Create a new LRU cache with a given size and factory
|
||||
pub fn sized(factory: &'a FF, size: NonZeroUsize, ttl: Option<Duration>) -> Self {
|
||||
Self {
|
||||
factory,
|
||||
|
@ -126,6 +134,7 @@ impl<
|
|||
}
|
||||
}
|
||||
|
||||
/// Evict all expired entries from the cache
|
||||
pub fn evict_expired(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
|
||||
|
@ -152,6 +161,7 @@ impl<
|
|||
}
|
||||
}
|
||||
|
||||
/// Get or insert a future into the cache
|
||||
pub fn get_or_insert(&self, key: K) -> Arc<LazyFuture<T, E, F>> {
|
||||
let factory = self.factory;
|
||||
let entry = self.inner.lock().unwrap().get(&key).cloned();
|
||||
|
@ -166,6 +176,7 @@ impl<
|
|||
}
|
||||
}
|
||||
|
||||
/// Query the cache for a given key or insert it if it doesn't exist
|
||||
pub fn query(&self, key: K) -> impl Future<Output = Result<T, E>> {
|
||||
self.maybe_evict();
|
||||
let future = self.get_or_insert(key);
|
||||
|
|
Loading…
Reference in a new issue