Refactor stack logic
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
parent
205958b0cc
commit
ab15f3e3ff
13 changed files with 727 additions and 132 deletions
113
Cargo.lock
generated
113
Cargo.lock
generated
|
@ -32,6 +32,21 @@ version = "0.2.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
|
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "android-tzdata"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "android_system_properties"
|
||||||
|
version = "0.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anstream"
|
name = "anstream"
|
||||||
version = "0.6.15"
|
version = "0.6.15"
|
||||||
|
@ -304,6 +319,21 @@ version = "1.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "chrono"
|
||||||
|
version = "0.4.38"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
|
||||||
|
dependencies = [
|
||||||
|
"android-tzdata",
|
||||||
|
"iana-time-zone",
|
||||||
|
"js-sys",
|
||||||
|
"num-traits",
|
||||||
|
"serde",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"windows-targets",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clang-sys"
|
name = "clang-sys"
|
||||||
version = "1.8.1"
|
version = "1.8.1"
|
||||||
|
@ -386,6 +416,15 @@ version = "0.8.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crc32fast"
|
||||||
|
version = "1.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crossbeam-utils"
|
name = "crossbeam-utils"
|
||||||
version = "0.8.20"
|
version = "0.8.20"
|
||||||
|
@ -479,18 +518,31 @@ dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
"axum-server",
|
"axum-server",
|
||||||
|
"chrono",
|
||||||
"clap",
|
"clap",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
"flate2",
|
||||||
"futures",
|
"futures",
|
||||||
"log",
|
"log",
|
||||||
"lru",
|
"lru",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "flate2"
|
||||||
|
version = "1.0.34"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
|
||||||
|
dependencies = [
|
||||||
|
"crc32fast",
|
||||||
|
"miniz_oxide",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fnv"
|
name = "fnv"
|
||||||
version = "1.0.7"
|
version = "1.0.7"
|
||||||
|
@ -827,6 +879,29 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iana-time-zone"
|
||||||
|
version = "0.1.61"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
|
||||||
|
dependencies = [
|
||||||
|
"android_system_properties",
|
||||||
|
"core-foundation-sys",
|
||||||
|
"iana-time-zone-haiku",
|
||||||
|
"js-sys",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"windows-core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iana-time-zone-haiku"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
|
@ -1029,6 +1104,15 @@ dependencies = [
|
||||||
"minimal-lexical",
|
"minimal-lexical",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "object"
|
name = "object"
|
||||||
version = "0.36.5"
|
version = "0.36.5"
|
||||||
|
@ -1565,6 +1649,26 @@ dependencies = [
|
||||||
"windows-sys 0.59.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "1.0.64"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "1.0.64"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tinyvec"
|
name = "tinyvec"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
|
@ -1878,6 +1982,15 @@ dependencies = [
|
||||||
"rustix",
|
"rustix",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-core"
|
||||||
|
version = "0.52.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
|
||||||
|
dependencies = [
|
||||||
|
"windows-targets",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "windows-registry"
|
name = "windows-registry"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
|
|
|
@ -3,17 +3,23 @@ name = "fedivet"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
tls = ["axum-server/tls-rustls", "axum-server/rustls-pemfile", "axum-server/tokio-rustls"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1.83"
|
async-trait = "0.1.83"
|
||||||
axum = "0.7.7"
|
axum = "0.7.7"
|
||||||
axum-server = { version = "0.7.1", features = ["tokio-rustls", "rustls-pemfile", "tls-rustls"] }
|
axum-server = { version = "0.7.1" }
|
||||||
|
chrono = { version = "0.4.38", features = ["serde"] }
|
||||||
clap = { version = "4.5.20", features = ["derive"] }
|
clap = { version = "4.5.20", features = ["derive"] }
|
||||||
dashmap = "6.1.0"
|
dashmap = "6.1.0"
|
||||||
env_logger = "0.11.5"
|
env_logger = "0.11.5"
|
||||||
|
flate2 = "1.0.34"
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
log = "0.4.22"
|
log = "0.4.22"
|
||||||
lru = "0.12.5"
|
lru = "0.12.5"
|
||||||
reqwest = { version = "0.12.8", features = ["stream"] }
|
reqwest = { version = "0.12.8", features = ["stream"] }
|
||||||
serde = { version = "1.0.210", features = ["derive"] }
|
serde = { version = "1.0.210", features = ["derive"] }
|
||||||
serde_json = "1.0.128"
|
serde_json = "1.0.128"
|
||||||
|
thiserror = "1.0.64"
|
||||||
tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros", "net", "sync", "fs", "signal", "time"] }
|
tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros", "net", "sync", "fs", "signal", "time"] }
|
||||||
|
|
244
src/evaluate/chain/audit.rs
Normal file
244
src/evaluate/chain/audit.rs
Normal file
|
@ -0,0 +1,244 @@
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use serde::Serialize;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::fs::{File, OpenOptions};
|
||||||
|
use std::sync::atomic::AtomicU32;
|
||||||
|
use std::{collections::HashMap, fmt::Debug, ops::DerefMut, path::PathBuf, sync::Arc};
|
||||||
|
use tokio::sync::{Mutex, RwLock};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
delegate,
|
||||||
|
evaluate::{Disposition, Evaluator},
|
||||||
|
HasAppState,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Save audit logs to a file
|
||||||
|
pub struct AuditOptions {
|
||||||
|
/// The path to save the audit logs
|
||||||
|
pub output: PathBuf,
|
||||||
|
/// The maximum size of the audit log file before it is rotated
|
||||||
|
pub rotate_size: Option<u64>,
|
||||||
|
/// The number of days to keep audit logs before vacuuming
|
||||||
|
pub vacuum_days: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuditOptions {
|
||||||
|
/// Create a new set of audit options
|
||||||
|
pub fn new(output: PathBuf) -> Self {
|
||||||
|
Self {
|
||||||
|
output,
|
||||||
|
rotate_size: None,
|
||||||
|
vacuum_days: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Set the maximum size of the audit log file before it is rotated
|
||||||
|
pub fn rotate_size(mut self, rotate_size: u64) -> Self {
|
||||||
|
self.rotate_size = Some(rotate_size);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Set the number of days to keep audit logs before vacuuming
|
||||||
|
pub fn vacuum_days(mut self, vacuum_days: u64) -> Self {
|
||||||
|
self.vacuum_days = Some(vacuum_days);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
/// Errors that can occur while writing audit logs
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[allow(missing_docs)]
|
||||||
|
pub enum AuditError {
|
||||||
|
#[error("Failed to write audit log: {0}")]
|
||||||
|
WriteError(#[from] std::io::Error),
|
||||||
|
|
||||||
|
#[error("Failed to serialize audit log: {0}")]
|
||||||
|
SerializeError(#[from] serde_json::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AuditState {
|
||||||
|
options: AuditOptions,
|
||||||
|
cur_file: Option<RwLock<HashMap<String, Mutex<File>>>>,
|
||||||
|
vacuum_counter: AtomicU32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuditState {
|
||||||
|
pub fn new(options: AuditOptions) -> Self {
|
||||||
|
if !options.output.exists() {
|
||||||
|
std::fs::create_dir_all(&options.output).expect("Failed to create audit log directory");
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
options,
|
||||||
|
cur_file: None,
|
||||||
|
vacuum_counter: AtomicU32::new(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn vacuum(&self) -> Result<(), AuditError> {
|
||||||
|
let read = self.cur_file.as_ref().unwrap().read().await;
|
||||||
|
let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
let files = std::fs::read_dir(&self.options.output)?
|
||||||
|
.filter_map(|f| f.ok())
|
||||||
|
.filter(|f| {
|
||||||
|
f.file_type().map(|t| t.is_file()).unwrap_or(false)
|
||||||
|
&& f.file_name().to_string_lossy().ends_with(".json")
|
||||||
|
})
|
||||||
|
.filter(|f| !in_use_files.contains(&f.file_name().to_string_lossy().to_string()));
|
||||||
|
|
||||||
|
for file in files {
|
||||||
|
let meta = file.metadata()?;
|
||||||
|
|
||||||
|
if let Some(days) = self.options.vacuum_days {
|
||||||
|
let duration = meta
|
||||||
|
.modified()?
|
||||||
|
.elapsed()
|
||||||
|
.unwrap_or(std::time::Duration::new(0, 0));
|
||||||
|
|
||||||
|
if duration.as_secs() >= days * 24 * 60 * 60 {
|
||||||
|
std::fs::remove_file(file.path())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
|
||||||
|
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
|
||||||
|
|
||||||
|
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
||||||
|
|
||||||
|
let full_name = format!("{}_{}.json", name, time_str);
|
||||||
|
|
||||||
|
let file = OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.write(true)
|
||||||
|
.append(true)
|
||||||
|
.open(&self.options.output.join(&full_name))?;
|
||||||
|
|
||||||
|
write.insert(name.to_string(), Mutex::new(file));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn write<E: IntoResponse + Serialize>(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
item: &AuditItem<'_, E>,
|
||||||
|
) -> Result<(), AuditError> {
|
||||||
|
if self
|
||||||
|
.vacuum_counter
|
||||||
|
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
|
||||||
|
>= 512
|
||||||
|
{
|
||||||
|
self.vacuum().await?;
|
||||||
|
self.vacuum_counter
|
||||||
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let read = self.cur_file.as_ref().unwrap().read().await;
|
||||||
|
|
||||||
|
if let Some(file) = read.get(name) {
|
||||||
|
let mut f = file.lock().await;
|
||||||
|
|
||||||
|
serde_json::to_writer(f.deref_mut(), &item)?;
|
||||||
|
|
||||||
|
let meta = f.metadata()?;
|
||||||
|
|
||||||
|
if let Some(size) = self.options.rotate_size {
|
||||||
|
if meta.len() >= size {
|
||||||
|
drop(f);
|
||||||
|
drop(read);
|
||||||
|
|
||||||
|
self.create_new_file(name).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
||||||
|
|
||||||
|
let file = File::create(&self.options.output.join(name))?;
|
||||||
|
|
||||||
|
write.insert(name.to_string(), Mutex::new(file));
|
||||||
|
|
||||||
|
// this is deliberately out of order to make sure we don't create endless files if serialization fails
|
||||||
|
serde_json::to_writer(write.get(name).unwrap().lock().await.deref_mut(), &item)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Audit<E: IntoResponse + 'static, I: HasAppState<E>> {
|
||||||
|
inner: I,
|
||||||
|
state: Arc<AuditState>,
|
||||||
|
_marker: std::marker::PhantomData<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: IntoResponse + 'static, I: HasAppState<E>> Audit<E, I> {
|
||||||
|
pub fn new(inner: I, options: AuditOptions) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
state: Arc::new(AuditState::new(options)),
|
||||||
|
_marker: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: IntoResponse + 'static, I: HasAppState<E>> Clone for Audit<E, I> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: self.inner.clone(),
|
||||||
|
state: self.state.clone(),
|
||||||
|
_marker: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
delegate!(state Audit::<E, I>.inner);
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct AuditItem<'r, E: IntoResponse + Serialize> {
|
||||||
|
info: &'r crate::APRequestInfo<'r>,
|
||||||
|
ctx: &'r Option<serde_json::Value>,
|
||||||
|
disposition: &'r Disposition<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl<
|
||||||
|
E: IntoResponse + Serialize + Send + Sync + 'static,
|
||||||
|
I: HasAppState<E> + Evaluator<E> + Sync,
|
||||||
|
> Evaluator<E> for Audit<E, I>
|
||||||
|
{
|
||||||
|
fn name() -> &'static str {
|
||||||
|
"Audit"
|
||||||
|
}
|
||||||
|
async fn evaluate<'r>(
|
||||||
|
&self,
|
||||||
|
ctx: Option<serde_json::Value>,
|
||||||
|
info: &crate::APRequestInfo<'r>,
|
||||||
|
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||||
|
let (disp, ctx) = self.inner.evaluate(ctx, info).await;
|
||||||
|
|
||||||
|
if ctx
|
||||||
|
.as_ref()
|
||||||
|
.map(|c| c.get("skip_audit").is_some())
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
return (disp, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
let item = AuditItem {
|
||||||
|
info,
|
||||||
|
ctx: &ctx,
|
||||||
|
disposition: &disp,
|
||||||
|
};
|
||||||
|
|
||||||
|
let state = self.state.clone();
|
||||||
|
|
||||||
|
state.write(Self::name(), &item).await.ok();
|
||||||
|
|
||||||
|
(disp, ctx)
|
||||||
|
}
|
||||||
|
}
|
52
src/evaluate/chain/meta.rs
Normal file
52
src/evaluate/chain/meta.rs
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use reqwest::Url;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
delegate,
|
||||||
|
model::ap::{AnyObject, NoteObject},
|
||||||
|
APRequestInfo, HasAppState,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Meta<E: IntoResponse + 'static, I: HasAppState<E>> {
|
||||||
|
inner: I,
|
||||||
|
_marker: std::marker::PhantomData<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_host(act: &APRequestInfo) -> Option<String> {
|
||||||
|
act.activity
|
||||||
|
.as_ref()
|
||||||
|
.ok()?
|
||||||
|
.actor
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| Url::parse(s).ok()?.host_str().map(|s| s.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_attributed_to(act: &APRequestInfo) -> Option<String> {
|
||||||
|
match &act.activity.as_ref().ok()?.object {
|
||||||
|
Some(arr) => match arr.clone().into_iter().next() {
|
||||||
|
Some(AnyObject::NoteObject(NoteObject {
|
||||||
|
attributed_to: Some(s),
|
||||||
|
..
|
||||||
|
})) => Some(s),
|
||||||
|
_ => None,
|
||||||
|
},
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_meta(act: &APRequestInfo) -> Option<MetaItem> {
|
||||||
|
Some(MetaItem {
|
||||||
|
instance_host: extract_host(act),
|
||||||
|
attributed_to: extract_attributed_to(act),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
delegate!(state Meta::<E, I>.inner);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MetaItem {
|
||||||
|
instance_host: Option<String>,
|
||||||
|
attributed_to: Option<String>,
|
||||||
|
}
|
2
src/evaluate/chain/mod.rs
Normal file
2
src/evaluate/chain/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod audit;
|
||||||
|
pub mod meta;
|
|
@ -1,5 +1,24 @@
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use reqwest::StatusCode;
|
use reqwest::StatusCode;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
pub mod chain;
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! delegate {
|
||||||
|
(state $str:ident::<E, I $(,$t:ty),*>.$field:ident) => {
|
||||||
|
use crate::{BaseAppState, ClientCache};
|
||||||
|
impl<E: IntoResponse + Clone + 'static, I: HasAppState<E>> HasAppState<E> for $str::<E, I, $($t),*> {
|
||||||
|
fn app_state(&self) -> &BaseAppState<E> {
|
||||||
|
self.inner.app_state()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_pool_ref(&self) -> &ClientCache {
|
||||||
|
self.inner.client_pool_ref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
use crate::{model::error::MisskeyError, APRequestInfo};
|
use crate::{model::error::MisskeyError, APRequestInfo};
|
||||||
|
|
||||||
|
@ -10,14 +29,50 @@ pub const ERROR_DENIED: MisskeyError = MisskeyError::new_const(
|
||||||
"This server cannot accept this activity.",
|
"This server cannot accept this activity.",
|
||||||
);
|
);
|
||||||
|
|
||||||
|
pub const ERR_BAD_REQUEST: MisskeyError = MisskeyError::new_const(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"UNPARSABLE_REQUEST",
|
||||||
|
"659c1254-1392-458e-aa77-557444031da8",
|
||||||
|
"The request is not HTTP compliant.",
|
||||||
|
);
|
||||||
|
|
||||||
|
pub const ERR_INTERNAL_SERVER_ERROR: MisskeyError = MisskeyError::new_const(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"INTERNAL_SERVER_ERROR",
|
||||||
|
"3a19659e-89e9-4c37-ada8-bbf620f2fc1a",
|
||||||
|
"An internal server error occurred.",
|
||||||
|
);
|
||||||
|
|
||||||
|
pub const ERR_SERVICE_TEMPORARILY_UNAVAILABLE: MisskeyError = MisskeyError::new_const(
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
"SERVICE_TEMPORARILY_UNAVAILABLE",
|
||||||
|
"39992bed-f58f-484d-bb67-8c8db6f0b224",
|
||||||
|
"The service is temporarily unavailable.",
|
||||||
|
);
|
||||||
|
|
||||||
|
pub const ERR_PAYLOAD_TOO_LARGE: MisskeyError = MisskeyError::new_const(
|
||||||
|
StatusCode::PAYLOAD_TOO_LARGE,
|
||||||
|
"PAYLOAD_TOO_LARGE",
|
||||||
|
"8007d1b7-0eab-41b2-bb17-95f06926ba2b",
|
||||||
|
"The request payload is too large.",
|
||||||
|
);
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
pub enum Disposition<E: IntoResponse> {
|
pub enum Disposition<E: IntoResponse> {
|
||||||
|
Allow,
|
||||||
Next,
|
Next,
|
||||||
Intercept(E),
|
Intercept(E),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
pub trait Evaluator<E: IntoResponse> {
|
pub trait Evaluator<E: IntoResponse> {
|
||||||
async fn evaluate<'r>(&self, info: &APRequestInfo<'r>) -> Disposition<E>;
|
fn name() -> &'static str;
|
||||||
|
|
||||||
|
async fn evaluate<'r>(
|
||||||
|
&self,
|
||||||
|
ctx: Option<serde_json::Value>,
|
||||||
|
info: &APRequestInfo<'r>,
|
||||||
|
) -> (Disposition<E>, Option<serde_json::Value>);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: IntoResponse> Into<Disposition<E>> for Result<(), E> {
|
impl<E: IntoResponse> Into<Disposition<E>> for Result<(), E> {
|
||||||
|
@ -37,7 +92,14 @@ impl<
|
||||||
F: Fn(&APRequestInfo<'_>) -> Fut + Send + Sync,
|
F: Fn(&APRequestInfo<'_>) -> Fut + Send + Sync,
|
||||||
> Evaluator<E> for F
|
> Evaluator<E> for F
|
||||||
{
|
{
|
||||||
async fn evaluate<'r>(&self, info: &APRequestInfo<'r>) -> Disposition<E> {
|
fn name() -> &'static str {
|
||||||
(*self)(info).await.into()
|
"Closure"
|
||||||
|
}
|
||||||
|
async fn evaluate<'r>(
|
||||||
|
&self,
|
||||||
|
ctx: Option<serde_json::Value>,
|
||||||
|
info: &APRequestInfo<'r>,
|
||||||
|
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||||
|
(self(info).await.into(), ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
187
src/lib.rs
187
src/lib.rs
|
@ -1,9 +1,11 @@
|
||||||
|
#![doc = include_str!("../README.md")]
|
||||||
#![feature(ip)]
|
#![feature(ip)]
|
||||||
#![feature(async_closure)]
|
#![feature(async_closure)]
|
||||||
#![warn(clippy::unwrap_used, clippy::expect_used)]
|
#![warn(clippy::unwrap_used, clippy::expect_used)]
|
||||||
#![warn(unsafe_code)]
|
#![warn(unsafe_code)]
|
||||||
#![warn(missing_docs)]
|
#![warn(missing_docs)]
|
||||||
|
|
||||||
|
use std::marker::PhantomData;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::{fmt::Display, net::SocketAddr};
|
use std::{fmt::Display, net::SocketAddr};
|
||||||
|
|
||||||
|
@ -14,28 +16,65 @@ use axum::{
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use client::ClientCache;
|
use client::ClientCache;
|
||||||
use evaluate::{Disposition, Evaluator};
|
use evaluate::chain::audit::{Audit, AuditOptions};
|
||||||
|
use evaluate::{
|
||||||
|
Disposition, Evaluator, ERR_BAD_REQUEST, ERR_INTERNAL_SERVER_ERROR, ERR_PAYLOAD_TOO_LARGE,
|
||||||
|
ERR_SERVICE_TEMPORARILY_UNAVAILABLE,
|
||||||
|
};
|
||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
use model::ap::Object;
|
use model::ap::AnyObject;
|
||||||
use model::{ap, error::MisskeyError};
|
use model::{ap, error::MisskeyError};
|
||||||
use network::stream::LimitedStream;
|
use network::stream::LimitedStream;
|
||||||
use network::Either;
|
use network::Either;
|
||||||
use reqwest::{self, StatusCode};
|
use reqwest::{self};
|
||||||
|
use serde::ser::{SerializeMap, SerializeStruct};
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
pub(crate) mod client;
|
pub(crate) mod client;
|
||||||
|
/// Evaluation framework
|
||||||
pub mod evaluate;
|
pub mod evaluate;
|
||||||
|
/// Data models
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub(crate) mod network;
|
pub(crate) mod network;
|
||||||
|
/// Server implementation
|
||||||
pub mod serve;
|
pub mod serve;
|
||||||
|
/// Data sources
|
||||||
pub mod source;
|
pub mod source;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
/// Information about an incoming ActivityPub request.
|
||||||
|
#[allow(missing_docs)]
|
||||||
pub struct APRequestInfo<'r> {
|
pub struct APRequestInfo<'r> {
|
||||||
pub method: &'r Method,
|
pub method: &'r Method,
|
||||||
pub uri: &'r Uri,
|
pub uri: &'r Uri,
|
||||||
pub header: &'r HeaderMap,
|
pub header: &'r HeaderMap,
|
||||||
pub connect: &'r SocketAddr,
|
pub connect: &'r SocketAddr,
|
||||||
pub activity: &'r Result<ap::Activity<Object>, (serde_json::Value, serde_json::Error)>,
|
pub activity: &'r Result<ap::Activity<AnyObject>, (serde_json::Value, serde_json::Error)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for APRequestInfo<'_> {
|
||||||
|
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||||
|
let mut state = serializer.serialize_struct("APRequestInfo", 5)?;
|
||||||
|
state.serialize_field("method", &self.method.as_str())?;
|
||||||
|
state.serialize_field("uri", &self.uri.to_string())?;
|
||||||
|
pub struct SerializeHeader<'a>(&'a HeaderMap);
|
||||||
|
impl<'a> Serialize for SerializeHeader<'a> {
|
||||||
|
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||||
|
let mut header = serializer.serialize_map(Some(self.0.len()))?;
|
||||||
|
for (key, value) in self.0.iter() {
|
||||||
|
header.serialize_entry(key.as_str(), &value.to_str().unwrap_or_default())?;
|
||||||
|
}
|
||||||
|
header.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.serialize_field("header", &SerializeHeader(self.header))?;
|
||||||
|
state.serialize_field("connect", &self.connect)?;
|
||||||
|
state.serialize_field(
|
||||||
|
"activity",
|
||||||
|
&self.activity.as_ref().map_err(|(v, e)| (v, e.to_string())),
|
||||||
|
)?;
|
||||||
|
state.end()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for APRequestInfo<'_> {
|
impl Display for APRequestInfo<'_> {
|
||||||
|
@ -48,80 +87,108 @@ impl Display for APRequestInfo<'_> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AppState<E: IntoResponse + 'static> {
|
/// Trait for accessing the application state.
|
||||||
backend: reqwest::Url,
|
pub trait HasAppState<E: IntoResponse + 'static>: Clone {
|
||||||
clients: ClientCache,
|
/// Get a reference to the application state.
|
||||||
inbox_stack: Vec<Box<dyn Evaluator<E> + Send + Sync>>,
|
fn app_state(&self) -> &BaseAppState<E>;
|
||||||
|
|
||||||
|
/// Get a reference to the client pool.
|
||||||
|
fn client_pool_ref(&self) -> &ClientCache;
|
||||||
|
|
||||||
|
/// Wrap the evaluator in an audit chain.
|
||||||
|
fn audited(self, opts: AuditOptions) -> Audit<E, Self>
|
||||||
|
where
|
||||||
|
Self: Sized + Send + Sync + Evaluator<E>,
|
||||||
|
E: Send + Sync,
|
||||||
|
{
|
||||||
|
Audit::new(self, opts)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: IntoResponse> AppState<E> {
|
/// Application state.
|
||||||
pub fn new(backend: reqwest::Url) -> Self {
|
pub struct BaseAppState<E: IntoResponse + 'static> {
|
||||||
Self {
|
backend: reqwest::Url,
|
||||||
backend,
|
clients: ClientCache,
|
||||||
clients: ClientCache::new(),
|
ctx_template: Option<serde_json::Value>,
|
||||||
inbox_stack: Vec::new(),
|
_marker: PhantomData<E>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<E: IntoResponse> HasAppState<E> for Arc<BaseAppState<E>> {
|
||||||
|
fn app_state(&self) -> &BaseAppState<E> {
|
||||||
|
&self
|
||||||
}
|
}
|
||||||
pub fn push_evaluator(&mut self, evaluator: Box<dyn Evaluator<E> + Send + Sync>) -> &mut Self {
|
|
||||||
self.inbox_stack.push(evaluator);
|
fn client_pool_ref(&self) -> &ClientCache {
|
||||||
self
|
|
||||||
}
|
|
||||||
pub fn client_pool_ref(&self) -> &ClientCache {
|
|
||||||
&self.clients
|
&self.clients
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const ERR_BAD_REQUEST: MisskeyError = MisskeyError::new_const(
|
#[async_trait::async_trait]
|
||||||
StatusCode::BAD_REQUEST,
|
impl<E: IntoResponse + Send + Sync> Evaluator<E> for Arc<BaseAppState<E>> {
|
||||||
"UNPARSABLE_REQUEST",
|
fn name() -> &'static str {
|
||||||
"659c1254-1392-458e-aa77-557444031da8",
|
"BaseAllow"
|
||||||
"The request is not HTTP compliant.",
|
}
|
||||||
);
|
async fn evaluate<'r>(
|
||||||
|
&self,
|
||||||
|
ctx: Option<serde_json::Value>,
|
||||||
|
_info: &APRequestInfo<'r>,
|
||||||
|
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||||
|
log::trace!("Evaluator fell through, accepting request");
|
||||||
|
(Disposition::Allow, ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const ERR_INTERNAL_SERVER_ERROR: MisskeyError = MisskeyError::new_const(
|
impl<E: IntoResponse> BaseAppState<E> {
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
/// Create a new application state.
|
||||||
"INTERNAL_SERVER_ERROR",
|
pub fn new(backend: reqwest::Url) -> Self {
|
||||||
"3a19659e-89e9-4c37-ada8-bbf620f2fc1a",
|
Self {
|
||||||
"An internal server error occurred.",
|
backend,
|
||||||
);
|
ctx_template: None,
|
||||||
|
clients: ClientCache::new(),
|
||||||
|
_marker: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Set the context template.
|
||||||
|
pub fn with_ctx_template(mut self, ctx: serde_json::Value) -> Self {
|
||||||
|
self.ctx_template = Some(ctx);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// With an empty context template.
|
||||||
|
pub fn with_empty_ctx(self) -> Self {
|
||||||
|
self.with_ctx_template(serde_json::Value::Object(Default::default()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const ERR_SERVICE_TEMPORARILY_UNAVAILABLE: MisskeyError = MisskeyError::new_const(
|
/// The main application.
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
pub struct ProxyApp<S, E>(std::marker::PhantomData<(S, E)>);
|
||||||
"SERVICE_TEMPORARILY_UNAVAILABLE",
|
|
||||||
"39992bed-f58f-484d-bb67-8c8db6f0b224",
|
|
||||||
"The service is temporarily unavailable.",
|
|
||||||
);
|
|
||||||
|
|
||||||
const ERR_PAYLOAD_TOO_LARGE: MisskeyError = MisskeyError::new_const(
|
|
||||||
StatusCode::PAYLOAD_TOO_LARGE,
|
|
||||||
"PAYLOAD_TOO_LARGE",
|
|
||||||
"8007d1b7-0eab-41b2-bb17-95f06926ba2b",
|
|
||||||
"The request payload is too large.",
|
|
||||||
);
|
|
||||||
|
|
||||||
pub struct ProxyApp<E>(std::marker::PhantomData<E>);
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
impl<E: IntoResponse + 'static> ProxyApp<E> {
|
impl<
|
||||||
|
E: IntoResponse + 'static + Send + Sync,
|
||||||
|
S: HasAppState<E> + Evaluator<E> + Send + Sync + 'static,
|
||||||
|
> ProxyApp<S, E>
|
||||||
|
{
|
||||||
/// Pass through the request to the backend with basic error handling.
|
/// Pass through the request to the backend with basic error handling.
|
||||||
pub async fn pass_through(
|
pub async fn pass_through(
|
||||||
method: Method,
|
method: Method,
|
||||||
State(app): State<Arc<AppState<E>>>,
|
State(app): State<S>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
OriginalUri(uri): OriginalUri,
|
OriginalUri(uri): OriginalUri,
|
||||||
header: HeaderMap,
|
header: HeaderMap,
|
||||||
body: Body,
|
body: Body,
|
||||||
) -> Result<impl IntoResponse, MisskeyError> {
|
) -> Result<impl IntoResponse, MisskeyError> {
|
||||||
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
|
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
|
||||||
|
let state = app.app_state();
|
||||||
|
|
||||||
app.clone()
|
app.app_state()
|
||||||
.clients
|
.clients
|
||||||
.with_client(&addr, |client| {
|
.with_client(&addr, |client| {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let req = client
|
let req = client
|
||||||
.request(
|
.request(
|
||||||
method.clone(),
|
method.clone(),
|
||||||
app.backend
|
state
|
||||||
|
.backend
|
||||||
.join(&path_and_query.to_string())
|
.join(&path_and_query.to_string())
|
||||||
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?,
|
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?,
|
||||||
)
|
)
|
||||||
|
@ -165,8 +232,9 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Handle incoming ActivityPub requests.
|
||||||
pub async fn inbox_handler(
|
pub async fn inbox_handler(
|
||||||
State(app): State<Arc<AppState<E>>>,
|
State(app): State<S>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
OriginalUri(uri): OriginalUri,
|
OriginalUri(uri): OriginalUri,
|
||||||
header: HeaderMap,
|
header: HeaderMap,
|
||||||
|
@ -196,7 +264,7 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let decode = {
|
let decode = {
|
||||||
let activity_decode = serde_json::from_slice::<ap::Activity<Object>>(&body);
|
let activity_decode = serde_json::from_slice::<ap::Activity<AnyObject>>(&body);
|
||||||
|
|
||||||
match activity_decode {
|
match activity_decode {
|
||||||
Ok(activity) => Ok(activity),
|
Ok(activity) => Ok(activity),
|
||||||
|
@ -215,21 +283,22 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
|
||||||
activity: &decode,
|
activity: &decode,
|
||||||
};
|
};
|
||||||
|
|
||||||
for evaluator in &app.inbox_stack {
|
let ctx = app.app_state().ctx_template.clone();
|
||||||
match evaluator.evaluate(&info).await {
|
match app.evaluate(ctx, &info).await.0 {
|
||||||
Disposition::Next => {}
|
Disposition::Intercept(e) => return Err(Either::B(e)),
|
||||||
Disposition::Intercept(e) => return Err(Either::B(e)),
|
_ => {}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
app.clone()
|
app.clone()
|
||||||
|
.app_state()
|
||||||
.clients
|
.clients
|
||||||
.with_client(&addr, |client| {
|
.with_client(&addr, |client| {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let req = client
|
let req = client
|
||||||
.request(
|
.request(
|
||||||
Method::POST,
|
Method::POST,
|
||||||
app.backend
|
app.app_state()
|
||||||
|
.backend
|
||||||
.join(&path_and_query.to_string())
|
.join(&path_and_query.to_string())
|
||||||
.map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?,
|
.map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?,
|
||||||
)
|
)
|
||||||
|
|
85
src/main.rs
85
src/main.rs
|
@ -1,14 +1,15 @@
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
|
use axum::response::IntoResponse;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use fedivet::evaluate::ERROR_DENIED;
|
use fedivet::evaluate::chain::audit::AuditOptions;
|
||||||
|
use fedivet::evaluate::Evaluator;
|
||||||
use fedivet::model::error::MisskeyError;
|
use fedivet::model::error::MisskeyError;
|
||||||
use fedivet::serve;
|
use fedivet::serve;
|
||||||
use fedivet::source::LruData;
|
use fedivet::BaseAppState;
|
||||||
use fedivet::APRequestInfo;
|
use fedivet::HasAppState;
|
||||||
use fedivet::AppState;
|
use serde::Serialize;
|
||||||
use reqwest::Url;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
pub struct Args {
|
pub struct Args {
|
||||||
|
@ -23,53 +24,11 @@ pub struct Args {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::unused_async)]
|
#[allow(clippy::unused_async)]
|
||||||
async fn build_state(args: &Args) -> AppState<MisskeyError> {
|
async fn build_state<E: IntoResponse + Clone + Serialize + Send + Sync + 'static>(
|
||||||
let mut state = AppState::new(args.backend.parse().expect("Invalid backend URL"));
|
base: Arc<BaseAppState<E>>,
|
||||||
|
_args: &Args,
|
||||||
let instance_history = Arc::new(LruData::sized(
|
) -> impl HasAppState<E> + Evaluator<E> {
|
||||||
&|host| async move { Ok::<_, ()>("Todo") },
|
base.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
|
||||||
512.try_into().unwrap(),
|
|
||||||
Some(Duration::from_secs(600)),
|
|
||||||
));
|
|
||||||
|
|
||||||
state.push_evaluator(Box::new(move |info: &APRequestInfo<'_>| {
|
|
||||||
let act = info.activity.as_ref().map_err(|_| ERROR_DENIED).cloned();
|
|
||||||
let instance_history = Arc::clone(&instance_history);
|
|
||||||
|
|
||||||
async move {
|
|
||||||
let act = act?;
|
|
||||||
let host = act
|
|
||||||
.actor
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|s| Url::parse(s).ok())
|
|
||||||
.ok_or(ERROR_DENIED)?;
|
|
||||||
|
|
||||||
let instance = instance_history
|
|
||||||
.query(host.host_str().unwrap().to_owned())
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match instance {
|
|
||||||
Ok(i) => {
|
|
||||||
log::info!("Instance history: {:?}", i);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Err(_) => Err::<(), _>(ERROR_DENIED),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
|
|
||||||
// let user_history = Arc::new(LruData::sized( ... ));
|
|
||||||
|
|
||||||
state.push_evaluator(Box::new(|info: &APRequestInfo<'_>| {
|
|
||||||
let act = info.activity.as_ref().map_err(|_| ERROR_DENIED).cloned();
|
|
||||||
async move {
|
|
||||||
let act = act?;
|
|
||||||
log::debug!("Activity: {:?}", act);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
|
|
||||||
state
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -81,7 +40,13 @@ async fn main() {
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let state = Arc::new(build_state(&args).await);
|
let state = build_state::<MisskeyError>(
|
||||||
|
Arc::new(BaseAppState::new(
|
||||||
|
args.backend.parse().expect("Invalid backend URL"),
|
||||||
|
)),
|
||||||
|
&args,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let (jh, handle) = serve::serve(
|
let (jh, handle) = serve::serve(
|
||||||
state.clone(),
|
state.clone(),
|
||||||
|
@ -97,10 +62,16 @@ async fn main() {
|
||||||
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||||
.expect("Failed to register SIGTERM handler");
|
.expect("Failed to register SIGTERM handler");
|
||||||
|
|
||||||
tokio::select! {
|
tokio::spawn(async move {
|
||||||
_ = gc_ticker.tick() => {
|
loop {
|
||||||
state.client_pool_ref().gc(std::time::Duration::from_secs(120));
|
gc_ticker.tick().await;
|
||||||
|
state
|
||||||
|
.client_pool_ref()
|
||||||
|
.gc(std::time::Duration::from_secs(120));
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
res = jh => {
|
res = jh => {
|
||||||
if let Err(e) = res {
|
if let Err(e) = res {
|
||||||
log::error!("Server error: {}", e);
|
log::error!("Server error: {}", e);
|
||||||
|
|
|
@ -6,6 +6,8 @@ use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
|
/// Serialize as either A or B.
|
||||||
|
#[allow(missing_docs)]
|
||||||
pub enum Either<A, B> {
|
pub enum Either<A, B> {
|
||||||
A(A),
|
A(A),
|
||||||
B(B),
|
B(B),
|
||||||
|
@ -13,14 +15,16 @@ pub enum Either<A, B> {
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
|
/// Represents a single value or multiple values.
|
||||||
|
#[allow(missing_docs)]
|
||||||
pub enum FlatArray<T> {
|
pub enum FlatArray<T> {
|
||||||
Single(T),
|
Single(T),
|
||||||
Multiple(Vec<T>),
|
Multiple(Vec<T>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for FlatArray<Object> {
|
impl<O> IntoIterator for FlatArray<O> {
|
||||||
type Item = Object;
|
type Item = O;
|
||||||
type IntoIter = std::vec::IntoIter<Object>;
|
type IntoIter = std::vec::IntoIter<O>;
|
||||||
|
|
||||||
fn into_iter(self) -> Self::IntoIter {
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
match self {
|
match self {
|
||||||
|
@ -31,6 +35,8 @@ impl IntoIterator for FlatArray<Object> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
/// Represents an ActivityStreams object.
|
||||||
|
#[allow(missing_docs)]
|
||||||
pub struct Object {
|
pub struct Object {
|
||||||
#[serde(rename = "@context")]
|
#[serde(rename = "@context")]
|
||||||
pub context: FlatArray<Either<String, serde_json::Value>>,
|
pub context: FlatArray<Either<String, serde_json::Value>>,
|
||||||
|
@ -38,14 +44,18 @@ pub struct Object {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
pub ty_: String,
|
pub ty_: String,
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
pub published: Option<chrono::DateTime<chrono::Utc>>,
|
||||||
pub to: Option<FlatArray<String>>,
|
pub to: Option<FlatArray<String>>,
|
||||||
pub cc: Option<FlatArray<String>>,
|
pub cc: Option<FlatArray<String>>,
|
||||||
pub bcc: Option<FlatArray<String>>,
|
pub bcc: Option<FlatArray<String>>,
|
||||||
|
pub url: Option<String>,
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub rest: HashMap<String, serde_json::Value>,
|
pub rest: HashMap<String, serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
/// Represents an ActivityStreams activity.
|
||||||
|
#[allow(missing_docs)]
|
||||||
pub struct Activity<O> {
|
pub struct Activity<O> {
|
||||||
pub actor: Option<String>,
|
pub actor: Option<String>,
|
||||||
pub object: Option<FlatArray<O>>,
|
pub object: Option<FlatArray<O>>,
|
||||||
|
@ -54,10 +64,42 @@ pub struct Activity<O> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
/// Represents an ActivityStreams note object.
|
||||||
|
#[allow(missing_docs)]
|
||||||
pub struct NoteObject {
|
pub struct NoteObject {
|
||||||
pub summary: Option<String>,
|
pub summary: Option<String>,
|
||||||
#[serde(rename = "inReplyTo")]
|
#[serde(rename = "inReplyTo")]
|
||||||
pub in_reply_to: Option<String>,
|
pub in_reply_to: Option<String>,
|
||||||
#[serde(rename = "attributedTo")]
|
#[serde(rename = "attributedTo")]
|
||||||
pub attributed_to: Option<String>,
|
pub attributed_to: Option<String>,
|
||||||
|
|
||||||
|
pub sensitive: Option<bool>,
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(rename = "contentMap")]
|
||||||
|
pub content_map: Option<HashMap<String, String>>,
|
||||||
|
|
||||||
|
pub attachment: Option<FlatArray<Attachment>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
/// Represents an ActivityStreams attachment.
|
||||||
|
#[allow(missing_docs)]
|
||||||
|
pub struct Attachment {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub ty_: String,
|
||||||
|
pub media_type: Option<String>,
|
||||||
|
pub url: String,
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub blurhash: Option<String>,
|
||||||
|
pub width: Option<u32>,
|
||||||
|
pub height: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
/// Represents any parsable ActivityStreams object.
|
||||||
|
#[allow(missing_docs)]
|
||||||
|
pub enum AnyObject {
|
||||||
|
NoteObject(NoteObject),
|
||||||
|
Object(Object),
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,14 +10,17 @@ use axum::{
|
||||||
use reqwest::StatusCode;
|
use reqwest::StatusCode;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
|
/// A trait for responses that can be converted into an axum response.
|
||||||
pub trait APResponse: Debug + Display + Send + IntoResponse {}
|
pub trait APResponse: Debug + Display + Send + IntoResponse {}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
/// A response that should be ignored.
|
||||||
pub struct Ignore {
|
pub struct Ignore {
|
||||||
reason: Cow<'static, str>,
|
reason: Cow<'static, str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Ignore {
|
impl Ignore {
|
||||||
|
/// Create a new Ignore response.
|
||||||
pub fn new(reason: impl Into<Cow<'static, str>>) -> Self {
|
pub fn new(reason: impl Into<Cow<'static, str>>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
reason: reason.into(),
|
reason: reason.into(),
|
||||||
|
@ -40,6 +43,7 @@ impl IntoResponse for Ignore {
|
||||||
impl APResponse for Ignore {}
|
impl APResponse for Ignore {}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
/// Errors that look like Misskey's error response.
|
||||||
pub struct MisskeyError {
|
pub struct MisskeyError {
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
status: StatusCode,
|
status: StatusCode,
|
||||||
|
@ -49,6 +53,7 @@ pub struct MisskeyError {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MisskeyError {
|
impl MisskeyError {
|
||||||
|
/// Create a new MisskeyError using const fn.
|
||||||
pub const fn new_const(
|
pub const fn new_const(
|
||||||
status: StatusCode,
|
status: StatusCode,
|
||||||
code: &'static str,
|
code: &'static str,
|
||||||
|
@ -62,6 +67,7 @@ impl MisskeyError {
|
||||||
message: Cow::Borrowed(message),
|
message: Cow::Borrowed(message),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// Create a new MisskeyError.
|
||||||
pub fn new(
|
pub fn new(
|
||||||
status: StatusCode,
|
status: StatusCode,
|
||||||
code: &'static str,
|
code: &'static str,
|
||||||
|
@ -76,6 +82,7 @@ impl MisskeyError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// HTTP 413 Payload Too Large
|
||||||
pub fn s_413(
|
pub fn s_413(
|
||||||
code: &'static str,
|
code: &'static str,
|
||||||
id: impl Into<Cow<'static, str>>,
|
id: impl Into<Cow<'static, str>>,
|
||||||
|
@ -84,6 +91,7 @@ impl MisskeyError {
|
||||||
Self::new(StatusCode::PAYLOAD_TOO_LARGE, code, id, message)
|
Self::new(StatusCode::PAYLOAD_TOO_LARGE, code, id, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// HTTP 500 Internal Server Error
|
||||||
pub fn s_500(
|
pub fn s_500(
|
||||||
code: &'static str,
|
code: &'static str,
|
||||||
id: impl Into<Cow<'static, str>>,
|
id: impl Into<Cow<'static, str>>,
|
||||||
|
@ -92,6 +100,7 @@ impl MisskeyError {
|
||||||
Self::new(StatusCode::INTERNAL_SERVER_ERROR, code, id, message)
|
Self::new(StatusCode::INTERNAL_SERVER_ERROR, code, id, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// HTTP 400 Bad Request
|
||||||
pub fn s_400(
|
pub fn s_400(
|
||||||
code: &'static str,
|
code: &'static str,
|
||||||
id: impl Into<Cow<'static, str>>,
|
id: impl Into<Cow<'static, str>>,
|
||||||
|
|
|
@ -1,2 +1,4 @@
|
||||||
|
/// ActivityPub types
|
||||||
pub mod ap;
|
pub mod ap;
|
||||||
|
/// Error handling
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
|
30
src/serve.rs
30
src/serve.rs
|
@ -1,14 +1,17 @@
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::{any, post},
|
routing::{any, post},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use axum_server::{tls_rustls::RustlsConfig, Handle};
|
use axum_server::Handle;
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
use crate::{APRequestInfo, AppState, ProxyApp};
|
use crate::{evaluate::Evaluator, HasAppState, ProxyApp};
|
||||||
|
|
||||||
|
/// Flag indicating whether the TLS feature is enabled
|
||||||
|
pub const HAS_TLS_FEATURE: bool = cfg!(feature = "tls");
|
||||||
|
|
||||||
#[allow(clippy::panic)]
|
#[allow(clippy::panic)]
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
|
@ -21,8 +24,11 @@ use crate::{APRequestInfo, AppState, ProxyApp};
|
||||||
/// - `tls_key`: The path to the TLS key file
|
/// - `tls_key`: The path to the TLS key file
|
||||||
///
|
///
|
||||||
/// Use [`tokio::select`] to listen on multiple addresses
|
/// Use [`tokio::select`] to listen on multiple addresses
|
||||||
pub async fn serve<E: IntoResponse + 'static>(
|
pub async fn serve<
|
||||||
state: Arc<AppState<E>>,
|
S: HasAppState<E> + Evaluator<E> + Send + Sync + 'static,
|
||||||
|
E: IntoResponse + Send + Sync + 'static,
|
||||||
|
>(
|
||||||
|
state: S,
|
||||||
listen: &str,
|
listen: &str,
|
||||||
tls_cert: Option<&str>,
|
tls_cert: Option<&str>,
|
||||||
tls_key: Option<&str>,
|
tls_key: Option<&str>,
|
||||||
|
@ -30,11 +36,11 @@ pub async fn serve<E: IntoResponse + 'static>(
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route(
|
.route(
|
||||||
"/inbox",
|
"/inbox",
|
||||||
post(ProxyApp::inbox_handler)
|
post(ProxyApp::<S, E>::inbox_handler)
|
||||||
.put(ProxyApp::inbox_handler)
|
.put(ProxyApp::<S, E>::inbox_handler)
|
||||||
.patch(ProxyApp::inbox_handler),
|
.patch(ProxyApp::<S, E>::inbox_handler),
|
||||||
)
|
)
|
||||||
.fallback(any(ProxyApp::pass_through));
|
.fallback(any(ProxyApp::<S, E>::pass_through));
|
||||||
|
|
||||||
let ms = app
|
let ms = app
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
|
@ -42,7 +48,9 @@ pub async fn serve<E: IntoResponse + 'static>(
|
||||||
|
|
||||||
let handle = Handle::new();
|
let handle = Handle::new();
|
||||||
match (tls_cert, tls_key) {
|
match (tls_cert, tls_key) {
|
||||||
|
#[cfg(feature = "tls")]
|
||||||
(Some(cert), Some(key)) => {
|
(Some(cert), Some(key)) => {
|
||||||
|
use axum_server::tls_rustls::RustlsConfig;
|
||||||
let tls_config = RustlsConfig::from_pem_file(cert, key)
|
let tls_config = RustlsConfig::from_pem_file(cert, key)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to load TLS certificate and key");
|
.expect("Failed to load TLS certificate and key");
|
||||||
|
@ -54,6 +62,10 @@ pub async fn serve<E: IntoResponse + 'static>(
|
||||||
let jh = tokio::spawn(server);
|
let jh = tokio::spawn(server);
|
||||||
(jh, handle)
|
(jh, handle)
|
||||||
}
|
}
|
||||||
|
#[cfg(not(feature = "tls"))]
|
||||||
|
(Some(_), Some(_)) => {
|
||||||
|
panic!("TLS support is not enabled")
|
||||||
|
}
|
||||||
(None, None) => {
|
(None, None) => {
|
||||||
log::info!("Listening on http://{}", listen);
|
log::info!("Listening on http://{}", listen);
|
||||||
let server = axum_server::bind(listen.parse().expect("invalid listen addr"))
|
let server = axum_server::bind(listen.parse().expect("invalid listen addr"))
|
||||||
|
|
|
@ -10,9 +10,12 @@ use std::{
|
||||||
|
|
||||||
use lru::LruCache;
|
use lru::LruCache;
|
||||||
|
|
||||||
|
/// Friend instance data source
|
||||||
pub mod friend;
|
pub mod friend;
|
||||||
|
/// Nodeinfo data source
|
||||||
pub mod nodeinfo;
|
pub mod nodeinfo;
|
||||||
|
|
||||||
|
/// Lazy future that caches its result
|
||||||
pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> {
|
pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> {
|
||||||
future: Mutex<Pin<Box<F>>>,
|
future: Mutex<Pin<Box<F>>>,
|
||||||
ttl: Option<Duration>,
|
ttl: Option<Duration>,
|
||||||
|
@ -20,6 +23,7 @@ pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFuture<T, E, F> {
|
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFuture<T, E, F> {
|
||||||
|
/// Create a new lazy future
|
||||||
pub fn new(future: F, ttl: Option<Duration>) -> Self {
|
pub fn new(future: F, ttl: Option<Duration>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
future: Mutex::new(Box::pin(future)),
|
future: Mutex::new(Box::pin(future)),
|
||||||
|
@ -73,6 +77,8 @@ struct LazyFutureHandle<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFutureHandle<T, E, F> {
|
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFutureHandle<T, E, F> {
|
||||||
|
/// Get a reference to the inner future
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn inner_ref(&self) -> &LazyFuture<T, E, F> {
|
pub fn inner_ref(&self) -> &LazyFuture<T, E, F> {
|
||||||
&self.inner
|
&self.inner
|
||||||
}
|
}
|
||||||
|
@ -94,6 +100,7 @@ impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> Future for LazyFuture
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// LRU cache for futures
|
||||||
pub struct LruData<
|
pub struct LruData<
|
||||||
'a,
|
'a,
|
||||||
K: Hash,
|
K: Hash,
|
||||||
|
@ -117,6 +124,7 @@ impl<
|
||||||
FF: Fn(&K) -> F,
|
FF: Fn(&K) -> F,
|
||||||
> LruData<'a, K, T, E, F, FF>
|
> LruData<'a, K, T, E, F, FF>
|
||||||
{
|
{
|
||||||
|
/// Create a new LRU cache with a given size and factory
|
||||||
pub fn sized(factory: &'a FF, size: NonZeroUsize, ttl: Option<Duration>) -> Self {
|
pub fn sized(factory: &'a FF, size: NonZeroUsize, ttl: Option<Duration>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
factory,
|
factory,
|
||||||
|
@ -126,6 +134,7 @@ impl<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Evict all expired entries from the cache
|
||||||
pub fn evict_expired(&self) {
|
pub fn evict_expired(&self) {
|
||||||
let mut inner = self.inner.lock().unwrap();
|
let mut inner = self.inner.lock().unwrap();
|
||||||
|
|
||||||
|
@ -152,6 +161,7 @@ impl<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get or insert a future into the cache
|
||||||
pub fn get_or_insert(&self, key: K) -> Arc<LazyFuture<T, E, F>> {
|
pub fn get_or_insert(&self, key: K) -> Arc<LazyFuture<T, E, F>> {
|
||||||
let factory = self.factory;
|
let factory = self.factory;
|
||||||
let entry = self.inner.lock().unwrap().get(&key).cloned();
|
let entry = self.inner.lock().unwrap().get(&key).cloned();
|
||||||
|
@ -166,6 +176,7 @@ impl<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Query the cache for a given key or insert it if it doesn't exist
|
||||||
pub fn query(&self, key: K) -> impl Future<Output = Result<T, E>> {
|
pub fn query(&self, key: K) -> impl Future<Output = Result<T, E>> {
|
||||||
self.maybe_evict();
|
self.maybe_evict();
|
||||||
let future = self.get_or_insert(key);
|
let future = self.get_or_insert(key);
|
||||||
|
|
Loading…
Reference in a new issue