Unit and integration test, configurable compression and transport

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-11-03 22:25:34 -06:00
parent 73ebf99243
commit fbd32b2ec0
No known key found for this signature in database
11 changed files with 2038 additions and 172 deletions

160
Cargo.lock generated
View file

@ -73,6 +73,21 @@ dependencies = [
"memchr",
]
[[package]]
name = "alloc-no-stdlib"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
[[package]]
name = "alloc-stdlib"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
dependencies = [
"alloc-no-stdlib",
]
[[package]]
name = "allocator-api2"
version = "0.2.18"
@ -128,6 +143,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "anyhow"
version = "1.0.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13"
[[package]]
name = "argon2"
version = "0.5.3"
@ -185,10 +206,11 @@ version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857"
dependencies = [
"brotli",
"futures-core",
"futures-io",
"memchr",
"pin-project-lite",
"tokio",
"zstd",
"zstd-safe",
]
@ -309,6 +331,27 @@ dependencies = [
"generic-array",
]
[[package]]
name = "brotli"
version = "7.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
"brotli-decompressor",
]
[[package]]
name = "brotli-decompressor"
version = "4.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
]
[[package]]
name = "bumpalo"
version = "3.16.0"
@ -672,6 +715,21 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -688,6 +746,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-intrusive"
version = "0.5.0"
@ -705,6 +774,17 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
@ -723,8 +803,10 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
@ -1127,6 +1209,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1"
[[package]]
name = "nohash-hasher"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
[[package]]
name = "nom"
version = "7.1.3"
@ -1310,6 +1398,26 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pin-project"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
version = "0.2.15"
@ -1518,10 +1626,12 @@ name = "replikey"
version = "0.1.0"
dependencies = [
"aes-gcm",
"anyhow",
"argon2",
"async-compression",
"clap",
"env_logger",
"futures",
"log",
"openssl",
"pem-rfc7468",
@ -1538,8 +1648,10 @@ dependencies = [
"time",
"tokio",
"tokio-rustls",
"tokio-util",
"toml",
"x509-parser",
"yamux",
]
[[package]]
@ -1951,6 +2063,12 @@ dependencies = [
"whoami",
]
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "stringprep"
version = "0.1.5"
@ -2122,6 +2240,20 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a"
dependencies = [
"bytes",
"futures-core",
"futures-io",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "toml"
version = "0.8.19"
@ -2382,6 +2514,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.26.5"
@ -2618,6 +2760,22 @@ dependencies = [
"time",
]
[[package]]
name = "yamux"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31b5e376a8b012bee9c423acdbb835fc34d45001cfa3106236a624e4b738028"
dependencies = [
"futures",
"log",
"nohash-hasher",
"parking_lot",
"pin-project",
"rand",
"static_assertions",
"web-time",
]
[[package]]
name = "yasna"
version = "0.5.2"

View file

@ -26,11 +26,15 @@ serde = { version = "1.0.214", features = ["derive"], optional = true }
sqlx = { version = "0.8.2", optional = true, default-features = false, features = ["tls-none", "postgres"] }
tokio = { version = "1.41.0", features = ["rt", "rt-multi-thread", "macros", "net", "io-util", "sync"], optional = true }
rustls = { version = "0.23.16", optional = true }
async-compression = { version = "0.4.17", optional = true, features = ["tokio", "zstd"] }
async-compression = { version = "0.4.17", optional = true, features = ["futures-io"] }
yamux = { version = "0.13.3", optional = true }
futures = { version = "0.3.31", optional = true }
tokio-util = { version = "0.7.12", features = ["compat"], optional = true }
anyhow = "1.0.92"
[features]
default = ["keygen", "networking", "service", "remote-crl", "setup-postgres"]
asyncio = ["dep:tokio"]
default = ["keygen", "networking", "service", "remote-crl", "setup-postgres", "yamux", "zstd", "brotli"]
asyncio = ["dep:tokio", "dep:futures", "dep:tokio-util"]
keygen = ["dep:rcgen", "dep:pem-rfc7468", "dep:rpassword", "dep:argon2", "dep:sha2", "dep:aes-gcm", "dep:time"]
networking = ["asyncio", "dep:tokio-rustls", "dep:rustls", "dep:async-compression"]
test-crosscheck-openssl = ["dep:openssl"]
@ -41,6 +45,9 @@ setup-postgres = ["dep:sqlx"]
stat-service = ["networking", "serde"]
rustls = ["dep:rustls"]
async-compression = ["dep:async-compression"]
yamux = ["dep:yamux", "networking"]
zstd = ["async-compression/zstd"]
brotli = ["async-compression/brotli"]
[[bin]]
name = "replikey"

View file

@ -96,9 +96,12 @@ fn main() {
print_feature!("keygen");
print_feature!("asyncio");
print_feature!("networking");
print_feature!("yamux");
print_feature!("service");
print_feature!("remote-crl");
print_feature!("setup-postgres");
print_feature!("zstd");
print_feature!("brotli");
}
#[cfg(feature = "keygen")]
SubCommand::Cert(cert) => match cert.subcmd {

View file

@ -1,5 +1,8 @@
#![forbid(unsafe_code)]
#[cfg(feature = "keygen")]
pub mod cert;
pub mod ops;
pub mod transport;
pub mod fs_crypt;

View file

@ -185,7 +185,7 @@ pub fn dump_certificate_params(params: &CertificateParams) {
for b in serial_number.as_ref() {
print!("{:02x}", b);
}
println!("");
println!();
}
for san in &params.subject_alt_names {
match san {

View file

@ -1,9 +1,16 @@
use std::{io::Cursor, net::ToSocketAddrs, sync::Arc};
use clap::Parser;
use futures::{
io::{
AsyncBufRead as FuturesAsyncBufRead, AsyncRead as FuturesAsyncRead, AsyncReadExt,
AsyncWrite as FuturesAsyncWrite, BufReader,
},
AsyncWriteExt as _,
};
use tokio::{
io::{AsyncBufRead, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
net::TcpStream,
io::AsyncWriteExt as _,
net::{TcpSocket, TcpStream},
};
use tokio_rustls::{
rustls::{
@ -14,6 +21,9 @@ use tokio_rustls::{
},
TlsAcceptor, TlsConnector,
};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use crate::transport::{yamux::YamuxStreamManager, MultiplexedC2S, MultiplexedS2C, StreamManager};
#[derive(Debug, Parser)]
pub struct NetworkCommand {
@ -30,22 +40,45 @@ pub enum NetworkSubCommand {
ForwardProxy(ForwardProxyCommand),
}
#[derive(Debug, Clone)]
pub struct ReverseProxySpec {
pub sni: String,
pub target: String,
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, thiserror::Error)]
pub enum ReverseProxySpecParseError {
#[error("Too many parts")]
TooManyParts,
#[error("No SNI provided")]
NoSni,
#[error("No target provided")]
NoTarget,
}
fn reverse_proxy_spec_parser(input: &str) -> Result<ReverseProxySpec, ReverseProxySpecParseError> {
let mut parts = input.split('/');
let sni = parts.next().ok_or(ReverseProxySpecParseError::NoSni)?;
let target = parts.next().ok_or(ReverseProxySpecParseError::NoTarget)?;
if parts.next().is_some() {
return Err(ReverseProxySpecParseError::TooManyParts);
}
Ok(ReverseProxySpec {
sni: sni.to_string(),
target: target.to_string(),
})
}
#[derive(Debug, Parser)]
pub struct ReverseProxyCommand {
#[clap(short, long)]
pub listen: String,
#[clap(long)]
pub redis_sni: String,
#[clap(long)]
pub redis_target: String,
#[clap(long)]
pub postgres_sni: String,
#[clap(long)]
pub postgres_target: String,
#[clap(long, help = "your.sni.tld/target:1234", value_parser = reverse_proxy_spec_parser)]
pub target: Vec<ReverseProxySpec>,
#[clap(long, help = "Certificate")]
pub cert: String,
@ -58,6 +91,60 @@ pub struct ReverseProxyCommand {
#[clap(long, help = "CRLs to use")]
pub crl: Vec<String>,
#[clap(long, help = "Transport to use", default_value = "plain")]
pub transport: Transport,
#[clap(long, help = "Compression to use", default_value = "none")]
pub compression: Compression,
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default)]
pub enum Transport {
#[default]
Plain,
#[cfg(feature = "yamux")]
YAMux,
}
impl std::str::FromStr for Transport {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"plain" => Ok(Transport::Plain),
#[cfg(feature = "yamux")]
"yamux" => Ok(Transport::YAMux),
_ => Err(format!("Unknown transport: {}", s)),
}
}
}
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, Default)]
pub enum Compression {
#[default]
None,
#[cfg(feature = "brotli")]
Brotli,
#[cfg(feature = "zstd")]
Zstd,
}
impl std::str::FromStr for Compression {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"none" => Ok(Compression::None),
#[cfg(feature = "brotli")]
"brotli" => Ok(Compression::Brotli),
#[cfg(feature = "zstd")]
"zstd" => Ok(Compression::Zstd),
_ => Err(format!("Unknown compression: {}", s)),
}
}
}
#[derive(Debug, Parser)]
@ -82,48 +169,77 @@ pub struct ForwardProxyCommand {
#[clap(long)]
pub crl: Vec<String>,
#[clap(long, help = "Transport to use", default_value = "plain")]
pub transport: Transport,
#[clap(long, help = "Compression to use", default_value = "none")]
pub compression: Compression,
}
fn compressor_to(w: impl AsyncWrite + Unpin) -> impl AsyncWrite + Unpin {
async_compression::tokio::write::ZstdEncoder::new(w)
fn compressor_to<'s>(
comp: Compression,
w: impl FuturesAsyncWrite + Unpin + Send + 's,
) -> Box<dyn FuturesAsyncWrite + Unpin + Send + 's> {
match comp {
Compression::None => Box::new(w),
#[cfg(feature = "brotli")]
Compression::Brotli => Box::new(async_compression::futures::write::BrotliEncoder::new(w)),
#[cfg(feature = "zstd")]
Compression::Zstd => Box::new(async_compression::futures::write::ZstdEncoder::new(w)),
}
}
fn decompressor_from(r: impl AsyncBufRead + Unpin) -> impl AsyncRead + Unpin {
async_compression::tokio::bufread::ZstdDecoder::new(r)
fn decompressor_from<'s>(
comp: Compression,
r: impl FuturesAsyncBufRead + Unpin + Send + 's,
) -> Box<dyn FuturesAsyncRead + Unpin + Send + 's> {
match comp {
Compression::None => Box::new(BufReader::new(r)),
#[cfg(feature = "brotli")]
Compression::Brotli => Box::new(async_compression::futures::bufread::BrotliDecoder::new(r)),
#[cfg(feature = "zstd")]
Compression::Zstd => Box::new(async_compression::futures::bufread::ZstdDecoder::new(r)),
}
}
async fn send_static_string(w: &mut (impl AsyncWrite + Unpin), s: &str) -> tokio::io::Result<()> {
let mut cursor = Cursor::new(s);
tokio::io::copy(&mut cursor, &mut compressor_to(w)).await?;
async fn send_static_string(
c: Compression,
w: &mut (impl FuturesAsyncWrite + Send + Unpin),
s: &str,
) -> futures::io::Result<()> {
futures::io::copy(&mut Cursor::new(s).compat(), &mut compressor_to(c, w)).await?;
Ok(())
}
async fn copy_bidirectional_compressed(
local: impl AsyncRead + AsyncWrite + Unpin,
remote: impl AsyncRead + AsyncWrite + Unpin,
) -> tokio::io::Result<(u64, u64)> {
let (mut local_rx, mut local_tx) = tokio::io::split(local);
let (remote_rx, remote_tx) = tokio::io::split(remote);
comp: Compression,
local: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin,
remote: impl FuturesAsyncRead + FuturesAsyncWrite + Send + Unpin,
) -> futures::io::Result<(u64, u64)> {
let (mut local_rx, mut local_tx) = local.split();
let (remote_rx, remote_tx) = remote.split();
let remote_rx_buf = BufReader::new(remote_rx);
let mut remote_tx_comp = compressor_to(remote_tx);
let mut remote_rx_decomp = decompressor_from(remote_rx_buf);
log::info!("Starting transfer");
let mut remote_tx_comp = compressor_to(comp, remote_tx);
let mut remote_rx_decomp = decompressor_from(comp, remote_rx_buf);
let uplink = async move {
let res = tokio::io::copy(&mut local_rx, &mut remote_tx_comp).await;
let shutdown = remote_tx_comp.shutdown().await;
log::info!("Starting transfer uplink");
let res = futures::io::copy(&mut local_rx, &mut remote_tx_comp).await;
log::info!("Finished uplink");
let shutdown = remote_tx_comp.close().await;
let res = res?;
shutdown?;
tokio::io::Result::Ok(res)
};
let downlink = async move {
let res = tokio::io::copy(&mut remote_rx_decomp, &mut local_tx).await;
let shutdown = local_tx.shutdown().await;
log::info!("Starting transfer downlink");
let res = futures::io::copy(&mut remote_rx_decomp, &mut local_tx).await;
log::info!("Finished downlink");
let shutdown = local_tx.close().await;
let res = res?;
shutdown?;
tokio::io::Result::Ok(res)
@ -184,93 +300,190 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box<dyn std:
let acceptor = TlsAcceptor::from(Arc::new(config));
let listener = tokio::net::TcpListener::bind(&opts.listen).await?;
let listen_addr = opts
.listen
.to_socket_addrs()?
.next()
.expect("Failed to resolve listen address");
let socket = match listen_addr {
std::net::SocketAddr::V4(_) => TcpSocket::new_v4()?,
std::net::SocketAddr::V6(_) => TcpSocket::new_v6()?,
};
// default options with keepalive
socket.set_reuseaddr(true)?;
socket.set_keepalive(true)?;
socket.bind(listen_addr)?;
let listener = socket.listen(1024)?;
log::info!("Listening on: {}", opts.listen);
let (redis_sni, postgres_sni, redis_target, postgres_target) = (
Arc::new(opts.redis_sni.clone()),
Arc::new(opts.postgres_sni.clone()),
opts.redis_target.clone(),
opts.postgres_target.clone(),
for spec in &opts.target {
if let Err(e) = spec.target.to_socket_addrs() {
eprintln!(
"Failed to resolve target address {}, are you sure it's correct? {}",
spec.target, e
);
return Ok(());
}
}
let sni_to_target = Arc::new(
opts.target
.into_iter()
.map(|spec| (spec.sni, spec.target))
.collect::<std::collections::HashMap<_, _>>(),
);
if let Err(e) = opts.redis_target.to_socket_addrs() {
eprintln!("Failed to resolve redis target: {}", e);
return Ok(());
}
if let Err(e) = opts.postgres_target.to_socket_addrs() {
eprintln!("Failed to resolve postgres target: {}", e);
return Ok(());
}
loop {
let (pt_stream, _) = match listener.accept().await {
let (pt_stream, addr) = match listener.accept().await {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to accept connection: {}", e);
continue;
}
};
log::info!("Accepted TCP connection from: {}", addr);
let acceptor = acceptor.clone();
let (redis_sni, postgres_sni, redis_target, postgres_target) = (
redis_sni.clone(),
postgres_sni.clone(),
redis_target.clone(),
postgres_target.clone(),
);
let sni_to_target = Arc::clone(&sni_to_target);
tokio::spawn(async move {
match acceptor.accept(pt_stream).await {
Ok(mut tls) => match tls.get_ref().1.server_name().map(|s| s.to_string()) {
Some(sni) if sni == *redis_sni => {
log::info!(
"Accepted Redis connection for {:?}",
tls.get_ref().1.server_name()
);
match tokio::net::TcpStream::connect(&redis_target).await {
Ok(redis) => {
if let Err(e) = copy_bidirectional_compressed(redis, tls).await {
eprintln!("Failed to copy data: {}", e);
Ok(mut tls) => {
log::info!("Accepted TLS connection from: {}", addr);
let multiplexer = match opts.transport {
Transport::Plain => None,
#[cfg(feature = "yamux")]
Transport::YAMux => {
log::info!("Using YAMux transport");
Some(YamuxStreamManager::new())
}
};
match tls.get_ref().1.server_name().map(|s| s.to_string()) {
Some(sni) => match sni_to_target.get(&sni).cloned() {
Some(target) => {
log::info!(
"Accepted connection for {:?}, forwarding to {}",
tls.get_ref().1.server_name(),
target
);
match multiplexer {
Some(m) => {
let mut inner_listener =
m.wrap_s2c_connection(tls.compat()).await;
tokio::spawn(async move {
loop {
log::info!("Waiting for new stream (YAMux)");
match inner_listener.accept().await {
Ok(Some(mut tls_multiplexed)) => {
let target = target.clone();
log::info!("Accepted new stream (YAMux)");
tokio::spawn(async move {
let target = match TcpStream::connect(
target,
)
.await
{
Ok(s) => s,
Err(e) => {
log::error!(
"Failed to connect to target: {}",
e
);
match tls_multiplexed
.close()
.await
{
Ok(_) => {}
Err(e) => {
log::error!(
"Failed to close multiplexed stream: {}",
e
);
}
}
return;
}
};
if let Err(e) =
copy_bidirectional_compressed(
opts.compression,
target.compat(),
tls_multiplexed,
)
.await
{
log::error!(
"Failed to copy data: {}",
e
);
}
});
}
Ok(None) => {
log::info!("No more streams to accept");
break;
}
Err(e) => {
log::error!(
"Failed to accept stream: {}",
e
);
break;
}
}
}
futures::future::poll_fn(|cx| {
inner_listener.poll_close(cx)
})
.await
.ok();
});
}
None => match tokio::net::TcpStream::connect(target).await {
Ok(target) => {
if let Err(e) = copy_bidirectional_compressed(
opts.compression,
target.compat(),
tls.compat(),
)
.await
{
log::error!("Failed to copy data: {}", e);
}
}
Err(e) => {
eprintln!("Failed to connect to target: {}", e);
tls.shutdown()
.await
.expect("Failed to shutdown TLS stream");
}
},
}
}
Err(e) => {
eprintln!("Failed to connect to redis: {}", e);
tls.shutdown().await.expect("Failed to shutdown TLS stream");
None => {
log::warn!("Accepted connection for {:?}, but SNI {} does not match any configured SNI", tls.get_ref().1.server_name(), sni);
let mut compat = tls.compat_write();
send_static_string(
opts.compression,
&mut compat,
format!("SNI {} does not match any configured SNI", sni)
.as_str(),
)
.await
.expect("Failed to send static string");
compat.close().await.expect("Failed to shutdown TLS stream");
}
},
_ => {
log::error!("No SNI provided");
let mut compat = tls.compat_write();
send_static_string(opts.compression, &mut compat, "No SNI provided")
.await
.expect("Failed to send static string");
compat.close().await.expect("Failed to shutdown TLS stream");
}
}
Some(sni) if sni == *postgres_sni => {
log::info!(
"Accepted Postgres connection for {:?}",
tls.get_ref().1.server_name()
);
match tokio::net::TcpStream::connect(&postgres_target).await {
Ok(postgres) => {
if let Err(e) = copy_bidirectional_compressed(postgres, tls).await {
eprintln!("Failed to copy data: {}", e);
}
}
Err(e) => {
eprintln!("Failed to connect to postgres: {}", e);
tls.shutdown().await.expect("Failed to shutdown TLS stream");
}
}
}
Some(sni) => {
log::warn!("Accepted connection for {:?}, but SNI {} does not match any configured SNI", tls.get_ref().1.server_name(), sni);
send_static_string(
&mut tls,
format!("SNI {} does not match any configured SNI", sni).as_str(),
)
.await
.expect("Failed to send static string");
tls.shutdown().await.expect("Failed to shutdown TLS stream");
}
_ => {
send_static_string(&mut tls, "No SNI provided")
.await
.expect("Failed to send static string");
eprintln!("No SNI provided");
tls.shutdown().await.expect("Failed to shutdown TLS stream");
}
},
}
Err(e) => {
eprintln!("Failed to accept connection: {}", e);
}
@ -279,7 +492,7 @@ pub async fn reverse_proxy(opts: ReverseProxyCommand) -> Result<(), Box<dyn std:
}
}
pub async fn forward_proxy(opts: ForwardProxyCommand) -> Result<(), Box<dyn std::error::Error>> {
pub async fn forward_proxy(opts: ForwardProxyCommand) -> anyhow::Result<()> {
let (_, ca_pem) = x509_parser::pem::parse_x509_pem(&std::fs::read(&opts.ca)?)?;
let (_, ca_cert) = x509_parser::parse_x509_certificate(&ca_pem.contents)?;
let mut cert_store = RootCertStore::empty();
@ -329,38 +542,199 @@ pub async fn forward_proxy(opts: ForwardProxyCommand) -> Result<(), Box<dyn std:
log::info!("Listening on: {}", opts.listen);
let sni = ServerName::try_from(match opts.sni {
Some(ref s) => s.as_str(),
None => opts.target.as_str(),
Some(ref s) => s.to_string(),
None => opts.target.to_string(),
})
.expect("Failed to parse SNI");
loop {
let (pt_stream, _) = match listener.accept().await {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to accept connection: {}", e);
continue;
}
};
let connector = connector.clone();
let tls_stream = match TcpStream::connect(&opts.target).await {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to connect to target: {}", e);
continue;
}
};
let sni = sni.to_owned();
tokio::spawn(async move {
match connector.connect(sni, tls_stream).await {
Ok(tls) => {
if let Err(e) = copy_bidirectional_compressed(pt_stream, tls).await {
eprintln!("Failed to copy data: {}", e);
match opts.transport {
Transport::Plain => loop {
let (local_pt, _) = match listener.accept().await {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to accept connection: {}", e);
continue;
}
};
let target = opts.target.clone();
let connector = connector.clone();
let sni = sni.clone();
tokio::spawn(async move {
let addrs = match target.to_socket_addrs() {
Ok(a) => a,
Err(e) => {
eprintln!("Failed to resolve target address: {}", e);
return;
}
};
let mut last_err = None;
for addr in addrs {
log::info!("Trying to connect to: {}", addr);
let socket = match match addr {
std::net::SocketAddr::V4(_) => TcpSocket::new_v4(),
std::net::SocketAddr::V6(_) => TcpSocket::new_v6(),
} {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to create socket: {}", e);
return;
}
};
match socket.set_keepalive(true) {
Ok(_) => {}
Err(e) => {
log::warn!("Failed to set keepalive: {}", e);
}
}
match socket.connect(addr).await {
Ok(target_pt) => {
let sni = sni.to_owned();
match connector.connect(sni, target_pt).await {
Ok(tls) => {
if let Err(e) = copy_bidirectional_compressed(
opts.compression,
local_pt.compat(),
tls.compat(),
)
.await
{
eprintln!("Failed to copy data: {}", e);
}
}
Err(e) => {
eprintln!("Failed to connect to target: {}", e);
}
}
last_err = None;
break;
}
Err(e) => {
last_err = Some(e);
}
}
}
Err(e) => {
eprintln!("Failed to connect to target: {}", e);
if let Some(e) = last_err {
log::error!("None of the target addresses worked: {}", e);
}
});
},
#[cfg(feature = "yamux")]
Transport::YAMux => {
let mut retries = 5;
loop {
if retries == 0 {
eprintln!("Too many retries, giving up");
break;
}
let conn = TcpStream::connect(&opts.target).await?;
log::info!("Connected to target");
let tls_stream = match connector.connect(sni.clone(), conn).await {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to connect to target: {}", e);
retries -= 1;
continue;
}
};
let mut inner_listener = YamuxStreamManager::new()
.wrap_c2s_connection(tls_stream.compat())
.await;
loop {
log::info!("Waiting for new TCP connection (YAMux)");
let new_conn = tokio::select! {
_ = futures::future::poll_fn(|cx| inner_listener.poll_next_inbound(cx)) => {
None
}
conn = listener.accept() => {
match conn {
Ok((pt_stream, _)) => {
log::info!("Starting new C2S connection (YAMux)");
match inner_listener.start_c2s().await {
Ok(s) => Some((pt_stream, s)),
Err(e) => {
log::error!("Failed to start C2S connection: {}", e);
continue;
}
}
}
Err(e) => {
log::error!("Failed to accept connection: {}", e);
continue;
}
}
}
};
if let Some((pt_stream, tls_multiplexed)) = new_conn {
log::info!("Obtained new connection (YAMux)");
tokio::spawn(async move {
if let Err(e) = copy_bidirectional_compressed(
opts.compression,
pt_stream.compat(),
tls_multiplexed,
)
.await
{
eprintln!("Failed to copy data: {}", e);
}
});
} else {
log::warn!("No more streams to accept, next connection will be retried");
break;
}
}
futures::future::poll_fn(|cx| inner_listener.poll_close(cx))
.await
.ok();
}
});
}
anyhow::bail!("Too many retries");
}
};
}
#[cfg(test)]
mod tests {
use super::*;
async fn test_compression_method(algo: Compression) {
let (r, w) = tokio::io::duplex(1024);
let data = "Hello, world!".repeat(1024);
let mut w = compressor_to(algo, w.compat());
let mut r = decompressor_from(algo, BufReader::new(r.compat()));
tokio::join!(
async {
w.write_all(data.as_bytes()).await.unwrap();
w.close().await.unwrap();
},
async {
let mut output = Vec::new();
r.read_to_end(&mut output).await.unwrap();
assert_eq!(output.len(), data.len());
assert_eq!(output, data.as_bytes());
}
);
}
macro_rules! test_compression {
($name:ident, $algo:expr) => {
#[tokio::test]
async fn $name() {
test_compression_method($algo).await;
}
};
}
test_compression!(test_compression_none, Compression::None);
#[cfg(feature = "brotli")]
test_compression!(test_compression_brotli, Compression::Brotli);
#[cfg(feature = "zstd")]
test_compression!(test_compression_zstd, Compression::Zstd);
}

View file

@ -141,7 +141,7 @@ pub async fn setup_postgres_pub(
opts: SetupPublicationCommand,
) -> Result<(), PostgresSetupError> {
let mut postgres = PgConnection::connect(
&&connection_string
&connection_string
.map(|s| s.to_string())
.or_else(postgres_connection_string_from_env)
.ok_or(PostgresSetupError::MissingConnection)?,
@ -196,7 +196,7 @@ pub async fn drop_postgres_pub(
opts: DropPublicationCommand,
) -> Result<(), PostgresSetupError> {
let mut postgres = PgConnection::connect(
&&connection_string
&connection_string
.map(|s| s.to_string())
.or_else(postgres_connection_string_from_env)
.ok_or(PostgresSetupError::MissingConnection)?,
@ -220,7 +220,7 @@ pub async fn add_table_to_postgres_pub(
opts: AddTableCommand,
) -> Result<(), PostgresSetupError> {
let mut postgres = PgConnection::connect(
&&connection_string
&connection_string
.map(|s| s.to_string())
.or_else(postgres_connection_string_from_env)
.ok_or(PostgresSetupError::MissingConnection)?,
@ -247,7 +247,7 @@ pub async fn drop_table_from_postgres_pub(
opts: DropTableCommand,
) -> Result<(), PostgresSetupError> {
let mut postgres = PgConnection::connect(
&&connection_string
&connection_string
.map(|s| s.to_string())
.or_else(postgres_connection_string_from_env)
.ok_or(PostgresSetupError::MissingConnection)?,
@ -274,7 +274,7 @@ pub async fn setup_postgres_sub(
opts: SetupSubscriptionCommand,
) -> Result<(), PostgresSetupError> {
let mut postgres = PgConnection::connect(
&&connection_string
&connection_string
.map(|s| s.to_string())
.or_else(postgres_connection_string_from_env)
.ok_or(PostgresSetupError::MissingConnection)?,
@ -324,7 +324,7 @@ pub async fn drop_postgres_sub(
opts: DropSubscriptionCommand,
) -> Result<(), PostgresSetupError> {
let mut postgres = PgConnection::connect(
&&connection_string
&connection_string
.map(|s| s.to_string())
.or_else(postgres_connection_string_from_env)
.ok_or(PostgresSetupError::MissingConnection)?,

View file

@ -3,7 +3,9 @@ use std::path::Path;
use clap::Parser;
use serde::Deserialize;
use super::network::ReverseProxyCommand;
use crate::ops::network::{ForwardProxyCommand, ReverseProxySpec};
use super::network::{Compression, ReverseProxyCommand, Transport};
const DEF_CONFIG_FILE: &str = "/etc/replikey.toml";
const CA_CERT: &str = "ca.pem";
@ -47,23 +49,40 @@ pub struct ConnectionConfig {
#[derive(Debug, Deserialize)]
pub struct MasterConfig {
listen: String,
redis_sni: String,
redis_target: String,
postgres_sni: String,
postgres_target: String,
transport: Option<Transport>,
compression: Option<Compression>,
redis: MasterServiceSpec,
postgres: MasterServiceSpec,
workdir: Option<String>,
crl: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct SlaveConfig {
listen: String,
redis_sni: String,
postgres_sni: String,
target: String,
transport: Option<Transport>,
compression: Option<Compression>,
redis: SlaveServiceSpec,
postgres: SlaveServiceSpec,
workdir: Option<String>,
crl: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct MasterServiceSpec {
pub target: String,
pub sni: String,
pub multiplex: Option<bool>,
}
#[derive(Debug, Deserialize)]
pub struct SlaveServiceSpec {
pub target: String,
pub sni: String,
pub listen: String,
pub multiplex: Option<bool>,
}
pub fn service_replicate_master(config: String) {
let config = std::fs::read_to_string(config).unwrap();
let config: Config = toml::from_str(&config).expect("Failed to parse config");
@ -80,23 +99,38 @@ pub fn service_replicate_master(config: String) {
let master_conf = config.connection.master.as_ref().unwrap();
let cmd = ReverseProxyCommand {
listen: master_conf.listen.clone(),
redis_sni: master_conf.redis_sni.clone(),
redis_target: master_conf.redis_target.clone(),
postgres_sni: master_conf.postgres_sni.clone(),
postgres_target: master_conf.postgres_target.clone(),
transport: master_conf.transport.unwrap_or_default(),
compression: master_conf.compression.unwrap_or_default(),
target: vec![
ReverseProxySpec {
sni: master_conf.redis.sni.clone(),
target: master_conf.redis.target.clone(),
},
ReverseProxySpec {
sni: master_conf.postgres.sni.clone(),
target: master_conf.postgres.target.clone(),
},
],
cert: Path::new(SERVER_CERT).to_string_lossy().to_string(),
key: Path::new(SERVER_KEY).to_string_lossy().to_string(),
ca: Path::new(CA_CERT).to_string_lossy().to_string(),
crl: config.connection.master.as_ref().unwrap().crl.clone(),
};
tokio::runtime::Runtime::new()
let res = tokio::runtime::Runtime::new()
.unwrap()
.block_on(crate::ops::network::reverse_proxy(cmd))
.unwrap();
.block_on(crate::ops::network::reverse_proxy(cmd));
println!("Replication master started");
match res {
Ok(_) => {
log::error!("Unexpected termination of reverse proxy");
std::process::exit(1);
}
Err(e) => {
log::error!("Failed to start reverse proxy: {}", e);
std::process::exit(1);
}
}
}
pub fn service_replicate_slave(config: String) {
@ -113,23 +147,59 @@ pub fn service_replicate_slave(config: String) {
}
let slave_conf = config.connection.slave.as_ref().unwrap();
let cmd = ReverseProxyCommand {
listen: slave_conf.listen.clone(),
redis_sni: slave_conf.redis_sni.clone(),
redis_target: slave_conf.redis_sni.clone(),
postgres_sni: slave_conf.postgres_sni.clone(),
postgres_target: slave_conf.postgres_sni.clone(),
let cmd_redis = ForwardProxyCommand {
listen: slave_conf.redis.listen.clone(),
transport: slave_conf.transport.unwrap_or_default(),
compression: slave_conf.compression.unwrap_or_default(),
sni: Some(slave_conf.redis.sni.clone()),
target: slave_conf.target.clone(),
cert: Path::new(CLIENT_CERT).to_string_lossy().to_string(),
key: Path::new(CLIENT_KEY).to_string_lossy().to_string(),
ca: Path::new(CA_CERT).to_string_lossy().to_string(),
crl: config.connection.slave.as_ref().unwrap().crl.clone(),
};
let cmd_postgres = ForwardProxyCommand {
listen: slave_conf.postgres.listen.clone(),
transport: slave_conf.transport.unwrap_or_default(),
compression: slave_conf.compression.unwrap_or_default(),
sni: Some(slave_conf.postgres.sni.clone()),
target: slave_conf.target.clone(),
cert: Path::new(CLIENT_CERT).to_string_lossy().to_string(),
key: Path::new(CLIENT_KEY).to_string_lossy().to_string(),
ca: Path::new(CA_CERT).to_string_lossy().to_string(),
crl: config.connection.slave.as_ref().unwrap().crl.clone(),
};
tokio::runtime::Runtime::new()
.unwrap()
.block_on(crate::ops::network::reverse_proxy(cmd))
.unwrap();
enum ProxyId {
Redis,
Postgres,
}
println!("Replication slave started");
let res = tokio::runtime::Runtime::new()
.unwrap()
.block_on(async move {
tokio::select! {
r = crate::ops::network::forward_proxy(cmd_redis) => {
(ProxyId::Redis, r)
},
r = crate::ops::network::forward_proxy(cmd_postgres) => {
(ProxyId::Postgres, r)
},
}
});
match res {
(ProxyId::Redis, Ok(_)) => {
log::error!("Unexpected termination of redis proxy");
}
(ProxyId::Postgres, Ok(_)) => {
log::error!("Unexpected termination of postgres proxy");
}
(ProxyId::Redis, Err(e)) => {
log::error!("Failed to start redis proxy: {}", e);
}
(ProxyId::Postgres, Err(e)) => {
log::error!("Failed to start postgres proxy: {}", e);
}
}
}

203
src/transport/mod.rs Normal file
View file

@ -0,0 +1,203 @@
use std::fmt::Debug;
use futures::{
future::BoxFuture,
io::{AsyncRead, AsyncWrite},
FutureExt,
};
#[cfg(feature = "yamux")]
pub mod yamux;
pub trait DynStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T: AsyncRead + AsyncWrite + Unpin + Send> DynStream for T {}
pub trait StreamManager<'c, C: 'c>: Send + Sync {
type S2CConn: MultiplexedS2C<C>;
type C2SConn: MultiplexedC2S<C>;
fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn>;
fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn>;
fn shutdown(&self) -> BoxFuture<'c, ()>;
}
// hangs rustc for some reason..
/*
pub fn manager_to_dyn<'c, C: 'c>(inner: impl StreamManager<'c, C> + 'c) -> DynStreamManager<'c, C> {
DynStreamManager {
inner: Box::new(inner),
}
}
pub struct DynStreamManager<'c, C: 'c> {
inner: Box<
dyn StreamManager<
'c,
C,
S2CConn = Box<
dyn MultiplexedS2C<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>>
+ 'c,
>,
C2SConn = Box<
dyn MultiplexedC2S<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>>
+ 'c,
>,
> + 'c,
>,
}
impl<'c, C> StreamManager<'c, C> for DynStreamManager<'c, C>
where
C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c,
<Self::C2SConn as MultiplexedC2S<C>>::Error: Error + 'static,
<Self::S2CConn as MultiplexedS2C<C>>::Error: Error + 'static,
{
type S2CConn = DynMultiplexedS2CConnection<'c, C>;
type C2SConn = DynMultiplexedC2SConnection<'c, C>;
fn wrap_c2s_connection(&'c self, stream: C) -> BoxFuture<'c, Self::C2SConn> {
self.inner
.wrap_c2s_connection(stream)
.map(|x| DynMultiplexedC2SConnection {
inner: Box::new(x.map(|x| {
x.map(|s| s.map(|s| Box::new(s) as _))
.map_err(anyhow::Error::new)
}) as _),
})
.boxed()
}
fn wrap_s2c_connection(&'c self, stream: C) -> BoxFuture<'c, Self::S2CConn> {
self.inner
.wrap_s2c_connection(stream)
.map(|x| DynMultiplexedS2CConnection {
inner: Box::new(x.map(|x| {
x.map(|s| s.map(|s| Box::new(s) as _))
.map_err(anyhow::Error::new) as _
})),
})
.boxed()
}
fn shutdown(&self) -> BoxFuture<'c, ()> {
async move { self.inner.shutdown().await }.boxed()
}
}
*/
pub trait MultiplexedS2C<C>: Send + Sync {
type Error: Debug + Send + Sync + Sized;
type Stream: DynStream;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>>;
fn map<F, S, E>(self, map: F) -> S2CMap<C, Self, F>
where
F: Fn(Result<Option<Self::Stream>, Self::Error>) -> Result<Option<S>, E> + Send + Sync,
S: DynStream,
Self: Sized,
{
S2CMap {
inner: self,
map,
_phantom: std::marker::PhantomData,
}
}
}
pub struct S2CMap<C, MC, F> {
inner: MC,
map: F,
_phantom: std::marker::PhantomData<C>,
}
impl<C: Send + Sync, E: Send + Sync + Debug, S: DynStream, MC, F: Send> MultiplexedS2C<C>
for S2CMap<C, MC, F>
where
MC: MultiplexedS2C<C>,
F: Fn(Result<Option<MC::Stream>, MC::Error>) -> Result<Option<S>, E> + Send + Sync,
{
type Error = E;
type Stream = S;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>> {
let inner = &mut self.inner;
let mapf = &self.map;
inner.accept().map(mapf).boxed()
}
}
pub trait MultiplexedC2S<C>: Send + Sync {
type Error: Debug + Send + Sync + Sized;
type Stream: DynStream;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>>;
fn map<F, S, E>(self, map: F) -> C2SMap<C, Self, F>
where
F: Fn(Result<Self::Stream, Self::Error>) -> Result<S, E> + Send + Sync,
S: DynStream,
Self: Sized,
{
C2SMap {
inner: self,
map,
_phantom: std::marker::PhantomData,
}
}
}
pub struct C2SMap<C, MC, F> {
inner: MC,
map: F,
_phantom: std::marker::PhantomData<C>,
}
impl<C: Send + Sync, E: Send + Sync + Debug, S: DynStream, MC, F: Send> MultiplexedC2S<C>
for C2SMap<C, MC, F>
where
MC: MultiplexedC2S<C>,
F: Fn(Result<MC::Stream, MC::Error>) -> Result<S, E> + Send + Sync,
{
type Error = E;
type Stream = S;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>> {
let inner = &mut self.inner;
let map = &self.map;
inner.start_c2s().map(map).boxed()
}
}
pub struct DynMultiplexedS2CConnection<'c, C: 'c> {
inner: Box<dyn MultiplexedS2C<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>> + 'c>,
}
impl<'c, C> MultiplexedS2C<C> for DynMultiplexedS2CConnection<'c, C> {
type Error = anyhow::Error;
type Stream = Box<dyn DynStream + 'c>;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>> {
async move {
self.inner
.accept()
.await
.map(|x| x.map(|x| Box::new(x) as _))
}
.boxed()
}
}
pub struct DynMultiplexedC2SConnection<'c, C: 'c> {
inner: Box<dyn MultiplexedC2S<C, Error = anyhow::Error, Stream = Box<dyn DynStream + 'c>> + 'c>,
}
impl<'c, C> MultiplexedC2S<C> for DynMultiplexedC2SConnection<'c, C> {
type Error = anyhow::Error;
type Stream = Box<dyn DynStream + 'c>;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>> {
async move { self.inner.start_c2s().await.map(|x| Box::new(x) as _) }.boxed()
}
}

80
src/transport/yamux.rs Normal file
View file

@ -0,0 +1,80 @@
use futures::{
future::BoxFuture,
io::{AsyncRead, AsyncWrite},
FutureExt,
};
use yamux::Config;
pub(crate) fn yamux_config() -> yamux::Config {
let mut conf = Config::default();
// we are in mTLS so be generous but still prevent accidental DoS
conf.set_read_after_close(true)
.set_max_num_streams(512)
.set_max_connection_receive_window(Some((256 << 10) * 512));
conf
}
pub struct YamuxStreamManager {
config: yamux::Config,
}
impl Default for YamuxStreamManager {
fn default() -> Self {
Self::new()
}
}
impl YamuxStreamManager {
pub fn new() -> Self {
Self {
config: yamux_config(),
}
}
}
impl<'c, C: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'c> super::StreamManager<'c, C>
for YamuxStreamManager
{
type S2CConn = yamux::Connection<C>;
type C2SConn = yamux::Connection<C>;
fn wrap_c2s_connection(&self, stream: C) -> BoxFuture<'c, Self::C2SConn> {
let config = self.config.clone();
async move { yamux::Connection::new(stream, config, yamux::Mode::Client) }.boxed()
}
fn wrap_s2c_connection(&self, stream: C) -> BoxFuture<'c, Self::S2CConn> {
let config = self.config.clone();
async move { yamux::Connection::new(stream, config, yamux::Mode::Server) }.boxed()
}
fn shutdown(&self) -> BoxFuture<'c, ()> {
async {}.boxed()
}
}
impl<C: AsyncRead + AsyncWrite + Unpin + Send + Sync> super::MultiplexedS2C<C>
for yamux::Connection<C>
{
type Error = yamux::ConnectionError;
type Stream = yamux::Stream;
fn accept(&mut self) -> BoxFuture<'_, Result<Option<Self::Stream>, Self::Error>> {
futures::future::poll_fn(|cx| self.poll_next_inbound(cx))
.map(|x| x.transpose())
.boxed()
}
}
impl<C: AsyncRead + AsyncWrite + Unpin + Send + Sync> super::MultiplexedC2S<C>
for yamux::Connection<C>
{
type Error = yamux::ConnectionError;
type Stream = yamux::Stream;
fn start_c2s(&mut self) -> BoxFuture<'_, Result<Self::Stream, Self::Error>> {
futures::future::poll_fn(|cx| self.poll_new_outbound(cx)).boxed()
}
}

968
tests/mtls_integration.rs Normal file
View file

@ -0,0 +1,968 @@
#![allow(clippy::all)]
use std::{
path::PathBuf,
sync::{atomic::AtomicU32, Once},
};
use clap::Parser;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use replikey::{
cert::UsageType,
ops::{
cert::{CertCommand, CertSubCommand},
network::{NetworkCommand, NetworkSubCommand},
},
};
use rustls::crypto::{aws_lc_rs, CryptoProvider};
use time::OffsetDateTime;
use tokio::{
io::{AsyncReadExt as _, AsyncWriteExt as _},
task::JoinSet,
};
use tokio_util::compat::TokioAsyncReadCompatExt;
static NEXT_PORT: AtomicU32 = AtomicU32::new(12311);
static TEST_SETUP: Once = Once::new();
fn setup_once() {
TEST_SETUP.call_once(|| {
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "info");
}
env_logger::init();
CryptoProvider::install_default(aws_lc_rs::default_provider())
.expect("Failed to install crypto provider");
});
}
fn next_port() -> u32 {
NEXT_PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
async fn test_stream_sequence<
const WRITE_SIZE: usize,
L: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
R: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
>(
left: L,
right: R,
) {
let (mut l_rx, mut l_tx) = left.split();
let (mut r_rx, mut r_tx) = right.split();
let l_write = async move {
let mut buf = [0u8; WRITE_SIZE];
for i in 0..100 {
buf.iter_mut()
.enumerate()
.for_each(|(j, x)| *x = ((i + j) ^ 123) as u8);
l_tx.write_all(&buf).await?;
}
l_tx.close().await?;
Ok::<_, futures::io::Error>(())
};
let r_write = async move {
let mut buf = [0u8; WRITE_SIZE];
for i in 0..100 {
buf.iter_mut()
.enumerate()
.for_each(|(j, x)| *x = ((i + j) ^ 321) as u8);
r_tx.write_all(&buf).await?;
}
r_tx.close().await?;
Ok(())
};
let l_read = async move {
let mut buf = [0u8; WRITE_SIZE];
for i in 0..100 {
l_rx.read_exact(&mut buf).await?;
for j in 0..WRITE_SIZE {
assert_eq!(buf[j], ((i + j) ^ 321) as u8);
}
}
assert_eq!(l_rx.read(&mut buf).await?, 0, "expected eof");
Ok(())
};
let r_read = async move {
let mut buf = [0u8; WRITE_SIZE];
for i in 0..100 {
r_rx.read_exact(&mut buf).await?;
for j in 0..WRITE_SIZE {
assert_eq!(buf[j], ((i + j) ^ 123) as u8);
}
}
assert_eq!(r_rx.read(&mut buf).await?, 0, "expected eof");
Ok(())
};
tokio::try_join!(l_write, r_write, l_read, r_read).expect("stream sequence failed");
}
#[tokio::test]
async fn in_memory_stream_works() {
let (left, right) = tokio::io::duplex(1024);
test_stream_sequence::<1024, _, _>(left.compat(), right.compat()).await;
let (left, right) = tokio::io::duplex(1024);
test_stream_sequence::<1, _, _>(left.compat(), right.compat()).await;
let (left, right) = tokio::io::duplex(1024);
test_stream_sequence::<3543, _, _>(left.compat(), right.compat()).await;
}
fn test_ca(dir: &PathBuf) -> (PathBuf, PathBuf) {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = dir.join(format!("ca-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create ca dir");
let cmd = [
"replikey".to_string(),
"create-ca".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--dn-common-name".to_string(),
"replikey-test-ca".to_string(),
"--output".to_string(),
path.to_string_lossy().to_string(),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::CreateCa(opts) = parsed.subcmd {
replikey::ops::cert::create_ca(opts, false);
} else {
panic!("failed to parse create-ca");
}
let pem = path.join("ca.pem");
let key = path.join("ca.key");
assert!(pem.exists());
assert!(key.exists());
(pem, key)
}
struct PeerCert {
key: PathBuf,
csr: PathBuf,
self_signed: PathBuf,
signed_cert: Option<PathBuf>,
}
impl PeerCert {
fn signed_by(&self, usage: UsageType, ca_key: &PathBuf) -> PeerCert {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = ca_key.parent().unwrap().join(format!("signed-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create signed dir");
let cmd = [
"replikey".to_string(),
"sign-server-csr".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--ca-dir".to_string(),
ca_key.parent().unwrap().to_string_lossy().to_string(),
"--input-csr".to_string(),
self.csr.to_string_lossy().to_string(),
"-d".to_string(),
"target1.local".to_string(),
"-d".to_string(),
"target2.local".to_string(),
"-d".to_string(),
"*.targets.local".to_string(),
"--output".to_string(),
format!("{}/server-signed.pem", path.to_string_lossy()),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::SignServerCSR(opts) = parsed.subcmd {
replikey::ops::cert::sign_csr(opts, usage, false);
} else {
panic!("failed to parse sign-server-csr");
}
PeerCert {
key: self.key.clone(),
csr: self.csr.clone(),
self_signed: self.self_signed.clone(),
signed_cert: Some(path.join("server-signed.pem")),
}
}
}
fn test_server_cert(dir: &PathBuf) -> PeerCert {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = dir.join(format!("server-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create server dir");
let cmd = [
"replikey".to_string(),
"create-server".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--dn-common-name".to_string(),
"replikey-test-server".to_string(),
"-d".to_string(),
"target1.local".to_string(),
"-d".to_string(),
"target2.local".to_string(),
"-d".to_string(),
"*.targets.local".to_string(),
"--output".to_string(),
path.to_string_lossy().to_string(),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::CreateServer(opts) = parsed.subcmd {
replikey::ops::cert::create_server(opts);
} else {
panic!("failed to parse create-server");
}
let pem = path.join("server.pem");
let key = path.join("server.key");
let csr = path.join("server.csr");
assert!(pem.exists());
assert!(key.exists());
PeerCert {
key,
csr,
self_signed: pem,
signed_cert: None,
}
}
fn test_client_cert(dir: &PathBuf) -> PeerCert {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = dir.join(format!("client-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create client dir");
let cmd = [
"replikey".to_string(),
"create-client".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--dn-common-name".to_string(),
"replikey-test-client".to_string(),
"--output".to_string(),
path.to_string_lossy().to_string(),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::CreateClient(opts) = parsed.subcmd {
replikey::ops::cert::create_client(opts);
} else {
panic!("failed to parse create-client");
}
let pem = path.join("client.pem");
let key = path.join("client.key");
let csr = path.join("client.csr");
assert!(pem.exists());
assert!(key.exists());
PeerCert {
key,
csr,
self_signed: pem,
signed_cert: None,
}
}
async fn start_test_target_1() -> u32 {
let port = next_port();
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("failed to bind listener");
tokio::spawn(async move {
loop {
let (mut stream, _) = listener.accept().await.expect("failed to accept");
log::info!("accepted connection on target1");
tokio::spawn(async move {
let (mut rx, mut tx) = stream.split();
let n = tokio::io::copy(&mut rx, &mut tx)
.await
.expect("failed to copy");
log::info!("copied {} bytes", n);
tx.shutdown().await.expect("failed to close");
});
}
});
port
}
async fn start_test_target_2() -> u32 {
let port = next_port();
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.expect("failed to bind listener");
tokio::spawn(async move {
loop {
let (mut stream, _) = listener.accept().await.expect("failed to accept");
log::info!("accepted connection on target2");
tokio::spawn(async move {
let (mut rx, mut tx) = stream.split();
let mut buf = [0u8; 1024];
let mut total = 0;
loop {
let n = rx.read(&mut buf).await.expect("failed to read");
total += n;
if n == 0 {
break;
}
buf[..n].iter_mut().for_each(|x| *x = !*x);
tx.write_all(&buf[..n]).await.expect("failed to write");
}
log::info!("copied {} bytes", total);
tx.shutdown().await.expect("failed to close");
});
}
});
port
}
async fn start_reverse_proxy(
target1: u32,
target2: u32,
cert: &PathBuf,
key: &PathBuf,
ca_cert: &PathBuf,
transport: &str,
) -> u32 {
let port = next_port();
let cmd = [
"replikey".to_string(),
"reverse-proxy".to_string(),
"--listen".to_string(),
format!("127.0.0.1:{}", port),
"--target".to_string(),
format!("target1.local/127.0.0.1:{}", target1),
"--target".to_string(),
format!("target2.local/127.0.0.1:{}", target2),
"--target".to_string(),
format!("target1.targets.local/127.0.0.1:{}", target1),
"--target".to_string(),
format!("target2.targets.local/127.0.0.1:{}", target2),
"--cert".to_string(),
cert.to_string_lossy().to_string(),
"--key".to_string(),
key.to_string_lossy().to_string(),
"--ca".to_string(),
ca_cert.to_string_lossy().to_string(),
"--transport".to_string(),
transport.to_string(),
];
let parsed = NetworkCommand::parse_from(&cmd);
if let NetworkSubCommand::ReverseProxy(opts) = parsed.subcmd {
tokio::spawn(async move {
replikey::ops::network::reverse_proxy(opts)
.await
.expect("failed to run reverse proxy");
});
port
} else {
panic!("failed to parse reverse-proxy");
}
}
async fn start_forward_proxy(
addr: &str,
sni: &str,
cert: &PathBuf,
key: &PathBuf,
ca_cert: &PathBuf,
transport: &str,
) -> u32 {
let port = next_port();
let cmd = [
"replikey".to_string(),
"forward-proxy".to_string(),
"--listen".to_string(),
format!("127.0.0.1:{}", port),
"--target".to_string(),
addr.to_string(),
"--sni".to_string(),
sni.to_string(),
"--cert".to_string(),
cert.to_string_lossy().to_string(),
"--key".to_string(),
key.to_string_lossy().to_string(),
"--ca".to_string(),
ca_cert.to_string_lossy().to_string(),
"--transport".to_string(),
transport.to_string(),
];
let parsed = NetworkCommand::parse_from(&cmd);
if let NetworkSubCommand::ForwardProxy(opts) = parsed.subcmd {
tokio::spawn(async move {
replikey::ops::network::forward_proxy(opts)
.await
.expect("failed to run forward proxy");
});
port
} else {
panic!("failed to parse forward-proxy");
}
}
async fn wait_for_port(port: u32) {
loop {
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
panic!("timeout waiting for port");
}
res = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) => {
if res.is_ok() {
break;
}
}
}
}
}
async fn should_not_get_anything(port: u32) {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.expect("failed to connect");
let res = stream.read(&mut [0u8; 1024]).await;
match res {
Ok(n) => {
assert_eq!(n, 0);
}
Err(e) => {
assert_eq!(e.kind(), std::io::ErrorKind::ConnectionReset);
}
}
}
async fn expect_target1_signature(port: u32) {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.expect("failed to connect");
let (mut rx, mut tx) = tokio::io::split(stream);
for batch in 0..8 {
log::debug!("expect_target1_signature batch {}", batch);
let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
let expect = data.clone();
let mut write_size = (1..4096).cycle();
let (_, rx_data) = tokio::join!(
async {
while !data.is_empty() {
let n = write_size.next().unwrap().min(data.len());
tx.write_all(&data[..n]).await.expect("failed to write");
log::trace!("wrote {} bytes for target1", n);
data = data.split_off(n);
}
tx.flush().await.expect("failed to flush");
log::info!("wrote {} bytes for target1", expect.len());
},
async {
let mut rxed = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = rx.read(&mut buf).await.expect("failed to read");
log::trace!("read {} bytes for target1", n);
if n == 0 {
break;
}
rxed.extend_from_slice(&buf[..n]);
if rxed.len() >= expect.len() {
break;
}
}
rxed
}
);
assert_eq!(rx_data, expect);
}
tx.shutdown().await.expect("failed to close");
}
async fn expect_target1_signature_buffered(port: u32) {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.expect("failed to connect");
let (mut rx, mut tx) = tokio::io::split(stream);
let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
let expect = data.clone();
let mut write_size = (1..4096).cycle();
let (_, rx_data) = tokio::join!(
async {
while !data.is_empty() {
let n = write_size.next().unwrap().min(data.len());
tx.write_all(&data[..n]).await.expect("failed to write");
data = data.split_off(n);
}
tx.shutdown().await.expect("failed to close");
},
async {
let mut rxed = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = rx.read(&mut buf).await.expect("failed to read");
if n == 0 {
break;
}
rxed.extend_from_slice(&buf[..n]);
if rxed.len() >= expect.len() {
break;
}
}
rxed
}
);
assert_eq!(rx_data, expect);
}
async fn expect_target2_signature(port: u32) {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.expect("failed to connect");
let (mut rx, mut tx) = tokio::io::split(stream);
for _batch in 0..8 {
let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
let expect = data.iter().map(|x| !x).collect::<Vec<_>>();
let mut write_size = (1..4096).cycle();
let (_, rx_data) = tokio::join!(
async {
while !data.is_empty() {
let n = write_size.next().unwrap().min(data.len());
tx.write_all(&data[..n]).await.expect("failed to write");
data = data.split_off(n);
}
tx.flush().await.expect("failed to flush");
},
async {
let mut rxed = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = rx.read(&mut buf).await.expect("failed to read");
if n == 0 {
break;
}
rxed.extend_from_slice(&buf[..n]);
if rxed.len() >= expect.len() {
break;
}
}
rxed
}
);
assert_eq!(rx_data, expect);
}
tx.shutdown().await.expect("failed to close");
}
async fn expect_target2_signature_buffered(port: u32) {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await
.expect("failed to connect");
let (mut rx, mut tx) = tokio::io::split(stream);
let mut data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
let expect = data.iter().map(|x| !x).collect::<Vec<_>>();
let mut write_size = (1..4096).cycle();
let (_, rx_data) = tokio::join!(
async {
while !data.is_empty() {
let n = write_size.next().unwrap().min(data.len());
tx.write_all(&data[..n]).await.expect("failed to write");
data = data.split_off(n);
}
tx.shutdown().await.expect("failed to close");
},
async {
let mut rxed = Vec::new();
let mut buf = [0u8; 1024];
loop {
let n = rx.read(&mut buf).await.expect("failed to read");
if n == 0 {
break;
}
rxed.extend_from_slice(&buf[..n]);
if rxed.len() >= expect.len() {
break;
}
}
rxed
}
);
assert_eq!(rx_data, expect);
}
#[tokio::test]
async fn test_mtls_integrated() {
setup_once();
let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca"));
let server_cert = test_server_cert(&PathBuf::from("target/test-server"));
let client_cert = test_client_cert(&PathBuf::from("target/test-client"));
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let target1 = start_test_target_1().await;
let target2 = start_test_target_2().await;
log::info!("target1: {}, target2: {}", target1, target2);
wait_for_port(target1).await;
wait_for_port(target2).await;
expect_target1_signature(target1).await;
expect_target2_signature(target2).await;
let proxy_port_self_signed = start_reverse_proxy(
target1,
target2,
&server_cert.self_signed,
&server_cert.key,
&ca_cert,
"plain",
)
.await;
wait_for_port(proxy_port_self_signed).await;
log::info!("reverse_proxy_port_self_signed: {}", proxy_port_self_signed);
let proxy_port = start_reverse_proxy(
target1,
target2,
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
"plain",
)
.await;
wait_for_port(proxy_port).await;
log::info!("reverse_proxy_port: {}", proxy_port);
for server_signed in [false, true] {
let forward_proxy_target1_self_signed = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target1.local",
&client_cert.self_signed,
&client_cert.key,
&ca_cert,
"plain",
)
.await;
log::info!(
"forward_proxy_target1_self_signed: {}",
forward_proxy_target1_self_signed
);
let forward_proxy_target1 = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target1.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"plain",
)
.await;
log::info!("forward_proxy_target1: {}", forward_proxy_target1);
let forward_proxy_target2 = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target2.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"plain",
)
.await;
log::info!("forward_proxy_target2: {}", forward_proxy_target2);
let forward_proxy_target1_wildcard = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target1.targets.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"plain",
)
.await;
let forward_proxy_target2_wildcard: u32 = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target2.targets.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"plain",
)
.await;
log::info!(
"forward_proxy_target2_wildcard: {}",
forward_proxy_target2_wildcard
);
let forward_proxy_unknown_sni = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"some-other.place",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"plain",
)
.await;
log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni);
wait_for_port(forward_proxy_target1).await;
wait_for_port(forward_proxy_target2).await;
wait_for_port(forward_proxy_target1_wildcard).await;
wait_for_port(forward_proxy_target2_wildcard).await;
wait_for_port(forward_proxy_unknown_sni).await;
log::info!("Test: self-signed cert");
should_not_get_anything(forward_proxy_target1_self_signed).await;
if server_signed {
let mut js = JoinSet::new();
for _ in 0..100 {
js.spawn(expect_target1_signature(forward_proxy_target1));
js.spawn(expect_target2_signature(forward_proxy_target2));
js.spawn(expect_target1_signature(forward_proxy_target1_wildcard));
js.spawn(expect_target2_signature(forward_proxy_target2_wildcard));
js.spawn(should_not_get_anything(forward_proxy_unknown_sni));
}
js.join_all().await;
} else {
let mut js = JoinSet::new();
for _ in 0..100 {
js.spawn(should_not_get_anything(forward_proxy_target1));
js.spawn(should_not_get_anything(forward_proxy_target2));
js.spawn(should_not_get_anything(forward_proxy_target1_wildcard));
js.spawn(should_not_get_anything(forward_proxy_target2_wildcard));
js.spawn(should_not_get_anything(forward_proxy_unknown_sni));
}
js.join_all().await;
}
}
}
#[tokio::test]
async fn test_mtls_role_reverse() {
let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca"));
let server_cert = test_server_cert(&PathBuf::from("target/test-server"));
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = test_client_cert(&PathBuf::from("target/test-client"));
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let target1 = start_test_target_1().await;
let target2 = start_test_target_2().await;
wait_for_port(target1).await;
wait_for_port(target2).await;
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let proxy_port = start_reverse_proxy(
target1,
target2,
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
"plain",
)
.await;
wait_for_port(proxy_port).await;
let forward_proxy_target1 = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port),
"target1.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"plain",
)
.await;
wait_for_port(forward_proxy_target1).await;
expect_target1_signature(forward_proxy_target1).await;
let proxy_port_reversed = start_reverse_proxy(
target1,
target2,
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"plain",
)
.await;
wait_for_port(proxy_port_reversed).await;
let forward_proxy_target1_reversed = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port_reversed),
"target1.local",
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
"plain",
)
.await;
wait_for_port(forward_proxy_target1_reversed).await;
should_not_get_anything(forward_proxy_target1_reversed).await;
}
#[tokio::test]
async fn test_mtls_yamux_transport() {
setup_once();
let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca"));
let server_cert = test_server_cert(&PathBuf::from("target/test-server"));
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = test_client_cert(&PathBuf::from("target/test-client"));
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let target1 = start_test_target_1().await;
let target2 = start_test_target_2().await;
wait_for_port(target1).await;
wait_for_port(target2).await;
let proxy_port = start_reverse_proxy(
target1,
target2,
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
"yamux",
)
.await;
wait_for_port(proxy_port).await;
wait_for_port(proxy_port).await;
let forward_proxy_target1 = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port),
"target1.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"yamux",
)
.await;
let forward_proxy_target2 = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port),
"target2.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"yamux",
)
.await;
wait_for_port(forward_proxy_target1).await;
wait_for_port(forward_proxy_target2).await;
log::info!("forward_proxy_target1: {}", forward_proxy_target1);
log::info!("forward_proxy_target2: {}", forward_proxy_target2);
let mut js = JoinSet::new();
for _ in 0..128 {
js.spawn(expect_target1_signature_buffered(forward_proxy_target1));
js.spawn(expect_target2_signature_buffered(forward_proxy_target2));
}
js.join_all().await;
}