Refactor stack logic

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-10-16 18:09:32 -05:00
parent 205958b0cc
commit ab15f3e3ff
No known key found for this signature in database
13 changed files with 727 additions and 132 deletions

113
Cargo.lock generated
View file

@ -32,6 +32,21 @@ version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.15" version = "0.6.15"
@ -304,6 +319,21 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-targets",
]
[[package]] [[package]]
name = "clang-sys" name = "clang-sys"
version = "1.8.1" version = "1.8.1"
@ -386,6 +416,15 @@ version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "crc32fast"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "crossbeam-utils" name = "crossbeam-utils"
version = "0.8.20" version = "0.8.20"
@ -479,18 +518,31 @@ dependencies = [
"async-trait", "async-trait",
"axum", "axum",
"axum-server", "axum-server",
"chrono",
"clap", "clap",
"dashmap", "dashmap",
"env_logger", "env_logger",
"flate2",
"futures", "futures",
"log", "log",
"lru", "lru",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",
"thiserror",
"tokio", "tokio",
] ]
[[package]]
name = "flate2"
version = "1.0.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@ -827,6 +879,29 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "iana-time-zone"
version = "0.1.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "idna" name = "idna"
version = "0.5.0" version = "0.5.0"
@ -1029,6 +1104,15 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "object" name = "object"
version = "0.36.5" version = "0.36.5"
@ -1565,6 +1649,26 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "thiserror"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.8.0" version = "1.8.0"
@ -1878,6 +1982,15 @@ dependencies = [
"rustix", "rustix",
] ]
[[package]]
name = "windows-core"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets",
]
[[package]] [[package]]
name = "windows-registry" name = "windows-registry"
version = "0.2.0" version = "0.2.0"

View file

@ -3,17 +3,23 @@ name = "fedivet"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[features]
tls = ["axum-server/tls-rustls", "axum-server/rustls-pemfile", "axum-server/tokio-rustls"]
[dependencies] [dependencies]
async-trait = "0.1.83" async-trait = "0.1.83"
axum = "0.7.7" axum = "0.7.7"
axum-server = { version = "0.7.1", features = ["tokio-rustls", "rustls-pemfile", "tls-rustls"] } axum-server = { version = "0.7.1" }
chrono = { version = "0.4.38", features = ["serde"] }
clap = { version = "4.5.20", features = ["derive"] } clap = { version = "4.5.20", features = ["derive"] }
dashmap = "6.1.0" dashmap = "6.1.0"
env_logger = "0.11.5" env_logger = "0.11.5"
flate2 = "1.0.34"
futures = "0.3.31" futures = "0.3.31"
log = "0.4.22" log = "0.4.22"
lru = "0.12.5" lru = "0.12.5"
reqwest = { version = "0.12.8", features = ["stream"] } reqwest = { version = "0.12.8", features = ["stream"] }
serde = { version = "1.0.210", features = ["derive"] } serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.128" serde_json = "1.0.128"
thiserror = "1.0.64"
tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros", "net", "sync", "fs", "signal", "time"] } tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros", "net", "sync", "fs", "signal", "time"] }

244
src/evaluate/chain/audit.rs Normal file
View file

@ -0,0 +1,244 @@
use axum::response::IntoResponse;
use serde::Serialize;
use std::collections::HashSet;
use std::fs::{File, OpenOptions};
use std::sync::atomic::AtomicU32;
use std::{collections::HashMap, fmt::Debug, ops::DerefMut, path::PathBuf, sync::Arc};
use tokio::sync::{Mutex, RwLock};
use crate::{
delegate,
evaluate::{Disposition, Evaluator},
HasAppState,
};
/// Save audit logs to a file
pub struct AuditOptions {
/// The path to save the audit logs
pub output: PathBuf,
/// The maximum size of the audit log file before it is rotated
pub rotate_size: Option<u64>,
/// The number of days to keep audit logs before vacuuming
pub vacuum_days: Option<u64>,
}
impl AuditOptions {
/// Create a new set of audit options
pub fn new(output: PathBuf) -> Self {
Self {
output,
rotate_size: None,
vacuum_days: None,
}
}
/// Set the maximum size of the audit log file before it is rotated
pub fn rotate_size(mut self, rotate_size: u64) -> Self {
self.rotate_size = Some(rotate_size);
self
}
/// Set the number of days to keep audit logs before vacuuming
pub fn vacuum_days(mut self, vacuum_days: u64) -> Self {
self.vacuum_days = Some(vacuum_days);
self
}
}
#[derive(Debug, thiserror::Error)]
/// Errors that can occur while writing audit logs
#[non_exhaustive]
#[allow(missing_docs)]
pub enum AuditError {
#[error("Failed to write audit log: {0}")]
WriteError(#[from] std::io::Error),
#[error("Failed to serialize audit log: {0}")]
SerializeError(#[from] serde_json::Error),
}
pub struct AuditState {
options: AuditOptions,
cur_file: Option<RwLock<HashMap<String, Mutex<File>>>>,
vacuum_counter: AtomicU32,
}
impl AuditState {
pub fn new(options: AuditOptions) -> Self {
if !options.output.exists() {
std::fs::create_dir_all(&options.output).expect("Failed to create audit log directory");
}
Self {
options,
cur_file: None,
vacuum_counter: AtomicU32::new(0),
}
}
pub async fn vacuum(&self) -> Result<(), AuditError> {
let read = self.cur_file.as_ref().unwrap().read().await;
let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
let files = std::fs::read_dir(&self.options.output)?
.filter_map(|f| f.ok())
.filter(|f| {
f.file_type().map(|t| t.is_file()).unwrap_or(false)
&& f.file_name().to_string_lossy().ends_with(".json")
})
.filter(|f| !in_use_files.contains(&f.file_name().to_string_lossy().to_string()));
for file in files {
let meta = file.metadata()?;
if let Some(days) = self.options.vacuum_days {
let duration = meta
.modified()?
.elapsed()
.unwrap_or(std::time::Duration::new(0, 0));
if duration.as_secs() >= days * 24 * 60 * 60 {
std::fs::remove_file(file.path())?;
}
}
}
Ok(())
}
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
let mut write = self.cur_file.as_ref().unwrap().write().await;
let full_name = format!("{}_{}.json", name, time_str);
let file = OpenOptions::new()
.create(true)
.write(true)
.append(true)
.open(&self.options.output.join(&full_name))?;
write.insert(name.to_string(), Mutex::new(file));
Ok(())
}
pub async fn write<E: IntoResponse + Serialize>(
&self,
name: &str,
item: &AuditItem<'_, E>,
) -> Result<(), AuditError> {
if self
.vacuum_counter
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
>= 512
{
self.vacuum().await?;
self.vacuum_counter
.store(0, std::sync::atomic::Ordering::Relaxed);
}
let read = self.cur_file.as_ref().unwrap().read().await;
if let Some(file) = read.get(name) {
let mut f = file.lock().await;
serde_json::to_writer(f.deref_mut(), &item)?;
let meta = f.metadata()?;
if let Some(size) = self.options.rotate_size {
if meta.len() >= size {
drop(f);
drop(read);
self.create_new_file(name).await?;
}
}
return Ok(());
}
let mut write = self.cur_file.as_ref().unwrap().write().await;
let file = File::create(&self.options.output.join(name))?;
write.insert(name.to_string(), Mutex::new(file));
// this is deliberately out of order to make sure we don't create endless files if serialization fails
serde_json::to_writer(write.get(name).unwrap().lock().await.deref_mut(), &item)?;
Ok(())
}
}
pub struct Audit<E: IntoResponse + 'static, I: HasAppState<E>> {
inner: I,
state: Arc<AuditState>,
_marker: std::marker::PhantomData<E>,
}
impl<E: IntoResponse + 'static, I: HasAppState<E>> Audit<E, I> {
pub fn new(inner: I, options: AuditOptions) -> Self {
Self {
inner,
state: Arc::new(AuditState::new(options)),
_marker: std::marker::PhantomData,
}
}
}
impl<E: IntoResponse + 'static, I: HasAppState<E>> Clone for Audit<E, I> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
state: self.state.clone(),
_marker: std::marker::PhantomData,
}
}
}
delegate!(state Audit::<E, I>.inner);
#[derive(Debug, Serialize)]
pub struct AuditItem<'r, E: IntoResponse + Serialize> {
info: &'r crate::APRequestInfo<'r>,
ctx: &'r Option<serde_json::Value>,
disposition: &'r Disposition<E>,
}
#[async_trait::async_trait]
impl<
E: IntoResponse + Serialize + Send + Sync + 'static,
I: HasAppState<E> + Evaluator<E> + Sync,
> Evaluator<E> for Audit<E, I>
{
fn name() -> &'static str {
"Audit"
}
async fn evaluate<'r>(
&self,
ctx: Option<serde_json::Value>,
info: &crate::APRequestInfo<'r>,
) -> (Disposition<E>, Option<serde_json::Value>) {
let (disp, ctx) = self.inner.evaluate(ctx, info).await;
if ctx
.as_ref()
.map(|c| c.get("skip_audit").is_some())
.unwrap_or(false)
{
return (disp, ctx);
}
let item = AuditItem {
info,
ctx: &ctx,
disposition: &disp,
};
let state = self.state.clone();
state.write(Self::name(), &item).await.ok();
(disp, ctx)
}
}

View file

@ -0,0 +1,52 @@
use axum::response::IntoResponse;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use crate::{
delegate,
model::ap::{AnyObject, NoteObject},
APRequestInfo, HasAppState,
};
#[derive(Clone)]
pub struct Meta<E: IntoResponse + 'static, I: HasAppState<E>> {
inner: I,
_marker: std::marker::PhantomData<E>,
}
pub fn extract_host(act: &APRequestInfo) -> Option<String> {
act.activity
.as_ref()
.ok()?
.actor
.as_ref()
.and_then(|s| Url::parse(s).ok()?.host_str().map(|s| s.to_string()))
}
pub fn extract_attributed_to(act: &APRequestInfo) -> Option<String> {
match &act.activity.as_ref().ok()?.object {
Some(arr) => match arr.clone().into_iter().next() {
Some(AnyObject::NoteObject(NoteObject {
attributed_to: Some(s),
..
})) => Some(s),
_ => None,
},
_ => None,
}
}
pub fn extract_meta(act: &APRequestInfo) -> Option<MetaItem> {
Some(MetaItem {
instance_host: extract_host(act),
attributed_to: extract_attributed_to(act),
})
}
delegate!(state Meta::<E, I>.inner);
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaItem {
instance_host: Option<String>,
attributed_to: Option<String>,
}

View file

@ -0,0 +1,2 @@
pub mod audit;
pub mod meta;

View file

@ -1,5 +1,24 @@
use axum::response::IntoResponse; use axum::response::IntoResponse;
use reqwest::StatusCode; use reqwest::StatusCode;
use serde::Serialize;
pub mod chain;
#[macro_export]
macro_rules! delegate {
(state $str:ident::<E, I $(,$t:ty),*>.$field:ident) => {
use crate::{BaseAppState, ClientCache};
impl<E: IntoResponse + Clone + 'static, I: HasAppState<E>> HasAppState<E> for $str::<E, I, $($t),*> {
fn app_state(&self) -> &BaseAppState<E> {
self.inner.app_state()
}
fn client_pool_ref(&self) -> &ClientCache {
self.inner.client_pool_ref()
}
}
};
}
use crate::{model::error::MisskeyError, APRequestInfo}; use crate::{model::error::MisskeyError, APRequestInfo};
@ -10,14 +29,50 @@ pub const ERROR_DENIED: MisskeyError = MisskeyError::new_const(
"This server cannot accept this activity.", "This server cannot accept this activity.",
); );
pub const ERR_BAD_REQUEST: MisskeyError = MisskeyError::new_const(
StatusCode::BAD_REQUEST,
"UNPARSABLE_REQUEST",
"659c1254-1392-458e-aa77-557444031da8",
"The request is not HTTP compliant.",
);
pub const ERR_INTERNAL_SERVER_ERROR: MisskeyError = MisskeyError::new_const(
StatusCode::INTERNAL_SERVER_ERROR,
"INTERNAL_SERVER_ERROR",
"3a19659e-89e9-4c37-ada8-bbf620f2fc1a",
"An internal server error occurred.",
);
pub const ERR_SERVICE_TEMPORARILY_UNAVAILABLE: MisskeyError = MisskeyError::new_const(
StatusCode::SERVICE_UNAVAILABLE,
"SERVICE_TEMPORARILY_UNAVAILABLE",
"39992bed-f58f-484d-bb67-8c8db6f0b224",
"The service is temporarily unavailable.",
);
pub const ERR_PAYLOAD_TOO_LARGE: MisskeyError = MisskeyError::new_const(
StatusCode::PAYLOAD_TOO_LARGE,
"PAYLOAD_TOO_LARGE",
"8007d1b7-0eab-41b2-bb17-95f06926ba2b",
"The request payload is too large.",
);
#[derive(Debug, Serialize)]
pub enum Disposition<E: IntoResponse> { pub enum Disposition<E: IntoResponse> {
Allow,
Next, Next,
Intercept(E), Intercept(E),
} }
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait Evaluator<E: IntoResponse> { pub trait Evaluator<E: IntoResponse> {
async fn evaluate<'r>(&self, info: &APRequestInfo<'r>) -> Disposition<E>; fn name() -> &'static str;
async fn evaluate<'r>(
&self,
ctx: Option<serde_json::Value>,
info: &APRequestInfo<'r>,
) -> (Disposition<E>, Option<serde_json::Value>);
} }
impl<E: IntoResponse> Into<Disposition<E>> for Result<(), E> { impl<E: IntoResponse> Into<Disposition<E>> for Result<(), E> {
@ -37,7 +92,14 @@ impl<
F: Fn(&APRequestInfo<'_>) -> Fut + Send + Sync, F: Fn(&APRequestInfo<'_>) -> Fut + Send + Sync,
> Evaluator<E> for F > Evaluator<E> for F
{ {
async fn evaluate<'r>(&self, info: &APRequestInfo<'r>) -> Disposition<E> { fn name() -> &'static str {
(*self)(info).await.into() "Closure"
}
async fn evaluate<'r>(
&self,
ctx: Option<serde_json::Value>,
info: &APRequestInfo<'r>,
) -> (Disposition<E>, Option<serde_json::Value>) {
(self(info).await.into(), ctx)
} }
} }

View file

@ -1,9 +1,11 @@
#![doc = include_str!("../README.md")]
#![feature(ip)] #![feature(ip)]
#![feature(async_closure)] #![feature(async_closure)]
#![warn(clippy::unwrap_used, clippy::expect_used)] #![warn(clippy::unwrap_used, clippy::expect_used)]
#![warn(unsafe_code)] #![warn(unsafe_code)]
#![warn(missing_docs)] #![warn(missing_docs)]
use std::marker::PhantomData;
use std::sync::Arc; use std::sync::Arc;
use std::{fmt::Display, net::SocketAddr}; use std::{fmt::Display, net::SocketAddr};
@ -14,28 +16,65 @@ use axum::{
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use client::ClientCache; use client::ClientCache;
use evaluate::{Disposition, Evaluator}; use evaluate::chain::audit::{Audit, AuditOptions};
use evaluate::{
Disposition, Evaluator, ERR_BAD_REQUEST, ERR_INTERNAL_SERVER_ERROR, ERR_PAYLOAD_TOO_LARGE,
ERR_SERVICE_TEMPORARILY_UNAVAILABLE,
};
use futures::TryStreamExt; use futures::TryStreamExt;
use model::ap::Object; use model::ap::AnyObject;
use model::{ap, error::MisskeyError}; use model::{ap, error::MisskeyError};
use network::stream::LimitedStream; use network::stream::LimitedStream;
use network::Either; use network::Either;
use reqwest::{self, StatusCode}; use reqwest::{self};
use serde::ser::{SerializeMap, SerializeStruct};
use serde::Serialize;
pub(crate) mod client; pub(crate) mod client;
/// Evaluation framework
pub mod evaluate; pub mod evaluate;
/// Data models
pub mod model; pub mod model;
pub(crate) mod network; pub(crate) mod network;
/// Server implementation
pub mod serve; pub mod serve;
/// Data sources
pub mod source; pub mod source;
#[derive(Debug)] #[derive(Debug)]
/// Information about an incoming ActivityPub request.
#[allow(missing_docs)]
pub struct APRequestInfo<'r> { pub struct APRequestInfo<'r> {
pub method: &'r Method, pub method: &'r Method,
pub uri: &'r Uri, pub uri: &'r Uri,
pub header: &'r HeaderMap, pub header: &'r HeaderMap,
pub connect: &'r SocketAddr, pub connect: &'r SocketAddr,
pub activity: &'r Result<ap::Activity<Object>, (serde_json::Value, serde_json::Error)>, pub activity: &'r Result<ap::Activity<AnyObject>, (serde_json::Value, serde_json::Error)>,
}
impl Serialize for APRequestInfo<'_> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut state = serializer.serialize_struct("APRequestInfo", 5)?;
state.serialize_field("method", &self.method.as_str())?;
state.serialize_field("uri", &self.uri.to_string())?;
pub struct SerializeHeader<'a>(&'a HeaderMap);
impl<'a> Serialize for SerializeHeader<'a> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut header = serializer.serialize_map(Some(self.0.len()))?;
for (key, value) in self.0.iter() {
header.serialize_entry(key.as_str(), &value.to_str().unwrap_or_default())?;
}
header.end()
}
}
state.serialize_field("header", &SerializeHeader(self.header))?;
state.serialize_field("connect", &self.connect)?;
state.serialize_field(
"activity",
&self.activity.as_ref().map_err(|(v, e)| (v, e.to_string())),
)?;
state.end()
}
} }
impl Display for APRequestInfo<'_> { impl Display for APRequestInfo<'_> {
@ -48,80 +87,108 @@ impl Display for APRequestInfo<'_> {
} }
} }
pub struct AppState<E: IntoResponse + 'static> { /// Trait for accessing the application state.
backend: reqwest::Url, pub trait HasAppState<E: IntoResponse + 'static>: Clone {
clients: ClientCache, /// Get a reference to the application state.
inbox_stack: Vec<Box<dyn Evaluator<E> + Send + Sync>>, fn app_state(&self) -> &BaseAppState<E>;
/// Get a reference to the client pool.
fn client_pool_ref(&self) -> &ClientCache;
/// Wrap the evaluator in an audit chain.
fn audited(self, opts: AuditOptions) -> Audit<E, Self>
where
Self: Sized + Send + Sync + Evaluator<E>,
E: Send + Sync,
{
Audit::new(self, opts)
}
} }
impl<E: IntoResponse> AppState<E> { /// Application state.
pub fn new(backend: reqwest::Url) -> Self { pub struct BaseAppState<E: IntoResponse + 'static> {
Self { backend: reqwest::Url,
backend, clients: ClientCache,
clients: ClientCache::new(), ctx_template: Option<serde_json::Value>,
inbox_stack: Vec::new(), _marker: PhantomData<E>,
} }
impl<E: IntoResponse> HasAppState<E> for Arc<BaseAppState<E>> {
fn app_state(&self) -> &BaseAppState<E> {
&self
} }
pub fn push_evaluator(&mut self, evaluator: Box<dyn Evaluator<E> + Send + Sync>) -> &mut Self {
self.inbox_stack.push(evaluator); fn client_pool_ref(&self) -> &ClientCache {
self
}
pub fn client_pool_ref(&self) -> &ClientCache {
&self.clients &self.clients
} }
} }
const ERR_BAD_REQUEST: MisskeyError = MisskeyError::new_const( #[async_trait::async_trait]
StatusCode::BAD_REQUEST, impl<E: IntoResponse + Send + Sync> Evaluator<E> for Arc<BaseAppState<E>> {
"UNPARSABLE_REQUEST", fn name() -> &'static str {
"659c1254-1392-458e-aa77-557444031da8", "BaseAllow"
"The request is not HTTP compliant.", }
); async fn evaluate<'r>(
&self,
ctx: Option<serde_json::Value>,
_info: &APRequestInfo<'r>,
) -> (Disposition<E>, Option<serde_json::Value>) {
log::trace!("Evaluator fell through, accepting request");
(Disposition::Allow, ctx)
}
}
const ERR_INTERNAL_SERVER_ERROR: MisskeyError = MisskeyError::new_const( impl<E: IntoResponse> BaseAppState<E> {
StatusCode::INTERNAL_SERVER_ERROR, /// Create a new application state.
"INTERNAL_SERVER_ERROR", pub fn new(backend: reqwest::Url) -> Self {
"3a19659e-89e9-4c37-ada8-bbf620f2fc1a", Self {
"An internal server error occurred.", backend,
); ctx_template: None,
clients: ClientCache::new(),
_marker: PhantomData,
}
}
/// Set the context template.
pub fn with_ctx_template(mut self, ctx: serde_json::Value) -> Self {
self.ctx_template = Some(ctx);
self
}
/// With an empty context template.
pub fn with_empty_ctx(self) -> Self {
self.with_ctx_template(serde_json::Value::Object(Default::default()))
}
}
const ERR_SERVICE_TEMPORARILY_UNAVAILABLE: MisskeyError = MisskeyError::new_const( /// The main application.
StatusCode::SERVICE_UNAVAILABLE, pub struct ProxyApp<S, E>(std::marker::PhantomData<(S, E)>);
"SERVICE_TEMPORARILY_UNAVAILABLE",
"39992bed-f58f-484d-bb67-8c8db6f0b224",
"The service is temporarily unavailable.",
);
const ERR_PAYLOAD_TOO_LARGE: MisskeyError = MisskeyError::new_const(
StatusCode::PAYLOAD_TOO_LARGE,
"PAYLOAD_TOO_LARGE",
"8007d1b7-0eab-41b2-bb17-95f06926ba2b",
"The request payload is too large.",
);
pub struct ProxyApp<E>(std::marker::PhantomData<E>);
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
impl<E: IntoResponse + 'static> ProxyApp<E> { impl<
E: IntoResponse + 'static + Send + Sync,
S: HasAppState<E> + Evaluator<E> + Send + Sync + 'static,
> ProxyApp<S, E>
{
/// Pass through the request to the backend with basic error handling. /// Pass through the request to the backend with basic error handling.
pub async fn pass_through( pub async fn pass_through(
method: Method, method: Method,
State(app): State<Arc<AppState<E>>>, State(app): State<S>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
OriginalUri(uri): OriginalUri, OriginalUri(uri): OriginalUri,
header: HeaderMap, header: HeaderMap,
body: Body, body: Body,
) -> Result<impl IntoResponse, MisskeyError> { ) -> Result<impl IntoResponse, MisskeyError> {
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?; let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
let state = app.app_state();
app.clone() app.app_state()
.clients .clients
.with_client(&addr, |client| { .with_client(&addr, |client| {
Box::pin(async move { Box::pin(async move {
let req = client let req = client
.request( .request(
method.clone(), method.clone(),
app.backend state
.backend
.join(&path_and_query.to_string()) .join(&path_and_query.to_string())
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?, .map_err(|_| ERR_INTERNAL_SERVER_ERROR)?,
) )
@ -165,8 +232,9 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
.await .await
} }
/// Handle incoming ActivityPub requests.
pub async fn inbox_handler( pub async fn inbox_handler(
State(app): State<Arc<AppState<E>>>, State(app): State<S>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
OriginalUri(uri): OriginalUri, OriginalUri(uri): OriginalUri,
header: HeaderMap, header: HeaderMap,
@ -196,7 +264,7 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
})?; })?;
let decode = { let decode = {
let activity_decode = serde_json::from_slice::<ap::Activity<Object>>(&body); let activity_decode = serde_json::from_slice::<ap::Activity<AnyObject>>(&body);
match activity_decode { match activity_decode {
Ok(activity) => Ok(activity), Ok(activity) => Ok(activity),
@ -215,21 +283,22 @@ impl<E: IntoResponse + 'static> ProxyApp<E> {
activity: &decode, activity: &decode,
}; };
for evaluator in &app.inbox_stack { let ctx = app.app_state().ctx_template.clone();
match evaluator.evaluate(&info).await { match app.evaluate(ctx, &info).await.0 {
Disposition::Next => {} Disposition::Intercept(e) => return Err(Either::B(e)),
Disposition::Intercept(e) => return Err(Either::B(e)), _ => {}
}
} }
app.clone() app.clone()
.app_state()
.clients .clients
.with_client(&addr, |client| { .with_client(&addr, |client| {
Box::pin(async move { Box::pin(async move {
let req = client let req = client
.request( .request(
Method::POST, Method::POST,
app.backend app.app_state()
.backend
.join(&path_and_query.to_string()) .join(&path_and_query.to_string())
.map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?, .map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?,
) )

View file

@ -1,14 +1,15 @@
use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use axum::response::IntoResponse;
use clap::Parser; use clap::Parser;
use fedivet::evaluate::ERROR_DENIED; use fedivet::evaluate::chain::audit::AuditOptions;
use fedivet::evaluate::Evaluator;
use fedivet::model::error::MisskeyError; use fedivet::model::error::MisskeyError;
use fedivet::serve; use fedivet::serve;
use fedivet::source::LruData; use fedivet::BaseAppState;
use fedivet::APRequestInfo; use fedivet::HasAppState;
use fedivet::AppState; use serde::Serialize;
use reqwest::Url;
#[derive(Parser)] #[derive(Parser)]
pub struct Args { pub struct Args {
@ -23,53 +24,11 @@ pub struct Args {
} }
#[allow(clippy::unused_async)] #[allow(clippy::unused_async)]
async fn build_state(args: &Args) -> AppState<MisskeyError> { async fn build_state<E: IntoResponse + Clone + Serialize + Send + Sync + 'static>(
let mut state = AppState::new(args.backend.parse().expect("Invalid backend URL")); base: Arc<BaseAppState<E>>,
_args: &Args,
let instance_history = Arc::new(LruData::sized( ) -> impl HasAppState<E> + Evaluator<E> {
&|host| async move { Ok::<_, ()>("Todo") }, base.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
512.try_into().unwrap(),
Some(Duration::from_secs(600)),
));
state.push_evaluator(Box::new(move |info: &APRequestInfo<'_>| {
let act = info.activity.as_ref().map_err(|_| ERROR_DENIED).cloned();
let instance_history = Arc::clone(&instance_history);
async move {
let act = act?;
let host = act
.actor
.as_ref()
.and_then(|s| Url::parse(s).ok())
.ok_or(ERROR_DENIED)?;
let instance = instance_history
.query(host.host_str().unwrap().to_owned())
.await;
match instance {
Ok(i) => {
log::info!("Instance history: {:?}", i);
Ok(())
}
Err(_) => Err::<(), _>(ERROR_DENIED),
}
}
}));
// let user_history = Arc::new(LruData::sized( ... ));
state.push_evaluator(Box::new(|info: &APRequestInfo<'_>| {
let act = info.activity.as_ref().map_err(|_| ERROR_DENIED).cloned();
async move {
let act = act?;
log::debug!("Activity: {:?}", act);
Ok(())
}
}));
state
} }
#[tokio::main] #[tokio::main]
@ -81,7 +40,13 @@ async fn main() {
let args = Args::parse(); let args = Args::parse();
let state = Arc::new(build_state(&args).await); let state = build_state::<MisskeyError>(
Arc::new(BaseAppState::new(
args.backend.parse().expect("Invalid backend URL"),
)),
&args,
)
.await;
let (jh, handle) = serve::serve( let (jh, handle) = serve::serve(
state.clone(), state.clone(),
@ -97,10 +62,16 @@ async fn main() {
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("Failed to register SIGTERM handler"); .expect("Failed to register SIGTERM handler");
tokio::select! { tokio::spawn(async move {
_ = gc_ticker.tick() => { loop {
state.client_pool_ref().gc(std::time::Duration::from_secs(120)); gc_ticker.tick().await;
state
.client_pool_ref()
.gc(std::time::Duration::from_secs(120));
} }
});
tokio::select! {
res = jh => { res = jh => {
if let Err(e) = res { if let Err(e) = res {
log::error!("Server error: {}", e); log::error!("Server error: {}", e);

View file

@ -6,6 +6,8 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
/// Serialize as either A or B.
#[allow(missing_docs)]
pub enum Either<A, B> { pub enum Either<A, B> {
A(A), A(A),
B(B), B(B),
@ -13,14 +15,16 @@ pub enum Either<A, B> {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
/// Represents a single value or multiple values.
#[allow(missing_docs)]
pub enum FlatArray<T> { pub enum FlatArray<T> {
Single(T), Single(T),
Multiple(Vec<T>), Multiple(Vec<T>),
} }
impl IntoIterator for FlatArray<Object> { impl<O> IntoIterator for FlatArray<O> {
type Item = Object; type Item = O;
type IntoIter = std::vec::IntoIter<Object>; type IntoIter = std::vec::IntoIter<O>;
fn into_iter(self) -> Self::IntoIter { fn into_iter(self) -> Self::IntoIter {
match self { match self {
@ -31,6 +35,8 @@ impl IntoIterator for FlatArray<Object> {
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
/// Represents an ActivityStreams object.
#[allow(missing_docs)]
pub struct Object { pub struct Object {
#[serde(rename = "@context")] #[serde(rename = "@context")]
pub context: FlatArray<Either<String, serde_json::Value>>, pub context: FlatArray<Either<String, serde_json::Value>>,
@ -38,14 +44,18 @@ pub struct Object {
#[serde(rename = "type")] #[serde(rename = "type")]
pub ty_: String, pub ty_: String,
pub name: Option<String>, pub name: Option<String>,
pub published: Option<chrono::DateTime<chrono::Utc>>,
pub to: Option<FlatArray<String>>, pub to: Option<FlatArray<String>>,
pub cc: Option<FlatArray<String>>, pub cc: Option<FlatArray<String>>,
pub bcc: Option<FlatArray<String>>, pub bcc: Option<FlatArray<String>>,
pub url: Option<String>,
#[serde(flatten)] #[serde(flatten)]
pub rest: HashMap<String, serde_json::Value>, pub rest: HashMap<String, serde_json::Value>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
/// Represents an ActivityStreams activity.
#[allow(missing_docs)]
pub struct Activity<O> { pub struct Activity<O> {
pub actor: Option<String>, pub actor: Option<String>,
pub object: Option<FlatArray<O>>, pub object: Option<FlatArray<O>>,
@ -54,10 +64,42 @@ pub struct Activity<O> {
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
/// Represents an ActivityStreams note object.
#[allow(missing_docs)]
pub struct NoteObject { pub struct NoteObject {
pub summary: Option<String>, pub summary: Option<String>,
#[serde(rename = "inReplyTo")] #[serde(rename = "inReplyTo")]
pub in_reply_to: Option<String>, pub in_reply_to: Option<String>,
#[serde(rename = "attributedTo")] #[serde(rename = "attributedTo")]
pub attributed_to: Option<String>, pub attributed_to: Option<String>,
pub sensitive: Option<bool>,
pub content: Option<String>,
#[serde(rename = "contentMap")]
pub content_map: Option<HashMap<String, String>>,
pub attachment: Option<FlatArray<Attachment>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Represents an ActivityStreams attachment.
#[allow(missing_docs)]
pub struct Attachment {
#[serde(rename = "type")]
pub ty_: String,
pub media_type: Option<String>,
pub url: String,
pub name: Option<String>,
pub blurhash: Option<String>,
pub width: Option<u32>,
pub height: Option<u32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
/// Represents any parsable ActivityStreams object.
#[allow(missing_docs)]
pub enum AnyObject {
NoteObject(NoteObject),
Object(Object),
} }

View file

@ -10,14 +10,17 @@ use axum::{
use reqwest::StatusCode; use reqwest::StatusCode;
use serde::Serialize; use serde::Serialize;
/// A trait for responses that can be converted into an axum response.
pub trait APResponse: Debug + Display + Send + IntoResponse {} pub trait APResponse: Debug + Display + Send + IntoResponse {}
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
/// A response that should be ignored.
pub struct Ignore { pub struct Ignore {
reason: Cow<'static, str>, reason: Cow<'static, str>,
} }
impl Ignore { impl Ignore {
/// Create a new Ignore response.
pub fn new(reason: impl Into<Cow<'static, str>>) -> Self { pub fn new(reason: impl Into<Cow<'static, str>>) -> Self {
Self { Self {
reason: reason.into(), reason: reason.into(),
@ -40,6 +43,7 @@ impl IntoResponse for Ignore {
impl APResponse for Ignore {} impl APResponse for Ignore {}
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
/// Errors that look like Misskey's error response.
pub struct MisskeyError { pub struct MisskeyError {
#[serde(skip)] #[serde(skip)]
status: StatusCode, status: StatusCode,
@ -49,6 +53,7 @@ pub struct MisskeyError {
} }
impl MisskeyError { impl MisskeyError {
/// Create a new MisskeyError using const fn.
pub const fn new_const( pub const fn new_const(
status: StatusCode, status: StatusCode,
code: &'static str, code: &'static str,
@ -62,6 +67,7 @@ impl MisskeyError {
message: Cow::Borrowed(message), message: Cow::Borrowed(message),
} }
} }
/// Create a new MisskeyError.
pub fn new( pub fn new(
status: StatusCode, status: StatusCode,
code: &'static str, code: &'static str,
@ -76,6 +82,7 @@ impl MisskeyError {
} }
} }
/// HTTP 413 Payload Too Large
pub fn s_413( pub fn s_413(
code: &'static str, code: &'static str,
id: impl Into<Cow<'static, str>>, id: impl Into<Cow<'static, str>>,
@ -84,6 +91,7 @@ impl MisskeyError {
Self::new(StatusCode::PAYLOAD_TOO_LARGE, code, id, message) Self::new(StatusCode::PAYLOAD_TOO_LARGE, code, id, message)
} }
/// HTTP 500 Internal Server Error
pub fn s_500( pub fn s_500(
code: &'static str, code: &'static str,
id: impl Into<Cow<'static, str>>, id: impl Into<Cow<'static, str>>,
@ -92,6 +100,7 @@ impl MisskeyError {
Self::new(StatusCode::INTERNAL_SERVER_ERROR, code, id, message) Self::new(StatusCode::INTERNAL_SERVER_ERROR, code, id, message)
} }
/// HTTP 400 Bad Request
pub fn s_400( pub fn s_400(
code: &'static str, code: &'static str,
id: impl Into<Cow<'static, str>>, id: impl Into<Cow<'static, str>>,

View file

@ -1,2 +1,4 @@
/// ActivityPub types
pub mod ap; pub mod ap;
/// Error handling
pub mod error; pub mod error;

View file

@ -1,14 +1,17 @@
use std::{net::SocketAddr, sync::Arc}; use std::net::SocketAddr;
use axum::{ use axum::{
response::IntoResponse, response::IntoResponse,
routing::{any, post}, routing::{any, post},
Router, Router,
}; };
use axum_server::{tls_rustls::RustlsConfig, Handle}; use axum_server::Handle;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use crate::{APRequestInfo, AppState, ProxyApp}; use crate::{evaluate::Evaluator, HasAppState, ProxyApp};
/// Flag indicating whether the TLS feature is enabled
pub const HAS_TLS_FEATURE: bool = cfg!(feature = "tls");
#[allow(clippy::panic)] #[allow(clippy::panic)]
#[allow(clippy::unwrap_used)] #[allow(clippy::unwrap_used)]
@ -21,8 +24,11 @@ use crate::{APRequestInfo, AppState, ProxyApp};
/// - `tls_key`: The path to the TLS key file /// - `tls_key`: The path to the TLS key file
/// ///
/// Use [`tokio::select`] to listen on multiple addresses /// Use [`tokio::select`] to listen on multiple addresses
pub async fn serve<E: IntoResponse + 'static>( pub async fn serve<
state: Arc<AppState<E>>, S: HasAppState<E> + Evaluator<E> + Send + Sync + 'static,
E: IntoResponse + Send + Sync + 'static,
>(
state: S,
listen: &str, listen: &str,
tls_cert: Option<&str>, tls_cert: Option<&str>,
tls_key: Option<&str>, tls_key: Option<&str>,
@ -30,11 +36,11 @@ pub async fn serve<E: IntoResponse + 'static>(
let app = Router::new() let app = Router::new()
.route( .route(
"/inbox", "/inbox",
post(ProxyApp::inbox_handler) post(ProxyApp::<S, E>::inbox_handler)
.put(ProxyApp::inbox_handler) .put(ProxyApp::<S, E>::inbox_handler)
.patch(ProxyApp::inbox_handler), .patch(ProxyApp::<S, E>::inbox_handler),
) )
.fallback(any(ProxyApp::pass_through)); .fallback(any(ProxyApp::<S, E>::pass_through));
let ms = app let ms = app
.with_state(state) .with_state(state)
@ -42,7 +48,9 @@ pub async fn serve<E: IntoResponse + 'static>(
let handle = Handle::new(); let handle = Handle::new();
match (tls_cert, tls_key) { match (tls_cert, tls_key) {
#[cfg(feature = "tls")]
(Some(cert), Some(key)) => { (Some(cert), Some(key)) => {
use axum_server::tls_rustls::RustlsConfig;
let tls_config = RustlsConfig::from_pem_file(cert, key) let tls_config = RustlsConfig::from_pem_file(cert, key)
.await .await
.expect("Failed to load TLS certificate and key"); .expect("Failed to load TLS certificate and key");
@ -54,6 +62,10 @@ pub async fn serve<E: IntoResponse + 'static>(
let jh = tokio::spawn(server); let jh = tokio::spawn(server);
(jh, handle) (jh, handle)
} }
#[cfg(not(feature = "tls"))]
(Some(_), Some(_)) => {
panic!("TLS support is not enabled")
}
(None, None) => { (None, None) => {
log::info!("Listening on http://{}", listen); log::info!("Listening on http://{}", listen);
let server = axum_server::bind(listen.parse().expect("invalid listen addr")) let server = axum_server::bind(listen.parse().expect("invalid listen addr"))

View file

@ -10,9 +10,12 @@ use std::{
use lru::LruCache; use lru::LruCache;
/// Friend instance data source
pub mod friend; pub mod friend;
/// Nodeinfo data source
pub mod nodeinfo; pub mod nodeinfo;
/// Lazy future that caches its result
pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> { pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> {
future: Mutex<Pin<Box<F>>>, future: Mutex<Pin<Box<F>>>,
ttl: Option<Duration>, ttl: Option<Duration>,
@ -20,6 +23,7 @@ pub struct LazyFuture<T: Clone, E, F: Future<Output = Result<T, E>>> {
} }
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFuture<T, E, F> { impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFuture<T, E, F> {
/// Create a new lazy future
pub fn new(future: F, ttl: Option<Duration>) -> Self { pub fn new(future: F, ttl: Option<Duration>) -> Self {
Self { Self {
future: Mutex::new(Box::pin(future)), future: Mutex::new(Box::pin(future)),
@ -73,6 +77,8 @@ struct LazyFutureHandle<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> {
} }
impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFutureHandle<T, E, F> { impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> LazyFutureHandle<T, E, F> {
/// Get a reference to the inner future
#[allow(dead_code)]
pub fn inner_ref(&self) -> &LazyFuture<T, E, F> { pub fn inner_ref(&self) -> &LazyFuture<T, E, F> {
&self.inner &self.inner
} }
@ -94,6 +100,7 @@ impl<T: Clone, E: Clone, F: Future<Output = Result<T, E>>> Future for LazyFuture
} }
} }
/// LRU cache for futures
pub struct LruData< pub struct LruData<
'a, 'a,
K: Hash, K: Hash,
@ -117,6 +124,7 @@ impl<
FF: Fn(&K) -> F, FF: Fn(&K) -> F,
> LruData<'a, K, T, E, F, FF> > LruData<'a, K, T, E, F, FF>
{ {
/// Create a new LRU cache with a given size and factory
pub fn sized(factory: &'a FF, size: NonZeroUsize, ttl: Option<Duration>) -> Self { pub fn sized(factory: &'a FF, size: NonZeroUsize, ttl: Option<Duration>) -> Self {
Self { Self {
factory, factory,
@ -126,6 +134,7 @@ impl<
} }
} }
/// Evict all expired entries from the cache
pub fn evict_expired(&self) { pub fn evict_expired(&self) {
let mut inner = self.inner.lock().unwrap(); let mut inner = self.inner.lock().unwrap();
@ -152,6 +161,7 @@ impl<
} }
} }
/// Get or insert a future into the cache
pub fn get_or_insert(&self, key: K) -> Arc<LazyFuture<T, E, F>> { pub fn get_or_insert(&self, key: K) -> Arc<LazyFuture<T, E, F>> {
let factory = self.factory; let factory = self.factory;
let entry = self.inner.lock().unwrap().get(&key).cloned(); let entry = self.inner.lock().unwrap().get(&key).cloned();
@ -166,6 +176,7 @@ impl<
} }
} }
/// Query the cache for a given key or insert it if it doesn't exist
pub fn query(&self, key: K) -> impl Future<Output = Result<T, E>> { pub fn query(&self, key: K) -> impl Future<Output = Result<T, E>> {
self.maybe_evict(); self.maybe_evict();
let future = self.get_or_insert(key); let future = self.get_or_insert(key);