diff --git a/Cargo.lock b/Cargo.lock index e2f67de..99318cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -208,9 +208,9 @@ checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" dependencies = [ "brotli", "futures-core", - "futures-io", "memchr", "pin-project-lite", + "tokio", "zstd", "zstd-safe", ] @@ -372,9 +372,9 @@ checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cc" -version = "1.1.31" +version = "1.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" +checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9" dependencies = [ "jobserver", "libc", @@ -396,6 +396,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "cipher" version = "0.4.4" @@ -715,21 +721,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" -[[package]] -name = "futures" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.31" @@ -746,17 +737,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" -[[package]] -name = "futures-executor" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - [[package]] name = "futures-intrusive" version = "0.5.0" @@ -774,17 +754,6 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" -[[package]] -name = "futures-macro" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "futures-sink" version = "0.3.31" @@ -803,10 +772,8 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ - "futures-channel", "futures-core", "futures-io", - "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1209,12 +1176,6 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" -[[package]] -name = "nohash-hasher" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451" - [[package]] name = "nom" version = "7.1.3" @@ -1398,26 +1359,6 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" -[[package]] -name = "pin-project" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" version = "0.2.15" @@ -1519,10 +1460,11 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "e346e016eacfff12233c243718197ca12f148c84e1e84268a896699b41c71780" dependencies = [ + "cfg_aliases", "libc", "once_cell", "socket2", @@ -1631,7 +1573,6 @@ dependencies = [ "async-compression", "clap", "env_logger", - "futures", "log", "openssl", "pem-rfc7468", @@ -1648,10 +1589,8 @@ dependencies = [ "time", "tokio", "tokio-rustls", - "tokio-util", "toml", "x509-parser", - "yamux", ] [[package]] @@ -2063,12 +2002,6 @@ dependencies = [ "whoami", ] -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "stringprep" version = "0.1.5" @@ -2094,9 +2027,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.86" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -2138,18 +2071,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede" +checksum = "3b3c6efbfc763e64eb85c11c25320f0737cb7364c4b6336db90aa9ebe27a0bbd" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5" +checksum = "b607164372e89797d78b8e23a6d67d5d1038c1c65efd52e1389ef8b77caba2a6" dependencies = [ "proc-macro2", "quote", @@ -2240,20 +2173,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-util" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" -dependencies = [ - "bytes", - "futures-core", - "futures-io", - "futures-sink", - "pin-project-lite", - "tokio", -] - [[package]] name = "toml" version = "0.8.19" @@ -2514,21 +2433,11 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "web-time" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - [[package]] name = "webpki-roots" -version = "0.26.5" +version = "0.26.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" +checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" dependencies = [ "rustls-pki-types", ] @@ -2760,22 +2669,6 @@ dependencies = [ "time", ] -[[package]] -name = "yamux" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31b5e376a8b012bee9c423acdbb835fc34d45001cfa3106236a624e4b738028" -dependencies = [ - "futures", - "log", - "nohash-hasher", - "parking_lot", - "pin-project", - "rand", - "static_assertions", - "web-time", -] - [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 46b31eb..e3c43c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,15 +26,12 @@ serde = { version = "1.0.214", features = ["derive"], optional = true } sqlx = { version = "0.8.2", optional = true, default-features = false, features = ["tls-none", "postgres"] } tokio = { version = "1.41.0", features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync"], optional = true } rustls = { version = "0.23.16", optional = true } -async-compression = { version = "0.4.17", optional = true, features = ["futures-io"] } -yamux = { version = "0.13.3", optional = true } -futures = { version = "0.3.31", optional = true } -tokio-util = { version = "0.7.12", features = ["compat"], optional = true } +async-compression = { version = "0.4.17", optional = true, features = ["tokio"] } anyhow = "1.0.92" [features] -default = ["keygen", "networking", "service", "remote-crl", "setup-postgres", "yamux", "zstd", "brotli"] -asyncio = ["dep:tokio", "dep:futures", "dep:tokio-util"] +default = ["keygen", "networking", "service", "remote-crl", "setup-postgres", "zstd", "brotli"] +asyncio = ["dep:tokio"] keygen = ["dep:rcgen", "dep:pem-rfc7468", "dep:rpassword", "dep:argon2", "dep:sha2", "dep:aes-gcm", "dep:time"] networking = ["asyncio", "dep:tokio-rustls", "dep:rustls", "dep:async-compression"] test-crosscheck-openssl = ["dep:openssl"] @@ -45,7 +42,6 @@ setup-postgres = ["dep:sqlx"] stat-service = ["networking", "serde"] rustls = ["dep:rustls"] async-compression = ["dep:async-compression"] -yamux = ["dep:yamux", "networking"] zstd = ["async-compression/zstd"] brotli = ["async-compression/brotli"] diff --git a/src/bin/replikey.rs b/src/bin/replikey.rs index 694d2de..808670f 100644 --- a/src/bin/replikey.rs +++ b/src/bin/replikey.rs @@ -96,7 +96,6 @@ fn main() { print_feature!("keygen"); print_feature!("asyncio"); print_feature!("networking"); - print_feature!("yamux"); print_feature!("service"); print_feature!("remote-crl"); print_feature!("setup-postgres"); diff --git a/src/lib.rs b/src/lib.rs index df31506..c255836 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,5 @@ #[cfg(feature = "keygen")] pub mod cert; pub mod ops; -pub mod transport; pub mod fs_crypt; diff --git a/src/ops/network.rs b/src/ops/network.rs index fafb655..2848454 100644 --- a/src/ops/network.rs +++ b/src/ops/network.rs @@ -1,16 +1,9 @@ use std::{io::Cursor, net::ToSocketAddrs, sync::Arc}; use clap::Parser; -use futures::{ - io::{ - AsyncBufRead as FuturesAsyncBufRead, AsyncRead as FuturesAsyncRead, AsyncReadExt, - AsyncWrite as FuturesAsyncWrite, BufReader, - }, - AsyncWriteExt as _, -}; use tokio::{ - io::AsyncWriteExt as _, - net::{TcpSocket, TcpStream}, + io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, + net::TcpSocket, }; use tokio_rustls::{ rustls::{ @@ -21,9 +14,6 @@ use tokio_rustls::{ }, TlsAcceptor, TlsConnector, }; -use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; - -use crate::transport::{yamux::YamuxStreamManager, MultiplexedC2S, MultiplexedS2C, StreamManager}; #[derive(Debug, Parser)] pub struct NetworkCommand { @@ -92,35 +82,10 @@ pub struct ReverseProxyCommand { #[clap(long, help = "CRLs to use")] pub crl: Vec, - #[clap(long, help = "Transport to use", default_value = "plain")] - pub transport: Transport, - #[clap(long, help = "Compression to use", default_value = "none")] pub compression: Compression, } -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Debug, Clone, Copy, Default)] -pub enum Transport { - #[default] - Plain, - #[cfg(feature = "yamux")] - YAMux, -} - -impl std::str::FromStr for Transport { - type Err = String; - - fn from_str(s: &str) -> Result { - match s.to_ascii_lowercase().as_str() { - "plain" => Ok(Transport::Plain), - #[cfg(feature = "yamux")] - "yamux" => Ok(Transport::YAMux), - _ => Err(format!("Unknown transport: {}", s)), - } - } -} - #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Debug, Clone, Copy, Default)] pub enum Compression { @@ -170,56 +135,53 @@ pub struct ForwardProxyCommand { #[clap(long)] pub crl: Vec, - #[clap(long, help = "Transport to use", default_value = "plain")] - pub transport: Transport, - #[clap(long, help = "Compression to use", default_value = "none")] pub compression: Compression, } fn compressor_to<'s>( comp: Compression, - w: impl FuturesAsyncWrite + Unpin + Send + 's, -) -> Box { + w: impl AsyncWrite + Unpin + Send + 's, +) -> Box { match comp { Compression::None => Box::new(w), #[cfg(feature = "brotli")] - Compression::Brotli => Box::new(async_compression::futures::write::BrotliEncoder::new(w)), + Compression::Brotli => Box::new(async_compression::tokio::write::BrotliEncoder::new(w)), #[cfg(feature = "zstd")] - Compression::Zstd => Box::new(async_compression::futures::write::ZstdEncoder::new(w)), + Compression::Zstd => Box::new(async_compression::tokio::write::ZstdEncoder::new(w)), } } fn decompressor_from<'s>( comp: Compression, - r: impl FuturesAsyncBufRead + Unpin + Send + 's, -) -> Box { + r: impl AsyncBufRead + Unpin + Send + 's, +) -> Box { match comp { Compression::None => Box::new(BufReader::new(r)), #[cfg(feature = "brotli")] - Compression::Brotli => Box::new(async_compression::futures::bufread::BrotliDecoder::new(r)), + Compression::Brotli => Box::new(async_compression::tokio::bufread::BrotliDecoder::new(r)), #[cfg(feature = "zstd")] - Compression::Zstd => Box::new(async_compression::futures::bufread::ZstdDecoder::new(r)), + Compression::Zstd => Box::new(async_compression::tokio::bufread::ZstdDecoder::new(r)), } } async fn send_static_string( c: Compression, - w: &mut (impl FuturesAsyncWrite + Send + Unpin), + w: &mut (impl AsyncWrite + Send + Unpin), s: &str, -) -> futures::io::Result<()> { - futures::io::copy(&mut Cursor::new(s).compat(), &mut compressor_to(c, w)).await?; +) -> tokio::io::Result<()> { + tokio::io::copy(&mut Cursor::new(s), &mut compressor_to(c, w)).await?; Ok(()) } async fn copy_bidirectional_compressed( comp: Compression, - local: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin, - remote: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin, -) -> futures::io::Result<(u64, u64)> { - let (mut local_rx, mut local_tx) = local.split(); - let (remote_rx, remote_tx) = remote.split(); + local: impl AsyncRead + AsyncWrite + Send + Unpin, + remote: impl AsyncRead + AsyncWrite + Send + Unpin, +) -> tokio::io::Result<(u64, u64)> { + let (mut local_rx, mut local_tx) = tokio::io::split(local); + let (remote_rx, remote_tx) = tokio::io::split(remote); let remote_rx_buf = BufReader::new(remote_rx); @@ -227,23 +189,22 @@ async fn copy_bidirectional_compressed( let mut remote_rx_decomp = decompressor_from(comp, remote_rx_buf); let uplink = async move { - log::info!("Starting transfer uplink"); - let res = futures::io::copy(&mut local_rx, &mut remote_tx_comp).await; + let res = tokio::io::copy(&mut local_rx, &mut remote_tx_comp).await; log::info!("Finished uplink"); - let shutdown = remote_tx_comp.close().await; + let shutdown = remote_tx_comp.shutdown().await; let res = res?; shutdown?; tokio::io::Result::Ok(res) }; let downlink = async move { - log::info!("Starting transfer downlink"); - let res = futures::io::copy(&mut remote_rx_decomp, &mut local_tx).await; + let res = tokio::io::copy(&mut remote_rx_decomp, &mut local_tx).await; log::info!("Finished downlink"); - let shutdown = local_tx.close().await; + let shutdown = local_tx.shutdown().await; let res = res?; shutdown?; tokio::io::Result::Ok(res) }; + log::debug!("Starting transfer"); let res = tokio::try_join!(uplink, downlink)?; log::info!( "Finished transferring {} bytes from local to remote and {} bytes from remote to local (compressed)", @@ -349,15 +310,6 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box { log::info!("Accepted TLS connection from: {}", addr); - let multiplexer = match opts.transport { - Transport::Plain => None, - #[cfg(feature = "yamux")] - Transport::YAMux => { - log::info!("Using YAMux transport"); - Some(YamuxStreamManager::new()) - } - }; - match tls.get_ref().1.server_name().map(|s| s.to_string()) { Some(sni) => match sni_to_target.get(&sni).cloned() { Some(target) => { @@ -366,121 +318,46 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box { - let mut inner_listener = - m.wrap_s2c_connection(tls.compat()).await; - tokio::spawn(async move { - loop { - log::info!("Waiting for new stream (YAMux)"); - match inner_listener.accept().await { - Ok(Some(mut tls_multiplexed)) => { - let target = target.clone(); - log::info!("Accepted new stream (YAMux)"); - tokio::spawn(async move { - let target = match TcpStream::connect( - target, - ) - .await - { - Ok(s) => s, - Err(e) => { - log::error!( - "Failed to connect to target: {}", - e - ); - match tls_multiplexed - .close() - .await - { - Ok(_) => {} - Err(e) => { - log::error!( - "Failed to close multiplexed stream: {}", - e - ); - } - } - return; - } - }; - if let Err(e) = - copy_bidirectional_compressed( - opts.compression, - target.compat(), - tls_multiplexed, - ) - .await - { - log::error!( - "Failed to copy data: {}", - e - ); - } - }); - } - Ok(None) => { - log::info!("No more streams to accept"); - break; - } - Err(e) => { - log::error!( - "Failed to accept stream: {}", - e - ); - break; - } - } - } - futures::future::poll_fn(|cx| { - inner_listener.poll_close(cx) - }) - .await - .ok(); - }); + + match tokio::net::TcpStream::connect(target).await { + Ok(target) => { + if let Err(e) = copy_bidirectional_compressed( + opts.compression, + target, + tls, + ) + .await + { + log::error!("Failed to copy data: {}", e); + } } - None => match tokio::net::TcpStream::connect(target).await { - Ok(target) => { - if let Err(e) = copy_bidirectional_compressed( - opts.compression, - target.compat(), - tls.compat(), - ) + Err(e) => { + eprintln!("Failed to connect to target: {}", e); + tls.shutdown() .await - { - log::error!("Failed to copy data: {}", e); - } - } - Err(e) => { - eprintln!("Failed to connect to target: {}", e); - tls.shutdown() - .await - .expect("Failed to shutdown TLS stream"); - } - }, + .expect("Failed to shutdown TLS stream"); + } } } None => { log::warn!("Accepted connection for {:?}, but SNI {} does not match any configured SNI", tls.get_ref().1.server_name(), sni); - let mut compat = tls.compat_write(); send_static_string( opts.compression, - &mut compat, + &mut tls, format!("SNI {} does not match any configured SNI", sni) .as_str(), ) .await .expect("Failed to send static string"); - compat.close().await.expect("Failed to shutdown TLS stream"); + tls.shutdown().await.expect("Failed to shutdown TLS stream"); } }, _ => { log::error!("No SNI provided"); - let mut compat = tls.compat_write(); - send_static_string(opts.compression, &mut compat, "No SNI provided") + send_static_string(opts.compression, &mut tls, "No SNI provided") .await .expect("Failed to send static string"); - compat.close().await.expect("Failed to shutdown TLS stream"); + tls.shutdown().await.expect("Failed to shutdown TLS stream"); } } } @@ -547,172 +424,95 @@ pub async fn forward_proxy(opts: ForwardProxyCommand) -> anyhow::Result<()> { }) .expect("Failed to parse SNI"); - match opts.transport { - Transport::Plain => loop { - let (local_pt, _) = match listener.accept().await { - Ok(s) => s, + loop { + let (local_pt, _) = match listener.accept().await { + Ok(s) => s, + Err(e) => { + eprintln!("Failed to accept connection: {}", e); + continue; + } + }; + let target = opts.target.clone(); + let connector = connector.clone(); + let sni = sni.clone(); + tokio::spawn(async move { + let addrs = match target.to_socket_addrs() { + Ok(a) => a, Err(e) => { - eprintln!("Failed to accept connection: {}", e); - continue; + eprintln!("Failed to resolve target address: {}", e); + return; } }; - let target = opts.target.clone(); - let connector = connector.clone(); - let sni = sni.clone(); - tokio::spawn(async move { - let addrs = match target.to_socket_addrs() { - Ok(a) => a, + let mut last_err = None; + for addr in addrs { + log::info!("Trying to connect to: {}", addr); + let socket = match match addr { + std::net::SocketAddr::V4(_) => TcpSocket::new_v4(), + std::net::SocketAddr::V6(_) => TcpSocket::new_v6(), + } { + Ok(s) => s, Err(e) => { - eprintln!("Failed to resolve target address: {}", e); + eprintln!("Failed to create socket: {}", e); return; } }; - let mut last_err = None; - for addr in addrs { - log::info!("Trying to connect to: {}", addr); - let socket = match match addr { - std::net::SocketAddr::V4(_) => TcpSocket::new_v4(), - std::net::SocketAddr::V6(_) => TcpSocket::new_v6(), - } { - Ok(s) => s, - Err(e) => { - eprintln!("Failed to create socket: {}", e); - return; - } - }; - match socket.set_keepalive(true) { - Ok(_) => {} - Err(e) => { - log::warn!("Failed to set keepalive: {}", e); - } - } - - match socket.connect(addr).await { - Ok(target_pt) => { - let sni = sni.to_owned(); - - match connector.connect(sni, target_pt).await { - Ok(tls) => { - if let Err(e) = copy_bidirectional_compressed( - opts.compression, - local_pt.compat(), - tls.compat(), - ) - .await - { - eprintln!("Failed to copy data: {}", e); - } - } - Err(e) => { - eprintln!("Failed to connect to target: {}", e); - } - } - last_err = None; - break; - } - Err(e) => { - last_err = Some(e); - } - } - } - if let Some(e) = last_err { - log::error!("None of the target addresses worked: {}", e); - } - }); - }, - #[cfg(feature = "yamux")] - Transport::YAMux => { - let mut retries = 5; - loop { - if retries == 0 { - eprintln!("Too many retries, giving up"); - break; - } - let conn = TcpStream::connect(&opts.target).await?; - log::info!("Connected to target"); - let tls_stream = match connector.connect(sni.clone(), conn).await { - Ok(s) => s, + match socket.set_keepalive(true) { + Ok(_) => {} Err(e) => { - eprintln!("Failed to connect to target: {}", e); - retries -= 1; - continue; + log::warn!("Failed to set keepalive: {}", e); } - }; - let mut inner_listener = YamuxStreamManager::new() - .wrap_c2s_connection(tls_stream.compat()) - .await; + } - loop { - log::info!("Waiting for new TCP connection (YAMux)"); - let new_conn = tokio::select! { - _ = futures::future::poll_fn(|cx| inner_listener.poll_next_inbound(cx)) => { - None - } - conn = listener.accept() => { - match conn { - Ok((pt_stream, _)) => { - log::info!("Starting new C2S connection (YAMux)"); - match inner_listener.start_c2s().await { - Ok(s) => Some((pt_stream, s)), - Err(e) => { - log::error!("Failed to start C2S connection: {}", e); - continue; - } - } - } - Err(e) => { - log::error!("Failed to accept connection: {}", e); - continue; + match socket.connect(addr).await { + Ok(target_pt) => { + let sni = sni.to_owned(); + + match connector.connect(sni, target_pt).await { + Ok(tls) => { + if let Err(e) = + copy_bidirectional_compressed(opts.compression, local_pt, tls) + .await + { + eprintln!("Failed to copy data: {}", e); } } - } - }; - - if let Some((pt_stream, tls_multiplexed)) = new_conn { - log::info!("Obtained new connection (YAMux)"); - tokio::spawn(async move { - if let Err(e) = copy_bidirectional_compressed( - opts.compression, - pt_stream.compat(), - tls_multiplexed, - ) - .await - { - eprintln!("Failed to copy data: {}", e); + Err(e) => { + eprintln!("Failed to connect to target: {}", e); } - }); - } else { - log::warn!("No more streams to accept, next connection will be retried"); + } + last_err = None; break; } + Err(e) => { + last_err = Some(e); + } } - - futures::future::poll_fn(|cx| inner_listener.poll_close(cx)) - .await - .ok(); } - - anyhow::bail!("Too many retries"); - } - }; + if let Some(e) = last_err { + log::error!("None of the target addresses worked: {}", e); + } + }); + } } #[cfg(test)] mod tests { + use tokio::io::AsyncReadExt; + use super::*; async fn test_compression_method(algo: Compression) { let (r, w) = tokio::io::duplex(1024); let data = "Hello, world!".repeat(1024); - let mut w = compressor_to(algo, w.compat()); + let mut w = compressor_to(algo, w); - let mut r = decompressor_from(algo, BufReader::new(r.compat())); + let mut r = decompressor_from(algo, BufReader::new(r)); tokio::join!( async { w.write_all(data.as_bytes()).await.unwrap(); - w.close().await.unwrap(); + w.shutdown().await.unwrap(); }, async { let mut output = Vec::new(); diff --git a/src/ops/service.rs b/src/ops/service.rs index 62c1944..b46d5f3 100644 --- a/src/ops/service.rs +++ b/src/ops/service.rs @@ -5,7 +5,7 @@ use serde::Deserialize; use crate::ops::network::{ForwardProxyCommand, ReverseProxySpec}; -use super::network::{Compression, ReverseProxyCommand, Transport}; +use super::network::{Compression, ReverseProxyCommand}; const DEF_CONFIG_FILE: &str = "/etc/replikey.toml"; const CA_CERT: &str = "ca.pem"; @@ -49,7 +49,6 @@ pub struct ConnectionConfig { #[derive(Debug, Deserialize)] pub struct MasterConfig { listen: String, - transport: Option, compression: Option, redis: MasterServiceSpec, postgres: MasterServiceSpec, @@ -60,7 +59,6 @@ pub struct MasterConfig { #[derive(Debug, Deserialize)] pub struct SlaveConfig { target: String, - transport: Option, compression: Option, redis: SlaveServiceSpec, postgres: SlaveServiceSpec, @@ -99,7 +97,6 @@ pub fn service_replicate_master(config: String) { let master_conf = config.connection.master.as_ref().unwrap(); let cmd = ReverseProxyCommand { listen: master_conf.listen.clone(), - transport: master_conf.transport.unwrap_or_default(), compression: master_conf.compression.unwrap_or_default(), target: vec![ ReverseProxySpec { @@ -149,7 +146,6 @@ pub fn service_replicate_slave(config: String) { let slave_conf = config.connection.slave.as_ref().unwrap(); let cmd_redis = ForwardProxyCommand { listen: slave_conf.redis.listen.clone(), - transport: slave_conf.transport.unwrap_or_default(), compression: slave_conf.compression.unwrap_or_default(), sni: Some(slave_conf.redis.sni.clone()), target: slave_conf.target.clone(), @@ -160,7 +156,6 @@ pub fn service_replicate_slave(config: String) { }; let cmd_postgres = ForwardProxyCommand { listen: slave_conf.postgres.listen.clone(), - transport: slave_conf.transport.unwrap_or_default(), compression: slave_conf.compression.unwrap_or_default(), sni: Some(slave_conf.postgres.sni.clone()), target: slave_conf.target.clone(), diff --git a/src/transport/mod.rs b/src/transport/mod.rs deleted file mode 100644 index b60d7e5..0000000 --- a/src/transport/mod.rs +++ /dev/null @@ -1,203 +0,0 @@ -use std::fmt::Debug; - -use futures::{ - future::BoxFuture, - io::{AsyncRead, AsyncWrite}, - FutureExt, -}; - -#[cfg(feature = "yamux")] -pub mod yamux; - -pub trait DynStream: AsyncRead + AsyncWrite + Unpin + Send {} - -impl DynStream for T {} - -pub trait StreamManager<'c, C: 'c>: Send + Sync { - type S2CConn: MultiplexedS2C; - type C2SConn: MultiplexedC2S; - - fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn>; - fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn>; - - fn shutdown(&self) -> BoxFuture<'c, ()>; -} - -// hangs rustc for some reason.. -/* - -pub fn manager_to_dyn<'c, C: 'c>(inner: impl StreamManager<'c, C> + 'c) -> DynStreamManager<'c, C> { - DynStreamManager { - inner: Box::new(inner), - } -} - -pub struct DynStreamManager<'c, C: 'c> { - inner: Box< - dyn StreamManager< - 'c, - C, - S2CConn = Box< - dyn MultiplexedS2C> - + 'c, - >, - C2SConn = Box< - dyn MultiplexedC2S> - + 'c, - >, - > + 'c, - >, -} - -impl<'c, C> StreamManager<'c, C> for DynStreamManager<'c, C> -where - C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c, - >::Error: Error + 'static, - >::Error: Error + 'static, -{ - type S2CConn = DynMultiplexedS2CConnection<'c, C>; - type C2SConn = DynMultiplexedC2SConnection<'c, C>; - - fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn> { - self.inner - .wrap_c2s_connection(stream) - .map(|x| DynMultiplexedC2SConnection { - inner: Box::new(x.map(|x| { - x.map(|s| s.map(|s| Box::new(s) as _)) - .map_err(anyhow::Error::new) - }) as _), - }) - .boxed() - } - - fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn> { - self.inner - .wrap_s2c_connection(stream) - .map(|x| DynMultiplexedS2CConnection { - inner: Box::new(x.map(|x| { - x.map(|s| s.map(|s| Box::new(s) as _)) - .map_err(anyhow::Error::new) as _ - })), - }) - .boxed() - } - - fn shutdown(&self) -> BoxFuture<'c, ()> { - async move { self.inner.shutdown().await }.boxed() - } -} - -*/ -pub trait MultiplexedS2C: Send + Sync { - type Error: Debug + Send + Sync + Sized; - type Stream: DynStream; - - fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>>; - - fn map(self, map: F) -> S2CMap - where - F: Fn(Result, Self::Error>) -> Result, E> + Send + Sync, - S: DynStream, - Self: Sized, - { - S2CMap { - inner: self, - map, - _phantom: std::marker::PhantomData, - } - } -} - -pub struct S2CMap { - inner: MC, - map: F, - _phantom: std::marker::PhantomData, -} - -impl MultiplexedS2C - for S2CMap -where - MC: MultiplexedS2C, - F: Fn(Result, MC::Error>) -> Result, E> + Send + Sync, -{ - type Error = E; - type Stream = S; - - fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>> { - let inner = &mut self.inner; - let mapf = &self.map; - inner.accept().map(mapf).boxed() - } -} - -pub trait MultiplexedC2S: Send + Sync { - type Error: Debug + Send + Sync + Sized; - type Stream: DynStream; - - fn start_c2s(&mut self) -> BoxFuture<'_, Result>; - - fn map(self, map: F) -> C2SMap - where - F: Fn(Result) -> Result + Send + Sync, - S: DynStream, - Self: Sized, - { - C2SMap { - inner: self, - map, - _phantom: std::marker::PhantomData, - } - } -} - -pub struct C2SMap { - inner: MC, - map: F, - _phantom: std::marker::PhantomData, -} - -impl MultiplexedC2S - for C2SMap -where - MC: MultiplexedC2S, - F: Fn(Result) -> Result + Send + Sync, -{ - type Error = E; - type Stream = S; - - fn start_c2s(&mut self) -> BoxFuture<'_, Result> { - let inner = &mut self.inner; - let map = &self.map; - inner.start_c2s().map(map).boxed() - } -} - -pub struct DynMultiplexedS2CConnection<'c, C: 'c> { - inner: Box> + 'c>, -} - -impl<'c, C> MultiplexedS2C for DynMultiplexedS2CConnection<'c, C> { - type Error = anyhow::Error; - type Stream = Box; - fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>> { - async move { - self.inner - .accept() - .await - .map(|x| x.map(|x| Box::new(x) as _)) - } - .boxed() - } -} - -pub struct DynMultiplexedC2SConnection<'c, C: 'c> { - inner: Box> + 'c>, -} - -impl<'c, C> MultiplexedC2S for DynMultiplexedC2SConnection<'c, C> { - type Error = anyhow::Error; - type Stream = Box; - fn start_c2s(&mut self) -> BoxFuture<'_, Result> { - async move { self.inner.start_c2s().await.map(|x| Box::new(x) as _) }.boxed() - } -} diff --git a/src/transport/yamux.rs b/src/transport/yamux.rs deleted file mode 100644 index aba71a0..0000000 --- a/src/transport/yamux.rs +++ /dev/null @@ -1,80 +0,0 @@ -use futures::{ - future::BoxFuture, - io::{AsyncRead, AsyncWrite}, - FutureExt, -}; -use yamux::Config; - -pub(crate) fn yamux_config() -> yamux::Config { - let mut conf = Config::default(); - - // we are in mTLS so be generous but still prevent accidental DoS - conf.set_read_after_close(true) - .set_max_num_streams(512) - .set_max_connection_receive_window(Some((256 << 10) * 512)); - - conf -} - -pub struct YamuxStreamManager { - config: yamux::Config, -} - -impl Default for YamuxStreamManager { - fn default() -> Self { - Self::new() - } -} - -impl YamuxStreamManager { - pub fn new() -> Self { - Self { - config: yamux_config(), - } - } -} - -impl<'c, C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c> super::StreamManager<'c, C> - for YamuxStreamManager -{ - type S2CConn = yamux::Connection; - type C2SConn = yamux::Connection; - - fn wrap_c2s_connection(&self, stream: C) -> BoxFuture<'c, Self::C2SConn> { - let config = self.config.clone(); - async move { yamux::Connection::new(stream, config, yamux::Mode::Client) }.boxed() - } - - fn wrap_s2c_connection(&self, stream: C) -> BoxFuture<'c, Self::S2CConn> { - let config = self.config.clone(); - async move { yamux::Connection::new(stream, config, yamux::Mode::Server) }.boxed() - } - - fn shutdown(&self) -> BoxFuture<'c, ()> { - async {}.boxed() - } -} - -impl super::MultiplexedS2C - for yamux::Connection -{ - type Error = yamux::ConnectionError; - type Stream = yamux::Stream; - - fn accept(&mut self) -> BoxFuture<'_, Result, Self::Error>> { - futures::future::poll_fn(|cx| self.poll_next_inbound(cx)) - .map(|x| x.transpose()) - .boxed() - } -} - -impl super::MultiplexedC2S - for yamux::Connection -{ - type Error = yamux::ConnectionError; - type Stream = yamux::Stream; - - fn start_c2s(&mut self) -> BoxFuture<'_, Result> { - futures::future::poll_fn(|cx| self.poll_new_outbound(cx)).boxed() - } -} diff --git a/tests/mtls_integration.rs b/tests/mtls_integration.rs index ed7d76a..3c8854d 100644 --- a/tests/mtls_integration.rs +++ b/tests/mtls_integration.rs @@ -6,7 +6,6 @@ use std::{ }; use clap::Parser; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use replikey::{ cert::UsageType, ops::{ @@ -17,10 +16,9 @@ use replikey::{ use rustls::crypto::{aws_lc_rs, CryptoProvider}; use time::OffsetDateTime; use tokio::{ - io::{AsyncReadExt as _, AsyncWriteExt as _}, + io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}, task::JoinSet, }; -use tokio_util::compat::TokioAsyncReadCompatExt; static NEXT_PORT: AtomicU32 = AtomicU32::new(12311); @@ -49,9 +47,9 @@ async fn test_stream_sequence< left: L, right: R, ) { - let (mut l_rx, mut l_tx) = left.split(); + let (mut l_rx, mut l_tx) = tokio::io::split(left); - let (mut r_rx, mut r_tx) = right.split(); + let (mut r_rx, mut r_tx) = tokio::io::split(right); let l_write = async move { let mut buf = [0u8; WRITE_SIZE]; @@ -61,9 +59,9 @@ async fn test_stream_sequence< .for_each(|(j, x)| *x = ((i + j) ^ 123) as u8); l_tx.write_all(&buf).await?; } - l_tx.close().await?; + l_tx.shutdown().await?; - Ok::<_, futures::io::Error>(()) + Ok::<_, tokio::io::Error>(()) }; let r_write = async move { let mut buf = [0u8; WRITE_SIZE]; @@ -73,7 +71,7 @@ async fn test_stream_sequence< .for_each(|(j, x)| *x = ((i + j) ^ 321) as u8); r_tx.write_all(&buf).await?; } - r_tx.close().await?; + r_tx.shutdown().await?; Ok(()) }; let l_read = async move { @@ -105,11 +103,11 @@ async fn test_stream_sequence< #[tokio::test] async fn in_memory_stream_works() { let (left, right) = tokio::io::duplex(1024); - test_stream_sequence::<1024, _, _>(left.compat(), right.compat()).await; + test_stream_sequence::<1024, _, _>(left, right).await; let (left, right) = tokio::io::duplex(1024); - test_stream_sequence::<1, _, _>(left.compat(), right.compat()).await; + test_stream_sequence::<1, _, _>(left, right).await; let (left, right) = tokio::io::duplex(1024); - test_stream_sequence::<3543, _, _>(left.compat(), right.compat()).await; + test_stream_sequence::<3543, _, _>(left, right).await; } fn test_ca(dir: &PathBuf) -> (PathBuf, PathBuf) { @@ -348,7 +346,7 @@ async fn start_reverse_proxy( cert: &PathBuf, key: &PathBuf, ca_cert: &PathBuf, - transport: &str, + compression: &str, ) -> u32 { let port = next_port(); @@ -371,8 +369,8 @@ async fn start_reverse_proxy( key.to_string_lossy().to_string(), "--ca".to_string(), ca_cert.to_string_lossy().to_string(), - "--transport".to_string(), - transport.to_string(), + "--compression".to_string(), + compression.to_string(), ]; let parsed = NetworkCommand::parse_from(&cmd); @@ -396,7 +394,7 @@ async fn start_forward_proxy( cert: &PathBuf, key: &PathBuf, ca_cert: &PathBuf, - transport: &str, + compression: &str, ) -> u32 { let port = next_port(); @@ -415,8 +413,8 @@ async fn start_forward_proxy( key.to_string_lossy().to_string(), "--ca".to_string(), ca_cert.to_string_lossy().to_string(), - "--transport".to_string(), - transport.to_string(), + "--compression".to_string(), + compression.to_string(), ]; let parsed = NetworkCommand::parse_from(&cmd); @@ -514,47 +512,6 @@ async fn expect_target1_signature(port: u32) { tx.shutdown().await.expect("failed to close"); } -async fn expect_target1_signature_buffered(port: u32) { - let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) - .await - .expect("failed to connect"); - - let (mut rx, mut tx) = tokio::io::split(stream); - - let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes(); - let expect = data.clone(); - - let mut write_size = (1..4096).cycle(); - - let (_, rx_data) = tokio::join!( - async { - while !data.is_empty() { - let n = write_size.next().unwrap().min(data.len()); - tx.write_all(&data[..n]).await.expect("failed to write"); - data = data.split_off(n); - } - tx.shutdown().await.expect("failed to close"); - }, - async { - let mut rxed = Vec::new(); - let mut buf = [0u8; 1024]; - loop { - let n = rx.read(&mut buf).await.expect("failed to read"); - if n == 0 { - break; - } - rxed.extend_from_slice(&buf[..n]); - if rxed.len() >= expect.len() { - break; - } - } - rxed - } - ); - - assert_eq!(rx_data, expect); -} - async fn expect_target2_signature(port: u32) { let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) .await @@ -600,49 +557,7 @@ async fn expect_target2_signature(port: u32) { tx.shutdown().await.expect("failed to close"); } -async fn expect_target2_signature_buffered(port: u32) { - let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) - .await - .expect("failed to connect"); - - let (mut rx, mut tx) = tokio::io::split(stream); - - let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes(); - - let expect = data.iter().map(|x| !x).collect::>(); - - let mut write_size = (1..4096).cycle(); - - let (_, rx_data) = tokio::join!( - async { - while !data.is_empty() { - let n = write_size.next().unwrap().min(data.len()); - tx.write_all(&data[..n]).await.expect("failed to write"); - data = data.split_off(n); - } - tx.shutdown().await.expect("failed to close"); - }, - async { - let mut rxed = Vec::new(); - let mut buf = [0u8; 1024]; - loop { - let n = rx.read(&mut buf).await.expect("failed to read"); - if n == 0 { - break; - } - rxed.extend_from_slice(&buf[..n]); - if rxed.len() >= expect.len() { - break; - } - } - rxed - } - ); - - assert_eq!(rx_data, expect); -} - -#[tokio::test] +#[tokio::test(flavor = "multi_thread", worker_threads = 16)] async fn test_mtls_integrated() { setup_once(); let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca")); @@ -660,172 +575,174 @@ async fn test_mtls_integrated() { expect_target1_signature(target1).await; expect_target2_signature(target2).await; - let proxy_port_self_signed = start_reverse_proxy( - target1, - target2, - &server_cert.self_signed, - &server_cert.key, - &ca_cert, - "plain", - ) - .await; - wait_for_port(proxy_port_self_signed).await; - log::info!("reverse_proxy_port_self_signed: {}", proxy_port_self_signed); - - let proxy_port = start_reverse_proxy( - target1, - target2, - server_cert.signed_cert.as_ref().unwrap(), - &server_cert.key, - &ca_cert, - "plain", - ) - .await; - wait_for_port(proxy_port).await; - log::info!("reverse_proxy_port: {}", proxy_port); - - for server_signed in [false, true] { - let forward_proxy_target1_self_signed = start_forward_proxy( - &format!( - "127.0.0.1:{}", - if server_signed { - proxy_port - } else { - proxy_port_self_signed - } - ), - "target1.local", - &client_cert.self_signed, - &client_cert.key, + for compression in ["none", "zstd", "brotli"] { + let proxy_port_self_signed = start_reverse_proxy( + target1, + target2, + &server_cert.self_signed, + &server_cert.key, &ca_cert, - "plain", + compression, ) .await; - log::info!( - "forward_proxy_target1_self_signed: {}", - forward_proxy_target1_self_signed - ); + wait_for_port(proxy_port_self_signed).await; + log::info!("reverse_proxy_port_self_signed: {}", proxy_port_self_signed); - let forward_proxy_target1 = start_forward_proxy( - &format!( - "127.0.0.1:{}", - if server_signed { - proxy_port - } else { - proxy_port_self_signed - } - ), - "target1.local", - client_cert.signed_cert.as_ref().unwrap(), - &client_cert.key, + let proxy_port = start_reverse_proxy( + target1, + target2, + server_cert.signed_cert.as_ref().unwrap(), + &server_cert.key, &ca_cert, - "plain", + compression, ) .await; - log::info!("forward_proxy_target1: {}", forward_proxy_target1); + wait_for_port(proxy_port).await; + log::info!("reverse_proxy_port: {}", proxy_port); - let forward_proxy_target2 = start_forward_proxy( - &format!( - "127.0.0.1:{}", - if server_signed { - proxy_port - } else { - proxy_port_self_signed + for server_signed in [false, true] { + let forward_proxy_target1_self_signed = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target1.local", + &client_cert.self_signed, + &client_cert.key, + &ca_cert, + compression, + ) + .await; + log::info!( + "forward_proxy_target1_self_signed: {}", + forward_proxy_target1_self_signed + ); + + let forward_proxy_target1 = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target1.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + compression, + ) + .await; + log::info!("forward_proxy_target1: {}", forward_proxy_target1); + + let forward_proxy_target2 = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target2.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + compression, + ) + .await; + log::info!("forward_proxy_target2: {}", forward_proxy_target2); + + let forward_proxy_target1_wildcard = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target1.targets.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + compression, + ) + .await; + + let forward_proxy_target2_wildcard: u32 = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "target2.targets.local", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + compression, + ) + .await; + log::info!( + "forward_proxy_target2_wildcard: {}", + forward_proxy_target2_wildcard + ); + + let forward_proxy_unknown_sni = start_forward_proxy( + &format!( + "127.0.0.1:{}", + if server_signed { + proxy_port + } else { + proxy_port_self_signed + } + ), + "some-other.place", + client_cert.signed_cert.as_ref().unwrap(), + &client_cert.key, + &ca_cert, + compression, + ) + .await; + log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni); + + wait_for_port(forward_proxy_target1).await; + wait_for_port(forward_proxy_target2).await; + wait_for_port(forward_proxy_target1_wildcard).await; + wait_for_port(forward_proxy_target2_wildcard).await; + wait_for_port(forward_proxy_unknown_sni).await; + log::info!("Test: self-signed cert"); + should_not_get_anything(forward_proxy_target1_self_signed).await; + + if server_signed { + let mut js = JoinSet::new(); + for _ in 0..100 { + js.spawn(expect_target1_signature(forward_proxy_target1)); + js.spawn(expect_target2_signature(forward_proxy_target2)); + js.spawn(expect_target1_signature(forward_proxy_target1_wildcard)); + js.spawn(expect_target2_signature(forward_proxy_target2_wildcard)); + js.spawn(should_not_get_anything(forward_proxy_unknown_sni)); } - ), - "target2.local", - client_cert.signed_cert.as_ref().unwrap(), - &client_cert.key, - &ca_cert, - "plain", - ) - .await; - log::info!("forward_proxy_target2: {}", forward_proxy_target2); - - let forward_proxy_target1_wildcard = start_forward_proxy( - &format!( - "127.0.0.1:{}", - if server_signed { - proxy_port - } else { - proxy_port_self_signed + js.join_all().await; + } else { + let mut js = JoinSet::new(); + for _ in 0..100 { + js.spawn(should_not_get_anything(forward_proxy_target1)); + js.spawn(should_not_get_anything(forward_proxy_target2)); + js.spawn(should_not_get_anything(forward_proxy_target1_wildcard)); + js.spawn(should_not_get_anything(forward_proxy_target2_wildcard)); + js.spawn(should_not_get_anything(forward_proxy_unknown_sni)); } - ), - "target1.targets.local", - client_cert.signed_cert.as_ref().unwrap(), - &client_cert.key, - &ca_cert, - "plain", - ) - .await; - - let forward_proxy_target2_wildcard: u32 = start_forward_proxy( - &format!( - "127.0.0.1:{}", - if server_signed { - proxy_port - } else { - proxy_port_self_signed - } - ), - "target2.targets.local", - client_cert.signed_cert.as_ref().unwrap(), - &client_cert.key, - &ca_cert, - "plain", - ) - .await; - log::info!( - "forward_proxy_target2_wildcard: {}", - forward_proxy_target2_wildcard - ); - - let forward_proxy_unknown_sni = start_forward_proxy( - &format!( - "127.0.0.1:{}", - if server_signed { - proxy_port - } else { - proxy_port_self_signed - } - ), - "some-other.place", - client_cert.signed_cert.as_ref().unwrap(), - &client_cert.key, - &ca_cert, - "plain", - ) - .await; - log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni); - - wait_for_port(forward_proxy_target1).await; - wait_for_port(forward_proxy_target2).await; - wait_for_port(forward_proxy_target1_wildcard).await; - wait_for_port(forward_proxy_target2_wildcard).await; - wait_for_port(forward_proxy_unknown_sni).await; - log::info!("Test: self-signed cert"); - should_not_get_anything(forward_proxy_target1_self_signed).await; - - if server_signed { - let mut js = JoinSet::new(); - for _ in 0..100 { - js.spawn(expect_target1_signature(forward_proxy_target1)); - js.spawn(expect_target2_signature(forward_proxy_target2)); - js.spawn(expect_target1_signature(forward_proxy_target1_wildcard)); - js.spawn(expect_target2_signature(forward_proxy_target2_wildcard)); - js.spawn(should_not_get_anything(forward_proxy_unknown_sni)); + js.join_all().await; } - js.join_all().await; - } else { - let mut js = JoinSet::new(); - for _ in 0..100 { - js.spawn(should_not_get_anything(forward_proxy_target1)); - js.spawn(should_not_get_anything(forward_proxy_target2)); - js.spawn(should_not_get_anything(forward_proxy_target1_wildcard)); - js.spawn(should_not_get_anything(forward_proxy_target2_wildcard)); - js.spawn(should_not_get_anything(forward_proxy_unknown_sni)); - } - js.join_all().await; } } } @@ -858,7 +775,7 @@ async fn test_mtls_role_reverse() { server_cert.signed_cert.as_ref().unwrap(), &server_cert.key, &ca_cert, - "plain", + "none", ) .await; @@ -870,7 +787,7 @@ async fn test_mtls_role_reverse() { client_cert.signed_cert.as_ref().unwrap(), &client_cert.key, &ca_cert, - "plain", + "none", ) .await; wait_for_port(forward_proxy_target1).await; @@ -883,7 +800,7 @@ async fn test_mtls_role_reverse() { client_cert.signed_cert.as_ref().unwrap(), &client_cert.key, &ca_cert, - "plain", + "none", ) .await; @@ -895,74 +812,10 @@ async fn test_mtls_role_reverse() { server_cert.signed_cert.as_ref().unwrap(), &server_cert.key, &ca_cert, - "plain", + "none", ) .await; wait_for_port(forward_proxy_target1_reversed).await; should_not_get_anything(forward_proxy_target1_reversed).await; } - -#[tokio::test] -async fn test_mtls_yamux_transport() { - setup_once(); - - let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca")); - - let server_cert = test_server_cert(&PathBuf::from("target/test-server")); - let server_cert = server_cert.signed_by(UsageType::Server, &ca_key); - let client_cert = test_client_cert(&PathBuf::from("target/test-client")); - let client_cert = client_cert.signed_by(UsageType::Client, &ca_key); - - let target1 = start_test_target_1().await; - - let target2 = start_test_target_2().await; - wait_for_port(target1).await; - wait_for_port(target2).await; - - let proxy_port = start_reverse_proxy( - target1, - target2, - server_cert.signed_cert.as_ref().unwrap(), - &server_cert.key, - &ca_cert, - "yamux", - ) - .await; - wait_for_port(proxy_port).await; - wait_for_port(proxy_port).await; - - let forward_proxy_target1 = start_forward_proxy( - &format!("127.0.0.1:{}", proxy_port), - "target1.local", - client_cert.signed_cert.as_ref().unwrap(), - &client_cert.key, - &ca_cert, - "yamux", - ) - .await; - - let forward_proxy_target2 = start_forward_proxy( - &format!("127.0.0.1:{}", proxy_port), - "target2.local", - client_cert.signed_cert.as_ref().unwrap(), - &client_cert.key, - &ca_cert, - "yamux", - ) - .await; - wait_for_port(forward_proxy_target1).await; - wait_for_port(forward_proxy_target2).await; - - log::info!("forward_proxy_target1: {}", forward_proxy_target1); - log::info!("forward_proxy_target2: {}", forward_proxy_target2); - - let mut js = JoinSet::new(); - - for _ in 0..128 { - js.spawn(expect_target1_signature_buffered(forward_proxy_target1)); - js.spawn(expect_target2_signature_buffered(forward_proxy_target2)); - } - - js.join_all().await; -}