fix reverse proxying

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-10-16 20:18:00 -05:00
parent b6dea3fd06
commit 92abe7b368
No known key found for this signature in database
7 changed files with 158 additions and 119 deletions

1
.gitignore vendored
View file

@ -1 +1,2 @@
/target /target
/inbox_audit

View file

@ -11,22 +11,22 @@ use crate::network::new_safe_client;
#[derive(Default, Clone)] #[derive(Default, Clone)]
pub struct ClientCache { pub struct ClientCache {
clients: DashMap<IpAddr, (Instant, reqwest::Client)>, safe_clients: DashMap<IpAddr, (Instant, reqwest::Client)>,
} }
impl ClientCache { impl ClientCache {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
clients: DashMap::new(), safe_clients: DashMap::new(),
} }
} }
pub async fn with_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R pub async fn with_safe_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R
where where
F: FnOnce(reqwest::Client) -> Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>> F: FnOnce(reqwest::Client) -> Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>
+ 'a, + 'a,
{ {
let client = self.clients.entry(addr.ip()).or_insert_with(|| { let client = self.safe_clients.entry(addr.ip()).or_insert_with(|| {
( (
Instant::now(), Instant::now(),
#[allow(clippy::expect_used)] #[allow(clippy::expect_used)]
@ -39,7 +39,7 @@ impl ClientCache {
pub fn gc(&self, dur: std::time::Duration) { pub fn gc(&self, dur: std::time::Duration) {
let now = Instant::now(); let now = Instant::now();
self.clients self.safe_clients
.retain(|_, (created, _)| now.duration_since(*created) < dur); .retain(|_, (created, _)| now.duration_since(*created) < dur);
} }
} }

View file

@ -1,6 +1,7 @@
use axum::response::IntoResponse; use axum::response::IntoResponse;
use flate2::write::GzEncoder; use flate2::write::GzEncoder;
use flate2::{Compression, GzBuilder}; use flate2::{Compression, GzBuilder};
use reqwest::Url;
use serde::Serialize; use serde::Serialize;
use std::collections::HashSet; use std::collections::HashSet;
use std::fs::{File, OpenOptions}; use std::fs::{File, OpenOptions};
@ -60,7 +61,7 @@ pub enum AuditError {
pub struct AuditState { pub struct AuditState {
options: AuditOptions, options: AuditOptions,
cur_file: Option<RwLock<HashMap<String, Mutex<GzEncoder<File>>>>>, cur_file: RwLock<HashMap<String, Mutex<GzEncoder<File>>>>,
vacuum_counter: AtomicU32, vacuum_counter: AtomicU32,
} }
@ -71,13 +72,13 @@ impl AuditState {
} }
Self { Self {
options, options,
cur_file: None, cur_file: RwLock::new(HashMap::new()),
vacuum_counter: AtomicU32::new(0), vacuum_counter: AtomicU32::new(0),
} }
} }
pub async fn vacuum(&self) -> Result<(), AuditError> { pub async fn vacuum(&self) -> Result<(), AuditError> {
let read = self.cur_file.as_ref().unwrap().read().await; let read = self.cur_file.read().await;
let in_use_files = read.keys().cloned().collect::<HashSet<_>>(); let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
let files = std::fs::read_dir(&self.options.output)? let files = std::fs::read_dir(&self.options.output)?
@ -109,9 +110,11 @@ impl AuditState {
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> { pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string(); let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
let mut write = self.cur_file.as_ref().unwrap().write().await; let mut write = self.cur_file.write().await;
let full_name = format!("{}_{}.json.gz", name, time_str); let sanitized_name = name.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
let full_name = format!("{}_{}.json.gz", sanitized_name, time_str);
let file = OpenOptions::new() let file = OpenOptions::new()
.create(true) .create(true)
@ -141,7 +144,7 @@ impl AuditState {
.store(0, std::sync::atomic::Ordering::Relaxed); .store(0, std::sync::atomic::Ordering::Relaxed);
} }
let read = self.cur_file.as_ref().unwrap().read().await; let read = self.cur_file.read().await;
if let Some(file) = read.get(name) { if let Some(file) = read.get(name) {
let mut f = file.lock().await; let mut f = file.lock().await;
@ -162,18 +165,16 @@ impl AuditState {
return Ok(()); return Ok(());
} }
let mut write = self.cur_file.as_ref().unwrap().write().await; drop(read);
{
self.create_new_file(name).await?;
let file = File::create(&self.options.output.join(name))?; let read = self.cur_file.read().await;
let file = GzBuilder::new() let file = read.get(name).unwrap();
.filename(name)
.write(file, Compression::default());
write.insert(name.to_string(), Mutex::new(file)); serde_json::to_writer(file.lock().await.deref_mut(), &item)?;
}
// this is deliberately out of order to make sure we don't create endless files if serialization fails
serde_json::to_writer(write.get(name).unwrap().lock().await.deref_mut(), &item)?;
Ok(()) Ok(())
} }
@ -230,6 +231,7 @@ impl<
) -> (Disposition<E>, Option<serde_json::Value>) { ) -> (Disposition<E>, Option<serde_json::Value>) {
let (disp, ctx) = self.inner.evaluate(ctx, info).await; let (disp, ctx) = self.inner.evaluate(ctx, info).await;
log::debug!("Audit: ctx = {:?}", ctx);
if ctx if ctx
.as_ref() .as_ref()
.map(|c| c.get("skip_audit").is_some()) .map(|c| c.get("skip_audit").is_some())
@ -246,7 +248,19 @@ impl<
let state = self.state.clone(); let state = self.state.clone();
state.write(Self::name(), &item).await.ok(); let key = info
.activity
.as_ref()
.map(|a| {
a.actor
.as_ref()
.and_then(|s| Url::parse(s).ok()?.host_str().map(|s| s.to_string()))
})
.ok()
.flatten()
.unwrap_or_else(|| format!("tcp_{}", info.connect.ip()));
state.write(&key, &item).await.ok();
(disp, ctx) (disp, ctx)
} }

View file

@ -26,8 +26,8 @@ use futures::TryStreamExt;
use model::ap::AnyObject; use model::ap::AnyObject;
use model::{ap, error::MisskeyError}; use model::{ap, error::MisskeyError};
use network::stream::LimitedStream; use network::stream::LimitedStream;
use network::Either; use network::{new_backend_client, Either};
use reqwest::{self}; use reqwest;
use serde::ser::{SerializeMap, SerializeStruct}; use serde::ser::{SerializeMap, SerializeStruct};
use serde::Serialize; use serde::Serialize;
@ -118,6 +118,7 @@ pub trait HasAppState<E: IntoResponse + 'static>: Clone {
/// Application state. /// Application state.
pub struct BaseAppState<E: IntoResponse + 'static> { pub struct BaseAppState<E: IntoResponse + 'static> {
backend: reqwest::Url, backend: reqwest::Url,
backend_client: reqwest::Client,
clients: ClientCache, clients: ClientCache,
ctx_template: Option<serde_json::Value>, ctx_template: Option<serde_json::Value>,
_marker: PhantomData<E>, _marker: PhantomData<E>,
@ -143,7 +144,6 @@ impl<E: IntoResponse + Send + Sync> Evaluator<E> for Arc<BaseAppState<E>> {
ctx: Option<serde_json::Value>, ctx: Option<serde_json::Value>,
_info: &APRequestInfo<'r>, _info: &APRequestInfo<'r>,
) -> (Disposition<E>, Option<serde_json::Value>) { ) -> (Disposition<E>, Option<serde_json::Value>) {
log::trace!("Evaluator fell through, accepting request");
(Disposition::Allow, ctx) (Disposition::Allow, ctx)
} }
} }
@ -155,6 +155,7 @@ impl<E: IntoResponse> BaseAppState<E> {
backend, backend,
ctx_template: None, ctx_template: None,
clients: ClientCache::new(), clients: ClientCache::new(),
backend_client: new_backend_client().expect("Failed to create backend client"),
_marker: PhantomData, _marker: PhantomData,
} }
} }
@ -184,80 +185,92 @@ impl<
State(app): State<S>, State(app): State<S>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
OriginalUri(uri): OriginalUri, OriginalUri(uri): OriginalUri,
header: HeaderMap, mut header: HeaderMap,
body: Body, body: Body,
) -> Result<impl IntoResponse, MisskeyError> { ) -> Result<impl IntoResponse, MisskeyError> {
log::debug!(
"Received request from {} to {:?}",
addr,
uri.path_and_query()
);
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?; let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
let state = app.app_state(); let state = app.app_state();
app.app_state() header.remove("host");
.clients
.with_client(&addr, |client| {
Box::pin(async move {
let req = client
.request(
method.clone(),
state
.backend
.join(&path_and_query.to_string())
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?,
)
.headers(header)
.body(reqwest::Body::wrap_stream(Box::pin(
body.into_data_stream(),
)))
.build()
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?;
let url_clone = req.url().clone(); let req = state
.backend_client
.request(
method.clone(),
state
.backend
.join(&path_and_query.to_string())
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?,
)
.headers(header)
.body(reqwest::Body::wrap_stream(Box::pin(
body.into_data_stream(),
)))
.build()
.map_err(|_| ERR_INTERNAL_SERVER_ERROR)?;
let resp = client.execute(req).await.map_err(|e: reqwest::Error| { let url_clone = req.url().clone();
log::error!(
"Failed to execute request: {} ({} {})",
e,
method,
url_clone
);
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
})?;
let mut resp_builder = Response::builder().status(resp.status()); let resp = state
.backend_client
resp_builder .execute(req)
.headers_mut()
.ok_or(ERR_INTERNAL_SERVER_ERROR)?
.extend(resp.headers().clone());
resp_builder
.header("X-Forwarded-For", addr.to_string())
.body(Body::from_stream(resp.bytes_stream().inspect_err(|e| {
log::error!("Failed to read response: {}", e);
})))
.map_err(|e| {
log::error!("Failed to build response: {}", e);
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
})
})
})
.await .await
.map_err(|e: reqwest::Error| {
log::error!(
"Failed to execute request: {} ({} {})",
e,
method,
url_clone
);
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
})?;
let mut resp_builder = Response::builder().status(resp.status());
resp_builder
.headers_mut()
.ok_or(ERR_INTERNAL_SERVER_ERROR)?
.extend(resp.headers().clone());
resp_builder
.header("X-Forwarded-For", addr.to_string())
.body(Body::from_stream(resp.bytes_stream().inspect_err(|e| {
log::error!("Failed to read response: {}", e);
})))
.map_err(|e| {
log::error!("Failed to build response: {}", e);
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
})
} }
/// Handle incoming ActivityPub requests. /// Handle incoming ActivityPub requests.
pub async fn inbox_handler( pub async fn inbox_handler(
method: Method,
State(app): State<S>, State(app): State<S>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
OriginalUri(uri): OriginalUri, OriginalUri(uri): OriginalUri,
header: HeaderMap, mut header: HeaderMap,
body: Body, body: Body,
) -> Result<impl IntoResponse, Either<MisskeyError, E>> { ) -> Result<impl IntoResponse, Either<MisskeyError, E>> {
log::debug!(
"Received request from {} to {:?}",
addr,
uri.path_and_query()
);
let path_and_query = uri.path_and_query().ok_or(Either::A(ERR_BAD_REQUEST))?; let path_and_query = uri.path_and_query().ok_or(Either::A(ERR_BAD_REQUEST))?;
let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 32 << 20); let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 1 << 20);
let mut body = Vec::new(); let mut body = Vec::new();
restricted_body_stream restricted_body_stream
.try_for_each(|chunk| { .try_for_each(|chunk| {
log::debug!("Received chunk of size {}", chunk.len());
body.extend_from_slice(&chunk); body.extend_from_slice(&chunk);
futures::future::ready(Ok(())) futures::future::ready(Ok(()))
}) })
@ -299,54 +312,55 @@ impl<
_ => {} _ => {}
} }
app.clone() header.remove("host");
let req = app
.app_state() .app_state()
.clients .backend_client
.with_client(&addr, |client| { .request(
Box::pin(async move { method.clone(),
let req = client app.app_state()
.request( .backend
Method::POST, .join(&path_and_query.to_string())
app.app_state() .map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?,
.backend )
.join(&path_and_query.to_string()) .headers(header)
.map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?, .body(body)
) .build()
.headers(header) .map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?;
.body(body)
.build()
.map_err(|_| Either::A(ERR_INTERNAL_SERVER_ERROR))?;
let url_clone = req.url().clone(); let url_clone = req.url().clone();
let resp = client.execute(req).await.map_err(|e: reqwest::Error| { let resp =
log::error!( app.app_state()
"Failed to execute request: {} ({} {})", .backend_client
e, .execute(req)
Method::POST, .await
url_clone .map_err(|e: reqwest::Error| {
); log::error!(
Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE) "Failed to execute request: {} ({} {})",
})?; e,
Method::POST,
url_clone
);
Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE)
})?;
let mut resp_builder = Response::builder().status(resp.status()); let mut resp_builder = Response::builder().status(resp.status());
resp_builder resp_builder
.headers_mut() .headers_mut()
.ok_or(Either::A(ERR_INTERNAL_SERVER_ERROR))? .ok_or(Either::A(ERR_INTERNAL_SERVER_ERROR))?
.extend(resp.headers().clone()); .extend(resp.headers().clone());
Ok(resp_builder Ok(resp_builder
.header("X-Forwarded-For", addr.to_string()) .header("X-Forwarded-For", addr.to_string())
.body(Body::from_stream(resp.bytes_stream().inspect_err(|e| { .body(Body::from_stream(resp.bytes_stream().inspect_err(|e| {
log::error!("Failed to read response: {}", e); log::error!("Failed to read response: {}", e);
}))) })))
.map_err(|e| { .map_err(|e| {
log::error!("Failed to build response: {}", e); log::error!("Failed to build response: {}", e);
Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE) Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE)
})?) })?)
})
})
.await
} }
} }

View file

@ -15,7 +15,7 @@ use serde::Serialize;
pub struct Args { pub struct Args {
#[clap(short, long, default_value = "127.0.0.1:3001")] #[clap(short, long, default_value = "127.0.0.1:3001")]
pub listen: String, pub listen: String,
#[clap(short, long, default_value = "http://web:3000")] #[clap(short, long, default_value = "https://echo.free.beeceptor.com")]
pub backend: String, pub backend: String,
#[clap(long)] #[clap(long)]
pub tls_cert: Option<String>, pub tls_cert: Option<String>,
@ -28,8 +28,8 @@ async fn build_state<E: IntoResponse + Clone + Serialize + Send + Sync + 'static
base: Arc<BaseAppState<E>>, base: Arc<BaseAppState<E>>,
_args: &Args, _args: &Args,
) -> impl HasAppState<E> + Evaluator<E> { ) -> impl HasAppState<E> + Evaluator<E> {
base.audited(AuditOptions::new(PathBuf::from("inbox_audit"))) base.extract_meta()
.extract_meta() .audited(AuditOptions::new(PathBuf::from("inbox_audit")))
} }
#[tokio::main] #[tokio::main]

View file

@ -61,3 +61,12 @@ pub(crate) fn new_safe_client(addr: &SocketAddr) -> reqwest::Result<reqwest::Cli
.tcp_keepalive(Some(std::time::Duration::from_secs(20))) .tcp_keepalive(Some(std::time::Duration::from_secs(20)))
.build() .build()
} }
/// Builds the HTTP client used to forward requests to the configured backend.
///
/// Configuration applied to the builder:
/// - connect timeout of 10 seconds,
/// - overall request timeout of 20 seconds,
/// - at most 4 redirects followed,
/// - TCP keepalive of 20 seconds.
///
/// # Errors
///
/// Returns the `reqwest::Error` reported by the builder when the client
/// cannot be constructed.
pub(crate) fn new_backend_client() -> reqwest::Result<reqwest::Client> {
    let connect_deadline = std::time::Duration::from_secs(10);
    let request_deadline = std::time::Duration::from_secs(20);
    let keepalive_interval = std::time::Duration::from_secs(20);

    // NOTE(review): the overall timeout presumably also bounds reading a streamed
    // response body — confirm 20 s is enough for large proxied responses.
    let builder = reqwest::Client::builder()
        .connect_timeout(connect_deadline)
        .timeout(request_deadline)
        .redirect(Policy::limited(4))
        .tcp_keepalive(Some(keepalive_interval));

    builder.build()
}

View file

@ -52,6 +52,7 @@ where
let p = self.stream.poll_next_unpin(cx); let p = self.stream.poll_next_unpin(cx);
match p { match p {
Poll::Ready(Some(Ok(mut data))) => { Poll::Ready(Some(Ok(mut data))) => {
log::trace!("Received {} bytes", data.len());
if data.len() > remaining_len { if data.len() > remaining_len {
self.limit.store(0, Ordering::Relaxed); self.limit.store(0, Ordering::Relaxed);
Poll::Ready(Some(Ok(data.split_to(remaining_len)))) Poll::Ready(Some(Ok(data.split_to(remaining_len))))