fix reverse proxying
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
parent
b6dea3fd06
commit
92abe7b368
7 changed files with 158 additions and 119 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1 +1,2 @@
|
|||
/target
|
||||
/inbox_audit
|
||||
|
|
|
@ -11,22 +11,22 @@ use crate::network::new_safe_client;
|
|||
#[derive(Default, Clone)]
|
||||
|
||||
pub struct ClientCache {
|
||||
clients: DashMap<IpAddr, (Instant, reqwest::Client)>,
|
||||
safe_clients: DashMap<IpAddr, (Instant, reqwest::Client)>,
|
||||
}
|
||||
|
||||
impl ClientCache {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
clients: DashMap::new(),
|
||||
safe_clients: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn with_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R
|
||||
pub async fn with_safe_client<'a, F, R>(&self, addr: &SocketAddr, f: F) -> R
|
||||
where
|
||||
F: FnOnce(reqwest::Client) -> Pin<Box<dyn std::future::Future<Output = R> + Send + 'a>>
|
||||
+ 'a,
|
||||
{
|
||||
let client = self.clients.entry(addr.ip()).or_insert_with(|| {
|
||||
let client = self.safe_clients.entry(addr.ip()).or_insert_with(|| {
|
||||
(
|
||||
Instant::now(),
|
||||
#[allow(clippy::expect_used)]
|
||||
|
@ -39,7 +39,7 @@ impl ClientCache {
|
|||
|
||||
pub fn gc(&self, dur: std::time::Duration) {
|
||||
let now = Instant::now();
|
||||
self.clients
|
||||
self.safe_clients
|
||||
.retain(|_, (created, _)| now.duration_since(*created) < dur);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use axum::response::IntoResponse;
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::{Compression, GzBuilder};
|
||||
use reqwest::Url;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
use std::fs::{File, OpenOptions};
|
||||
|
@ -60,7 +61,7 @@ pub enum AuditError {
|
|||
|
||||
pub struct AuditState {
|
||||
options: AuditOptions,
|
||||
cur_file: Option<RwLock<HashMap<String, Mutex<GzEncoder<File>>>>>,
|
||||
cur_file: RwLock<HashMap<String, Mutex<GzEncoder<File>>>>,
|
||||
vacuum_counter: AtomicU32,
|
||||
}
|
||||
|
||||
|
@ -71,13 +72,13 @@ impl AuditState {
|
|||
}
|
||||
Self {
|
||||
options,
|
||||
cur_file: None,
|
||||
cur_file: RwLock::new(HashMap::new()),
|
||||
vacuum_counter: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn vacuum(&self) -> Result<(), AuditError> {
|
||||
let read = self.cur_file.as_ref().unwrap().read().await;
|
||||
let read = self.cur_file.read().await;
|
||||
let in_use_files = read.keys().cloned().collect::<HashSet<_>>();
|
||||
|
||||
let files = std::fs::read_dir(&self.options.output)?
|
||||
|
@ -109,9 +110,11 @@ impl AuditState {
|
|||
pub async fn create_new_file(&self, name: &str) -> Result<(), AuditError> {
|
||||
let time_str = chrono::Utc::now().format("%Y-%m-%d_%H-%M-%S").to_string();
|
||||
|
||||
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
||||
let mut write = self.cur_file.write().await;
|
||||
|
||||
let full_name = format!("{}_{}.json.gz", name, time_str);
|
||||
let sanitized_name = name.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
|
||||
|
||||
let full_name = format!("{}_{}.json.gz", sanitized_name, time_str);
|
||||
|
||||
let file = OpenOptions::new()
|
||||
.create(true)
|
||||
|
@ -141,7 +144,7 @@ impl AuditState {
|
|||
.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let read = self.cur_file.as_ref().unwrap().read().await;
|
||||
let read = self.cur_file.read().await;
|
||||
|
||||
if let Some(file) = read.get(name) {
|
||||
let mut f = file.lock().await;
|
||||
|
@ -162,18 +165,16 @@ impl AuditState {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let mut write = self.cur_file.as_ref().unwrap().write().await;
|
||||
drop(read);
|
||||
{
|
||||
self.create_new_file(name).await?;
|
||||
|
||||
let file = File::create(&self.options.output.join(name))?;
|
||||
let read = self.cur_file.read().await;
|
||||
|
||||
let file = GzBuilder::new()
|
||||
.filename(name)
|
||||
.write(file, Compression::default());
|
||||
let file = read.get(name).unwrap();
|
||||
|
||||
write.insert(name.to_string(), Mutex::new(file));
|
||||
|
||||
// this is deliberately out of order to make sure we don't create endless files if serialization fails
|
||||
serde_json::to_writer(write.get(name).unwrap().lock().await.deref_mut(), &item)?;
|
||||
serde_json::to_writer(file.lock().await.deref_mut(), &item)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -230,6 +231,7 @@ impl<
|
|||
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||
let (disp, ctx) = self.inner.evaluate(ctx, info).await;
|
||||
|
||||
log::debug!("Audit: ctx = {:?}", ctx);
|
||||
if ctx
|
||||
.as_ref()
|
||||
.map(|c| c.get("skip_audit").is_some())
|
||||
|
@ -246,7 +248,19 @@ impl<
|
|||
|
||||
let state = self.state.clone();
|
||||
|
||||
state.write(Self::name(), &item).await.ok();
|
||||
let key = info
|
||||
.activity
|
||||
.as_ref()
|
||||
.map(|a| {
|
||||
a.actor
|
||||
.as_ref()
|
||||
.and_then(|s| Url::parse(s).ok()?.host_str().map(|s| s.to_string()))
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
.unwrap_or_else(|| format!("tcp_{}", info.connect.ip()));
|
||||
|
||||
state.write(&key, &item).await.ok();
|
||||
|
||||
(disp, ctx)
|
||||
}
|
||||
|
|
64
src/lib.rs
64
src/lib.rs
|
@ -26,8 +26,8 @@ use futures::TryStreamExt;
|
|||
use model::ap::AnyObject;
|
||||
use model::{ap, error::MisskeyError};
|
||||
use network::stream::LimitedStream;
|
||||
use network::Either;
|
||||
use reqwest::{self};
|
||||
use network::{new_backend_client, Either};
|
||||
use reqwest;
|
||||
use serde::ser::{SerializeMap, SerializeStruct};
|
||||
use serde::Serialize;
|
||||
|
||||
|
@ -118,6 +118,7 @@ pub trait HasAppState<E: IntoResponse + 'static>: Clone {
|
|||
/// Application state.
|
||||
pub struct BaseAppState<E: IntoResponse + 'static> {
|
||||
backend: reqwest::Url,
|
||||
backend_client: reqwest::Client,
|
||||
clients: ClientCache,
|
||||
ctx_template: Option<serde_json::Value>,
|
||||
_marker: PhantomData<E>,
|
||||
|
@ -143,7 +144,6 @@ impl<E: IntoResponse + Send + Sync> Evaluator<E> for Arc<BaseAppState<E>> {
|
|||
ctx: Option<serde_json::Value>,
|
||||
_info: &APRequestInfo<'r>,
|
||||
) -> (Disposition<E>, Option<serde_json::Value>) {
|
||||
log::trace!("Evaluator fell through, accepting request");
|
||||
(Disposition::Allow, ctx)
|
||||
}
|
||||
}
|
||||
|
@ -155,6 +155,7 @@ impl<E: IntoResponse> BaseAppState<E> {
|
|||
backend,
|
||||
ctx_template: None,
|
||||
clients: ClientCache::new(),
|
||||
backend_client: new_backend_client().expect("Failed to create backend client"),
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -184,17 +185,21 @@ impl<
|
|||
State(app): State<S>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
OriginalUri(uri): OriginalUri,
|
||||
header: HeaderMap,
|
||||
mut header: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<impl IntoResponse, MisskeyError> {
|
||||
log::debug!(
|
||||
"Received request from {} to {:?}",
|
||||
addr,
|
||||
uri.path_and_query()
|
||||
);
|
||||
let path_and_query = uri.path_and_query().ok_or(ERR_BAD_REQUEST)?;
|
||||
let state = app.app_state();
|
||||
|
||||
app.app_state()
|
||||
.clients
|
||||
.with_client(&addr, |client| {
|
||||
Box::pin(async move {
|
||||
let req = client
|
||||
header.remove("host");
|
||||
|
||||
let req = state
|
||||
.backend_client
|
||||
.request(
|
||||
method.clone(),
|
||||
state
|
||||
|
@ -211,7 +216,11 @@ impl<
|
|||
|
||||
let url_clone = req.url().clone();
|
||||
|
||||
let resp = client.execute(req).await.map_err(|e: reqwest::Error| {
|
||||
let resp = state
|
||||
.backend_client
|
||||
.execute(req)
|
||||
.await
|
||||
.map_err(|e: reqwest::Error| {
|
||||
log::error!(
|
||||
"Failed to execute request: {} ({} {})",
|
||||
e,
|
||||
|
@ -237,27 +246,31 @@ impl<
|
|||
log::error!("Failed to build response: {}", e);
|
||||
ERR_SERVICE_TEMPORARILY_UNAVAILABLE
|
||||
})
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Handle incoming ActivityPub requests.
|
||||
pub async fn inbox_handler(
|
||||
method: Method,
|
||||
State(app): State<S>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
OriginalUri(uri): OriginalUri,
|
||||
header: HeaderMap,
|
||||
mut header: HeaderMap,
|
||||
body: Body,
|
||||
) -> Result<impl IntoResponse, Either<MisskeyError, E>> {
|
||||
log::debug!(
|
||||
"Received request from {} to {:?}",
|
||||
addr,
|
||||
uri.path_and_query()
|
||||
);
|
||||
let path_and_query = uri.path_and_query().ok_or(Either::A(ERR_BAD_REQUEST))?;
|
||||
|
||||
let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 32 << 20);
|
||||
let restricted_body_stream = LimitedStream::new(body.into_data_stream(), 1 << 20);
|
||||
|
||||
let mut body = Vec::new();
|
||||
|
||||
restricted_body_stream
|
||||
.try_for_each(|chunk| {
|
||||
log::debug!("Received chunk of size {}", chunk.len());
|
||||
body.extend_from_slice(&chunk);
|
||||
futures::future::ready(Ok(()))
|
||||
})
|
||||
|
@ -299,14 +312,13 @@ impl<
|
|||
_ => {}
|
||||
}
|
||||
|
||||
app.clone()
|
||||
header.remove("host");
|
||||
|
||||
let req = app
|
||||
.app_state()
|
||||
.clients
|
||||
.with_client(&addr, |client| {
|
||||
Box::pin(async move {
|
||||
let req = client
|
||||
.backend_client
|
||||
.request(
|
||||
Method::POST,
|
||||
method.clone(),
|
||||
app.app_state()
|
||||
.backend
|
||||
.join(&path_and_query.to_string())
|
||||
|
@ -319,7 +331,12 @@ impl<
|
|||
|
||||
let url_clone = req.url().clone();
|
||||
|
||||
let resp = client.execute(req).await.map_err(|e: reqwest::Error| {
|
||||
let resp =
|
||||
app.app_state()
|
||||
.backend_client
|
||||
.execute(req)
|
||||
.await
|
||||
.map_err(|e: reqwest::Error| {
|
||||
log::error!(
|
||||
"Failed to execute request: {} ({} {})",
|
||||
e,
|
||||
|
@ -345,8 +362,5 @@ impl<
|
|||
log::error!("Failed to build response: {}", e);
|
||||
Either::A(ERR_SERVICE_TEMPORARILY_UNAVAILABLE)
|
||||
})?)
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ use serde::Serialize;
|
|||
pub struct Args {
|
||||
#[clap(short, long, default_value = "127.0.0.1:3001")]
|
||||
pub listen: String,
|
||||
#[clap(short, long, default_value = "http://web:3000")]
|
||||
#[clap(short, long, default_value = "https://echo.free.beeceptor.com")]
|
||||
pub backend: String,
|
||||
#[clap(long)]
|
||||
pub tls_cert: Option<String>,
|
||||
|
@ -28,8 +28,8 @@ async fn build_state<E: IntoResponse + Clone + Serialize + Send + Sync + 'static
|
|||
base: Arc<BaseAppState<E>>,
|
||||
_args: &Args,
|
||||
) -> impl HasAppState<E> + Evaluator<E> {
|
||||
base.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
|
||||
.extract_meta()
|
||||
base.extract_meta()
|
||||
.audited(AuditOptions::new(PathBuf::from("inbox_audit")))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
|
|
@ -61,3 +61,12 @@ pub(crate) fn new_safe_client(addr: &SocketAddr) -> reqwest::Result<reqwest::Cli
|
|||
.tcp_keepalive(Some(std::time::Duration::from_secs(20)))
|
||||
.build()
|
||||
}
|
||||
|
||||
pub(crate) fn new_backend_client() -> reqwest::Result<reqwest::Client> {
|
||||
reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.timeout(std::time::Duration::from_secs(20))
|
||||
.redirect(Policy::limited(4))
|
||||
.tcp_keepalive(Some(std::time::Duration::from_secs(20)))
|
||||
.build()
|
||||
}
|
||||
|
|
|
@ -52,6 +52,7 @@ where
|
|||
let p = self.stream.poll_next_unpin(cx);
|
||||
match p {
|
||||
Poll::Ready(Some(Ok(mut data))) => {
|
||||
log::trace!("Received {} bytes", data.len());
|
||||
if data.len() > remaining_len {
|
||||
self.limit.store(0, Ordering::Relaxed);
|
||||
Poll::Ready(Some(Ok(data.split_to(remaining_len))))
|
||||
|
|
Loading…
Reference in a new issue