fix reverse proxying
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
parent
b6dea3fd06
commit
92abe7b368
7 changed files with 158 additions and 119 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1 +1,2 @@
|
||||||
/target
|
/target
|
||||||
|
/inbox_audit
|
||||||
|
|
|
@ -11,22 +11,22 @@ use crate::network::new_safe_client;
|
||||||
#[derive(Default, Clone)]
|
#[derive(Default, Clone)]
|
||||||
|
|
||||||
pub struct ClientCache {
|
pub struct ClientCache {
|
||||||
clients: DashMap<IpAddr, (Instant, reqwest::Client)>,
|
safe_clients: DashMap<IpAddr, (Instant, reqwest::Client)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientCache {
|
impl ClientCache {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
clients: DashMap::new(),
|
safe_clients: DashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn with_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R
|
pub async fn with_safe_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R
|
||||||
where
|
where
|
||||||
F: FnOnce(reqwest::Client) -> Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>
|
F: FnOnce(reqwest::Client) -> Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>
|
||||||
+ 'a,
|
+ 'a,
|
||||||
{
|
{
|
||||||
let client = self.clients.entry(addr.ip()).or_insert_with(|| {
|
let client = self.safe_clients.entry(addr.ip()).or_insert_with(|| {
|
||||||
(
|
(
|
||||||
Instant::now(),
|
Instant::now(),
|
||||||
#[allow(clippy::expect_used)]
|
#[allow(clippy::expect_used)]
|
||||||
|
@ -39,7 +39,7 @@ impl ClientCache {
|
||||||
|
|
||||||
pub fn gc(&self, dur: std::time::Duration) {
|
pub fn gc(&self, dur: std::time::Duration) {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
self.clients
|
self.safe_clients
|
||||||
.retain(|_, (created, _)| now.duration_since(*created) < dur);
|
.retain(|_, (created, _)| now.duration_since(*created) < dur);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use flate2::write::GzEncoder;
|
use flate2::write::GzEncoder;
|
||||||
use flate2::{Compression, GzBuilder};
|
use flate2::{Compression, GzBuilder};
|
||||||
|
use reqwest::Url;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::fs::{File, OpenOptions};
|
use std::fs::{File, OpenOptions};
|
||||||
|
@ -60,7 +61,7 @@ pub enum AuditError {
|
||||||
|
|
||||||
pub struct AuditState {
|
pub struct AuditState {
|
||||||
options: AuditOptions,
|
options: AuditOptions,
|
||||||
cur_file: Option<RwLock<HashMap<String, Mutex<GzEncoder<File>>>>>,
|
cur_file: RwLock<HashMap<String, Mutex<GzEncoder<File>>>>,
|
||||||
vacuum_counter: AtomicU32,
|
vacuum_counter: AtomicU32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,13 +72,13 @@ impl AuditState {
|
||||||
}
|
}
|
||||||
Self {
|
Self {
|
||||||
options,
|
options,
|
||||||
cur_file: None,
|
cur_file: RwLock::new(HashMap::new()),
|
||||||
vacuum_counter: AtomicU32::new(0),
|
vacuum_counter: AtomicU32::new(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn vacuum(&self) -> Result<(), AuditError> {
|
pub async fn vacuum(&self) -> Result<(), AuditError> {
|
||||||
let read = self.cur_file.as_ref().unwrap().read().await;
|
let read = self.cur_file.read().await;
|
||||||
let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
|
let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
|
||||||
|
|
||||||
let files = std::fs::read_dir(&self.options.output)?
|
let files = std::fs::read_dir(&self.options.output)?
|
||||||
|
@ -109,9 +110,11 @@ impl AuditState {
|
||||||
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
|
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
|
||||||
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
|
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
|
||||||
|
|
||||||
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
let mut write = self.cur_file.write().await;
|
||||||
|
|
||||||
let full_name = format!("{}_{}.json.gz", name, time_str);
|
let sanitized_name = name.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
|
||||||
|
|
||||||
|
let full_name = format!("{}_{}.json.gz", sanitized_name, time_str);
|
||||||
|
|
||||||
let file = OpenOptions::new()
|
let file = OpenOptions::new()
|
||||||
.create(true)
|
.create(true)
|
||||||
|
@ -141,7 +144,7 @@ impl AuditState {
|
||||||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
let read = self.cur_file.as_ref().unwrap().read().await;
|
let read = self.cur_file.read().await;
|
||||||
|
|
||||||
if let Some(file) = read.get(name) {
|
if let Some(file) = read.get(name) {
|
||||||
let mut f = file.lock().await;
|
let mut f = file.lock().await;
|
||||||
|
@ -162,18 +165,16 @@ impl AuditState {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
drop(read);
|
||||||
|
{
|
||||||
|
self.create_new_file(name).await?;
|
||||||
|
|
||||||
let file = File::create(&self.options.output.join(name))?;
|
let read = self.cur_file.read().await;
|
||||||
|
|
||||||
let file = GzBuilder::new()
|
let file = read.get(name).unwrap();
|
||||||
.filename(name)
|
|
||||||
.write(file, Compression::default());
|
|
||||||
|
|
||||||
write.insert(name.to_string(), Mutex::new(file));
|
serde_json::to_writer(file.lock().await.deref_mut(), &item)?;
|
||||||
|
}
|
||||||
// this is deliberately out of order to make sure we don't create endless files if serialization fails
|
|
||||||
serde_json::to_writer(write.get(name).unwrap().lock().await.deref_mut(), &item)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -230,6 +231,7 @@ impl<
|
||||||
) -> (Disposition<E>, Option<serde_json::Value>) {
|
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||||
let (disp, ctx) = self.inner.evaluate(ctx, info).await;
|
let (disp, ctx) = self.inner.evaluate(ctx, info).await;
|
||||||
|
|
||||||
|
log::debug!("Audit: ctx = {:?}", ctx);
|
||||||
if ctx
|
if ctx
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|c| c.get("skip_audit").is_some())
|
.map(|c| c.get("skip_audit").is_some())
|
||||||
|
@ -246,7 +248,19 @@ impl<
|
||||||
|
|
||||||
let state = self.state.clone();
|
let state = self.state.clone();
|
||||||
|
|
||||||
state.write(Self::name(), &item).await.ok();
|
let key = info
|
||||||
|
.activity
|
||||||
|
.as_ref()
|
||||||
|
.map(|a| {
|
||||||
|
a.actor
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| Url::parse(s).ok()?.host_str().map(|s| s.to_string()))
|
||||||
|
})
|
||||||
|
.ok()
|
||||||
|
.flatten()
|
||||||
|
.unwrap_or_else(|| format!("tcp_{}", info.connect.ip()));
|
||||||
|
|
||||||
|
state.write(&key, &item).await.ok();
|
||||||
|
|
||||||
(disp, ctx)
|
(disp, ctx)
|
||||||
}
|
}
|
||||||
|
|
64
src/lib.rs
64
src/lib.rs
|
@ -26,8 +26,8 @@ use futures::TryStreamExt;
|
||||||
use model::ap::AnyObject;
|
use model::ap::AnyObject;
|
||||||
use model::{ap, error::MisskeyError};
|
use model::{ap, error::MisskeyError};
|
||||||
use network::stream::LimitedStream;
|
use network::stream::LimitedStream;
|
||||||
use network::Either;
|
use network::{new_backend_client, Either};
|
||||||
use reqwest::{self};
|
use reqwest;
|
||||||
use serde::ser::{SerializeMap, SerializeStruct};
|
use serde::ser::{SerializeMap, SerializeStruct};
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
|
@ -118,6 +118,7 @@ pub trait HasAppState<E: IntoResponse + 'static>: Clone {
|
||||||
/// Application state.
|
/// Application state.
|
||||||
pub struct BaseAppState<E: IntoResponse + 'static> {
|
pub struct BaseAppState<E: IntoResponse + 'static> {
|
||||||
backend: reqwest::Url,
|
backend: reqwest::Url,
|
||||||
|
backend_client: reqwest::Client,
|
||||||
clients: ClientCache,
|
clients: ClientCache,
|
||||||
ctx_template: Option<serde_json::Value>,
|
ctx_template: Option<serde_json::Value>,
|
||||||
_marker: PhantomData<E>,
|
_marker: PhantomData<E>,
|
||||||
|
@ -143,7 +144,6 @@ impl<E: IntoResponse + Send + Sync> Evaluator<E> for Arc<BaseAppState<E>> {
|
||||||
ctx: Option<serde_json::Value>,
|
ctx: Option<serde_json::Value>,
|
||||||
_info: &APRequestInfo<'r>,
|
_info: &APRequestInfo<'r>,
|
||||||
) -> (Disposition<E>, Option<serde_json::Value>) {
|
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||||
log::trace!("Evaluator fell through, accepting request");
|
|
||||||
(Disposition::Allow, ctx)
|
(Disposition::Allow, ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -155,6 +155,7 @@ impl<E: IntoResponse> BaseAppState<E> {
|
||||||
backend,
|
backend,
|
||||||
ctx_template: None,
|
ctx_template: None,
|
||||||
clients: ClientCache::new(),
|
clients: ClientCache::new(),
|
||||||
|
backend_client: new_backend_client().expect("Failed to create backend client"),
|
||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -184,17 +185,21 @@ impl<
|
||||||
State(app): State<S>,
|
State(app): State<S>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
OriginalUri(uri): OriginalUri,
|
OriginalUri(uri): OriginalUri,
|
||||||
header: HeaderMap,
|
mut header: HeaderMap,
|
||||||
body: Body,
|
body: Body,
|
||||||
) -> Result<impl IntoResponse, MisskeyError> {
|
) -> Result<impl IntoResponse, MisskeyError> {
|
||||||
|
log::debug!(
|
||||||
|
"Received request from {} to {:?}",
|
||||||
|
addr,
|
||||||
|
uri.path_and_query()
|
||||||
|
);
|
||||||
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
|
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
|
||||||
let state = app.app_state();
|
let state = app.app_state();
|
||||||
|
|
||||||
app.app_state()
|
header.remove("host");
|
||||||
.clients
|
|
||||||
.with_client(&addr, |client| {
|
let req = state
|
||||||
Box::pin(async move {
|
.backend_client
|
||||||
let req = client
|
|
||||||
.request(
|
.request(
|
||||||
method.clone(),
|
method.clone(),
|
||||||
state
|
state
|
||||||
|
@ -211,7 +216,11 @@ impl<
|
||||||
|
|
||||||
let url_clone = req.url().clone();
|
let url_clone = req.url().clone();
|
||||||
|
|
||||||
let resp = client.execute(req).await.map_err(|e: reqwest::Error| {
|
let resp = state
|
||||||
|
.backend_client
|
||||||
|
.execute(req)
|
||||||
|
.await
|
||||||
|
.map_err(|e: reqwest::Error| {
|
||||||
log::error!(
|
log::error!(
|
||||||
"Failed to execute request: {} ({} {})",
|
"Failed to execute request: {} ({} {})",
|
||||||
e,
|
e,
|
||||||
|
@ -237,27 +246,31 @@ impl<
|
||||||
log::error!("Failed to build response: {}", e);
|
log::error!("Failed to build response: {}", e);
|
||||||
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
|
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
|
||||||
})
|
})
|
||||||
})
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle incoming ActivityPub requests.
|
/// Handle incoming ActivityPub requests.
|
||||||
pub async fn inbox_handler(
|
pub async fn inbox_handler(
|
||||||
|
method: Method,
|
||||||
State(app): State<S>,
|
State(app): State<S>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
OriginalUri(uri): OriginalUri,
|
OriginalUri(uri): OriginalUri,
|
||||||
header: HeaderMap,
|
mut header: HeaderMap,
|
||||||
body: Body,
|
body: Body,
|
||||||
) -> Result<impl IntoResponse, Either<MisskeyError, E>> {
|
) -> Result<impl IntoResponse, Either<MisskeyError, E>> {
|
||||||
|
log::debug!(
|
||||||
|
"Received request from {} to {:?}",
|
||||||
|
addr,
|
||||||
|
uri.path_and_query()
|
||||||
|
);
|
||||||
let path_and_query = uri.path_and_query().ok_or(Either::A(ERR_BAD_REQUEST))?;
|
let path_and_query = uri.path_and_query().ok_or(Either::A(ERR_BAD_REQUEST))?;
|
||||||
|
|
||||||
let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 32 << 20);
|
let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 1 << 20);
|
||||||
|
|
||||||
let mut body = Vec::new();
|
let mut body = Vec::new();
|
||||||
|
|
||||||
restricted_body_stream
|
restricted_body_stream
|
||||||
.try_for_each(|chunk| {
|
.try_for_each(|chunk| {
|
||||||
|
log::debug!("Received chunk of size {}", chunk.len());
|
||||||
body.extend_from_slice(&chunk);
|
body.extend_from_slice(&chunk);
|
||||||
futures::future::ready(Ok(()))
|
futures::future::ready(Ok(()))
|
||||||
})
|
})
|
||||||
|
@ -299,14 +312,13 @@ impl<
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
app.clone()
|
header.remove("host");
|
||||||
|
|
||||||
|
let req = app
|
||||||
.app_state()
|
.app_state()
|
||||||
.clients
|
.backend_client
|
||||||
.with_client(&addr, |client| {
|
|
||||||
Box::pin(async move {
|
|
||||||
let req = client
|
|
||||||
.request(
|
.request(
|
||||||
Method::POST,
|
method.clone(),
|
||||||
app.app_state()
|
app.app_state()
|
||||||
.backend
|
.backend
|
||||||
.join(&path_and_query.to_string())
|
.join(&path_and_query.to_string())
|
||||||
|
@ -319,7 +331,12 @@ impl<
|
||||||
|
|
||||||
let url_clone = req.url().clone();
|
let url_clone = req.url().clone();
|
||||||
|
|
||||||
let resp = client.execute(req).await.map_err(|e: reqwest::Error| {
|
let resp =
|
||||||
|
app.app_state()
|
||||||
|
.backend_client
|
||||||
|
.execute(req)
|
||||||
|
.await
|
||||||
|
.map_err(|e: reqwest::Error| {
|
||||||
log::error!(
|
log::error!(
|
||||||
"Failed to execute request: {} ({} {})",
|
"Failed to execute request: {} ({} {})",
|
||||||
e,
|
e,
|
||||||
|
@ -345,8 +362,5 @@ impl<
|
||||||
log::error!("Failed to build response: {}", e);
|
log::error!("Failed to build response: {}", e);
|
||||||
Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE)
|
Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE)
|
||||||
})?)
|
})?)
|
||||||
})
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ use serde::Serialize;
|
||||||
pub struct Args {
|
pub struct Args {
|
||||||
#[clap(short, long, default_value = "127.0.0.1:3001")]
|
#[clap(short, long, default_value = "127.0.0.1:3001")]
|
||||||
pub listen: String,
|
pub listen: String,
|
||||||
#[clap(short, long, default_value = "http://web:3000")]
|
#[clap(short, long, default_value = "https://echo.free.beeceptor.com")]
|
||||||
pub backend: String,
|
pub backend: String,
|
||||||
#[clap(long)]
|
#[clap(long)]
|
||||||
pub tls_cert: Option<String>,
|
pub tls_cert: Option<String>,
|
||||||
|
@ -28,8 +28,8 @@ async fn build_state<E: IntoResponse + Clone + Serialize + Send + Sync + 'static
|
||||||
base: Arc<BaseAppState<E>>,
|
base: Arc<BaseAppState<E>>,
|
||||||
_args: &Args,
|
_args: &Args,
|
||||||
) -> impl HasAppState<E> + Evaluator<E> {
|
) -> impl HasAppState<E> + Evaluator<E> {
|
||||||
base.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
|
base.extract_meta()
|
||||||
.extract_meta()
|
.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
|
|
@ -61,3 +61,12 @@ pub(crate) fn new_safe_client(addr: &SocketAddr) -> reqwest::Result<reqwest::Cli
|
||||||
.tcp_keepalive(Some(std::time::Duration::from_secs(20)))
|
.tcp_keepalive(Some(std::time::Duration::from_secs(20)))
|
||||||
.build()
|
.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn new_backend_client() -> reqwest::Result<reqwest::Client> {
|
||||||
|
reqwest::Client::builder()
|
||||||
|
.connect_timeout(std::time::Duration::from_secs(10))
|
||||||
|
.timeout(std::time::Duration::from_secs(20))
|
||||||
|
.redirect(Policy::limited(4))
|
||||||
|
.tcp_keepalive(Some(std::time::Duration::from_secs(20)))
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
|
|
@ -52,6 +52,7 @@ where
|
||||||
let p = self.stream.poll_next_unpin(cx);
|
let p = self.stream.poll_next_unpin(cx);
|
||||||
match p {
|
match p {
|
||||||
Poll::Ready(Some(Ok(mut data))) => {
|
Poll::Ready(Some(Ok(mut data))) => {
|
||||||
|
log::trace!("Received {} bytes", data.len());
|
||||||
if data.len() > remaining_len {
|
if data.len() > remaining_len {
|
||||||
self.limit.store(0, Ordering::Relaxed);
|
self.limit.store(0, Ordering::Relaxed);
|
||||||
Poll::Ready(Some(Ok(data.split_to(remaining_len))))
|
Poll::Ready(Some(Ok(data.split_to(remaining_len))))
|
||||||
|
|
Loading…
Reference in a new issue