revert multiplexing

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-11-04 01:34:25 -06:00
parent ef53cb079d
commit 2c3914a201
No known key found for this signature in database
9 changed files with 302 additions and 1050 deletions

147
Cargo.lock generated
View file

@ -208,9 +208,9 @@ checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857"
dependencies = [ dependencies = [
"brotli", "brotli",
"futures-core", "futures-core",
"futures-io",
"memchr", "memchr",
"pin-project-lite", "pin-project-lite",
"tokio",
"zstd", "zstd",
"zstd-safe", "zstd-safe",
] ]
@ -372,9 +372,9 @@ checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.31" version = "1.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
@ -396,6 +396,12 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]] [[package]]
name = "cipher" name = "cipher"
version = "0.4.4" version = "0.4.4"
@ -715,21 +721,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.31" version = "0.3.31"
@ -746,17 +737,6 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]] [[package]]
name = "futures-intrusive" name = "futures-intrusive"
version = "0.5.0" version = "0.5.0"
@ -774,17 +754,6 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.31" version = "0.3.31"
@ -803,10 +772,8 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [ dependencies = [
"futures-channel",
"futures-core", "futures-core",
"futures-io", "futures-io",
"futures-macro",
"futures-sink", "futures-sink",
"futures-task", "futures-task",
"memchr", "memchr",
@ -1209,12 +1176,6 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1"
[[package]]
name = "nohash-hasher"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.3" version = "7.1.3"
@ -1398,26 +1359,6 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pin-project"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.15" version = "0.2.15"
@ -1519,10 +1460,11 @@ dependencies = [
[[package]] [[package]]
name = "quinn-udp" name = "quinn-udp"
version = "0.5.5" version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" checksum = "e346e016eacfff12233c243718197ca12f148c84e1e84268a896699b41c71780"
dependencies = [ dependencies = [
"cfg_aliases",
"libc", "libc",
"once_cell", "once_cell",
"socket2", "socket2",
@ -1631,7 +1573,6 @@ dependencies = [
"async-compression", "async-compression",
"clap", "clap",
"env_logger", "env_logger",
"futures",
"log", "log",
"openssl", "openssl",
"pem-rfc7468", "pem-rfc7468",
@ -1648,10 +1589,8 @@ dependencies = [
"time", "time",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-util",
"toml", "toml",
"x509-parser", "x509-parser",
"yamux",
] ]
[[package]] [[package]]
@ -2063,12 +2002,6 @@ dependencies = [
"whoami", "whoami",
] ]
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]] [[package]]
name = "stringprep" name = "stringprep"
version = "0.1.5" version = "0.1.5"
@ -2094,9 +2027,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.86" version = "2.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c" checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2138,18 +2071,18 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.66" version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede" checksum = "3b3c6efbfc763e64eb85c11c25320f0737cb7364c4b6336db90aa9ebe27a0bbd"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.66" version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5" checksum = "b607164372e89797d78b8e23a6d67d5d1038c1c65efd52e1389ef8b77caba2a6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2240,20 +2173,6 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-util"
version = "0.7.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a"
dependencies = [
"bytes",
"futures-core",
"futures-io",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.8.19" version = "0.8.19"
@ -2514,21 +2433,11 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]] [[package]]
name = "webpki-roots" name = "webpki-roots"
version = "0.26.5" version = "0.26.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a" checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958"
dependencies = [ dependencies = [
"rustls-pki-types", "rustls-pki-types",
] ]
@ -2760,22 +2669,6 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "yamux"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31b5e376a8b012bee9c423acdbb835fc34d45001cfa3106236a624e4b738028"
dependencies = [
"futures",
"log",
"nohash-hasher",
"parking_lot",
"pin-project",
"rand",
"static_assertions",
"web-time",
]
[[package]] [[package]]
name = "yasna" name = "yasna"
version = "0.5.2" version = "0.5.2"

View file

@ -26,15 +26,12 @@ serde = { version = "1.0.214", features = ["derive"], optional = true }
sqlx = { version = "0.8.2", optional = true, default-features = false, features = ["tls-none", "postgres"] } sqlx = { version = "0.8.2", optional = true, default-features = false, features = ["tls-none", "postgres"] }
tokio = { version = "1.41.0", features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync"], optional = true } tokio = { version = "1.41.0", features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync"], optional = true }
rustls = { version = "0.23.16", optional = true } rustls = { version = "0.23.16", optional = true }
async-compression = { version = "0.4.17", optional = true, features = ["futures-io"] } async-compression = { version = "0.4.17", optional = true, features = ["tokio"] }
yamux = { version = "0.13.3", optional = true }
futures = { version = "0.3.31", optional = true }
tokio-util = { version = "0.7.12", features = ["compat"], optional = true }
anyhow = "1.0.92" anyhow = "1.0.92"
[features] [features]
default = ["keygen", "networking", "service", "remote-crl", "setup-postgres", "yamux", "zstd", "brotli"] default = ["keygen", "networking", "service", "remote-crl", "setup-postgres", "zstd", "brotli"]
asyncio = ["dep:tokio", "dep:futures", "dep:tokio-util"] asyncio = ["dep:tokio"]
keygen = ["dep:rcgen", "dep:pem-rfc7468", "dep:rpassword", "dep:argon2", "dep:sha2", "dep:aes-gcm", "dep:time"] keygen = ["dep:rcgen", "dep:pem-rfc7468", "dep:rpassword", "dep:argon2", "dep:sha2", "dep:aes-gcm", "dep:time"]
networking = ["asyncio", "dep:tokio-rustls", "dep:rustls", "dep:async-compression"] networking = ["asyncio", "dep:tokio-rustls", "dep:rustls", "dep:async-compression"]
test-crosscheck-openssl = ["dep:openssl"] test-crosscheck-openssl = ["dep:openssl"]
@ -45,7 +42,6 @@ setup-postgres = ["dep:sqlx"]
stat-service = ["networking", "serde"] stat-service = ["networking", "serde"]
rustls = ["dep:rustls"] rustls = ["dep:rustls"]
async-compression = ["dep:async-compression"] async-compression = ["dep:async-compression"]
yamux = ["dep:yamux", "networking"]
zstd = ["async-compression/zstd"] zstd = ["async-compression/zstd"]
brotli = ["async-compression/brotli"] brotli = ["async-compression/brotli"]

View file

@ -96,7 +96,6 @@ fn main() {
print_feature!("keygen"); print_feature!("keygen");
print_feature!("asyncio"); print_feature!("asyncio");
print_feature!("networking"); print_feature!("networking");
print_feature!("yamux");
print_feature!("service"); print_feature!("service");
print_feature!("remote-crl"); print_feature!("remote-crl");
print_feature!("setup-postgres"); print_feature!("setup-postgres");

View file

@ -3,6 +3,5 @@
#[cfg(feature = "keygen")] #[cfg(feature = "keygen")]
pub mod cert; pub mod cert;
pub mod ops; pub mod ops;
pub mod transport;
pub mod fs_crypt; pub mod fs_crypt;

View file

@ -1,16 +1,9 @@
use std::{io::Cursor, net::ToSocketAddrs, sync::Arc}; use std::{io::Cursor, net::ToSocketAddrs, sync::Arc};
use clap::Parser; use clap::Parser;
use futures::{
io::{
AsyncBufRead as FuturesAsyncBufRead, AsyncRead as FuturesAsyncRead, AsyncReadExt,
AsyncWrite as FuturesAsyncWrite, BufReader,
},
AsyncWriteExt as _,
};
use tokio::{ use tokio::{
io::AsyncWriteExt as _, io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
net::{TcpSocket, TcpStream}, net::TcpSocket,
}; };
use tokio_rustls::{ use tokio_rustls::{
rustls::{ rustls::{
@ -21,9 +14,6 @@ use tokio_rustls::{
}, },
TlsAcceptor, TlsConnector, TlsAcceptor, TlsConnector,
}; };
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use crate::transport::{yamux::YamuxStreamManager, MultiplexedC2S, MultiplexedS2C, StreamManager};
#[derive(Debug, Parser)] #[derive(Debug, Parser)]
pub struct NetworkCommand { pub struct NetworkCommand {
@ -92,35 +82,10 @@ pub struct ReverseProxyCommand {
#[clap(long, help = "CRLs to use")] #[clap(long, help = "CRLs to use")]
pub crl: Vec<String>, pub crl: Vec<String>,
#[clap(long, help = "Transport to use", default_value = "plain")]
pub transport: Transport,
#[clap(long, help = "Compression to use", default_value = "none")] #[clap(long, help = "Compression to use", default_value = "none")]
pub compression: Compression, pub compression: Compression,
} }
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default)]
pub enum Transport {
#[default]
Plain,
#[cfg(feature = "yamux")]
YAMux,
}
impl std::str::FromStr for Transport {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"plain" => Ok(Transport::Plain),
#[cfg(feature = "yamux")]
"yamux" => Ok(Transport::YAMux),
_ => Err(format!("Unknown transport: {}", s)),
}
}
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
pub enum Compression { pub enum Compression {
@ -170,56 +135,53 @@ pub struct ForwardProxyCommand {
#[clap(long)] #[clap(long)]
pub crl: Vec<String>, pub crl: Vec<String>,
#[clap(long, help = "Transport to use", default_value = "plain")]
pub transport: Transport,
#[clap(long, help = "Compression to use", default_value = "none")] #[clap(long, help = "Compression to use", default_value = "none")]
pub compression: Compression, pub compression: Compression,
} }
fn compressor_to<'s>( fn compressor_to<'s>(
comp: Compression, comp: Compression,
w: impl FuturesAsyncWrite + Unpin + Send + 's, w: impl AsyncWrite + Unpin + Send + 's,
) -> Box<dyn FuturesAsyncWrite + Unpin + Send + 's> { ) -> Box<dyn AsyncWrite + Unpin + Send + 's> {
match comp { match comp {
Compression::None => Box::new(w), Compression::None => Box::new(w),
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
Compression::Brotli => Box::new(async_compression::futures::write::BrotliEncoder::new(w)), Compression::Brotli => Box::new(async_compression::tokio::write::BrotliEncoder::new(w)),
#[cfg(feature = "zstd")] #[cfg(feature = "zstd")]
Compression::Zstd => Box::new(async_compression::futures::write::ZstdEncoder::new(w)), Compression::Zstd => Box::new(async_compression::tokio::write::ZstdEncoder::new(w)),
} }
} }
fn decompressor_from<'s>( fn decompressor_from<'s>(
comp: Compression, comp: Compression,
r: impl FuturesAsyncBufRead + Unpin + Send + 's, r: impl AsyncBufRead + Unpin + Send + 's,
) -> Box<dyn FuturesAsyncRead + Unpin + Send + 's> { ) -> Box<dyn AsyncRead + Unpin + Send + 's> {
match comp { match comp {
Compression::None => Box::new(BufReader::new(r)), Compression::None => Box::new(BufReader::new(r)),
#[cfg(feature = "brotli")] #[cfg(feature = "brotli")]
Compression::Brotli => Box::new(async_compression::futures::bufread::BrotliDecoder::new(r)), Compression::Brotli => Box::new(async_compression::tokio::bufread::BrotliDecoder::new(r)),
#[cfg(feature = "zstd")] #[cfg(feature = "zstd")]
Compression::Zstd => Box::new(async_compression::futures::bufread::ZstdDecoder::new(r)), Compression::Zstd => Box::new(async_compression::tokio::bufread::ZstdDecoder::new(r)),
} }
} }
async fn send_static_string( async fn send_static_string(
c: Compression, c: Compression,
w: &mut (impl FuturesAsyncWrite + Send + Unpin), w: &mut (impl AsyncWrite + Send + Unpin),
s: &str, s: &str,
) -> futures::io::Result<()> { ) -> tokio::io::Result<()> {
futures::io::copy(&mut Cursor::new(s).compat(), &mut compressor_to(c, w)).await?; tokio::io::copy(&mut Cursor::new(s), &mut compressor_to(c, w)).await?;
Ok(()) Ok(())
} }
async fn copy_bidirectional_compressed( async fn copy_bidirectional_compressed(
comp: Compression, comp: Compression,
local: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin, local: impl AsyncRead + AsyncWrite + Send + Unpin,
remote: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin, remote: impl AsyncRead + AsyncWrite + Send + Unpin,
) -> futures::io::Result<(u64, u64)> { ) -> tokio::io::Result<(u64, u64)> {
let (mut local_rx, mut local_tx) = local.split(); let (mut local_rx, mut local_tx) = tokio::io::split(local);
let (remote_rx, remote_tx) = remote.split(); let (remote_rx, remote_tx) = tokio::io::split(remote);
let remote_rx_buf = BufReader::new(remote_rx); let remote_rx_buf = BufReader::new(remote_rx);
@ -227,23 +189,22 @@ async fn copy_bidirectional_compressed(
let mut remote_rx_decomp = decompressor_from(comp, remote_rx_buf); let mut remote_rx_decomp = decompressor_from(comp, remote_rx_buf);
let uplink = async move { let uplink = async move {
log::info!("Starting transfer uplink"); let res = tokio::io::copy(&mut local_rx, &mut remote_tx_comp).await;
let res = futures::io::copy(&mut local_rx, &mut remote_tx_comp).await;
log::info!("Finished uplink"); log::info!("Finished uplink");
let shutdown = remote_tx_comp.close().await; let shutdown = remote_tx_comp.shutdown().await;
let res = res?; let res = res?;
shutdown?; shutdown?;
tokio::io::Result::Ok(res) tokio::io::Result::Ok(res)
}; };
let downlink = async move { let downlink = async move {
log::info!("Starting transfer downlink"); let res = tokio::io::copy(&mut remote_rx_decomp, &mut local_tx).await;
let res = futures::io::copy(&mut remote_rx_decomp, &mut local_tx).await;
log::info!("Finished downlink"); log::info!("Finished downlink");
let shutdown = local_tx.close().await; let shutdown = local_tx.shutdown().await;
let res = res?; let res = res?;
shutdown?; shutdown?;
tokio::io::Result::Ok(res) tokio::io::Result::Ok(res)
}; };
log::debug!("Starting transfer");
let res = tokio::try_join!(uplink, downlink)?; let res = tokio::try_join!(uplink, downlink)?;
log::info!( log::info!(
"Finished transferring {} bytes from local to remote and {} bytes from remote to local (compressed)", "Finished transferring {} bytes from local to remote and {} bytes from remote to local (compressed)",
@ -349,15 +310,6 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box<dyn std:
Ok(mut tls) => { Ok(mut tls) => {
log::info!("Accepted TLS connection from: {}", addr); log::info!("Accepted TLS connection from: {}", addr);
let multiplexer = match opts.transport {
Transport::Plain => None,
#[cfg(feature = "yamux")]
Transport::YAMux => {
log::info!("Using YAMux transport");
Some(YamuxStreamManager::new())
}
};
match tls.get_ref().1.server_name().map(|s| s.to_string()) { match tls.get_ref().1.server_name().map(|s| s.to_string()) {
Some(sni) => match sni_to_target.get(&sni).cloned() { Some(sni) => match sni_to_target.get(&sni).cloned() {
Some(target) => { Some(target) => {
@ -366,85 +318,13 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box<dyn std:
tls.get_ref().1.server_name(), tls.get_ref().1.server_name(),
target target
); );
match multiplexer {
Some(m) => { match tokio::net::TcpStream::connect(target).await {
let mut inner_listener =
m.wrap_s2c_connection(tls.compat()).await;
tokio::spawn(async move {
loop {
log::info!("Waiting for new stream (YAMux)");
match inner_listener.accept().await {
Ok(Some(mut tls_multiplexed)) => {
let target = target.clone();
log::info!("Accepted new stream (YAMux)");
tokio::spawn(async move {
let target = match TcpStream::connect(
target,
)
.await
{
Ok(s) => s,
Err(e) => {
log::error!(
"Failed to connect to target: {}",
e
);
match tls_multiplexed
.close()
.await
{
Ok(_) => {}
Err(e) => {
log::error!(
"Failed to close multiplexed stream: {}",
e
);
}
}
return;
}
};
if let Err(e) =
copy_bidirectional_compressed(
opts.compression,
target.compat(),
tls_multiplexed,
)
.await
{
log::error!(
"Failed to copy data: {}",
e
);
}
});
}
Ok(None) => {
log::info!("No more streams to accept");
break;
}
Err(e) => {
log::error!(
"Failed to accept stream: {}",
e
);
break;
}
}
}
futures::future::poll_fn(|cx| {
inner_listener.poll_close(cx)
})
.await
.ok();
});
}
None => match tokio::net::TcpStream::connect(target).await {
Ok(target) => { Ok(target) => {
if let Err(e) = copy_bidirectional_compressed( if let Err(e) = copy_bidirectional_compressed(
opts.compression, opts.compression,
target.compat(), target,
tls.compat(), tls,
) )
.await .await
{ {
@ -457,30 +337,27 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box<dyn std:
.await .await
.expect("Failed to shutdown TLS stream"); .expect("Failed to shutdown TLS stream");
} }
},
} }
} }
None => { None => {
log::warn!("Accepted connection for {:?}, but SNI {} does not match any configured SNI", tls.get_ref().1.server_name(), sni); log::warn!("Accepted connection for {:?}, but SNI {} does not match any configured SNI", tls.get_ref().1.server_name(), sni);
let mut compat = tls.compat_write();
send_static_string( send_static_string(
opts.compression, opts.compression,
&mut compat, &mut tls,
format!("SNI {} does not match any configured SNI", sni) format!("SNI {} does not match any configured SNI", sni)
.as_str(), .as_str(),
) )
.await .await
.expect("Failed to send static string"); .expect("Failed to send static string");
compat.close().await.expect("Failed to shutdown TLS stream"); tls.shutdown().await.expect("Failed to shutdown TLS stream");
} }
}, },
_ => { _ => {
log::error!("No SNI provided"); log::error!("No SNI provided");
let mut compat = tls.compat_write(); send_static_string(opts.compression, &mut tls, "No SNI provided")
send_static_string(opts.compression, &mut compat, "No SNI provided")
.await .await
.expect("Failed to send static string"); .expect("Failed to send static string");
compat.close().await.expect("Failed to shutdown TLS stream"); tls.shutdown().await.expect("Failed to shutdown TLS stream");
} }
} }
} }
@ -547,8 +424,7 @@ pub async fn forward_proxy(opts: ForwardProxyCommand) -> anyhow::Result<()> {
}) })
.expect("Failed to parse SNI"); .expect("Failed to parse SNI");
match opts.transport { loop {
Transport::Plain => loop {
let (local_pt, _) = match listener.accept().await { let (local_pt, _) = match listener.accept().await {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
@ -593,11 +469,8 @@ pub async fn forward_proxy(opts: ForwardProxyCommand) -> anyhow::Result<()> {
match connector.connect(sni, target_pt).await { match connector.connect(sni, target_pt).await {
Ok(tls) => { Ok(tls) => {
if let Err(e) = copy_bidirectional_compressed( if let Err(e) =
opts.compression, copy_bidirectional_compressed(opts.compression, local_pt, tls)
local_pt.compat(),
tls.compat(),
)
.await .await
{ {
eprintln!("Failed to copy data: {}", e); eprintln!("Failed to copy data: {}", e);
@ -619,100 +492,27 @@ pub async fn forward_proxy(opts: ForwardProxyCommand) -> anyhow::Result<()> {
log::error!("None of the target addresses worked: {}", e); log::error!("None of the target addresses worked: {}", e);
} }
}); });
},
#[cfg(feature = "yamux")]
Transport::YAMux => {
let mut retries = 5;
loop {
if retries == 0 {
eprintln!("Too many retries, giving up");
break;
} }
let conn = TcpStream::connect(&opts.target).await?;
log::info!("Connected to target");
let tls_stream = match connector.connect(sni.clone(), conn).await {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to connect to target: {}", e);
retries -= 1;
continue;
}
};
let mut inner_listener = YamuxStreamManager::new()
.wrap_c2s_connection(tls_stream.compat())
.await;
loop {
log::info!("Waiting for new TCP connection (YAMux)");
let new_conn = tokio::select! {
_ = futures::future::poll_fn(|cx| inner_listener.poll_next_inbound(cx)) => {
None
}
conn = listener.accept() => {
match conn {
Ok((pt_stream, _)) => {
log::info!("Starting new C2S connection (YAMux)");
match inner_listener.start_c2s().await {
Ok(s) => Some((pt_stream, s)),
Err(e) => {
log::error!("Failed to start C2S connection: {}", e);
continue;
}
}
}
Err(e) => {
log::error!("Failed to accept connection: {}", e);
continue;
}
}
}
};
if let Some((pt_stream, tls_multiplexed)) = new_conn {
log::info!("Obtained new connection (YAMux)");
tokio::spawn(async move {
if let Err(e) = copy_bidirectional_compressed(
opts.compression,
pt_stream.compat(),
tls_multiplexed,
)
.await
{
eprintln!("Failed to copy data: {}", e);
}
});
} else {
log::warn!("No more streams to accept, next connection will be retried");
break;
}
}
futures::future::poll_fn(|cx| inner_listener.poll_close(cx))
.await
.ok();
}
anyhow::bail!("Too many retries");
}
};
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use tokio::io::AsyncReadExt;
use super::*; use super::*;
async fn test_compression_method(algo: Compression) { async fn test_compression_method(algo: Compression) {
let (r, w) = tokio::io::duplex(1024); let (r, w) = tokio::io::duplex(1024);
let data = "Hello, world!".repeat(1024); let data = "Hello, world!".repeat(1024);
let mut w = compressor_to(algo, w.compat()); let mut w = compressor_to(algo, w);
let mut r = decompressor_from(algo, BufReader::new(r.compat())); let mut r = decompressor_from(algo, BufReader::new(r));
tokio::join!( tokio::join!(
async { async {
w.write_all(data.as_bytes()).await.unwrap(); w.write_all(data.as_bytes()).await.unwrap();
w.close().await.unwrap(); w.shutdown().await.unwrap();
}, },
async { async {
let mut output = Vec::new(); let mut output = Vec::new();

View file

@ -5,7 +5,7 @@ use serde::Deserialize;
use crate::ops::network::{ForwardProxyCommand, ReverseProxySpec}; use crate::ops::network::{ForwardProxyCommand, ReverseProxySpec};
use super::network::{Compression, ReverseProxyCommand, Transport}; use super::network::{Compression, ReverseProxyCommand};
const DEF_CONFIG_FILE: &str = "/etc/replikey.toml"; const DEF_CONFIG_FILE: &str = "/etc/replikey.toml";
const CA_CERT: &str = "ca.pem"; const CA_CERT: &str = "ca.pem";
@ -49,7 +49,6 @@ pub struct ConnectionConfig {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct MasterConfig { pub struct MasterConfig {
listen: String, listen: String,
transport: Option<Transport>,
compression: Option<Compression>, compression: Option<Compression>,
redis: MasterServiceSpec, redis: MasterServiceSpec,
postgres: MasterServiceSpec, postgres: MasterServiceSpec,
@ -60,7 +59,6 @@ pub struct MasterConfig {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct SlaveConfig { pub struct SlaveConfig {
target: String, target: String,
transport: Option<Transport>,
compression: Option<Compression>, compression: Option<Compression>,
redis: SlaveServiceSpec, redis: SlaveServiceSpec,
postgres: SlaveServiceSpec, postgres: SlaveServiceSpec,
@ -99,7 +97,6 @@ pub fn service_replicate_master(config: String) {
let master_conf = config.connection.master.as_ref().unwrap(); let master_conf = config.connection.master.as_ref().unwrap();
let cmd = ReverseProxyCommand { let cmd = ReverseProxyCommand {
listen: master_conf.listen.clone(), listen: master_conf.listen.clone(),
transport: master_conf.transport.unwrap_or_default(),
compression: master_conf.compression.unwrap_or_default(), compression: master_conf.compression.unwrap_or_default(),
target: vec![ target: vec![
ReverseProxySpec { ReverseProxySpec {
@ -149,7 +146,6 @@ pub fn service_replicate_slave(config: String) {
let slave_conf = config.connection.slave.as_ref().unwrap(); let slave_conf = config.connection.slave.as_ref().unwrap();
let cmd_redis = ForwardProxyCommand { let cmd_redis = ForwardProxyCommand {
listen: slave_conf.redis.listen.clone(), listen: slave_conf.redis.listen.clone(),
transport: slave_conf.transport.unwrap_or_default(),
compression: slave_conf.compression.unwrap_or_default(), compression: slave_conf.compression.unwrap_or_default(),
sni: Some(slave_conf.redis.sni.clone()), sni: Some(slave_conf.redis.sni.clone()),
target: slave_conf.target.clone(), target: slave_conf.target.clone(),
@ -160,7 +156,6 @@ pub fn service_replicate_slave(config: String) {
}; };
let cmd_postgres = ForwardProxyCommand { let cmd_postgres = ForwardProxyCommand {
listen: slave_conf.postgres.listen.clone(), listen: slave_conf.postgres.listen.clone(),
transport: slave_conf.transport.unwrap_or_default(),
compression: slave_conf.compression.unwrap_or_default(), compression: slave_conf.compression.unwrap_or_default(),
sni: Some(slave_conf.postgres.sni.clone()), sni: Some(slave_conf.postgres.sni.clone()),
target: slave_conf.target.clone(), target: slave_conf.target.clone(),

View file

@ -1,203 +0,0 @@
use std::fmt::Debug;
use futures::{
future::BoxFuture,
io::{AsyncRead, AsyncWrite},
FutureExt,
};
#[cfg(feature = "yamux")]
pub mod yamux;
pub trait DynStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DynStream for T {}
pub trait StreamManager<'c, C: 'c>: Send + Sync {
type S2CConn: MultiplexedS2C<C>;
type C2SConn: MultiplexedC2S<C>;
fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn>;
fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn>;
fn shutdown(&self) -> BoxFuture<'c, ()>;
}
// hangs rustc for some reason..
/*
pub fn manager_to_dyn<'c, C: 'c>(inner: impl StreamManager<'c, C> + 'c) -> DynStreamManager<'c, C> {
DynStreamManager {
inner: Box::new(inner),
}
}
pub struct DynStreamManager<'c, C: 'c> {
inner: Box<
dyn StreamManager<
'c,
C,
S2CConn = Box<
dyn MultiplexedS2C<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>>
+ 'c,
>,
C2SConn = Box<
dyn MultiplexedC2S<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>>
+ 'c,
>,
> + 'c,
>,
}
impl<'c, C> StreamManager<'c, C> for DynStreamManager<'c, C>
where
C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c,
<Self::C2SConn as MultiplexedC2S<C>>::Error: Error + 'static,
<Self::S2CConn as MultiplexedS2C<C>>::Error: Error + 'static,
{
type S2CConn = DynMultiplexedS2CConnection<'c, C>;
type C2SConn = DynMultiplexedC2SConnection<'c, C>;
fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn> {
self.inner
.wrap_c2s_connection(stream)
.map(|x| DynMultiplexedC2SConnection {
inner: Box::new(x.map(|x| {
x.map(|s| s.map(|s| Box::new(s) as _))
.map_err(anyhow::Error::new)
}) as _),
})
.boxed()
}
fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn> {
self.inner
.wrap_s2c_connection(stream)
.map(|x| DynMultiplexedS2CConnection {
inner: Box::new(x.map(|x| {
x.map(|s| s.map(|s| Box::new(s) as _))
.map_err(anyhow::Error::new) as _
})),
})
.boxed()
}
fn shutdown(&self) -> BoxFuture<'c, ()> {
async move { self.inner.shutdown().await }.boxed()
}
}
*/
pub trait MultiplexedS2C<C>: Send + Sync {
type Error: Debug + Send + Sync + Sized;
type Stream: DynStream;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>>;
fn map<F, S, E>(self, map: F) -> S2CMap<C, Self, F>
where
F: Fn(Result<Option<Self::Stream>, Self::Error>) -> Result<Option<S>, E> + Send + Sync,
S: DynStream,
Self: Sized,
{
S2CMap {
inner: self,
map,
_phantom: std::marker::PhantomData,
}
}
}
pub struct S2CMap<C, MC, F> {
inner: MC,
map: F,
_phantom: std::marker::PhantomData<C>,
}
impl<C: Send + Sync, E: Send + Sync + Debug, S: DynStream, MC, F: Send> MultiplexedS2C<C>
for S2CMap<C, MC, F>
where
MC: MultiplexedS2C<C>,
F: Fn(Result<Option<MC::Stream>, MC::Error>) -> Result<Option<S>, E> + Send + Sync,
{
type Error = E;
type Stream = S;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>> {
let inner = &mut self.inner;
let mapf = &self.map;
inner.accept().map(mapf).boxed()
}
}
pub trait MultiplexedC2S<C>: Send + Sync {
type Error: Debug + Send + Sync + Sized;
type Stream: DynStream;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>>;
fn map<F, S, E>(self, map: F) -> C2SMap<C, Self, F>
where
F: Fn(Result<Self::Stream, Self::Error>) -> Result<S, E> + Send + Sync,
S: DynStream,
Self: Sized,
{
C2SMap {
inner: self,
map,
_phantom: std::marker::PhantomData,
}
}
}
pub struct C2SMap<C, MC, F> {
inner: MC,
map: F,
_phantom: std::marker::PhantomData<C>,
}
impl<C: Send + Sync, E: Send + Sync + Debug, S: DynStream, MC, F: Send> MultiplexedC2S<C>
for C2SMap<C, MC, F>
where
MC: MultiplexedC2S<C>,
F: Fn(Result<MC::Stream, MC::Error>) -> Result<S, E> + Send + Sync,
{
type Error = E;
type Stream = S;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>> {
let inner = &mut self.inner;
let map = &self.map;
inner.start_c2s().map(map).boxed()
}
}
pub struct DynMultiplexedS2CConnection<'c, C: 'c> {
inner: Box<dyn MultiplexedS2C<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>> + 'c>,
}
impl<'c, C> MultiplexedS2C<C> for DynMultiplexedS2CConnection<'c, C> {
type Error = anyhow::Error;
type Stream = Box<dyn DynStream + 'c>;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>> {
async move {
self.inner
.accept()
.await
.map(|x| x.map(|x| Box::new(x) as _))
}
.boxed()
}
}
pub struct DynMultiplexedC2SConnection<'c, C: 'c> {
inner: Box<dyn MultiplexedC2S<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>> + 'c>,
}
impl<'c, C> MultiplexedC2S<C> for DynMultiplexedC2SConnection<'c, C> {
type Error = anyhow::Error;
type Stream = Box<dyn DynStream + 'c>;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>> {
async move { self.inner.start_c2s().await.map(|x| Box::new(x) as _) }.boxed()
}
}

View file

@ -1,80 +0,0 @@
use futures::{
future::BoxFuture,
io::{AsyncRead, AsyncWrite},
FutureExt,
};
use yamux::Config;
pub(crate) fn yamux_config() -> yamux::Config {
let mut conf = Config::default();
// we are in mTLS so be generous but still prevent accidental DoS
conf.set_read_after_close(true)
.set_max_num_streams(512)
.set_max_connection_receive_window(Some((256 << 10) * 512));
conf
}
pub struct YamuxStreamManager {
config: yamux::Config,
}
impl Default for YamuxStreamManager {
fn default() -> Self {
Self::new()
}
}
impl YamuxStreamManager {
pub fn new() -> Self {
Self {
config: yamux_config(),
}
}
}
impl<'c, C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c> super::StreamManager<'c, C>
for YamuxStreamManager
{
type S2CConn = yamux::Connection<C>;
type C2SConn = yamux::Connection<C>;
fn wrap_c2s_connection(&self, stream: C) -> BoxFuture<'c, Self::C2SConn> {
let config = self.config.clone();
async move { yamux::Connection::new(stream, config, yamux::Mode::Client) }.boxed()
}
fn wrap_s2c_connection(&self, stream: C) -> BoxFuture<'c, Self::S2CConn> {
let config = self.config.clone();
async move { yamux::Connection::new(stream, config, yamux::Mode::Server) }.boxed()
}
fn shutdown(&self) -> BoxFuture<'c, ()> {
async {}.boxed()
}
}
impl<C: AsyncRead + AsyncWrite + Unpin + Send + Sync> super::MultiplexedS2C<C>
for yamux::Connection<C>
{
type Error = yamux::ConnectionError;
type Stream = yamux::Stream;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>> {
futures::future::poll_fn(|cx| self.poll_next_inbound(cx))
.map(|x| x.transpose())
.boxed()
}
}
impl<C: AsyncRead + AsyncWrite + Unpin + Send + Sync> super::MultiplexedC2S<C>
for yamux::Connection<C>
{
type Error = yamux::ConnectionError;
type Stream = yamux::Stream;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>> {
futures::future::poll_fn(|cx| self.poll_new_outbound(cx)).boxed()
}
}

View file

@ -6,7 +6,6 @@ use std::{
}; };
use clap::Parser; use clap::Parser;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use replikey::{ use replikey::{
cert::UsageType, cert::UsageType,
ops::{ ops::{
@ -17,10 +16,9 @@ use replikey::{
use rustls::crypto::{aws_lc_rs, CryptoProvider}; use rustls::crypto::{aws_lc_rs, CryptoProvider};
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::{ use tokio::{
io::{AsyncReadExt as _, AsyncWriteExt as _}, io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _},
task::JoinSet, task::JoinSet,
}; };
use tokio_util::compat::TokioAsyncReadCompatExt;
static NEXT_PORT: AtomicU32 = AtomicU32::new(12311); static NEXT_PORT: AtomicU32 = AtomicU32::new(12311);
@ -49,9 +47,9 @@ async fn test_stream_sequence<
left: L, left: L,
right: R, right: R,
) { ) {
let (mut l_rx, mut l_tx) = left.split(); let (mut l_rx, mut l_tx) = tokio::io::split(left);
let (mut r_rx, mut r_tx) = right.split(); let (mut r_rx, mut r_tx) = tokio::io::split(right);
let l_write = async move { let l_write = async move {
let mut buf = [0u8; WRITE_SIZE]; let mut buf = [0u8; WRITE_SIZE];
@ -61,9 +59,9 @@ async fn test_stream_sequence<
.for_each(|(j, x)| *x = ((i + j) ^ 123) as u8); .for_each(|(j, x)| *x = ((i + j) ^ 123) as u8);
l_tx.write_all(&buf).await?; l_tx.write_all(&buf).await?;
} }
l_tx.close().await?; l_tx.shutdown().await?;
Ok::<_, futures::io::Error>(()) Ok::<_, tokio::io::Error>(())
}; };
let r_write = async move { let r_write = async move {
let mut buf = [0u8; WRITE_SIZE]; let mut buf = [0u8; WRITE_SIZE];
@ -73,7 +71,7 @@ async fn test_stream_sequence<
.for_each(|(j, x)| *x = ((i + j) ^ 321) as u8); .for_each(|(j, x)| *x = ((i + j) ^ 321) as u8);
r_tx.write_all(&buf).await?; r_tx.write_all(&buf).await?;
} }
r_tx.close().await?; r_tx.shutdown().await?;
Ok(()) Ok(())
}; };
let l_read = async move { let l_read = async move {
@ -105,11 +103,11 @@ async fn test_stream_sequence<
#[tokio::test] #[tokio::test]
async fn in_memory_stream_works() { async fn in_memory_stream_works() {
let (left, right) = tokio::io::duplex(1024); let (left, right) = tokio::io::duplex(1024);
test_stream_sequence::<1024, _, _>(left.compat(), right.compat()).await; test_stream_sequence::<1024, _, _>(left, right).await;
let (left, right) = tokio::io::duplex(1024); let (left, right) = tokio::io::duplex(1024);
test_stream_sequence::<1, _, _>(left.compat(), right.compat()).await; test_stream_sequence::<1, _, _>(left, right).await;
let (left, right) = tokio::io::duplex(1024); let (left, right) = tokio::io::duplex(1024);
test_stream_sequence::<3543, _, _>(left.compat(), right.compat()).await; test_stream_sequence::<3543, _, _>(left, right).await;
} }
fn test_ca(dir: &PathBuf) -> (PathBuf, PathBuf) { fn test_ca(dir: &PathBuf) -> (PathBuf, PathBuf) {
@ -348,7 +346,7 @@ async fn start_reverse_proxy(
cert: &PathBuf, cert: &PathBuf,
key: &PathBuf, key: &PathBuf,
ca_cert: &PathBuf, ca_cert: &PathBuf,
transport: &str, compression: &str,
) -> u32 { ) -> u32 {
let port = next_port(); let port = next_port();
@ -371,8 +369,8 @@ async fn start_reverse_proxy(
key.to_string_lossy().to_string(), key.to_string_lossy().to_string(),
"--ca".to_string(), "--ca".to_string(),
ca_cert.to_string_lossy().to_string(), ca_cert.to_string_lossy().to_string(),
"--transport".to_string(), "--compression".to_string(),
transport.to_string(), compression.to_string(),
]; ];
let parsed = NetworkCommand::parse_from(&cmd); let parsed = NetworkCommand::parse_from(&cmd);
@ -396,7 +394,7 @@ async fn start_forward_proxy(
cert: &PathBuf, cert: &PathBuf,
key: &PathBuf, key: &PathBuf,
ca_cert: &PathBuf, ca_cert: &PathBuf,
transport: &str, compression: &str,
) -> u32 { ) -> u32 {
let port = next_port(); let port = next_port();
@ -415,8 +413,8 @@ async fn start_forward_proxy(
key.to_string_lossy().to_string(), key.to_string_lossy().to_string(),
"--ca".to_string(), "--ca".to_string(),
ca_cert.to_string_lossy().to_string(), ca_cert.to_string_lossy().to_string(),
"--transport".to_string(), "--compression".to_string(),
transport.to_string(), compression.to_string(),
]; ];
let parsed = NetworkCommand::parse_from(&cmd); let parsed = NetworkCommand::parse_from(&cmd);
@ -514,47 +512,6 @@ async fn expect_target1_signature(port: u32) {
tx.shutdown().await.expect("failed to close"); tx.shutdown().await.expect("failed to close");
} }
async fn expect_target1_signature_buffered(port: u32) {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.expect("failed to connect");
let (mut rx, mut tx) = tokio::io::split(stream);
let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
let expect = data.clone();
let mut write_size = (1..4096).cycle();
let (_, rx_data) = tokio::join!(
async {
while !data.is_empty() {
let n = write_size.next().unwrap().min(data.len());
tx.write_all(&data[..n]).await.expect("failed to write");
data = data.split_off(n);
}
tx.shutdown().await.expect("failed to close");
},
async {
let mut rxed = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = rx.read(&mut buf).await.expect("failed to read");
if n == 0 {
break;
}
rxed.extend_from_slice(&buf[..n]);
if rxed.len() >= expect.len() {
break;
}
}
rxed
}
);
assert_eq!(rx_data, expect);
}
async fn expect_target2_signature(port: u32) { async fn expect_target2_signature(port: u32) {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await .await
@ -600,49 +557,7 @@ async fn expect_target2_signature(port: u32) {
tx.shutdown().await.expect("failed to close"); tx.shutdown().await.expect("failed to close");
} }
async fn expect_target2_signature_buffered(port: u32) { #[tokio::test(flavor = "multi_thread", worker_threads = 16)]
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.expect("failed to connect");
let (mut rx, mut tx) = tokio::io::split(stream);
let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
let expect = data.iter().map(|x| !x).collect::<Vec<_>>();
let mut write_size = (1..4096).cycle();
let (_, rx_data) = tokio::join!(
async {
while !data.is_empty() {
let n = write_size.next().unwrap().min(data.len());
tx.write_all(&data[..n]).await.expect("failed to write");
data = data.split_off(n);
}
tx.shutdown().await.expect("failed to close");
},
async {
let mut rxed = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = rx.read(&mut buf).await.expect("failed to read");
if n == 0 {
break;
}
rxed.extend_from_slice(&buf[..n]);
if rxed.len() >= expect.len() {
break;
}
}
rxed
}
);
assert_eq!(rx_data, expect);
}
#[tokio::test]
async fn test_mtls_integrated() { async fn test_mtls_integrated() {
setup_once(); setup_once();
let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca")); let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca"));
@ -660,13 +575,14 @@ async fn test_mtls_integrated() {
expect_target1_signature(target1).await; expect_target1_signature(target1).await;
expect_target2_signature(target2).await; expect_target2_signature(target2).await;
for compression in ["none", "zstd", "brotli"] {
let proxy_port_self_signed = start_reverse_proxy( let proxy_port_self_signed = start_reverse_proxy(
target1, target1,
target2, target2,
&server_cert.self_signed, &server_cert.self_signed,
&server_cert.key, &server_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
wait_for_port(proxy_port_self_signed).await; wait_for_port(proxy_port_self_signed).await;
@ -678,7 +594,7 @@ async fn test_mtls_integrated() {
server_cert.signed_cert.as_ref().unwrap(), server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key, &server_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
wait_for_port(proxy_port).await; wait_for_port(proxy_port).await;
@ -698,7 +614,7 @@ async fn test_mtls_integrated() {
&client_cert.self_signed, &client_cert.self_signed,
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
log::info!( log::info!(
@ -719,7 +635,7 @@ async fn test_mtls_integrated() {
client_cert.signed_cert.as_ref().unwrap(), client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
log::info!("forward_proxy_target1: {}", forward_proxy_target1); log::info!("forward_proxy_target1: {}", forward_proxy_target1);
@ -737,7 +653,7 @@ async fn test_mtls_integrated() {
client_cert.signed_cert.as_ref().unwrap(), client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
log::info!("forward_proxy_target2: {}", forward_proxy_target2); log::info!("forward_proxy_target2: {}", forward_proxy_target2);
@ -755,7 +671,7 @@ async fn test_mtls_integrated() {
client_cert.signed_cert.as_ref().unwrap(), client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
@ -772,7 +688,7 @@ async fn test_mtls_integrated() {
client_cert.signed_cert.as_ref().unwrap(), client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
log::info!( log::info!(
@ -793,7 +709,7 @@ async fn test_mtls_integrated() {
client_cert.signed_cert.as_ref().unwrap(), client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", compression,
) )
.await; .await;
log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni); log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni);
@ -828,6 +744,7 @@ async fn test_mtls_integrated() {
js.join_all().await; js.join_all().await;
} }
} }
}
} }
#[tokio::test] #[tokio::test]
@ -858,7 +775,7 @@ async fn test_mtls_role_reverse() {
server_cert.signed_cert.as_ref().unwrap(), server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key, &server_cert.key,
&ca_cert, &ca_cert,
"plain", "none",
) )
.await; .await;
@ -870,7 +787,7 @@ async fn test_mtls_role_reverse() {
client_cert.signed_cert.as_ref().unwrap(), client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", "none",
) )
.await; .await;
wait_for_port(forward_proxy_target1).await; wait_for_port(forward_proxy_target1).await;
@ -883,7 +800,7 @@ async fn test_mtls_role_reverse() {
client_cert.signed_cert.as_ref().unwrap(), client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key, &client_cert.key,
&ca_cert, &ca_cert,
"plain", "none",
) )
.await; .await;
@ -895,74 +812,10 @@ async fn test_mtls_role_reverse() {
server_cert.signed_cert.as_ref().unwrap(), server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key, &server_cert.key,
&ca_cert, &ca_cert,
"plain", "none",
) )
.await; .await;
wait_for_port(forward_proxy_target1_reversed).await; wait_for_port(forward_proxy_target1_reversed).await;
should_not_get_anything(forward_proxy_target1_reversed).await; should_not_get_anything(forward_proxy_target1_reversed).await;
} }
#[tokio::test]
async fn test_mtls_yamux_transport() {
setup_once();
let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca"));
let server_cert = test_server_cert(&PathBuf::from("target/test-server"));
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = test_client_cert(&PathBuf::from("target/test-client"));
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let target1 = start_test_target_1().await;
let target2 = start_test_target_2().await;
wait_for_port(target1).await;
wait_for_port(target2).await;
let proxy_port = start_reverse_proxy(
target1,
target2,
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
"yamux",
)
.await;
wait_for_port(proxy_port).await;
wait_for_port(proxy_port).await;
let forward_proxy_target1 = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port),
"target1.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"yamux",
)
.await;
let forward_proxy_target2 = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port),
"target2.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"yamux",
)
.await;
wait_for_port(forward_proxy_target1).await;
wait_for_port(forward_proxy_target2).await;
log::info!("forward_proxy_target1: {}", forward_proxy_target1);
log::info!("forward_proxy_target2: {}", forward_proxy_target2);
let mut js = JoinSet::new();
for _ in 0..128 {
js.spawn(expect_target1_signature_buffered(forward_proxy_target1));
js.spawn(expect_target2_signature_buffered(forward_proxy_target2));
}
js.join_all().await;
}