replikey/tests/mtls_integration.rs

822 lines
25 KiB
Rust
Raw Normal View History

#![allow(clippy::all)]
use std::{
    path::{Path, PathBuf},
    sync::{
        atomic::{AtomicU32, AtomicU64, Ordering},
        Once,
    },
};
use clap::Parser;
use replikey::{
    cert::UsageType,
    ops::{
        cert::{CertCommand, CertSubCommand},
        network::{NetworkCommand, NetworkSubCommand},
    },
};
use rustls::crypto::{aws_lc_rs, CryptoProvider};
use time::OffsetDateTime;
use tokio::{
    io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _},
    task::JoinSet,
};
/// Next TCP port handed out by `next_port`; every listener started by this file draws from it,
/// so no two listeners are given the same port.
static NEXT_PORT: AtomicU32 = AtomicU32::new(12_311);
/// Guards the one-time, process-wide initialisation performed by `setup_once`.
static TEST_SETUP: Once = Once::new();
/// Performs process-wide test initialisation exactly once: installs `env_logger` and makes
/// aws-lc-rs the default rustls `CryptoProvider`.
///
/// When `RUST_LOG` is unset, logging defaults to the `info` level. The default is supplied to
/// the logger builder directly instead of writing `RUST_LOG` with `std::env::set_var`. Mutating
/// the environment while other test threads may be running is a data race, and `set_var` is
/// `unsafe` as of Rust 2024.
///
/// # Panics
/// Panics if a logger or a default crypto provider has already been installed elsewhere.
fn setup_once() {
    TEST_SETUP.call_once(|| {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
            .init();
        CryptoProvider::install_default(aws_lc_rs::default_provider())
            .expect("Failed to install crypto provider");
    });
}
/// Reserves a fresh port number for a test listener.
///
/// Ports are handed out sequentially from `NEXT_PORT` through an atomic `fetch_add`, so
/// concurrent callers never receive the same value.
fn next_port() -> u32 {
    use std::sync::atomic::Ordering::SeqCst;
    NEXT_PORT.fetch_add(1, SeqCst)
}
/// Drives a full-duplex exchange between `left` and `right` and verifies both directions.
///
/// Each side writes 100 chunks of `WRITE_SIZE` bytes and then shuts down its write half. Each
/// byte is derived from the chunk index and byte offset, keyed with `123` for left→right and
/// `321` for right→left. Each side reads the peer's 100 chunks, checks every byte, and then
/// asserts that the stream reports EOF. All four halves are driven concurrently, so a transport
/// with a small internal buffer cannot deadlock the exchange.
///
/// # Panics
/// Panics on any I/O error ("stream sequence failed"), on a mismatching byte, or when data
/// follows the expected chunks.
async fn test_stream_sequence<
    const WRITE_SIZE: usize,
    L: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    R: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
>(
    left: L,
    right: R,
) {
    const ROUNDS: usize = 100;
    const LEFT_KEY: usize = 123;
    const RIGHT_KEY: usize = 321;

    /// Byte at `offset` of chunk `round` for a sender keyed with `key`.
    fn pattern(round: usize, offset: usize, key: usize) -> u8 {
        ((round + offset) ^ key) as u8
    }

    /// Writes `ROUNDS` patterned chunks of `N` bytes, then shuts the writer down.
    async fn send<const N: usize, W: AsyncWrite + Unpin>(
        mut tx: W,
        key: usize,
    ) -> std::io::Result<()> {
        let mut chunk = [0u8; N];
        for round in 0..ROUNDS {
            for (offset, slot) in chunk.iter_mut().enumerate() {
                *slot = pattern(round, offset, key);
            }
            tx.write_all(&chunk).await?;
        }
        tx.shutdown().await
    }

    /// Reads and checks `ROUNDS` patterned chunks of `N` bytes, then expects EOF.
    async fn recv<const N: usize, Rd: AsyncRead + Unpin>(
        mut rx: Rd,
        key: usize,
    ) -> std::io::Result<()> {
        let mut chunk = [0u8; N];
        for round in 0..ROUNDS {
            rx.read_exact(&mut chunk).await?;
            for (offset, &byte) in chunk.iter().enumerate() {
                assert_eq!(byte, pattern(round, offset, key));
            }
        }
        assert_eq!(rx.read(&mut chunk).await?, 0, "expected eof");
        Ok(())
    }

    let (l_rx, l_tx) = tokio::io::split(left);
    let (r_rx, r_tx) = tokio::io::split(right);
    tokio::try_join!(
        send::<WRITE_SIZE, _>(l_tx, LEFT_KEY),
        send::<WRITE_SIZE, _>(r_tx, RIGHT_KEY),
        recv::<WRITE_SIZE, _>(l_rx, RIGHT_KEY),
        recv::<WRITE_SIZE, _>(r_rx, LEFT_KEY),
    )
    .expect("stream sequence failed");
}
/// Runs `test_stream_sequence` over in-memory `tokio::io::duplex` pipes with a 1024-byte
/// buffer. It uses chunk sizes equal to, smaller than, and larger than that buffer.
#[tokio::test]
async fn in_memory_stream_works() {
    const PIPE_CAPACITY: usize = 1024;
    {
        let (a, b) = tokio::io::duplex(PIPE_CAPACITY);
        test_stream_sequence::<1024, _, _>(a, b).await;
    }
    {
        let (a, b) = tokio::io::duplex(PIPE_CAPACITY);
        test_stream_sequence::<1, _, _>(a, b).await;
    }
    {
        let (a, b) = tokio::io::duplex(PIPE_CAPACITY);
        test_stream_sequence::<3543, _, _>(a, b).await;
    }
}
fn test_ca(dir: &PathBuf) -> (PathBuf, PathBuf) {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = dir.join(format!("ca-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create ca dir");
let cmd = [
"replikey".to_string(),
"create-ca".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--dn-common-name".to_string(),
"replikey-test-ca".to_string(),
"--output".to_string(),
path.to_string_lossy().to_string(),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::CreateCa(opts) = parsed.subcmd {
replikey::ops::cert::create_ca(opts, false);
} else {
panic!("failed to parse create-ca");
}
let pem = path.join("ca.pem");
let key = path.join("ca.key");
assert!(pem.exists());
assert!(key.exists());
(pem, key)
}
struct PeerCert {
key: PathBuf,
csr: PathBuf,
self_signed: PathBuf,
signed_cert: Option<PathBuf>,
}
impl PeerCert {
fn signed_by(&self, usage: UsageType, ca_key: &PathBuf) -> PeerCert {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = ca_key.parent().unwrap().join(format!("signed-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create signed dir");
let cmd = [
"replikey".to_string(),
"sign-server-csr".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--ca-dir".to_string(),
ca_key.parent().unwrap().to_string_lossy().to_string(),
"--input-csr".to_string(),
self.csr.to_string_lossy().to_string(),
"-d".to_string(),
"target1.local".to_string(),
"-d".to_string(),
"target2.local".to_string(),
"-d".to_string(),
"*.targets.local".to_string(),
"--output".to_string(),
format!("{}/server-signed.pem", path.to_string_lossy()),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::SignServerCSR(opts) = parsed.subcmd {
replikey::ops::cert::sign_csr(opts, usage, false);
} else {
panic!("failed to parse sign-server-csr");
}
PeerCert {
key: self.key.clone(),
csr: self.csr.clone(),
self_signed: self.self_signed.clone(),
signed_cert: Some(path.join("server-signed.pem")),
}
}
}
fn test_server_cert(dir: &PathBuf) -> PeerCert {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = dir.join(format!("server-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create server dir");
let cmd = [
"replikey".to_string(),
"create-server".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--dn-common-name".to_string(),
"replikey-test-server".to_string(),
"-d".to_string(),
"target1.local".to_string(),
"-d".to_string(),
"target2.local".to_string(),
"-d".to_string(),
"*.targets.local".to_string(),
"--output".to_string(),
path.to_string_lossy().to_string(),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::CreateServer(opts) = parsed.subcmd {
replikey::ops::cert::create_server(opts);
} else {
panic!("failed to parse create-server");
}
let pem = path.join("server.pem");
let key = path.join("server.key");
let csr = path.join("server.csr");
assert!(pem.exists());
assert!(key.exists());
PeerCert {
key,
csr,
self_signed: pem,
signed_cert: None,
}
}
fn test_client_cert(dir: &PathBuf) -> PeerCert {
let id = OffsetDateTime::now_utc().unix_timestamp_nanos();
let path = dir.join(format!("client-{}/", id));
std::fs::create_dir_all(&path).expect("failed to create client dir");
let cmd = [
"replikey".to_string(),
"create-client".to_string(),
"--valid-days".to_string(),
"365".to_string(),
"--dn-common-name".to_string(),
"replikey-test-client".to_string(),
"--output".to_string(),
path.to_string_lossy().to_string(),
];
let parsed = CertCommand::parse_from(&cmd);
if let CertSubCommand::CreateClient(opts) = parsed.subcmd {
replikey::ops::cert::create_client(opts);
} else {
panic!("failed to parse create-client");
}
let pem = path.join("client.pem");
let key = path.join("client.key");
let csr = path.join("client.csr");
assert!(pem.exists());
assert!(key.exists());
PeerCert {
key,
csr,
self_signed: pem,
signed_cert: None,
}
}
/// Starts an echo server on `127.0.0.1:<next_port()>` and returns its port.
///
/// The listener is bound before this returns. Every accepted connection is served on its own
/// task, which copies everything it reads straight back and shuts down its write half once
/// the peer's EOF has been reached. I/O errors panic inside the spawned tasks.
async fn start_test_target_1() -> u32 {
    let port = next_port();
    let bind_addr = format!("127.0.0.1:{}", port);
    let listener = tokio::net::TcpListener::bind(bind_addr)
        .await
        .expect("failed to bind listener");
    tokio::spawn(async move {
        loop {
            let (mut conn, _peer) = listener.accept().await.expect("failed to accept");
            log::info!("accepted connection on target1");
            tokio::spawn(async move {
                let (mut reader, mut writer) = conn.split();
                let echoed = tokio::io::copy(&mut reader, &mut writer)
                    .await
                    .expect("failed to copy");
                log::info!("copied {} bytes", echoed);
                writer.shutdown().await.expect("failed to close");
            });
        }
    });
    port
}
/// Starts a bit-inverting server on `127.0.0.1:<next_port()>` and returns its port.
///
/// Each accepted connection is served on its own task. The task reads up to 1024 bytes at a
/// time, writes back the bitwise complement of each byte, and shuts down its write half at
/// EOF. I/O errors panic inside the spawned tasks.
async fn start_test_target_2() -> u32 {
    let port = next_port();
    let bind_addr = format!("127.0.0.1:{}", port);
    let listener = tokio::net::TcpListener::bind(bind_addr)
        .await
        .expect("failed to bind listener");
    tokio::spawn(async move {
        loop {
            let (mut conn, _peer) = listener.accept().await.expect("failed to accept");
            log::info!("accepted connection on target2");
            tokio::spawn(async move {
                let (mut reader, mut writer) = conn.split();
                let mut chunk = [0u8; 1024];
                let mut total = 0;
                loop {
                    let len = match reader.read(&mut chunk).await.expect("failed to read") {
                        0 => break,
                        len => len,
                    };
                    total += len;
                    for byte in &mut chunk[..len] {
                        *byte = !*byte;
                    }
                    writer
                        .write_all(&chunk[..len])
                        .await
                        .expect("failed to write");
                }
                log::info!("copied {} bytes", total);
                writer.shutdown().await.expect("failed to close");
            });
        }
    });
    port
}
/// Spawns a `replikey reverse-proxy` on `127.0.0.1:<next_port()>` and returns its port.
///
/// SNI names `target1.local` and `target1.targets.local` are routed to `target1`, and
/// `target2.local` and `target2.targets.local` to `target2`. The proxy presents `cert`/`key`,
/// verifies peers against `ca_cert`, and uses the given `compression` mode. It runs on a
/// detached task, so this returns before the listener is necessarily bound.
///
/// # Panics
/// Panics if the arguments do not parse as `reverse-proxy`. A proxy failure panics inside the
/// spawned task.
async fn start_reverse_proxy(
    target1: u32,
    target2: u32,
    cert: &PathBuf,
    key: &PathBuf,
    ca_cert: &PathBuf,
    compression: &str,
) -> u32 {
    let port = next_port();
    let lossy = |p: &PathBuf| p.to_string_lossy().to_string();
    let route = |host: &str, target: u32| format!("{}/127.0.0.1:{}", host, target);
    let cmd = [
        "replikey".to_string(),
        "reverse-proxy".to_string(),
        "--listen".to_string(),
        format!("127.0.0.1:{}", port),
        "--target".to_string(),
        route("target1.local", target1),
        "--target".to_string(),
        route("target2.local", target2),
        "--target".to_string(),
        route("target1.targets.local", target1),
        "--target".to_string(),
        route("target2.targets.local", target2),
        "--cert".to_string(),
        lossy(cert),
        "--key".to_string(),
        lossy(key),
        "--ca".to_string(),
        lossy(ca_cert),
        "--compression".to_string(),
        compression.to_string(),
    ];
    let NetworkSubCommand::ReverseProxy(opts) = NetworkCommand::parse_from(&cmd).subcmd else {
        panic!("failed to parse reverse-proxy");
    };
    tokio::spawn(async move {
        replikey::ops::network::reverse_proxy(opts)
            .await
            .expect("failed to run reverse proxy");
    });
    port
}
/// Spawns a `replikey forward-proxy` on `127.0.0.1:<next_port()>` and returns its port.
///
/// The proxy forwards to `addr` using `sni` as the TLS server name. It presents `cert`/`key`,
/// verifies the server against `ca_cert`, and uses the given `compression` mode. It runs on a
/// detached task, so this returns before the listener is necessarily bound.
///
/// # Panics
/// Panics if the arguments do not parse as `forward-proxy`. A proxy failure panics inside the
/// spawned task.
async fn start_forward_proxy(
    addr: &str,
    sni: &str,
    cert: &PathBuf,
    key: &PathBuf,
    ca_cert: &PathBuf,
    compression: &str,
) -> u32 {
    let port = next_port();
    let lossy = |p: &PathBuf| p.to_string_lossy().to_string();
    let cmd = [
        "replikey".to_string(),
        "forward-proxy".to_string(),
        "--listen".to_string(),
        format!("127.0.0.1:{}", port),
        "--target".to_string(),
        addr.to_string(),
        "--sni".to_string(),
        sni.to_string(),
        "--cert".to_string(),
        lossy(cert),
        "--key".to_string(),
        lossy(key),
        "--ca".to_string(),
        lossy(ca_cert),
        "--compression".to_string(),
        compression.to_string(),
    ];
    let NetworkSubCommand::ForwardProxy(opts) = NetworkCommand::parse_from(&cmd).subcmd else {
        panic!("failed to parse forward-proxy");
    };
    tokio::spawn(async move {
        replikey::ops::network::forward_proxy(opts)
            .await
            .expect("failed to run forward proxy");
    });
    port
}
async fn wait_for_port(port: u32) {
loop {
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
panic!("timeout waiting for port");
}
res = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) => {
if res.is_ok() {
break;
}
}
}
}
}
/// Connects to `127.0.0.1:port` and asserts that the connection yields no data.
///
/// The first read must report either a clean EOF or a connection reset.
///
/// # Panics
/// Panics if the connection cannot be established, if any bytes are received, or if the read
/// fails with an error other than `ConnectionReset`.
async fn should_not_get_anything(port: u32) {
    let addr = format!("127.0.0.1:{}", port);
    let mut conn = tokio::net::TcpStream::connect(addr)
        .await
        .expect("failed to connect");
    let mut scratch = [0u8; 1024];
    match conn.read(&mut scratch).await {
        // Orderly close without any payload.
        Ok(received) => assert_eq!(received, 0),
        // Abortive close (RST) is an equally acceptable rejection.
        Err(err) => assert_eq!(err.kind(), std::io::ErrorKind::ConnectionReset),
    }
}
/// Checks that the service at `127.0.0.1:port` behaves like target1 (an echo server).
///
/// Over a single connection, runs 8 rounds. Each round sends `"GET / HTTP/1.0\r\n\r\n"`
/// repeated 1024 times, in chunks of 1, 2, 3, … bytes, while concurrently reading the reply.
/// The reply must equal the payload exactly. The write half is shut down at the end.
///
/// # Panics
/// Panics on any I/O failure or if a round's reply differs from its payload.
async fn expect_target1_signature(port: u32) {
    let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
        .await
        .expect("failed to connect");
    let (mut rx, mut tx) = tokio::io::split(stream);
    for batch in 0..8 {
        log::debug!("expect_target1_signature batch {}", batch);
        let payload = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
        let mut chunk_sizes = (1..4096).cycle();
        let send = async {
            let mut remaining = payload.as_slice();
            while !remaining.is_empty() {
                let n = chunk_sizes.next().unwrap().min(remaining.len());
                let (head, tail) = remaining.split_at(n);
                tx.write_all(head).await.expect("failed to write");
                log::trace!("wrote {} bytes for target1", n);
                remaining = tail;
            }
            tx.flush().await.expect("failed to flush");
            log::info!("wrote {} bytes for target1", payload.len());
        };
        let receive = async {
            let mut received = Vec::new();
            let mut buf = [0u8; 1024];
            loop {
                let n = rx.read(&mut buf).await.expect("failed to read");
                log::trace!("read {} bytes for target1", n);
                if n == 0 {
                    break;
                }
                received.extend_from_slice(&buf[..n]);
                if received.len() >= payload.len() {
                    break;
                }
            }
            received
        };
        let ((), received) = tokio::join!(send, receive);
        assert_eq!(received, payload);
    }
    tx.shutdown().await.expect("failed to close");
}
/// Checks that the service at `127.0.0.1:port` behaves like target2 (a bit-inverting server).
///
/// Over a single connection, runs 8 rounds. Each round sends `"GET / HTTP/1.0\r\n\r\n"`
/// repeated 1024 times, in chunks of 1, 2, 3, … bytes, while concurrently reading the reply.
/// The reply must be the bytewise complement of the payload. The write half is shut down at
/// the end.
///
/// # Panics
/// Panics on any I/O failure or if a round's reply is not the inverted payload.
async fn expect_target2_signature(port: u32) {
    let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
        .await
        .expect("failed to connect");
    let (mut rx, mut tx) = tokio::io::split(stream);
    for _ in 0..8 {
        let payload = "GET / HTTP/1.0\r\n\r\n".repeat(1024).into_bytes();
        let inverted: Vec<u8> = payload.iter().map(|byte| !byte).collect();
        let mut chunk_sizes = (1..4096).cycle();
        let send = async {
            let mut remaining = payload.as_slice();
            while !remaining.is_empty() {
                let n = chunk_sizes.next().unwrap().min(remaining.len());
                let (head, tail) = remaining.split_at(n);
                tx.write_all(head).await.expect("failed to write");
                remaining = tail;
            }
            tx.flush().await.expect("failed to flush");
        };
        let receive = async {
            let mut received = Vec::new();
            let mut buf = [0u8; 1024];
            loop {
                let n = rx.read(&mut buf).await.expect("failed to read");
                if n == 0 {
                    break;
                }
                received.extend_from_slice(&buf[..n]);
                if received.len() >= inverted.len() {
                    break;
                }
            }
            received
        };
        let ((), received) = tokio::join!(send, receive);
        assert_eq!(received, inverted);
    }
    tx.shutdown().await.expect("failed to close");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
async fn test_mtls_integrated() {
setup_once();
let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca"));
let server_cert = test_server_cert(&PathBuf::from("target/test-server"));
let client_cert = test_client_cert(&PathBuf::from("target/test-client"));
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let target1 = start_test_target_1().await;
let target2 = start_test_target_2().await;
log::info!("target1: {}, target2: {}", target1, target2);
wait_for_port(target1).await;
wait_for_port(target2).await;
expect_target1_signature(target1).await;
expect_target2_signature(target2).await;
for compression in ["none", "zstd", "brotli"] {
let proxy_port_self_signed = start_reverse_proxy(
target1,
target2,
&server_cert.self_signed,
&server_cert.key,
&ca_cert,
compression,
)
.await;
wait_for_port(proxy_port_self_signed).await;
log::info!("reverse_proxy_port_self_signed: {}", proxy_port_self_signed);
let proxy_port = start_reverse_proxy(
target1,
target2,
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
compression,
)
.await;
wait_for_port(proxy_port).await;
log::info!("reverse_proxy_port: {}", proxy_port);
for server_signed in [false, true] {
let forward_proxy_target1_self_signed = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target1.local",
&client_cert.self_signed,
&client_cert.key,
&ca_cert,
compression,
)
.await;
log::info!(
"forward_proxy_target1_self_signed: {}",
forward_proxy_target1_self_signed
);
let forward_proxy_target1 = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target1.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
compression,
)
.await;
log::info!("forward_proxy_target1: {}", forward_proxy_target1);
let forward_proxy_target2 = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target2.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
compression,
)
.await;
log::info!("forward_proxy_target2: {}", forward_proxy_target2);
let forward_proxy_target1_wildcard = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target1.targets.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
compression,
)
.await;
let forward_proxy_target2_wildcard: u32 = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"target2.targets.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
compression,
)
.await;
log::info!(
"forward_proxy_target2_wildcard: {}",
forward_proxy_target2_wildcard
);
let forward_proxy_unknown_sni = start_forward_proxy(
&format!(
"127.0.0.1:{}",
if server_signed {
proxy_port
} else {
proxy_port_self_signed
}
),
"some-other.place",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
compression,
)
.await;
log::info!("forward_proxy_unknown_sni: {}", forward_proxy_unknown_sni);
wait_for_port(forward_proxy_target1).await;
wait_for_port(forward_proxy_target2).await;
wait_for_port(forward_proxy_target1_wildcard).await;
wait_for_port(forward_proxy_target2_wildcard).await;
wait_for_port(forward_proxy_unknown_sni).await;
log::info!("Test: self-signed cert");
should_not_get_anything(forward_proxy_target1_self_signed).await;
if server_signed {
let mut js = JoinSet::new();
for _ in 0..100 {
js.spawn(expect_target1_signature(forward_proxy_target1));
js.spawn(expect_target2_signature(forward_proxy_target2));
js.spawn(expect_target1_signature(forward_proxy_target1_wildcard));
js.spawn(expect_target2_signature(forward_proxy_target2_wildcard));
js.spawn(should_not_get_anything(forward_proxy_unknown_sni));
}
js.join_all().await;
} else {
let mut js = JoinSet::new();
for _ in 0..100 {
js.spawn(should_not_get_anything(forward_proxy_target1));
js.spawn(should_not_get_anything(forward_proxy_target2));
js.spawn(should_not_get_anything(forward_proxy_target1_wildcard));
js.spawn(should_not_get_anything(forward_proxy_target2_wildcard));
js.spawn(should_not_get_anything(forward_proxy_unknown_sni));
}
js.join_all().await;
}
}
}
}
#[tokio::test]
async fn test_mtls_role_reverse() {
let (ca_cert, ca_key) = test_ca(&PathBuf::from("target/test-ca"));
let server_cert = test_server_cert(&PathBuf::from("target/test-server"));
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = test_client_cert(&PathBuf::from("target/test-client"));
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let target1 = start_test_target_1().await;
let target2 = start_test_target_2().await;
wait_for_port(target1).await;
wait_for_port(target2).await;
let server_cert = server_cert.signed_by(UsageType::Server, &ca_key);
let client_cert = client_cert.signed_by(UsageType::Client, &ca_key);
let proxy_port = start_reverse_proxy(
target1,
target2,
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
"none",
)
.await;
wait_for_port(proxy_port).await;
let forward_proxy_target1 = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port),
"target1.local",
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"none",
)
.await;
wait_for_port(forward_proxy_target1).await;
expect_target1_signature(forward_proxy_target1).await;
let proxy_port_reversed = start_reverse_proxy(
target1,
target2,
client_cert.signed_cert.as_ref().unwrap(),
&client_cert.key,
&ca_cert,
"none",
)
.await;
wait_for_port(proxy_port_reversed).await;
let forward_proxy_target1_reversed = start_forward_proxy(
&format!("127.0.0.1:{}", proxy_port_reversed),
"target1.local",
server_cert.signed_cert.as_ref().unwrap(),
&server_cert.key,
&ca_cert,
"none",
)
.await;
wait_for_port(forward_proxy_target1_reversed).await;
should_not_get_anything(forward_proxy_target1_reversed).await;
}