diff --git a/Cargo.lock b/Cargo.lock index fcb2ad2..e2f67de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -73,6 +73,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.18" @@ -128,6 +143,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "anyhow" +version = "1.0.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13" + [[package]] name = "argon2" version = "0.5.3" @@ -185,10 +206,11 @@ version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" dependencies = [ + "brotli", "futures-core", + "futures-io", "memchr", "pin-project-lite", - "tokio", "zstd", "zstd-safe", ] @@ -309,6 +331,27 @@ dependencies = [ "generic-array", ] +[[package]] +name = "brotli" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -672,6 +715,21 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -688,6 +746,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-intrusive" version = "0.5.0" @@ -705,6 +774,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -723,8 +803,10 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1127,6 +1209,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" +[[package]] +name = "nohash-hasher" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" + [[package]] name = "nom" version = "7.1.3" @@ -1310,6 +1398,26 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.15" @@ -1518,10 +1626,12 @@ name = "replikey" version = "0.1.0" dependencies = [ "aes-gcm", + "anyhow", "argon2", "async-compression", "clap", "env_logger", + "futures", "log", "openssl", "pem-rfc7468", @@ -1538,8 +1648,10 @@ dependencies = [ "time", "tokio", "tokio-rustls", + "tokio-util", "toml", "x509-parser", + "yamux", ] [[package]] @@ -1951,6 +2063,12 @@ dependencies = [ "whoami", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "stringprep" version = "0.1.5" @@ -2122,6 +2240,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.19" @@ -2382,6 +2514,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.5" @@ -2618,6 +2760,22 @@ dependencies = [ "time", ] +[[package]] +name = "yamux" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a31b5e376a8b012bee9c423acdbb835fc34d45001cfa3106236a624e4b738028" +dependencies = [ + "futures", + "log", + "nohash-hasher", + "parking_lot", + "pin-project", + "rand", + "static_assertions", + "web-time", +] + [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 9bd9755..46b31eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,11 +26,15 @@ serde = { version = "1.0.214", features = ["derive"], optional = true } sqlx = { version = "0.8.2", optional = true, default-features = false, features = ["tls-none", "postgres"] } tokio = { version = "1.41.0", features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync"], optional = true } rustls = { version = "0.23.16", optional = true } -async-compression = { version = "0.4.17", optional = true, features = ["tokio", "zstd"] } +async-compression = { version = "0.4.17", optional = true, features = ["futures-io"] } +yamux = { version = "0.13.3", optional = true } +futures = { version = "0.3.31", optional = true } +tokio-util = { version = "0.7.12", features = ["compat"], optional = true } +anyhow = "1.0.92" [features] -default = ["keygen", "networking", "service", "remote-crl", "setup-postgres"] -asyncio = ["dep:tokio"] +default = ["keygen", "networking", "service", "remote-crl", "setup-postgres", "yamux", "zstd", "brotli"] +asyncio = ["dep:tokio", "dep:futures", "dep:tokio-util"] keygen = ["dep:rcgen", "dep:pem-rfc7468", "dep:rpassword", "dep:argon2", "dep:sha2", "dep:aes-gcm", "dep:time"] networking = ["asyncio", "dep:tokio-rustls", "dep:rustls", "dep:async-compression"] test-crosscheck-openssl = ["dep:openssl"] @@ -41,6 +45,9 @@ setup-postgres = ["dep:sqlx"] stat-service = ["networking", "serde"] rustls = ["dep:rustls"] async-compression = ["dep:async-compression"] +yamux = ["dep:yamux", "networking"] +zstd = ["async-compression/zstd"] +brotli = ["async-compression/brotli"] [[bin]] name = "replikey" diff --git a/src/bin/replikey.rs b/src/bin/replikey.rs index 93ed220..694d2de 100644 --- a/src/bin/replikey.rs +++ b/src/bin/replikey.rs @@ -96,9 +96,12 @@ fn main() { print_feature!("keygen"); print_feature!("asyncio"); print_feature!("networking"); + print_feature!("yamux"); print_feature!("service"); print_feature!("remote-crl"); print_feature!("setup-postgres"); + print_feature!("zstd"); + print_feature!("brotli"); } #[cfg(feature = "keygen")] SubCommand::Cert(cert) => match cert.subcmd { diff --git a/src/lib.rs b/src/lib.rs index 4722e34..df31506 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ +#![forbid(unsafe_code)] + #[cfg(feature = "keygen")] pub mod cert; pub mod ops; +pub mod transport; pub mod fs_crypt; diff --git a/src/ops/cert.rs b/src/ops/cert.rs index 7539c5e..38559b3 100644 --- a/src/ops/cert.rs +++ b/src/ops/cert.rs @@ -185,7 +185,7 @@ pub fn dump_certificate_params(params: &CertificateParams) { for b in serial_number.as_ref() { print!("{:02x}", b); } - println!(""); + println!(); } for san in ¶ms.subject_alt_names { match san { diff --git a/src/ops/network.rs b/src/ops/network.rs index c9cd9bf..fafb655 100644 --- a/src/ops/network.rs +++ b/src/ops/network.rs @@ -1,9 +1,16 @@ use std::{io::Cursor, net::ToSocketAddrs, sync::Arc}; use clap::Parser; +use futures::{ + io::{ + AsyncBufRead as FuturesAsyncBufRead, AsyncRead as FuturesAsyncRead, AsyncReadExt, + AsyncWrite as FuturesAsyncWrite, BufReader, + }, + AsyncWriteExt as _, +}; use tokio::{ - io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, - net::TcpStream, + io::AsyncWriteExt as _, + net::{TcpSocket, TcpStream}, }; use tokio_rustls::{ rustls::{ @@ -14,6 +21,9 @@ use tokio_rustls::{ }, TlsAcceptor, TlsConnector, }; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +use crate::transport::{yamux::YamuxStreamManager, MultiplexedC2S, MultiplexedS2C, StreamManager}; #[derive(Debug, Parser)] pub struct NetworkCommand { @@ -30,22 +40,45 @@ pub enum NetworkSubCommand { ForwardProxy(ForwardProxyCommand), } +#[derive(Debug, Clone)] +pub struct ReverseProxySpec { + pub sni: String, + pub target: String, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, thiserror::Error)] +pub enum ReverseProxySpecParseError { + #[error("Too many parts")] + TooManyParts, + #[error("No SNI provided")] + NoSni, + #[error("No target provided")] + NoTarget, +} + +fn reverse_proxy_spec_parser(input: &str) -> Result { + let mut parts = input.split('/'); + let sni = parts.next().ok_or(ReverseProxySpecParseError::NoSni)?; + let target = parts.next().ok_or(ReverseProxySpecParseError::NoTarget)?; + + if parts.next().is_some() { + return Err(ReverseProxySpecParseError::TooManyParts); + } + + Ok(ReverseProxySpec { + sni: sni.to_string(), + target: target.to_string(), + }) +} + #[derive(Debug, Parser)] pub struct ReverseProxyCommand { #[clap(short, long)] pub listen: String, - #[clap(long)] - pub redis_sni: String, - - #[clap(long)] - pub redis_target: String, - - #[clap(long)] - pub postgres_sni: String, - - #[clap(long)] - pub postgres_target: String, + #[clap(long, help = "your.sni.tld/target:1234", value_parser = reverse_proxy_spec_parser)] + pub target: Vec, #[clap(long, help = "Certificate")] pub cert: String, @@ -58,6 +91,60 @@ pub struct ReverseProxyCommand { #[clap(long, help = "CRLs to use")] pub crl: Vec, + + #[clap(long, help = "Transport to use", default_value = "plain")] + pub transport: Transport, + + #[clap(long, help = "Compression to use", default_value = "none")] + pub compression: Compression, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, Default)] +pub enum Transport { + #[default] + Plain, + #[cfg(feature = "yamux")] + YAMux, +} + +impl std::str::FromStr for Transport { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "plain" => Ok(Transport::Plain), + #[cfg(feature = "yamux")] + "yamux" => Ok(Transport::YAMux), + _ => Err(format!("Unknown transport: {}", s)), + } + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, Default)] +pub enum Compression { + #[default] + None, + #[cfg(feature = "brotli")] + Brotli, + #[cfg(feature = "zstd")] + Zstd, +} + +impl std::str::FromStr for Compression { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "none" => Ok(Compression::None), + #[cfg(feature = "brotli")] + "brotli" => Ok(Compression::Brotli), + #[cfg(feature = "zstd")] + "zstd" => Ok(Compression::Zstd), + _ => Err(format!("Unknown compression: {}", s)), + } + } } #[derive(Debug, Parser)] @@ -82,48 +169,77 @@ pub struct ForwardProxyCommand { #[clap(long)] pub crl: Vec, + + #[clap(long, help = "Transport to use", default_value = "plain")] + pub transport: Transport, + + #[clap(long, help = "Compression to use", default_value = "none")] + pub compression: Compression, } -fn compressor_to(w: impl AsyncWrite + Unpin) -> impl AsyncWrite + Unpin { - async_compression::tokio::write::ZstdEncoder::new(w) +fn compressor_to<'s>( + comp: Compression, + w: impl FuturesAsyncWrite + Unpin + Send + 's, +) -> Box { + match comp { + Compression::None => Box::new(w), + #[cfg(feature = "brotli")] + Compression::Brotli => Box::new(async_compression::futures::write::BrotliEncoder::new(w)), + #[cfg(feature = "zstd")] + Compression::Zstd => Box::new(async_compression::futures::write::ZstdEncoder::new(w)), + } } -fn decompressor_from(r: impl AsyncBufRead + Unpin) -> impl AsyncRead + Unpin { - async_compression::tokio::bufread::ZstdDecoder::new(r) +fn decompressor_from<'s>( + comp: Compression, + r: impl FuturesAsyncBufRead + Unpin + Send + 's, +) -> Box { + match comp { + Compression::None => Box::new(BufReader::new(r)), + #[cfg(feature = "brotli")] + Compression::Brotli => Box::new(async_compression::futures::bufread::BrotliDecoder::new(r)), + #[cfg(feature = "zstd")] + Compression::Zstd => Box::new(async_compression::futures::bufread::ZstdDecoder::new(r)), + } } -async fn send_static_string(w: &mut (impl AsyncWrite + Unpin), s: &str) -> tokio::io::Result<()> { - let mut cursor = Cursor::new(s); - - tokio::io::copy(&mut cursor, &mut compressor_to(w)).await?; +async fn send_static_string( + c: Compression, + w: &mut (impl FuturesAsyncWrite + Send + Unpin), + s: &str, +) -> futures::io::Result<()> { + futures::io::copy(&mut Cursor::new(s).compat(), &mut compressor_to(c, w)).await?; Ok(()) } async fn copy_bidirectional_compressed( - local: impl AsyncRead + AsyncWrite + Unpin, - remote: impl AsyncRead + AsyncWrite + Unpin, -) -> tokio::io::Result<(u64, u64)> { - let (mut local_rx, mut local_tx) = tokio::io::split(local); - let (remote_rx, remote_tx) = tokio::io::split(remote); + comp: Compression, + local: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin, + remote: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin, +) -> futures::io::Result<(u64, u64)> { + let (mut local_rx, mut local_tx) = local.split(); + let (remote_rx, remote_tx) = remote.split(); let remote_rx_buf = BufReader::new(remote_rx); - let mut remote_tx_comp = compressor_to(remote_tx); - let mut remote_rx_decomp = decompressor_from(remote_rx_buf); - - log::info!("Starting transfer"); + let mut remote_tx_comp = compressor_to(comp, remote_tx); + let mut remote_rx_decomp = decompressor_from(comp, remote_rx_buf); let uplink = async move { - let res = tokio::io::copy(&mut local_rx, &mut remote_tx_comp).await; - let shutdown = remote_tx_comp.shutdown().await; + log::info!("Starting transfer uplink"); + let res = futures::io::copy(&mut local_rx, &mut remote_tx_comp).await; + log::info!("Finished uplink"); + let shutdown = remote_tx_comp.close().await; let res = res?; shutdown?; tokio::io::Result::Ok(res) }; let downlink = async move { - let res = tokio::io::copy(&mut remote_rx_decomp, &mut local_tx).await; - let shutdown = local_tx.shutdown().await; + log::info!("Starting transfer downlink"); + let res = futures::io::copy(&mut remote_rx_decomp, &mut local_tx).await; + log::info!("Finished downlink"); + let shutdown = local_tx.close().await; let res = res?; shutdown?; tokio::io::Result::Ok(res) @@ -184,93 +300,190 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box TcpSocket::new_v4()?, + std::net::SocketAddr::V6(_) => TcpSocket::new_v6()?, + }; + // default options with keepalive + socket.set_reuseaddr(true)?; + socket.set_keepalive(true)?; + socket.bind(listen_addr)?; + let listener = socket.listen(1024)?; log::info!("Listening on: {}", opts.listen); - let (redis_sni, postgres_sni, redis_target, postgres_target) = ( - Arc::new(opts.redis_sni.clone()), - Arc::new(opts.postgres_sni.clone()), - opts.redis_target.clone(), - opts.postgres_target.clone(), + for spec in &opts.target { + if let Err(e) = spec.target.to_socket_addrs() { + eprintln!( + "Failed to resolve target address {}, are you sure it's correct? {}", + spec.target, e + ); + return Ok(()); + } + } + + let sni_to_target = Arc::new( + opts.target + .into_iter() + .map(|spec| (spec.sni, spec.target)) + .collect::>(), ); - if let Err(e) = opts.redis_target.to_socket_addrs() { - eprintln!("Failed to resolve redis target: {}", e); - return Ok(()); - } - if let Err(e) = opts.postgres_target.to_socket_addrs() { - eprintln!("Failed to resolve postgres target: {}", e); - return Ok(()); - } + loop { - let (pt_stream, _) = match listener.accept().await { + let (pt_stream, addr) = match listener.accept().await { Ok(s) => s, Err(e) => { eprintln!("Failed to accept connection: {}", e); continue; } }; + log::info!("Accepted TCP connection from: {}", addr); let acceptor = acceptor.clone(); - let (redis_sni, postgres_sni, redis_target, postgres_target) = ( - redis_sni.clone(), - postgres_sni.clone(), - redis_target.clone(), - postgres_target.clone(), - ); + let sni_to_target = Arc::clone(&sni_to_target); tokio::spawn(async move { match acceptor.accept(pt_stream).await { - Ok(mut tls) => match tls.get_ref().1.server_name().map(|s| s.to_string()) { - Some(sni) if sni == *redis_sni => { - log::info!( - "Accepted Redis connection for {:?}", - tls.get_ref().1.server_name() - ); - match tokio::net::TcpStream::connect(&redis_target).await { - Ok(redis) => { - if let Err(e) = copy_bidirectional_compressed(redis, tls).await { - eprintln!("Failed to copy data: {}", e); + Ok(mut tls) => { + log::info!("Accepted TLS connection from: {}", addr); + + let multiplexer = match opts.transport { + Transport::Plain => None, + #[cfg(feature = "yamux")] + Transport::YAMux => { + log::info!("Using YAMux transport"); + Some(YamuxStreamManager::new()) + } + }; + + match tls.get_ref().1.server_name().map(|s| s.to_string()) { + Some(sni) => match sni_to_target.get(&sni).cloned() { + Some(target) => { + log::info!( + "Accepted connection for {:?}, forwarding to {}", + tls.get_ref().1.server_name(), + target + ); + match multiplexer { + Some(m) => { + let mut inner_listener = + m.wrap_s2c_connection(tls.compat()).await; + tokio::spawn(async move { + loop { + log::info!("Waiting for new stream (YAMux)"); + match inner_listener.accept().await { + Ok(Some(mut tls_multiplexed)) => { + let target = target.clone(); + log::info!("Accepted new stream (YAMux)"); + tokio::spawn(async move { + let target = match TcpStream::connect( + target, + ) + .await + { + Ok(s) => s, + Err(e) => { + log::error!( + "Failed to connect to target: {}", + e + ); + match tls_multiplexed + .close() + .await + { + Ok(_) => {} + Err(e) => { + log::error!( + "Failed to close multiplexed stream: {}", + e + ); + } + } + return; + } + }; + if let Err(e) = + copy_bidirectional_compressed( + opts.compression, + target.compat(), + tls_multiplexed, + ) + .await + { + log::error!( + "Failed to copy data: {}", + e + ); + } + }); + } + Ok(None) => { + log::info!("No more streams to accept"); + break; + } + Err(e) => { + log::error!( + "Failed to accept stream: {}", + e + ); + break; + } + } + } + futures::future::poll_fn(|cx| { + inner_listener.poll_close(cx) + }) + .await + .ok(); + }); + } + None => match tokio::net::TcpStream::connect(target).await { + Ok(target) => { + if let Err(e) = copy_bidirectional_compressed( + opts.compression, + target.compat(), + tls.compat(), + ) + .await + { + log::error!("Failed to copy data: {}", e); + } + } + Err(e) => { + eprintln!("Failed to connect to target: {}", e); + tls.shutdown() + .await + .expect("Failed to shutdown TLS stream"); + } + }, } } - Err(e) => { - eprintln!("Failed to connect to redis: {}", e); - tls.shutdown().await.expect("Failed to shutdown TLS stream"); + None => { + log::warn!("Accepted connection for {:?}, but SNI {} does not match any configured SNI", tls.get_ref().1.server_name(), sni); + let mut compat = tls.compat_write(); + send_static_string( + opts.compression, + &mut compat, + format!("SNI {} does not match any configured SNI", sni) + .as_str(), + ) + .await + .expect("Failed to send static string"); + compat.close().await.expect("Failed to shutdown TLS stream"); } + }, + _ => { + log::error!("No SNI provided"); + let mut compat = tls.compat_write(); + send_static_string(opts.compression, &mut compat, "No SNI provided") + .await + .expect("Failed to send static string"); + compat.close().await.expect("Failed to shutdown TLS stream"); } } - Some(sni) if sni == *postgres_sni => { - log::info!( - "Accepted Postgres connection for {:?}", - tls.get_ref().1.server_name() - ); - match tokio::net::TcpStream::connect(&postgres_target).await { - Ok(postgres) => { - if let Err(e) = copy_bidirectional_compressed(postgres, tls).await { - eprintln!("Failed to copy data: {}", e); - } - } - Err(e) => { - eprintln!("Failed to connect to postgres: {}", e); - tls.shutdown().await.expect("Failed to shutdown TLS stream"); - } - } - } - Some(sni) => { - log::warn!("Accepted connection for {:?}, but SNI {} does not match any configured SNI", tls.get_ref().1.server_name(), sni); - send_static_string( - &mut tls, - format!("SNI {} does not match any configured SNI", sni).as_str(), - ) - .await - .expect("Failed to send static string"); - tls.shutdown().await.expect("Failed to shutdown TLS stream"); - } - _ => { - send_static_string(&mut tls, "No SNI provided") - .await - .expect("Failed to send static string"); - eprintln!("No SNI provided"); - tls.shutdown().await.expect("Failed to shutdown TLS stream"); - } - }, + } Err(e) => { eprintln!("Failed to accept connection: {}", e); } @@ -279,7 +492,7 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box Result<(), Box> { +pub async fn forward_proxy(opts: ForwardProxyCommand) -> anyhow::Result<()> { let (_, ca_pem) = x509_parser::pem::parse_x509_pem(&std::fs::read(&opts.ca)?)?; let (_, ca_cert) = x509_parser::parse_x509_certificate(&ca_pem.contents)?; let mut cert_store = RootCertStore::empty(); @@ -329,38 +542,199 @@ pub async fn forward_proxy(opts: ForwardProxyCommand) -> Result<(), Box s.as_str(), - None => opts.target.as_str(), + Some(ref s) => s.to_string(), + None => opts.target.to_string(), }) .expect("Failed to parse SNI"); - loop { - let (pt_stream, _) = match listener.accept().await { - Ok(s) => s, - Err(e) => { - eprintln!("Failed to accept connection: {}", e); - continue; - } - }; - let connector = connector.clone(); - let tls_stream = match TcpStream::connect(&opts.target).await { - Ok(s) => s, - Err(e) => { - eprintln!("Failed to connect to target: {}", e); - continue; - } - }; - let sni = sni.to_owned(); - tokio::spawn(async move { - match connector.connect(sni, tls_stream).await { - Ok(tls) => { - if let Err(e) = copy_bidirectional_compressed(pt_stream, tls).await { - eprintln!("Failed to copy data: {}", e); + + match opts.transport { + Transport::Plain => loop { + let (local_pt, _) = match listener.accept().await { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to accept connection: {}", e); + continue; + } + }; + let target = opts.target.clone(); + let connector = connector.clone(); + let sni = sni.clone(); + tokio::spawn(async move { + let addrs = match target.to_socket_addrs() { + Ok(a) => a, + Err(e) => { + eprintln!("Failed to resolve target address: {}", e); + return; + } + }; + let mut last_err = None; + for addr in addrs { + log::info!("Trying to connect to: {}", addr); + let socket = match match addr { + std::net::SocketAddr::V4(_) => TcpSocket::new_v4(), + std::net::SocketAddr::V6(_) => TcpSocket::new_v6(), + } { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to create socket: {}", e); + return; + } + }; + match socket.set_keepalive(true) { + Ok(_) => {} + Err(e) => { + log::warn!("Failed to set keepalive: {}", e); + } + } + + match socket.connect(addr).await { + Ok(target_pt) => { + let sni = sni.to_owned(); + + match connector.connect(sni, target_pt).await { + Ok(tls) => { + if let Err(e) = copy_bidirectional_compressed( + opts.compression, + local_pt.compat(), + tls.compat(), + ) + .await + { + eprintln!("Failed to copy data: {}", e); + } + } + Err(e) => { + eprintln!("Failed to connect to target: {}", e); + } + } + last_err = None; + break; + } + Err(e) => { + last_err = Some(e); + } } } - Err(e) => { - eprintln!("Failed to connect to target: {}", e); + if let Some(e) = last_err { + log::error!("None of the target addresses worked: {}", e); } + }); + }, + #[cfg(feature = "yamux")] + Transport::YAMux => { + let mut retries = 5; + loop { + if retries == 0 { + eprintln!("Too many retries, giving up"); + break; + } + let conn = TcpStream::connect(&opts.target).await?; + log::info!("Connected to target"); + let tls_stream = match connector.connect(sni.clone(), conn).await { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to connect to target: {}", e); + retries -= 1; + continue; + } + }; + let mut inner_listener = YamuxStreamManager::new() + .wrap_c2s_connection(tls_stream.compat()) + .await; + + loop { + log::info!("Waiting for new TCP connection (YAMux)"); + let new_conn = tokio::select! { + _ = futures::future::poll_fn(|cx| inner_listener.poll_next_inbound(cx)) => { + None + } + conn = listener.accept() => { + match conn { + Ok((pt_stream, _)) => { + log::info!("Starting new C2S connection (YAMux)"); + match inner_listener.start_c2s().await { + Ok(s) => Some((pt_stream, s)), + Err(e) => { + log::error!("Failed to start C2S connection: {}", e); + continue; + } + } + } + Err(e) => { + log::error!("Failed to accept connection: {}", e); + continue; + } + } + } + }; + + if let Some((pt_stream, tls_multiplexed)) = new_conn { + log::info!("Obtained new connection (YAMux)"); + tokio::spawn(async move { + if let Err(e) = copy_bidirectional_compressed( + opts.compression, + pt_stream.compat(), + tls_multiplexed, + ) + .await + { + eprintln!("Failed to copy data: {}", e); + } + }); + } else { + log::warn!("No more streams to accept, next connection will be retried"); + break; + } + } + + futures::future::poll_fn(|cx| inner_listener.poll_close(cx)) + .await + .ok(); } - }); - } + + anyhow::bail!("Too many retries"); + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn test_compression_method(algo: Compression) { + let (r, w) = tokio::io::duplex(1024); + let data = "Hello, world!".repeat(1024); + + let mut w = compressor_to(algo, w.compat()); + + let mut r = decompressor_from(algo, BufReader::new(r.compat())); + + tokio::join!( + async { + w.write_all(data.as_bytes()).await.unwrap(); + w.close().await.unwrap(); + }, + async { + let mut output = Vec::new(); + r.read_to_end(&mut output).await.unwrap(); + assert_eq!(output.len(), data.len()); + assert_eq!(output, data.as_bytes()); + } + ); + } + + macro_rules! test_compression { + ($name:ident, $algo:expr) => { + #[tokio::test] + async fn $name() { + test_compression_method($algo).await; + } + }; + } + + test_compression!(test_compression_none, Compression::None); + #[cfg(feature = "brotli")] + test_compression!(test_compression_brotli, Compression::Brotli); + #[cfg(feature = "zstd")] + test_compression!(test_compression_zstd, Compression::Zstd); } diff --git a/src/ops/postgres.rs b/src/ops/postgres.rs index cf530e1..357e9fe 100644 --- a/src/ops/postgres.rs +++ b/src/ops/postgres.rs @@ -141,7 +141,7 @@ pub async fn setup_postgres_pub( opts: SetupPublicationCommand, ) -> Result<(), PostgresSetupError> { let mut postgres = PgConnection::connect( - &&connection_string + &connection_string .map(|s| s.to_string()) .or_else(postgres_connection_string_from_env) .ok_or(PostgresSetupError::MissingConnection)?, @@ -196,7 +196,7 @@ pub async fn drop_postgres_pub( opts: DropPublicationCommand, ) -> Result<(), PostgresSetupError> { let mut postgres = PgConnection::connect( - &&connection_string + &connection_string .map(|s| s.to_string()) .or_else(postgres_connection_string_from_env) .ok_or(PostgresSetupError::MissingConnection)?, @@ -220,7 +220,7 @@ pub async fn add_table_to_postgres_pub( opts: AddTableCommand, ) -> Result<(), PostgresSetupError> { let mut postgres = PgConnection::connect( - &&connection_string + &connection_string .map(|s| s.to_string()) .or_else(postgres_connection_string_from_env) .ok_or(PostgresSetupError::MissingConnection)?, @@ -247,7 +247,7 @@ pub async fn drop_table_from_postgres_pub( opts: DropTableCommand, ) -> Result<(), PostgresSetupError> { let mut postgres = PgConnection::connect( - &&connection_string + &connection_string .map(|s| s.to_string()) .or_else(postgres_connection_string_from_env) .ok_or(PostgresSetupError::MissingConnection)?, @@ -274,7 +274,7 @@ pub async fn setup_postgres_sub( opts: SetupSubscriptionCommand, ) -> Result<(), PostgresSetupError> { let mut postgres = PgConnection::connect( - &&connection_string + &connection_string .map(|s| s.to_string()) .or_else(postgres_connection_string_from_env) .ok_or(PostgresSetupError::MissingConnection)?, @@ -324,7 +324,7 @@ pub async fn drop_postgres_sub( opts: DropSubscriptionCommand, ) -> Result<(), PostgresSetupError> { let mut postgres = PgConnection::connect( - &&connection_string + &connection_string .map(|s| s.to_string()) .or_else(postgres_connection_string_from_env) .ok_or(PostgresSetupError::MissingConnection)?, diff --git a/src/ops/service.rs b/src/ops/service.rs index eae09c1..62c1944 100644 --- a/src/ops/service.rs +++ b/src/ops/service.rs @@ -3,7 +3,9 @@ use std::path::Path; use clap::Parser; use serde::Deserialize; -use super::network::ReverseProxyCommand; +use crate::ops::network::{ForwardProxyCommand, ReverseProxySpec}; + +use super::network::{Compression, ReverseProxyCommand, Transport}; const DEF_CONFIG_FILE: &str = "/etc/replikey.toml"; const CA_CERT: &str = "ca.pem"; @@ -47,23 +49,40 @@ pub struct ConnectionConfig { #[derive(Debug, Deserialize)] pub struct MasterConfig { listen: String, - redis_sni: String, - redis_target: String, - postgres_sni: String, - postgres_target: String, + transport: Option, + compression: Option, + redis: MasterServiceSpec, + postgres: MasterServiceSpec, workdir: Option, crl: Vec, } #[derive(Debug, Deserialize)] pub struct SlaveConfig { - listen: String, - redis_sni: String, - postgres_sni: String, + target: String, + transport: Option, + compression: Option, + redis: SlaveServiceSpec, + postgres: SlaveServiceSpec, workdir: Option, crl: Vec, } +#[derive(Debug, Deserialize)] +pub struct MasterServiceSpec { + pub target: String, + pub sni: String, + pub multiplex: Option, +} + +#[derive(Debug, Deserialize)] +pub struct SlaveServiceSpec { + pub target: String, + pub sni: String, + pub listen: String, + pub multiplex: Option, +} + pub fn service_replicate_master(config: String) { let config = std::fs::read_to_string(config).unwrap(); let config: Config = toml::from_str(&config).expect("Failed to parse config"); @@ -80,23 +99,38 @@ pub fn service_replicate_master(config: String) { let master_conf = config.connection.master.as_ref().unwrap(); let cmd = ReverseProxyCommand { listen: master_conf.listen.clone(), - redis_sni: master_conf.redis_sni.clone(), - redis_target: master_conf.redis_target.clone(), - postgres_sni: master_conf.postgres_sni.clone(), - postgres_target: master_conf.postgres_target.clone(), - + transport: master_conf.transport.unwrap_or_default(), + compression: master_conf.compression.unwrap_or_default(), + target: vec![ + ReverseProxySpec { + sni: master_conf.redis.sni.clone(), + target: master_conf.redis.target.clone(), + }, + ReverseProxySpec { + sni: master_conf.postgres.sni.clone(), + target: master_conf.postgres.target.clone(), + }, + ], cert: Path::new(SERVER_CERT).to_string_lossy().to_string(), key: Path::new(SERVER_KEY).to_string_lossy().to_string(), ca: Path::new(CA_CERT).to_string_lossy().to_string(), crl: config.connection.master.as_ref().unwrap().crl.clone(), }; - tokio::runtime::Runtime::new() + let res = tokio::runtime::Runtime::new() .unwrap() - .block_on(crate::ops::network::reverse_proxy(cmd)) - .unwrap(); + .block_on(crate::ops::network::reverse_proxy(cmd)); - println!("Replication master started"); + match res { + Ok(_) => { + log::error!("Unexpected termination of reverse proxy"); + std::process::exit(1); + } + Err(e) => { + log::error!("Failed to start reverse proxy: {}", e); + std::process::exit(1); + } + } } pub fn service_replicate_slave(config: String) { @@ -113,23 +147,59 @@ pub fn service_replicate_slave(config: String) { } let slave_conf = config.connection.slave.as_ref().unwrap(); - let cmd = ReverseProxyCommand { - listen: slave_conf.listen.clone(), - redis_sni: slave_conf.redis_sni.clone(), - redis_target: slave_conf.redis_sni.clone(), - postgres_sni: slave_conf.postgres_sni.clone(), - postgres_target: slave_conf.postgres_sni.clone(), - + let cmd_redis = ForwardProxyCommand { + listen: slave_conf.redis.listen.clone(), + transport: slave_conf.transport.unwrap_or_default(), + compression: slave_conf.compression.unwrap_or_default(), + sni: Some(slave_conf.redis.sni.clone()), + target: slave_conf.target.clone(), + cert: Path::new(CLIENT_CERT).to_string_lossy().to_string(), + key: Path::new(CLIENT_KEY).to_string_lossy().to_string(), + ca: Path::new(CA_CERT).to_string_lossy().to_string(), + crl: config.connection.slave.as_ref().unwrap().crl.clone(), + }; + let cmd_postgres = ForwardProxyCommand { + listen: slave_conf.postgres.listen.clone(), + transport: slave_conf.transport.unwrap_or_default(), + compression: slave_conf.compression.unwrap_or_default(), + sni: Some(slave_conf.postgres.sni.clone()), + target: slave_conf.target.clone(), cert: Path::new(CLIENT_CERT).to_string_lossy().to_string(), key: Path::new(CLIENT_KEY).to_string_lossy().to_string(), ca: Path::new(CA_CERT).to_string_lossy().to_string(), crl: config.connection.slave.as_ref().unwrap().crl.clone(), }; - tokio::runtime::Runtime::new() - .unwrap() - .block_on(crate::ops::network::reverse_proxy(cmd)) - .unwrap(); + enum ProxyId { + Redis, + Postgres, + } - println!("Replication slave started"); + let res = tokio::runtime::Runtime::new() + .unwrap() + .block_on(async move { + tokio::select! { + r = crate::ops::network::forward_proxy(cmd_redis) => { + (ProxyId::Redis, r) + }, + r = crate::ops::network::forward_proxy(cmd_postgres) => { + (ProxyId::Postgres, r) + }, + } + }); + + match res { + (ProxyId::Redis, Ok(_)) => { + log::error!("Unexpected termination of redis proxy"); + } + (ProxyId::Postgres, Ok(_)) => { + log::error!("Unexpected termination of postgres proxy"); + } + (ProxyId::Redis, Err(e)) => { + log::error!("Failed to start redis proxy: {}", e); + } + (ProxyId::Postgres, Err(e)) => { + log::error!("Failed to start postgres proxy: {}", e); + } + } } diff --git a/src/transport/mod.rs b/src/transport/mod.rs new file mode 100644 index 0000000..b60d7e5 --- /dev/null +++ b/src/transport/mod.rs @@ -0,0 +1,203 @@ +use std::fmt::Debug; + +use futures::{ + future::BoxFuture, + io::{AsyncRead, AsyncWrite}, + FutureExt, +}; + +#[cfg(feature = "yamux")] +pub mod yamux; + +pub trait DynStream: AsyncRead + AsyncWrite + Unpin + Send {} + +impl DynStream for T {} + +pub trait StreamManager<'c, C: 'c>: Send + Sync { + type S2CConn: MultiplexedS2C; + type C2SConn: MultiplexedC2S; + + fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn>; + fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn>; + + fn shutdown(&self) -> BoxFuture<'c, ()>; +} + +// hangs rustc for some reason.. +/* + +pub fn manager_to_dyn<'c, C: 'c>(inner: impl StreamManager<'c, C> + 'c) -> DynStreamManager<'c, C> { + DynStreamManager { + inner: Box::new(inner), + } +} + +pub struct DynStreamManager<'c, C: 'c> { + inner: Box< + dyn StreamManager< + 'c, + C, + S2CConn = Box< + dyn MultiplexedS2C> + + 'c, + >, + C2SConn = Box< + dyn MultiplexedC2S> + + 'c, + >, + > + 'c, + >, +} + +impl<'c, C> StreamManager<'c, C> for DynStreamManager<'c, C> +where + C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c, + >::Error: Error + 'static, + >::Error: Error + 'static, +{ + type S2CConn = DynMultiplexedS2CConnection<'c, C>; + type C2SConn = DynMultiplexedC2SConnection<'c, C>; + + fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn> { + self.inner + .wrap_c2s_connection(stream) + .map(|x| DynMultiplexedC2SConnection { + inner: Box::new(x.map(|x| { + x.map(|s| s.map(|s| Box::new(s) as _)) + .map_err(anyhow::Error::new) + }) as _), + }) + .boxed() + } + + fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn> { + self.inner + .wrap_s2c_connection(stream) + .map(|x| DynMultiplexedS2CConnection { + inner: Box::new(x.map(|x| { + x.map(|s| s.map(|s| Box::new(s) as _)) + .map_err(anyhow::Error::new) as _ + })), + }) + .boxed() + } + + fn shutdown(&self) -> BoxFuture<'c, ()> { + async move { self.inner.shutdown().await }.boxed() + } +} + +*/ +pub trait MultiplexedS2C: Send + Sync { + type Error: Debug + Send + Sync + Sized; + type Stream: DynStream; + + fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>>; + + fn map(self, map: F) -> S2CMap + where + F: Fn(Result, Self::Error>) -> Result, E> + Send + Sync, + S: DynStream, + Self: Sized, + { + S2CMap { + inner: self, + map, + _phantom: std::marker::PhantomData, + } + } +} + +pub struct S2CMap { + inner: MC, + map: F, + _phantom: std::marker::PhantomData, +} + +impl MultiplexedS2C + for S2CMap +where + MC: MultiplexedS2C, + F: Fn(Result, MC::Error>) -> Result, E> + Send + Sync, +{ + type Error = E; + type Stream = S; + + fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>> { + let inner = &mut self.inner; + let mapf = &self.map; + inner.accept().map(mapf).boxed() + } +} + +pub trait MultiplexedC2S: Send + Sync { + type Error: Debug + Send + Sync + Sized; + type Stream: DynStream; + + fn start_c2s(&mut self) -> BoxFuture<'_, Result>; + + fn map(self, map: F) -> C2SMap + where + F: Fn(Result) -> Result + Send + Sync, + S: DynStream, + Self: Sized, + { + C2SMap { + inner: self, + map, + _phantom: std::marker::PhantomData, + } + } +} + +pub struct C2SMap { + inner: MC, + map: F, + _phantom: std::marker::PhantomData, +} + +impl MultiplexedC2S + for C2SMap +where + MC: MultiplexedC2S, + F: Fn(Result) -> Result + Send + Sync, +{ + type Error = E; + type Stream = S; + + fn start_c2s(&mut self) -> BoxFuture<'_, Result> { + let inner = &mut self.inner; + let map = &self.map; + inner.start_c2s().map(map).boxed() + } +} + +pub struct DynMultiplexedS2CConnection<'c, C: 'c> { + inner: Box> + 'c>, +} + +impl<'c, C> MultiplexedS2C for DynMultiplexedS2CConnection<'c, C> { + type Error = anyhow::Error; + type Stream = Box; + fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>> { + async move { + self.inner + .accept() + .await + .map(|x| x.map(|x| Box::new(x) as _)) + } + .boxed() + } +} + +pub struct DynMultiplexedC2SConnection<'c, C: 'c> { + inner: Box> + 'c>, +} + +impl<'c, C> MultiplexedC2S for DynMultiplexedC2SConnection<'c, C> { + type Error = anyhow::Error; + type Stream = Box; + fn start_c2s(&mut self) -> BoxFuture<'_, Result> { + async move { self.inner.start_c2s().await.map(|x| Box::new(x) as _) }.boxed() + } +} diff --git a/src/transport/yamux.rs b/src/transport/yamux.rs new file mode 100644 index 0000000..aba71a0 --- /dev/null +++ b/src/transport/yamux.rs @@ -0,0 +1,80 @@ +use futures::{ + future::BoxFuture, + io::{AsyncRead, AsyncWrite}, + FutureExt, +}; +use yamux::Config; + +pub(crate) fn yamux_config() -> yamux::Config { + let mut conf = Config::default(); + + // we are in mTLS so be generous but still prevent accidental DoS + conf.set_read_after_close(true) + .set_max_num_streams(512) + .set_max_connection_receive_window(Some((256 << 10) * 512)); + + conf +} + +pub struct YamuxStreamManager { + config: yamux::Config, +} + +impl Default for YamuxStreamManager { + fn default() -> Self { + Self::new() + } +} + +impl YamuxStreamManager { + pub fn new() -> Self { + Self { + config: yamux_config(), + } + } +} + +impl<'c, C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c> super::StreamManager<'c, C> + for YamuxStreamManager +{ + type S2CConn = yamux::Connection; + type C2SConn = yamux::Connection; + + fn wrap_c2s_connection(&self, stream: C) -> BoxFuture<'c, Self::C2SConn> { + let config = self.config.clone(); + async move { yamux::Connection::new(stream, config, yamux::Mode::Client) }.boxed() + } + + fn wrap_s2c_connection(&self, stream: C) -> BoxFuture<'c, Self::S2CConn> { + let config = self.config.clone(); + async move { yamux::Connection::new(stream, config, yamux::Mode::Server) }.boxed() + } + + fn shutdown(&self) -> BoxFuture<'c, ()> { + async {}.boxed() + } +} + +impl super::MultiplexedS2C + for yamux::Connection +{ + type Error = yamux::ConnectionError; + type Stream = yamux::Stream; + + fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>> { + futures::future::poll_fn(|cx| self.poll_next_inbound(cx)) + .map(|x| x.transpose()) + .boxed() + } +} + +impl super::MultiplexedC2S + for yamux::Connection +{ + type Error = yamux::ConnectionError; + type Stream = yamux::Stream; + + fn start_c2s(&mut self) -> BoxFuture<'_, Result> { + futures::future::poll_fn(|cx| self.poll_new_outbound(cx)).boxed() + } +} diff --git a/tests/mtls_integration.rs b/tests/mtls_integration.rs new file mode 100644 index 0000000..ed7d76a --- /dev/null +++ b/tests/mtls_integration.rs @@ -0,0 +1,968 @@ +#![allow(clippy::all)] + +use std::{ + path::PathBuf, + sync::{atomic::AtomicU32, Once}, +}; + +use clap::Parser; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use replikey::{ + cert::UsageType, + ops::{ + cert::{CertCommand, CertSubCommand}, + network::{NetworkCommand, NetworkSubCommand}, + }, +}; +use rustls::crypto::{aws_lc_rs, CryptoProvider}; +use time::OffsetDateTime; +use tokio::{ + io::{AsyncReadExt as _, AsyncWriteExt as _}, + task::JoinSet, +}; +use tokio_util::compat::TokioAsyncReadCompatExt; + +static NEXT_PORT: AtomicU32 = AtomicU32::new(12311); + +static TEST_SETUP: Once = Once::new(); + +fn setup_once() { + TEST_SETUP.call_once(|| { + if std::env::var("RUST_LOG").is_err() { + std::env::set_var("RUST_LOG", "info"); + } + env_logger::init(); + CryptoProvider::install_default(aws_lc_rs::default_provider()) + .expect("Failed to install crypto provider"); + }); +} + +fn next_port() -> u32 { + NEXT_PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst) +} + +async fn test_stream_sequence< + const WRITE_SIZE: usize, + L: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + R: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, +>( + left: L, + right: R, +) { + let (mut l_rx, mut l_tx) = left.split(); + + let (mut r_rx, mut r_tx) = right.split(); + + let l_write = async move { + let mut buf = [0u8; WRITE_SIZE]; + for i in 0..100 { + buf.iter_mut() + .enumerate() + .for_each(|(j, x)| *x = ((i + j) ^ 123) as u8); + l_tx.write_all(&buf).await?; + } + l_tx.close().await?; + + Ok::<_, futures::io::Error>(()) + }; + let r_write = async move { + let mut buf = [0u8; WRITE_SIZE]; + for i in 0..100 { + buf.iter_mut() + .enumerate() + .for_each(|(j, x)| *x = ((i + j) ^ 321) as u8); + r_tx.write_all(&buf).await?; + } + r_tx.close().await?; + Ok(()) + }; + let l_read = async move { + let mut buf = [0u8; WRITE_SIZE]; + for i in 0..100 { + l_rx.read_exact(&mut buf).await?; + for j in 0..WRITE_SIZE { + assert_eq!(buf[j], ((i + j) ^ 321) as u8); + } + } + assert_eq!(l_rx.read(&mut buf).await?, 0, "expected eof"); + Ok(()) + }; + let r_read = async move { + let mut buf = [0u8; WRITE_SIZE]; + for i in 0..100 { + r_rx.read_exact(&mut buf).await?; + for j in 0..WRITE_SIZE { + assert_eq!(buf[j], ((i + j) ^ 123) as u8); + } + } + assert_eq!(r_rx.read(&mut buf).await?, 0, "expected eof"); + Ok(()) + }; + + tokio::try_join!(l_write, r_write, l_read, r_read).expect("stream sequence failed"); +} + +#[tokio::test] +async fn in_memory_stream_works() { + let (left, right) = tokio::io::duplex(1024); + test_stream_sequence::<1024, _, _>(left.compat(), right.compat()).await; + let (left, right) = tokio::io::duplex(1024); + test_stream_sequence::<1, _, _>(left.compat(), right.compat()).await; + let (left, right) = tokio::io::duplex(1024); + test_stream_sequence::<3543, _, _>(left.compat(), right.compat()).await; +} + +fn test_ca(dir: &PathBuf) -> (PathBuf, PathBuf) { + let id = OffsetDateTime::now_utc().unix_timestamp_nanos(); + + let path = dir.join(format!("ca-{}/", id)); + + std::fs::create_dir_all(&path).expect("failed to create ca dir"); + + let cmd = [ + "replikey".to_string(), + "create-ca".to_string(), + "--valid-days".to_string(), + "365".to_string(), + "--dn-common-name".to_string(), + "replikey-test-ca".to_string(), + "--output".to_string(), + path.to_string_lossy().to_string(), + ]; + + let parsed = CertCommand::parse_from(&cmd); + + if let CertSubCommand::CreateCa(opts) = parsed.subcmd { + replikey::ops::cert::create_ca(opts, false); + } else { + panic!("failed to parse create-ca"); + } + + let pem = path.join("ca.pem"); + let key = path.join("ca.key"); + + assert!(pem.exists()); + assert!(key.exists()); + + (pem, key) +} + +struct PeerCert { + key: PathBuf, + csr: PathBuf, + self_signed: PathBuf, + signed_cert: Option, +} + +impl PeerCert { + fn signed_by(&self, usage: UsageType, ca_key: &PathBuf) -> PeerCert { + let id = OffsetDateTime::now_utc().unix_timestamp_nanos(); + + let path = ca_key.parent().unwrap().join(format!("signed-{}/", id)); + + std::fs::create_dir_all(&path).expect("failed to create signed dir"); + + let cmd = [ + "replikey".to_string(), + "sign-server-csr".to_string(), + "--valid-days".to_string(), + "365".to_string(), + "--ca-dir".to_string(), + ca_key.parent().unwrap().to_string_lossy().to_string(), + "--input-csr".to_string(), + self.csr.to_string_lossy().to_string(), + "-d".to_string(), + "target1.local".to_string(), + "-d".to_string(), + "target2.local".to_string(), + "-d".to_string(), + "*.targets.local".to_string(), + "--output".to_string(), + format!("{}/server-signed.pem", path.to_string_lossy()), + ]; + + let parsed = CertCommand::parse_from(&cmd); + + if let CertSubCommand::SignServerCSR(opts) = parsed.subcmd { + replikey::ops::cert::sign_csr(opts, usage, false); + } else { + panic!("failed to parse sign-server-csr"); + } + + PeerCert { + key: self.key.clone(), + csr: self.csr.clone(), + self_signed: self.self_signed.clone(), + signed_cert: Some(path.join("server-signed.pem")), + } + } +} + +fn test_server_cert(dir: &PathBuf) -> PeerCert { + let id = OffsetDateTime::now_utc().unix_timestamp_nanos(); + + let path = dir.join(format!("server-{}/", id)); + + std::fs::create_dir_all(&path).expect("failed to create server dir"); + + let cmd = [ + "replikey".to_string(), + "create-server".to_string(), + "--valid-days".to_string(), + "365".to_string(), + "--dn-common-name".to_string(), + "replikey-test-server".to_string(), + "-d".to_string(), + "target1.local".to_string(), + "-d".to_string(), + "target2.local".to_string(), + "-d".to_string(), + "*.targets.local".to_string(), + "--output".to_string(), + path.to_string_lossy().to_string(), + ]; + + let parsed = CertCommand::parse_from(&cmd); + + if let CertSubCommand::CreateServer(opts) = parsed.subcmd { + replikey::ops::cert::create_server(opts); + } else { + panic!("failed to parse create-server"); + } + + let pem = path.join("server.pem"); + let key = path.join("server.key"); + let csr = path.join("server.csr"); + + assert!(pem.exists()); + assert!(key.exists()); + + PeerCert { + key, + csr, + self_signed: pem, + signed_cert: None, + } +} + +fn test_client_cert(dir: &PathBuf) -> PeerCert { + let id = OffsetDateTime::now_utc().unix_timestamp_nanos(); + + let path = dir.join(format!("client-{}/", id)); + + std::fs::create_dir_all(&path).expect("failed to create client dir"); + + let cmd = [ + "replikey".to_string(), + "create-client".to_string(), + "--valid-days".to_string(), + "365".to_string(), + "--dn-common-name".to_string(), + "replikey-test-client".to_string(), + "--output".to_string(), + path.to_string_lossy().to_string(), + ]; + + let parsed = CertCommand::parse_from(&cmd); + + if let CertSubCommand::CreateClient(opts) = parsed.subcmd { + replikey::ops::cert::create_client(opts); + } else { + panic!("failed to parse create-client"); + } + + let pem = path.join("client.pem"); + let key = path.join("client.key"); + let csr = path.join("client.csr"); + + assert!(pem.exists()); + assert!(key.exists()); + + PeerCert { + key, + csr, + self_signed: pem, + signed_cert: None, + } +} + +async fn start_test_target_1() -> u32 { + let port = next_port(); + let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .expect("failed to bind listener"); + + tokio::spawn(async move { + loop { + let (mut stream, _) = listener.accept().await.expect("failed to accept"); + log::info!("accepted connection on target1"); + tokio::spawn(async move { + let (mut rx, mut tx) = stream.split(); + let n = tokio::io::copy(&mut rx, &mut tx) + .await + .expect("failed to copy"); + log::info!("copied {} bytes", n); + tx.shutdown().await.expect("failed to close"); + }); + } + }); + + port +} + +async fn start_test_target_2() -> u32 { + let port = next_port(); + let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .expect("failed to bind listener"); + + tokio::spawn(async move { + loop { + let (mut stream, _) = listener.accept().await.expect("failed to accept"); + log::info!("accepted connection on target2"); + tokio::spawn(async move { + let (mut rx, mut tx) = stream.split(); + let mut buf = [0u8; 1024]; + let mut total = 0; + loop { + let n = rx.read(&mut buf).await.expect("failed to read"); + total += n; + if n == 0 { + break; + } + buf[..n].iter_mut().for_each(|x| *x = !*x); + tx.write_all(&buf[..n]).await.expect("failed to write"); + } + log::info!("copied {} bytes", total); + tx.shutdown().await.expect("failed to close"); + }); + } + }); + + port +} + +async fn start_reverse_proxy( + target1: u32, + target2: u32, + cert: &PathBuf, + key: &PathBuf, + ca_cert: &PathBuf, + transport: &str, +) -> u32 { + let port = next_port(); + + let cmd = [ + "replikey".to_string(), + "reverse-proxy".to_string(), + "--listen".to_string(), + format!("127.0.0.1:{}", port), + "--target".to_string(), + format!("target1.local/127.0.0.1:{}", target1), + "--target".to_string(), + format!("target2.local/127.0.0.1:{}", target2), + "--target".to_string(), + format!("target1.targets.local/127.0.0.1:{}", target1), + "--target".to_string(), + format!("target2.targets.local/127.0.0.1:{}", target2), + "--cert".to_string(), + cert.to_string_lossy().to_string(), + "--key".to_string(), + key.to_string_lossy().to_string(), + "--ca".to_string(), + ca_cert.to_string_lossy().to_string(), + "--transport".to_string(), + transport.to_string(), + ]; + + let parsed = NetworkCommand::parse_from(&cmd); + + if let NetworkSubCommand::ReverseProxy(opts) = parsed.subcmd { + tokio::spawn(async move { + replikey::ops::network::reverse_proxy(opts) + .await + .expect("failed to run reverse proxy"); + }); + + port + } else { + panic!("failed to parse reverse-proxy"); + } +} + +async fn start_forward_proxy( + addr: &str, + sni: &str, + cert: &PathBuf, + key: &PathBuf, + ca_cert: &PathBuf, + transport: &str, +) -> u32 { + let port = next_port(); + + let cmd = [ + "replikey".to_string(), + "forward-proxy".to_string(), + "--listen".to_string(), + format!("127.0.0.1:{}", port), + "--target".to_string(), + addr.to_string(), + "--sni".to_string(), + sni.to_string(), + "--cert".to_string(), + cert.to_string_lossy().to_string(), + "--key".to_string(), + key.to_string_lossy().to_string(), + "--ca".to_string(), + ca_cert.to_string_lossy().to_string(), + "--transport".to_string(), + transport.to_string(), + ]; + + let parsed = NetworkCommand::parse_from(&cmd); + + if let NetworkSubCommand::ForwardProxy(opts) = parsed.subcmd { + tokio::spawn(async move { + replikey::ops::network::forward_proxy(opts) + .await + .expect("failed to run forward proxy"); + }); + + port + } else { + panic!("failed to parse forward-proxy"); + } +} + +async fn wait_for_port(port: u32) { + loop { + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { + panic!("timeout waiting for port"); + } + res = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) => { + if res.is_ok() { + break; + } + } + } + } +} + +async fn should_not_get_anything(port: u32) { + let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("failed to connect"); + + let res = stream.read(&mut [0u8; 1024]).await; + + match res { + Ok(n) => { + assert_eq!(n, 0); + } + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::ConnectionReset); + } + } +} + +async fn expect_target1_signature(port: u32) { + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("failed to connect"); + let (mut rx, mut tx) = tokio::io::split(stream); + + for batch in 0..8 { + log::debug!("expect_target1_signature batch {}", batch); + let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes(); + let expect = data.clone(); + + let mut write_size = (1..4096).cycle(); + + let (_, rx_data) = tokio::join!( + async { + while !data.is_empty() { + let n = write_size.next().unwrap().min(data.len()); + tx.write_all(&data[..n]).await.expect("failed to write"); + log::trace!("wrote {} bytes for target1", n); + data = data.split_off(n); + } + tx.flush().await.expect("failed to flush"); + log::info!("wrote {} bytes for target1", expect.len()); + }, + async { + let mut rxed = Vec::new(); + let mut buf = [0u8; 1024]; + loop { + let n = rx.read(&mut buf).await.expect("failed to read"); + log::trace!("read {} bytes for target1", n); + if n == 0 { + break; + } + rxed.extend_from_slice(&buf[..n]); + if rxed.len() >= expect.len() { + break; + } + } + rxed + } + ); + + assert_eq!(rx_data, expect); + } + + tx.shutdown().await.expect("failed to close"); +} + +async fn expect_target1_signature_buffered(port: u32) { + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("failed to connect"); + + let (mut rx, mut tx) = tokio::io::split(stream); + + let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes(); + let expect = data.clone(); + + let mut write_size = (1..4096).cycle(); + + let (_, rx_data) = tokio::join!( + async { + while !data.is_empty() { + let n = write_size.next().unwrap().min(data.len()); + tx.write_all(&data[..n]).await.expect("failed to write"); + data = data.split_off(n); + } + tx.shutdown().await.expect("failed to close"); + }, + async { + let mut rxed = Vec::new(); + let mut buf = [0u8; 1024]; + loop { + let n = rx.read(&mut buf).await.expect("failed to read"); + if n == 0 { + break; + } + rxed.extend_from_slice(&buf[..n]); + if rxed.len() >= expect.len() { + break; + } + } + rxed + } + ); + + assert_eq!(rx_data, expect); +} + +async fn expect_target2_signature(port: u32) { + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("failed to connect"); + + let (mut rx, mut tx) = tokio::io::split(stream); + + for _batch in 0..8 { + let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes(); + let expect = data.iter().map(|x| !x).collect::>(); + + let mut write_size = (1..4096).cycle(); + + let (_, rx_data) = tokio::join!( + async { + while !data.is_empty() { + let n = write_size.next().unwrap().min(data.len()); + tx.write_all(&data[..n]).await.expect("failed to write"); + data = data.split_off(n); + } + tx.flush().await.expect("failed to flush"); + }, + async { + let mut rxed = Vec::new(); + let mut buf = [0u8; 1024]; + loop { + let n = rx.read(&mut buf).await.expect("failed to read"); + if n == 0 { + break; + } + rxed.extend_from_slice(&buf[..n]); + if rxed.len() >= expect.len() { + break; + } + } + rxed + } + ); + + assert_eq!(rx_data, expect); + } + + tx.shutdown().await.expect("failed to close"); +} + +async fn expect_target2_signature_buffered(port: u32) { + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("failed to connect"); + + let (mut rx, mut tx) = tokio::io::split(stream); + + let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes(); + + let expect = data.iter().map(|x| !x).collect::>(); + + let mut write_size = (1..4096).cycle(); + + let (_, rx_data) = tokio::join!( + async { + while !data.is_empty() { + let n = write_size.next().unwrap().min(data.len()); + tx.write_all(&data[..n]).await.expect("failed to write"); + data = data.split_off(n); + } + tx.shutdown().await.expect("failed to close"); + }, + async { + let mut rxed = Vec::new(); + let mut buf = [0u8; 1024]; + loop { + let n = rx.read(&mut buf).await.expect("failed to read"); + if n == 0 { + break; + } + rxed.extend_from_slice(&buf[..n]); + if rxed.len() >= expect.len() { + break; + } + } + rxed + } + ); + + assert_eq!(rx_data, expect); +} + +#[tokio::test] +async fn test_mtls_integrated() { + setup_once(); + let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca")); + let server_cert = test_server_cert(&PathBuf::from("target/test-server")); + let client_cert = test_client_cert(&PathBuf::from("target/test-client")); + + let server_cert = server_cert.signed_by(UsageType::Server, &ca_key); + let client_cert = client_cert.signed_by(UsageType::Client, &ca_key); + + let target1 = start_test_target_1().await; + let target2 = start_test_target_2().await; + log::info!("target1: {}, target2: {}", target1, target2); + wait_for_port(target1).await; + wait_for_port(target2).await; + expect_target1_signature(target1).await; + expect_target2_signature(target2).await; + + let proxy_port_self_signed = start_reverse_proxy( + target1, + target2, + &server_cert.self_signed, + &server_cert.key, + &ca_cert, + "plain", + ) + .await; + wait_for_port(proxy_port_self_signed).await; + log::info!("reverse_proxy_port_self_signed: {}", proxy_port_self_signed); + + let proxy_port = start_reverse_proxy( + target1, + target2, + server_cert.signed_cert.as_ref().unwrap(), + &server_cert.key, + &ca_cert, + "plain", + ) + .await; + wait_for_port(proxy_port).await; + log::info!("reverse_proxy_port: {}", proxy_port); + + for server_signed in [false, true] { + let forward_proxy_target1_self_signed = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target1.local", + &client_cert.self_signed, + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + log::info!( + "forward_proxy_target1_self_signed: {}", + forward_proxy_target1_self_signed + ); + + let forward_proxy_target1 = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target1.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + log::info!("forward_proxy_target1: {}", forward_proxy_target1); + + let forward_proxy_target2 = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target2.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + log::info!("forward_proxy_target2: {}", forward_proxy_target2); + + let forward_proxy_target1_wildcard = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target1.targets.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + + let forward_proxy_target2_wildcard: u32 = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target2.targets.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + log::info!( + "forward_proxy_target2_wildcard: {}", + forward_proxy_target2_wildcard + ); + + let forward_proxy_unknown_sni = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "some-other.place", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni); + + wait_for_port(forward_proxy_target1).await; + wait_for_port(forward_proxy_target2).await; + wait_for_port(forward_proxy_target1_wildcard).await; + wait_for_port(forward_proxy_target2_wildcard).await; + wait_for_port(forward_proxy_unknown_sni).await; + log::info!("Test: self-signed cert"); + should_not_get_anything(forward_proxy_target1_self_signed).await; + + if server_signed { + let mut js = JoinSet::new(); + for _ in 0..100 { + js.spawn(expect_target1_signature(forward_proxy_target1)); + js.spawn(expect_target2_signature(forward_proxy_target2)); + js.spawn(expect_target1_signature(forward_proxy_target1_wildcard)); + js.spawn(expect_target2_signature(forward_proxy_target2_wildcard)); + js.spawn(should_not_get_anything(forward_proxy_unknown_sni)); + } + js.join_all().await; + } else { + let mut js = JoinSet::new(); + for _ in 0..100 { + js.spawn(should_not_get_anything(forward_proxy_target1)); + js.spawn(should_not_get_anything(forward_proxy_target2)); + js.spawn(should_not_get_anything(forward_proxy_target1_wildcard)); + js.spawn(should_not_get_anything(forward_proxy_target2_wildcard)); + js.spawn(should_not_get_anything(forward_proxy_unknown_sni)); + } + js.join_all().await; + } + } +} + +#[tokio::test] +async fn test_mtls_role_reverse() { + let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca")); + + let server_cert = test_server_cert(&PathBuf::from("target/test-server")); + + let server_cert = server_cert.signed_by(UsageType::Server, &ca_key); + + let client_cert = test_client_cert(&PathBuf::from("target/test-client")); + + let client_cert = client_cert.signed_by(UsageType::Client, &ca_key); + + let target1 = start_test_target_1().await; + + let target2 = start_test_target_2().await; + + wait_for_port(target1).await; + wait_for_port(target2).await; + + let server_cert = server_cert.signed_by(UsageType::Server, &ca_key); + let client_cert = client_cert.signed_by(UsageType::Client, &ca_key); + + let proxy_port = start_reverse_proxy( + target1, + target2, + server_cert.signed_cert.as_ref().unwrap(), + &server_cert.key, + &ca_cert, + "plain", + ) + .await; + + wait_for_port(proxy_port).await; + + let forward_proxy_target1 = start_forward_proxy( + &format!("127.0.0.1:{}", proxy_port), + "target1.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + wait_for_port(forward_proxy_target1).await; + + expect_target1_signature(forward_proxy_target1).await; + + let proxy_port_reversed = start_reverse_proxy( + target1, + target2, + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "plain", + ) + .await; + + wait_for_port(proxy_port_reversed).await; + + let forward_proxy_target1_reversed = start_forward_proxy( + &format!("127.0.0.1:{}", proxy_port_reversed), + "target1.local", + server_cert.signed_cert.as_ref().unwrap(), + &server_cert.key, + &ca_cert, + "plain", + ) + .await; + wait_for_port(forward_proxy_target1_reversed).await; + + should_not_get_anything(forward_proxy_target1_reversed).await; +} + +#[tokio::test] +async fn test_mtls_yamux_transport() { + setup_once(); + + let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca")); + + let server_cert = test_server_cert(&PathBuf::from("target/test-server")); + let server_cert = server_cert.signed_by(UsageType::Server, &ca_key); + let client_cert = test_client_cert(&PathBuf::from("target/test-client")); + let client_cert = client_cert.signed_by(UsageType::Client, &ca_key); + + let target1 = start_test_target_1().await; + + let target2 = start_test_target_2().await; + wait_for_port(target1).await; + wait_for_port(target2).await; + + let proxy_port = start_reverse_proxy( + target1, + target2, + server_cert.signed_cert.as_ref().unwrap(), + &server_cert.key, + &ca_cert, + "yamux", + ) + .await; + wait_for_port(proxy_port).await; + wait_for_port(proxy_port).await; + + let forward_proxy_target1 = start_forward_proxy( + &format!("127.0.0.1:{}", proxy_port), + "target1.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "yamux", + ) + .await; + + let forward_proxy_target2 = start_forward_proxy( + &format!("127.0.0.1:{}", proxy_port), + "target2.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + "yamux", + ) + .await; + wait_for_port(forward_proxy_target1).await; + wait_for_port(forward_proxy_target2).await; + + log::info!("forward_proxy_target1: {}", forward_proxy_target1); + log::info!("forward_proxy_target2: {}", forward_proxy_target2); + + let mut js = JoinSet::new(); + + for _ in 0..128 { + js.spawn(expect_target1_signature_buffered(forward_proxy_target1)); + js.spawn(expect_target2_signature_buffered(forward_proxy_target2)); + } + + js.join_all().await; +}