fix reverse proxying

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-10-16 20:18:00 -05:00
parent b6dea3fd06
commit 92abe7b368
No known key found for this signature in database
7 changed files with 158 additions and 119 deletions

1
.gitignore vendored
View file

@ -1 +1,2 @@
/target
/inbox_audit

View file

@ -11,22 +11,22 @@ use crate::network::new_safe_client;
#[derive(Default, Clone)]
pub struct ClientCache {
clients: DashMap<IpAddr, (Instant, reqwest::Client)>,
safe_clients: DashMap<IpAddr, (Instant, reqwest::Client)>,
}
impl ClientCache {
pub fn new() -> Self {
Self {
clients: DashMap::new(),
safe_clients: DashMap::new(),
}
}
pub async fn with_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R
pub async fn with_safe_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R
where
F: FnOnce(reqwest::Client) -> Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>
+ 'a,
{
let client = self.clients.entry(addr.ip()).or_insert_with(|| {
let client = self.safe_clients.entry(addr.ip()).or_insert_with(|| {
(
Instant::now(),
#[allow(clippy::expect_used)]
@ -39,7 +39,7 @@ impl ClientCache {
pub fn gc(&self, dur: std::time::Duration) {
let now = Instant::now();
self.clients
self.safe_clients
.retain(|_, (created, _)| now.duration_since(*created) < dur);
}
}

View file

@ -1,6 +1,7 @@
use axum::response::IntoResponse;
use flate2::write::GzEncoder;
use flate2::{Compression, GzBuilder};
use reqwest::Url;
use serde::Serialize;
use std::collections::HashSet;
use std::fs::{File, OpenOptions};
@ -60,7 +61,7 @@ pub enum AuditError {
pub struct AuditState {
options: AuditOptions,
cur_file: Option<RwLock<HashMap<String, Mutex<GzEncoder<File>>>>>,
cur_file: RwLock<HashMap<String, Mutex<GzEncoder<File>>>>,
vacuum_counter: AtomicU32,
}
@ -71,13 +72,13 @@ impl AuditState {
}
Self {
options,
cur_file: None,
cur_file: RwLock::new(HashMap::new()),
vacuum_counter: AtomicU32::new(0),
}
}
pub async fn vacuum(&self) -> Result<(), AuditError> {
let read = self.cur_file.as_ref().unwrap().read().await;
let read = self.cur_file.read().await;
let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
let files = std::fs::read_dir(&self.options.output)?
@ -109,9 +110,11 @@ impl AuditState {
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
let mut write = self.cur_file.as_ref().unwrap().write().await;
let mut write = self.cur_file.write().await;
let full_name = format!("{}_{}.json.gz", name, time_str);
let sanitized_name = name.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
let full_name = format!("{}_{}.json.gz", sanitized_name, time_str);
let file = OpenOptions::new()
.create(true)
@ -141,7 +144,7 @@ impl AuditState {
.store(0, std::sync::atomic::Ordering::Relaxed);
}
let read = self.cur_file.as_ref().unwrap().read().await;
let read = self.cur_file.read().await;
if let Some(file) = read.get(name) {
let mut f = file.lock().await;
@ -162,18 +165,16 @@ impl AuditState {
return Ok(());
}
let mut write = self.cur_file.as_ref().unwrap().write().await;
drop(read);
{
self.create_new_file(name).await?;
let file = File::create(&self.options.output.join(name))?;
let read = self.cur_file.read().await;
let file = GzBuilder::new()
.filename(name)
.write(file, Compression::default());
let file = read.get(name).unwrap();
write.insert(name.to_string(), Mutex::new(file));
// this is deliberately out of order to make sure we don't create endless files if serialization fails
serde_json::to_writer(write.get(name).unwrap().lock().await.deref_mut(), &item)?;
serde_json::to_writer(file.lock().await.deref_mut(), &item)?;
}
Ok(())
}
@ -230,6 +231,7 @@ impl<
) -> (Disposition<E>, Option<serde_json::Value>) {
let (disp, ctx) = self.inner.evaluate(ctx, info).await;
log::debug!("Audit: ctx = {:?}", ctx);
if ctx
.as_ref()
.map(|c| c.get("skip_audit").is_some())
@ -246,7 +248,19 @@ impl<
let state = self.state.clone();
state.write(Self::name(), &item).await.ok();
let key = info
.activity
.as_ref()
.map(|a| {
a.actor
.as_ref()
.and_then(|s| Url::parse(s).ok()?.host_str().map(|s| s.to_string()))
})
.ok()
.flatten()
.unwrap_or_else(|| format!("tcp_{}", info.connect.ip()));
state.write(&key, &item).await.ok();
(disp, ctx)
}

View file

@ -26,8 +26,8 @@ use futures::TryStreamExt;
use model::ap::AnyObject;
use model::{ap, error::MisskeyError};
use network::stream::LimitedStream;
use network::Either;
use reqwest::{self};
use network::{new_backend_client, Either};
use reqwest;
use serde::ser::{SerializeMap, SerializeStruct};
use serde::Serialize;
@ -118,6 +118,7 @@ pub trait HasAppState<E: IntoResponse + 'static>: Clone {
/// Application state.
pub struct BaseAppState<E: IntoResponse + 'static> {
backend: reqwest::Url,
backend_client: reqwest::Client,
clients: ClientCache,
ctx_template: Option<serde_json::Value>,
_marker: PhantomData<E>,
@ -143,7 +144,6 @@ impl<E: IntoResponse + Send + Sync> Evaluator<E> for Arc<BaseAppState<E>> {
ctx: Option<serde_json::Value>,
_info: &APRequestInfo<'r>,
) -> (Disposition<E>, Option<serde_json::Value>) {
log::trace!("Evaluator fell through, accepting request");
(Disposition::Allow, ctx)
}
}
@ -155,6 +155,7 @@ impl<E: IntoResponse> BaseAppState<E> {
backend,
ctx_template: None,
clients: ClientCache::new(),
backend_client: new_backend_client().expect("Failed to create backend client"),
_marker: PhantomData,
}
}
@ -184,17 +185,21 @@ impl<
State(app): State<S>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
OriginalUri(uri): OriginalUri,
header: HeaderMap,
mut header: HeaderMap,
body: Body,
) -> Result<impl IntoResponse, MisskeyError> {
log::debug!(
"Received request from {} to {:?}",
addr,
uri.path_and_query()
);
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
let state = app.app_state();
app.app_state()
.clients
.with_client(&addr, |client| {
Box::pin(async move {
let req = client
header.remove("host");
let req = state
.backend_client
.request(
method.clone(),
state
@ -211,7 +216,11 @@ impl<
let url_clone = req.url().clone();
let resp = client.execute(req).await.map_err(|e: reqwest::Error| {
let resp = state
.backend_client
.execute(req)
.await
.map_err(|e: reqwest::Error| {
log::error!(
"Failed to execute request: {} ({} {})",
e,
@ -237,27 +246,31 @@ impl<
log::error!("Failed to build response: {}", e);
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
})
})
})
.await
}
/// Handle incoming ActivityPub requests.
pub async fn inbox_handler(
method: Method,
State(app): State<S>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
OriginalUri(uri): OriginalUri,
header: HeaderMap,
mut header: HeaderMap,
body: Body,
) -> Result<impl IntoResponse, Either<MisskeyError, E>> {
log::debug!(
"Received request from {} to {:?}",
addr,
uri.path_and_query()
);
let path_and_query = uri.path_and_query().ok_or(Either::A(ERR_BAD_REQUEST))?;
let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 32 << 20);
let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 1 << 20);
let mut body = Vec::new();
restricted_body_stream
.try_for_each(|chunk| {
log::debug!("Received chunk of size {}", chunk.len());
body.extend_from_slice(&chunk);
futures::future::ready(Ok(()))
})
@ -299,14 +312,13 @@ impl<
_ => {}
}
app.clone()
header.remove("host");
let req = app
.app_state()
.clients
.with_client(&addr, |client| {
Box::pin(async move {
let req = client
.backend_client
.request(
Method::POST,
method.clone(),
app.app_state()
.backend
.join(&path_and_query.to_string())
@ -319,7 +331,12 @@ impl<
let url_clone = req.url().clone();
let resp = client.execute(req).await.map_err(|e: reqwest::Error| {
let resp =
app.app_state()
.backend_client
.execute(req)
.await
.map_err(|e: reqwest::Error| {
log::error!(
"Failed to execute request: {} ({} {})",
e,
@ -345,8 +362,5 @@ impl<
log::error!("Failed to build response: {}", e);
Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE)
})?)
})
})
.await
}
}

View file

@ -15,7 +15,7 @@ use serde::Serialize;
pub struct Args {
#[clap(short, long, default_value = "127.0.0.1:3001")]
pub listen: String,
#[clap(short, long, default_value = "http://web:3000")]
#[clap(short, long, default_value = "https://echo.free.beeceptor.com")]
pub backend: String,
#[clap(long)]
pub tls_cert: Option<String>,
@ -28,8 +28,8 @@ async fn build_state<E: IntoResponse + Clone + Serialize + Send + Sync + 'static
base: Arc<BaseAppState<E>>,
_args: &Args,
) -> impl HasAppState<E> + Evaluator<E> {
base.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
.extract_meta()
base.extract_meta()
.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
}
#[tokio::main]

View file

@ -61,3 +61,12 @@ pub(crate) fn new_safe_client(addr: &SocketAddr) -> reqwest::Result<reqwest::Cli
.tcp_keepalive(Some(std::time::Duration::from_secs(20)))
.build()
}
pub(crate) fn new_backend_client() -> reqwest::Result<reqwest::Client> {
reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(10))
.timeout(std::time::Duration::from_secs(20))
.redirect(Policy::limited(4))
.tcp_keepalive(Some(std::time::Duration::from_secs(20)))
.build()
}

View file

@ -52,6 +52,7 @@ where
let p = self.stream.poll_next_unpin(cx);
match p {
Poll::Ready(Some(Ok(mut data))) => {
log::trace!("Received {} bytes", data.len());
if data.len() > remaining_len {
self.limit.store(0, Ordering::Relaxed);
Poll::Ready(Some(Ok(data.split_to(remaining_len))))