//! Integration tests for replikey's certificate tooling and its mTLS forward/reverse proxies.
//!
//! The tests generate a throwaway CA plus server/client certificates under `target/`, start
//! simple TCP "targets" on localhost, and drive traffic through
//! forward-proxy -> reverse-proxy -> target chains, checking which certificate / SNI
//! combinations are allowed through.

// NOTE(review): blanket clippy suppression for this whole test crate; consider narrowing it.
#![allow(clippy::all)]

use std::{
    path::{Path, PathBuf},
    sync::{
        atomic::{AtomicU32, Ordering},
        Once,
    },
};

use clap::Parser;
use replikey::{
    cert::UsageType,
    ops::{
        cert::{CertCommand, CertSubCommand},
        network::{NetworkCommand, NetworkSubCommand},
    },
};
use rustls::crypto::{aws_lc_rs, CryptoProvider};
use time::OffsetDateTime;
use tokio::{
    io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _},
    task::JoinSet,
};

/// Next localhost port handed out by [`next_port`]; shared so tests in this binary never
/// pick the same port.
static NEXT_PORT: AtomicU32 = AtomicU32::new(12311);
/// Disambiguates scratch directories created within the same clock tick (see [`unique_dir`]).
static NEXT_DIR_ID: AtomicU32 = AtomicU32::new(0);
/// Guards the process-wide initialisation performed by [`setup_once`].
static TEST_SETUP: Once = Once::new();

/// Performs process-wide test initialisation exactly once.
///
/// Defaults `RUST_LOG` to `info`, initialises `env_logger`, and installs aws-lc-rs as the
/// default rustls crypto provider. Every test that starts a TLS proxy calls this first, so the
/// provider is installed no matter which test the harness happens to run first.
fn setup_once() {
    TEST_SETUP.call_once(|| {
        if std::env::var("RUST_LOG").is_err() {
            std::env::set_var("RUST_LOG", "info");
        }
        env_logger::init();
        CryptoProvider::install_default(aws_lc_rs::default_provider())
            .expect("Failed to install crypto provider");
    });
}

/// Hands out the next localhost port number for a listener started by this test binary.
fn next_port() -> u32 {
    NEXT_PORT.fetch_add(1, Ordering::SeqCst)
}

/// Creates and returns a fresh directory `<base>/<prefix>-<nanos>-<pid>-<seq>/`.
///
/// A timestamp alone is not guaranteed unique when tests run in parallel (clock resolution
/// varies by platform), so the process id and a per-process counter are appended.
///
/// # Panics
/// Panics if the directory cannot be created.
fn unique_dir(base: &Path, prefix: &str) -> PathBuf {
    let nanos = OffsetDateTime::now_utc().unix_timestamp_nanos();
    let seq = NEXT_DIR_ID.fetch_add(1, Ordering::SeqCst);
    let path = base.join(format!("{}-{}-{}-{}/", prefix, nanos, std::process::id(), seq));
    std::fs::create_dir_all(&path).unwrap_or_else(|e| {
        panic!("failed to create {} dir {}: {}", prefix, path.display(), e)
    });
    path
}

/// Number of fixed-size chunks each side of [`test_stream_sequence`] writes.
const STREAM_CHUNKS: usize = 100;
/// Pattern key for bytes written by the left endpoint (and therefore expected by the right).
const LEFT_KEY: usize = 123;
/// Pattern key for bytes written by the right endpoint (and therefore expected by the left).
const RIGHT_KEY: usize = 321;

/// Byte `j` of chunk `i` in the test pattern keyed by `key` (intentionally truncated to `u8`).
fn pattern_byte(i: usize, j: usize, key: usize) -> u8 {
    ((i + j) ^ key) as u8
}

/// Writes [`STREAM_CHUNKS`] chunks of `N` pattern bytes (keyed by `key`) to `tx`, then shuts
/// the writer down so the peer observes EOF.
async fn write_pattern<const N: usize, W: AsyncWrite + Unpin>(
    mut tx: W,
    key: usize,
) -> std::io::Result<()> {
    let mut buf = [0u8; N];
    for i in 0..STREAM_CHUNKS {
        for (j, byte) in buf.iter_mut().enumerate() {
            *byte = pattern_byte(i, j, key);
        }
        tx.write_all(&buf).await?;
    }
    tx.shutdown().await
}

/// Reads [`STREAM_CHUNKS`] chunks of `N` bytes from `rx`, asserting that each matches the
/// pattern keyed by `key`, then asserts that the stream is at EOF.
async fn read_pattern<const N: usize, R: AsyncRead + Unpin>(
    mut rx: R,
    key: usize,
) -> std::io::Result<()> {
    let mut buf = [0u8; N];
    for i in 0..STREAM_CHUNKS {
        rx.read_exact(&mut buf).await?;
        for (j, &byte) in buf.iter().enumerate() {
            assert_eq!(byte, pattern_byte(i, j, key), "mismatch at chunk {} byte {}", i, j);
        }
    }
    assert_eq!(rx.read(&mut buf).await?, 0, "expected eof");
    Ok(())
}

/// Exercises a bidirectional byte stream between `left` and `right`.
///
/// Both ends concurrently write [`STREAM_CHUNKS`] chunks of `WRITE_SIZE` bytes with distinct
/// patterns and shut down. Each end verifies every byte the other sent, followed by EOF.
///
/// # Panics
/// Panics on any I/O error or pattern mismatch.
async fn test_stream_sequence<
    const WRITE_SIZE: usize,
    L: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    R: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
>(
    left: L,
    right: R,
) {
    let (l_rx, l_tx) = tokio::io::split(left);
    let (r_rx, r_tx) = tokio::io::split(right);

    tokio::try_join!(
        write_pattern::<WRITE_SIZE, _>(l_tx, LEFT_KEY),
        write_pattern::<WRITE_SIZE, _>(r_tx, RIGHT_KEY),
        read_pattern::<WRITE_SIZE, _>(l_rx, RIGHT_KEY),
        read_pattern::<WRITE_SIZE, _>(r_rx, LEFT_KEY)
    )
    .expect("stream sequence failed");
}

#[tokio::test]
async fn in_memory_stream_works() {
    let (left, right) = tokio::io::duplex(1024);
    test_stream_sequence::<1024, _, _>(left, right).await;
    let (left, right) = tokio::io::duplex(1024);
    test_stream_sequence::<1, _, _>(left, right).await;
    let (left, right) = tokio::io::duplex(1024);
    test_stream_sequence::<3543, _, _>(left, right).await;
}

/// Creates a test CA with `replikey create-ca` in a fresh directory under `dir`.
///
/// Returns `(certificate_path, key_path)` (`ca.pem`, `ca.key`).
fn test_ca(dir: &Path) -> (PathBuf, PathBuf) {
    let path = unique_dir(dir, "ca");
    let cmd = [
        "replikey".to_string(),
        "create-ca".to_string(),
        "--valid-days".to_string(),
        "365".to_string(),
        "--dn-common-name".to_string(),
        "replikey-test-ca".to_string(),
        "--output".to_string(),
        path.to_string_lossy().to_string(),
    ];
    let parsed = CertCommand::parse_from(&cmd);
    let CertSubCommand::CreateCa(opts) = parsed.subcmd else {
        panic!("failed to parse create-ca");
    };
    replikey::ops::cert::create_ca(opts, false);

    let pem = path.join("ca.pem");
    let key = path.join("ca.key");
    assert!(pem.exists(), "CA certificate {} was not created", pem.display());
    assert!(key.exists(), "CA key {} was not created", key.display());
    (pem, key)
}

/// Files produced for one peer by `replikey create-server` / `create-client`.
struct PeerCert {
    /// Key file, passed to the proxies via `--key`.
    key: PathBuf,
    /// Certificate signing request, consumed by [`PeerCert::signed_by`].
    csr: PathBuf,
    /// Self-signed certificate written alongside the key.
    self_signed: PathBuf,
    /// CA-signed certificate; `Some` only on values returned by [`PeerCert::signed_by`].
    signed_cert: Option<PathBuf>,
}

impl PeerCert {
    /// Signs this peer's CSR with the CA whose key is `ca_key`, for the given `usage`.
    ///
    /// The CA directory is taken to be `ca_key`'s parent. The certificate is issued for
    /// `target1.local`, `target2.local` and `*.targets.local`, and written to a fresh
    /// `signed-…/server-signed.pem` inside that directory. Returns a copy of `self` with
    /// `signed_cert` filled in.
    fn signed_by(&self, usage: UsageType, ca_key: &Path) -> PeerCert {
        let ca_dir = ca_key
            .parent()
            .expect("CA key path must have a parent directory");
        let path = unique_dir(ca_dir, "signed");
        let output = path.join("server-signed.pem");
        let cmd = [
            "replikey".to_string(),
            "sign-server-csr".to_string(),
            "--valid-days".to_string(),
            "365".to_string(),
            "--ca-dir".to_string(),
            ca_dir.to_string_lossy().to_string(),
            "--input-csr".to_string(),
            self.csr.to_string_lossy().to_string(),
            "-d".to_string(),
            "target1.local".to_string(),
            "-d".to_string(),
            "target2.local".to_string(),
            "-d".to_string(),
            "*.targets.local".to_string(),
            "--output".to_string(),
            output.to_string_lossy().to_string(),
        ];
        let parsed = CertCommand::parse_from(&cmd);
        let CertSubCommand::SignServerCSR(opts) = parsed.subcmd else {
            panic!("failed to parse sign-server-csr");
        };
        replikey::ops::cert::sign_csr(opts, usage, false);

        PeerCert {
            key: self.key.clone(),
            csr: self.csr.clone(),
            self_signed: self.self_signed.clone(),
            signed_cert: Some(output),
        }
    }

    /// Path of the CA-signed certificate.
    ///
    /// # Panics
    /// Panics if this value was not produced by [`PeerCert::signed_by`].
    fn signed(&self) -> &Path {
        self.signed_cert
            .as_deref()
            .expect("certificate has not been signed by the CA")
    }
}

/// Creates a server key pair, CSR and self-signed certificate with `replikey create-server`
/// (SANs `target1.local`, `target2.local`, `*.targets.local`) in a fresh directory under `dir`.
fn test_server_cert(dir: &Path) -> PeerCert {
    let path = unique_dir(dir, "server");
    let cmd = [
        "replikey".to_string(),
        "create-server".to_string(),
        "--valid-days".to_string(),
        "365".to_string(),
        "--dn-common-name".to_string(),
        "replikey-test-server".to_string(),
        "-d".to_string(),
        "target1.local".to_string(),
        "-d".to_string(),
        "target2.local".to_string(),
        "-d".to_string(),
        "*.targets.local".to_string(),
        "--output".to_string(),
        path.to_string_lossy().to_string(),
    ];
    let parsed = CertCommand::parse_from(&cmd);
    let CertSubCommand::CreateServer(opts) = parsed.subcmd else {
        panic!("failed to parse create-server");
    };
    replikey::ops::cert::create_server(opts);

    let pem = path.join("server.pem");
    let key = path.join("server.key");
    let csr = path.join("server.csr");
    assert!(pem.exists(), "server certificate {} was not created", pem.display());
    assert!(key.exists(), "server key {} was not created", key.display());
    assert!(csr.exists(), "server CSR {} was not created", csr.display());
    PeerCert {
        key,
        csr,
        self_signed: pem,
        signed_cert: None,
    }
}

/// Creates a client key pair, CSR and self-signed certificate with `replikey create-client`
/// in a fresh directory under `dir`.
fn test_client_cert(dir: &Path) -> PeerCert {
    let path = unique_dir(dir, "client");
    let cmd = [
        "replikey".to_string(),
        "create-client".to_string(),
        "--valid-days".to_string(),
        "365".to_string(),
        "--dn-common-name".to_string(),
        "replikey-test-client".to_string(),
        "--output".to_string(),
        path.to_string_lossy().to_string(),
    ];
    let parsed = CertCommand::parse_from(&cmd);
    let CertSubCommand::CreateClient(opts) = parsed.subcmd else {
        panic!("failed to parse create-client");
    };
    replikey::ops::cert::create_client(opts);

    let pem = path.join("client.pem");
    let key = path.join("client.key");
    let csr = path.join("client.csr");
    assert!(pem.exists(), "client certificate {} was not created", pem.display());
    assert!(key.exists(), "client key {} was not created", key.display());
    assert!(csr.exists(), "client CSR {} was not created", csr.display());
    PeerCert {
        key,
        csr,
        self_signed: pem,
        signed_cert: None,
    }
}

/// Starts a localhost TCP server on a fresh port and returns that port.
///
/// Each connection echoes every received byte back through `transform` until the client
/// closes its write side; the server then shuts down its own write side.
async fn start_transforming_target(label: &'static str, transform: fn(u8) -> u8) -> u32 {
    let port = next_port();
    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .expect("failed to bind listener");
    tokio::spawn(async move {
        loop {
            let (mut stream, _) = listener.accept().await.expect("failed to accept");
            log::info!("accepted connection on {}", label);
            tokio::spawn(async move {
                let (mut rx, mut tx) = stream.split();
                let mut buf = [0u8; 1024];
                let mut total = 0usize;
                loop {
                    let n = rx.read(&mut buf).await.expect("failed to read");
                    if n == 0 {
                        break;
                    }
                    total += n;
                    buf[..n].iter_mut().for_each(|byte| *byte = transform(*byte));
                    tx.write_all(&buf[..n]).await.expect("failed to write");
                }
                log::info!("{} copied {} bytes", label, total);
                tx.shutdown().await.expect("failed to close");
            });
        }
    });
    port
}

/// Starts "target1": a TCP echo server. Returns its port.
async fn start_test_target_1() -> u32 {
    start_transforming_target("target1", |byte| byte).await
}

/// Starts "target2": a TCP server echoing the bitwise complement of every byte. Returns its port.
async fn start_test_target_2() -> u32 {
    start_transforming_target("target2", |byte| !byte).await
}

/// Spawns `replikey reverse-proxy` on a fresh localhost port and returns that port.
///
/// SNI `target1.local` / `target1.targets.local` routes to `target1`, and
/// `target2.local` / `target2.targets.local` routes to `target2`. The proxy runs in a
/// detached task and may not be accepting yet when this returns; use [`wait_for_port`].
/// A proxy failure panics inside that detached task, not in the caller.
async fn start_reverse_proxy(
    target1: u32,
    target2: u32,
    cert: &Path,
    key: &Path,
    ca_cert: &Path,
    compression: &str,
) -> u32 {
    let port = next_port();
    let cmd = [
        "replikey".to_string(),
        "reverse-proxy".to_string(),
        "--listen".to_string(),
        format!("127.0.0.1:{}", port),
        "--target".to_string(),
        format!("target1.local/127.0.0.1:{}", target1),
        "--target".to_string(),
        format!("target2.local/127.0.0.1:{}", target2),
        "--target".to_string(),
        format!("target1.targets.local/127.0.0.1:{}", target1),
        "--target".to_string(),
        format!("target2.targets.local/127.0.0.1:{}", target2),
        "--cert".to_string(),
        cert.to_string_lossy().to_string(),
        "--key".to_string(),
        key.to_string_lossy().to_string(),
        "--ca".to_string(),
        ca_cert.to_string_lossy().to_string(),
        "--compression".to_string(),
        compression.to_string(),
    ];
    let parsed = NetworkCommand::parse_from(&cmd);
    let NetworkSubCommand::ReverseProxy(opts) = parsed.subcmd else {
        panic!("failed to parse reverse-proxy");
    };
    tokio::spawn(async move {
        replikey::ops::network::reverse_proxy(opts)
            .await
            .expect("failed to run reverse proxy");
    });
    port
}

/// Spawns `replikey forward-proxy` on a fresh localhost port and returns that port.
///
/// The proxy forwards to `addr` using SNI `sni`, presenting `cert`/`key` and trusting
/// `ca_cert`. It runs in a detached task and may not be accepting yet when this returns;
/// use [`wait_for_port`].
async fn start_forward_proxy(
    addr: &str,
    sni: &str,
    cert: &Path,
    key: &Path,
    ca_cert: &Path,
    compression: &str,
) -> u32 {
    let port = next_port();
    let cmd = [
        "replikey".to_string(),
        "forward-proxy".to_string(),
        "--listen".to_string(),
        format!("127.0.0.1:{}", port),
        "--target".to_string(),
        addr.to_string(),
        "--sni".to_string(),
        sni.to_string(),
        "--cert".to_string(),
        cert.to_string_lossy().to_string(),
        "--key".to_string(),
        key.to_string_lossy().to_string(),
        "--ca".to_string(),
        ca_cert.to_string_lossy().to_string(),
        "--compression".to_string(),
        compression.to_string(),
    ];
    let parsed = NetworkCommand::parse_from(&cmd);
    let NetworkSubCommand::ForwardProxy(opts) = parsed.subcmd else {
        panic!("failed to parse forward-proxy");
    };
    tokio::spawn(async move {
        replikey::ops::network::forward_proxy(opts)
            .await
            .expect("failed to run forward proxy");
    });
    port
}

/// Polls `127.0.0.1:port` until a TCP connection succeeds.
///
/// # Panics
/// Panics if no connection succeeds within 5 seconds in total.
async fn wait_for_port(port: u32) {
    let addr = format!("127.0.0.1:{}", port);
    let poll = async {
        loop {
            if tokio::net::TcpStream::connect(&addr).await.is_ok() {
                break;
            }
            // A refused connect returns immediately; pause briefly instead of spinning.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    };
    // One deadline for the whole wait. Recreating the timer on each attempt would never
    // fire, because refused connects complete long before it.
    if tokio::time::timeout(std::time::Duration::from_secs(5), poll)
        .await
        .is_err()
    {
        panic!("timeout waiting for port");
    }
}

/// Connects to `port` and asserts that the peer delivers no data.
///
/// The first read must return EOF or fail with `ConnectionReset`.
async fn should_not_get_anything(port: u32) {
    let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
        .await
        .expect("failed to connect");
    match stream.read(&mut [0u8; 1024]).await {
        Ok(n) => assert_eq!(n, 0),
        Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::ConnectionReset),
    }
}

/// Connects to `port` and checks that it answers each request with `transform` applied per byte.
///
/// Runs 8 batches over one connection. Each batch concurrently writes 1024 repetitions of an
/// HTTP/1.0 request line in varying chunk sizes and reads until as many bytes came back. It
/// then asserts the response equals the request with `transform` applied to every byte.
/// Finally it shuts down the write half.
async fn expect_echo(port: u32, label: &'static str, transform: fn(u8) -> u8) {
    let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
        .await
        .expect("failed to connect");
    let (mut rx, mut tx) = tokio::io::split(stream);
    for batch in 0..8 {
        log::debug!("{} batch {}", label, batch);
        let data = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
        let expect: Vec<u8> = data.iter().copied().map(transform).collect();
        let (_, rx_data) = tokio::join!(
            async {
                // Vary write sizes (1, 2, ..., 4095, 1, ...) so the proxies see arbitrary
                // chunk boundaries.
                let mut write_size = (1..4096).cycle();
                let mut rest = &data[..];
                while !rest.is_empty() {
                    let n = write_size
                        .next()
                        .expect("cycle is infinite")
                        .min(rest.len());
                    let (head, tail) = rest.split_at(n);
                    tx.write_all(head).await.expect("failed to write");
                    log::trace!("wrote {} bytes for {}", n, label);
                    rest = tail;
                }
                tx.flush().await.expect("failed to flush");
                log::info!("wrote {} bytes for {}", data.len(), label);
            },
            async {
                let mut rxed = Vec::with_capacity(expect.len());
                let mut buf = [0u8; 1024];
                while rxed.len() < expect.len() {
                    let n = rx.read(&mut buf).await.expect("failed to read");
                    log::trace!("read {} bytes for {}", n, label);
                    if n == 0 {
                        break;
                    }
                    rxed.extend_from_slice(&buf[..n]);
                }
                rxed
            }
        );
        assert_eq!(rx_data, expect);
    }
    tx.shutdown().await.expect("failed to close");
}

/// Asserts that `port` behaves like target1 (bytes echoed unchanged).
async fn expect_target1_signature(port: u32) {
    expect_echo(port, "target1", |byte| byte).await;
}

/// Asserts that `port` behaves like target2 (bytes echoed bitwise-complemented).
async fn expect_target2_signature(port: u32) {
    expect_echo(port, "target2", |byte| !byte).await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
async fn test_mtls_integrated() {
    setup_once();
    let (ca_cert, ca_key) = test_ca(Path::new("target/test-ca"));
    let server_cert = test_server_cert(Path::new("target/test-server"))
        .signed_by(UsageType::Server, &ca_key);
    let client_cert = test_client_cert(Path::new("target/test-client"))
        .signed_by(UsageType::Client, &ca_key);

    let target1 = start_test_target_1().await;
    let target2 = start_test_target_2().await;
    log::info!("target1: {}, target2: {}", target1, target2);
    wait_for_port(target1).await;
    wait_for_port(target2).await;

    // Sanity-check the targets directly before putting proxies in front of them.
    expect_target1_signature(target1).await;
    expect_target2_signature(target2).await;

    for compression in ["none", "zstd", "brotli"] {
        // Reverse proxy presenting the self-signed server certificate.
        let proxy_port_self_signed = start_reverse_proxy(
            target1,
            target2,
            &server_cert.self_signed,
            &server_cert.key,
            &ca_cert,
            compression,
        )
        .await;
        wait_for_port(proxy_port_self_signed).await;
        log::info!("reverse_proxy_port_self_signed: {}", proxy_port_self_signed);

        // Reverse proxy presenting the CA-signed server certificate.
        let proxy_port = start_reverse_proxy(
            target1,
            target2,
            server_cert.signed(),
            &server_cert.key,
            &ca_cert,
            compression,
        )
        .await;
        wait_for_port(proxy_port).await;
        log::info!("reverse_proxy_port: {}", proxy_port);

        for server_signed in [false, true] {
            let upstream_port = if server_signed {
                proxy_port
            } else {
                proxy_port_self_signed
            };
            let upstream = format!("127.0.0.1:{}", upstream_port);

            // Client presents its self-signed certificate.
            let forward_proxy_target1_self_signed = start_forward_proxy(
                &upstream,
                "target1.local",
                &client_cert.self_signed,
                &client_cert.key,
                &ca_cert,
                compression,
            )
            .await;
            log::info!(
                "forward_proxy_target1_self_signed: {}",
                forward_proxy_target1_self_signed
            );

            let forward_proxy_target1 = start_forward_proxy(
                &upstream,
                "target1.local",
                client_cert.signed(),
                &client_cert.key,
                &ca_cert,
                compression,
            )
            .await;
            log::info!("forward_proxy_target1: {}", forward_proxy_target1);

            let forward_proxy_target2 = start_forward_proxy(
                &upstream,
                "target2.local",
                client_cert.signed(),
                &client_cert.key,
                &ca_cert,
                compression,
            )
            .await;
            log::info!("forward_proxy_target2: {}", forward_proxy_target2);

            let forward_proxy_target1_wildcard = start_forward_proxy(
                &upstream,
                "target1.targets.local",
                client_cert.signed(),
                &client_cert.key,
                &ca_cert,
                compression,
            )
            .await;
            log::info!(
                "forward_proxy_target1_wildcard: {}",
                forward_proxy_target1_wildcard
            );

            let forward_proxy_target2_wildcard = start_forward_proxy(
                &upstream,
                "target2.targets.local",
                client_cert.signed(),
                &client_cert.key,
                &ca_cert,
                compression,
            )
            .await;
            log::info!(
                "forward_proxy_target2_wildcard: {}",
                forward_proxy_target2_wildcard
            );

            let forward_proxy_unknown_sni = start_forward_proxy(
                &upstream,
                "some-other.place",
                client_cert.signed(),
                &client_cert.key,
                &ca_cert,
                compression,
            )
            .await;
            log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni);

            // Every forward proxy must be listening before we connect to it, including the
            // self-signed one probed just below.
            for port in [
                forward_proxy_target1_self_signed,
                forward_proxy_target1,
                forward_proxy_target2,
                forward_proxy_target1_wildcard,
                forward_proxy_target2_wildcard,
                forward_proxy_unknown_sni,
            ] {
                wait_for_port(port).await;
            }

            log::info!("Test: self-signed cert");
            should_not_get_anything(forward_proxy_target1_self_signed).await;

            let mut js = JoinSet::new();
            for _ in 0..100 {
                if server_signed {
                    // CA-signed on both sides: known SNIs reach their target, unknown SNI does not.
                    js.spawn(expect_target1_signature(forward_proxy_target1));
                    js.spawn(expect_target2_signature(forward_proxy_target2));
                    js.spawn(expect_target1_signature(forward_proxy_target1_wildcard));
                    js.spawn(expect_target2_signature(forward_proxy_target2_wildcard));
                    js.spawn(should_not_get_anything(forward_proxy_unknown_sni));
                } else {
                    // Upstream presents a self-signed certificate: nothing may get through.
                    js.spawn(should_not_get_anything(forward_proxy_target1));
                    js.spawn(should_not_get_anything(forward_proxy_target2));
                    js.spawn(should_not_get_anything(forward_proxy_target1_wildcard));
                    js.spawn(should_not_get_anything(forward_proxy_target2_wildcard));
                    js.spawn(should_not_get_anything(forward_proxy_unknown_sni));
                }
            }
            js.join_all().await;
        }
    }
}

#[tokio::test]
async fn test_mtls_role_reverse() {
    // Install the crypto provider and logger deterministically instead of depending on
    // test_mtls_integrated having run first in this process.
    setup_once();

    let (ca_cert, ca_key) = test_ca(Path::new("target/test-ca"));
    let server_cert = test_server_cert(Path::new("target/test-server"))
        .signed_by(UsageType::Server, &ca_key);
    let client_cert = test_client_cert(Path::new("target/test-client"))
        .signed_by(UsageType::Client, &ca_key);

    let target1 = start_test_target_1().await;
    let target2 = start_test_target_2().await;
    wait_for_port(target1).await;
    wait_for_port(target2).await;

    // Baseline: each certificate used in its intended role works end to end.
    let proxy_port = start_reverse_proxy(
        target1,
        target2,
        server_cert.signed(),
        &server_cert.key,
        &ca_cert,
        "none",
    )
    .await;
    wait_for_port(proxy_port).await;
    let forward_proxy_target1 = start_forward_proxy(
        &format!("127.0.0.1:{}", proxy_port),
        "target1.local",
        client_cert.signed(),
        &client_cert.key,
        &ca_cert,
        "none",
    )
    .await;
    wait_for_port(forward_proxy_target1).await;
    expect_target1_signature(forward_proxy_target1).await;

    // Swapped roles: the reverse proxy presents the client-usage certificate and the forward
    // proxy presents the server-usage one; no data may get through.
    let proxy_port_reversed = start_reverse_proxy(
        target1,
        target2,
        client_cert.signed(),
        &client_cert.key,
        &ca_cert,
        "none",
    )
    .await;
    wait_for_port(proxy_port_reversed).await;
    let forward_proxy_target1_reversed = start_forward_proxy(
        &format!("127.0.0.1:{}", proxy_port_reversed),
        "target1.local",
        server_cert.signed(),
        &server_cert.key,
        &ca_cert,
        "none",
    )
    .await;
    wait_for_port(forward_proxy_target1_reversed).await;
    should_not_get_anything(forward_proxy_target1_reversed).await;
}