From bd6d8625ed08ec4b7e312deb3d50dcfc00060bcb Mon Sep 17 00:00:00 2001 From: eternal-flame-AD Date: Wed, 4 Sep 2024 22:28:59 -0500 Subject: [PATCH] refactor code structure Signed-off-by: eternal-flame-AD --- src/client.rs | 404 ++++++++++++++++++++ src/{e2e/mod.rs => e2e.rs} | 4 +- src/filter.rs | 30 ++ src/io.rs | 61 +++ src/lib.rs | 745 +++++++++++++++++++------------------ src/main.rs | 463 +---------------------- src/model.rs | 17 + 7 files changed, 912 insertions(+), 812 deletions(-) create mode 100644 src/client.rs rename src/{e2e/mod.rs => e2e.rs} (97%) create mode 100644 src/filter.rs create mode 100644 src/io.rs create mode 100644 src/model.rs diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..1b68774 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,404 @@ +use std::{pin::Pin, sync::Arc}; + +use crate::{ + e2e::{decrypt_file, ErrOrWrongHash}, + io::prompt, + mxc_url_to_https, DumpError, +}; +use futures::{future::BoxFuture, stream, StreamExt, TryStream, TryStreamExt}; +use matrix_sdk::{ + bytes::Bytes, + deserialized_responses::TimelineEvent, + encryption::verification::{Verification, VerificationRequestState}, + matrix_auth::MatrixSession, + room::MessagesOptions, + ruma::events::key::verification::VerificationMethod, + Client, Room, +}; +use matrix_sdk_crypto::SasState; + +use ruma_events::room::{ + message::{MessageType, RoomMessageEvent}, + MediaSource, +}; +use tokio::io::AsyncReadExt; + +pub struct MatrixClient { + client: Client, +} + +pub type PinTryStream<'s, O, E> = + Pin> + Send + 's>>; + +pub struct FileStream<'s> { + pub filename: String, + pub url: String, + pub stream: PinTryStream<'s, Bytes, DumpError>, +} + +pub enum AuthMethod<'a> { + Password { + user: &'a str, + password: &'a str, + initial_device_display_name: Option<&'a str>, + }, + Session(MatrixSession), +} + +impl MatrixClient { + pub fn new(client: Client) -> Self { + Self { client } + } + + pub fn new_arc(client: Client) -> Arc { + Arc::new(Self::new(client)) + } + + pub fn client(&self) -> &Client { + &self.client + } + + pub async fn login(&self, auth: AuthMethod<'_>) -> Result<(), matrix_sdk::Error> { + match auth { + AuthMethod::Password { + user, + password, + initial_device_display_name, + } => { + let tmp = self.client.matrix_auth().login_username(user, &password); + if let Some(name) = initial_device_display_name { + tmp.initial_device_display_name(&name) + } else { + tmp + } + .await + .map(|_| ()) + } + AuthMethod::Session(session) => { + self.client.matrix_auth().restore_session(session).await + } + } + } + + pub fn try_read_attachment<'c, 'a: 'c>( + self: Arc, + client: &'c matrix_sdk::reqwest::Client, + msg: &'a RoomMessageEvent, + ) -> Result, DumpError>>>, DumpError> { + macro_rules! impl_file_like { + ($msg:expr, $($variant:ident),*) => { + match $msg { + $( + MessageType::$variant(file) => { + Ok(Some(Box::pin( + async move { + let src = match &file.source { + MediaSource::Plain(s) => s, + MediaSource::Encrypted(e) => &e.url, + }; + let filename = file.filename.as_deref().map(|s| s.to_string()).unwrap_or(file.body.clone()); + let url = mxc_url_to_https(src.as_str(), self.client.homeserver().as_str()); + let resp = client.get(&url).send().await?; + let body = resp.bytes_stream(); + Ok(FileStream { + filename, + url, + stream: match &file.source { + MediaSource::Plain(_) => { + Box::pin(body.map_err(DumpError::from)) as Pin> + Send>> + } + MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err( + |e| match e { + ErrOrWrongHash::Err(e) => e.into(), + ErrOrWrongHash::WrongHash => DumpError::HashMismatch, + }, + )) as Pin> + Send>>, + } + }) + })))} + )* + _ => Ok(None), + } + }; + } + match msg { + RoomMessageEvent::Original(msg) => { + impl_file_like!(&msg.content.msgtype, Image, Video, Audio, File) + } + _ => Ok(None), + } + } + + pub async fn setup_e2e(self: Arc) -> bool { + let client = &self.client; + + log::info!("Preparing e2e machine"); + client + .encryption() + .wait_for_e2ee_initialization_tasks() + .await; + log::info!("E2E machine ready"); + + let own_device = client + .encryption() + .get_own_device() + .await + .expect("Failed to get own device") + .expect("No own device found"); + + if own_device.is_cross_signed_by_owner() { + log::info!("Cross-signing keys are already set up"); + return true; + } + + let mut stdin = tokio::io::stdin(); + let mut stdout = tokio::io::stdout(); + + let devices = client + .encryption() + .get_user_devices(own_device.user_id()) + .await + .expect("Failed to get devices") + .devices() + .collect::>(); + + for (i, d) in devices.iter().enumerate() { + log::info!( + "Device {}: {} ({})", + i, + d.display_name().unwrap_or("Unnamed"), + d.device_id() + ); + } + + let device_num = prompt( + &mut stdin, + &mut stdout, + "Enter device number to verify with: ", + ) + .await + .unwrap_or_else(|e| { + log::error!("Failed to read device number: {}", e); + String::new() + }) + .trim() + .parse::() + .expect("Failed to parse device number"); + + let device = devices.get(device_num).expect("Invalid device number"); + + log::info!( + "Requesting verification with {} ({})", + device.display_name().unwrap_or("Unnamed"), + device.device_id() + ); + + let req = match device + .request_verification_with_methods(vec![VerificationMethod::SasV1]) + .await + { + Ok(req) => req, + Err(e) => { + log::error!( + "Failed to request verification for {}: {}", + device.device_id(), + e + ); + return false; + } + }; + + let device_name = format!( + "{} ({})", + device.display_name().unwrap_or("Unnamed"), + device.device_id() + ); + + let device_name_clone = device_name.clone(); + + let mut c = req.changes(); + + while let Some(change) = c.next().await { + match change { + VerificationRequestState::Done => { + log::info!("Verification successful for {}", device_name_clone); + return true; + } + VerificationRequestState::Cancelled(info) => { + log::info!( + "Verification canceled for {}: {:?}", + device.device_id(), + info + ); + return false; + } + VerificationRequestState::Transitioned { verification } => { + log::info!( + "Verification transitioned for {}: {:?}", + device.device_id(), + verification + ); + match verification { + Verification::SasV1(v) => { + v.accept().await.expect("Failed to accept verification"); + let emoji_str = v + .emoji() + .map(|emojis| { + emojis + .iter() + .map(|e| format!("{} ({})", e.symbol, e.description)) + .collect::>() + .join(", ") + }) + .unwrap_or_else(|| "No emojis".to_string()); + let decimals = v + .decimals() + .map(|(n1, n2, n3)| format!("{} {} {}", n1, n2, n3)) + .unwrap_or_else(|| "No decimals".to_string()); + + if prompt( + &mut stdin, + &mut stdout, + &format!( + "Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n): ", + device.device_id(), + emoji_str, + decimals + ), + ) + .await + .unwrap_or_else(|e| { + log::error!("Failed to read verification response: {}", e); + "n".to_string() + }) + .trim() == "y" { + v.confirm().await.expect("Failed to confirm"); + } else { + v.cancel().await.expect("Failed to cancel"); + } + } + _ => unimplemented!(), + } + } + VerificationRequestState::Ready { + their_methods, + our_methods, + .. + } => { + log::info!( + "Verification ready for {}: their methods: {:?}, our methods: {:?}", + device.device_id(), + their_methods, + our_methods + ); + + req.accept_with_methods(vec![VerificationMethod::SasV1]) + .await + .expect("Failed to accept verification"); + + let sas = req + .start_sas() + .await + .expect("Failed to start SAS") + .expect("No SAS"); + + sas.accept().await.expect("Failed to accept SAS"); + + while let Some(event) = sas.changes().next().await { + log::info!("SAS event: {:?}", event); + match event { + SasState::Started { protocols } => { + log::info!("SAS started with protocols: {:?}", protocols); + sas.accept().await.expect("Failed to accept SAS"); + } + SasState::Cancelled(c) => { + log::info!("SAS canceled: {:?}", c); + return false; + } + SasState::Done { + verified_devices, + verified_identities, + } => { + log::info!( + "SAS done: verified devices: {:?}, verified identities: {:?}", + verified_devices, + verified_identities + ); + } + SasState::KeysExchanged { emojis, decimals } => { + println!( + "Verification for {}:\nEmoji: {}\nDecimals: [{}, {}, {}]\n Confirm? (y/n)", + device.device_id(), + emojis.map(|e| e.emojis + .iter() + .map(|e| format!("{} ({})", e.symbol, e.description)) + .collect::>() + .join(", "), + ).unwrap_or_else(|| "No emojis".to_string()), + decimals.0, decimals.1, decimals.2 + ); + + let mut response = String::new(); + + while let Ok(c) = stdin.read_u8().await { + if c == b'\n' { + break; + } + response.push(c as char); + } + + if response.trim() == "y" { + sas.confirm().await.expect("Failed to confirm SAS"); + } else { + sas.cancel().await.expect("Failed to cancel SAS"); + } + } + SasState::Accepted { accepted_protocols } => { + log::info!("SAS accepted with protocols: {:?}", accepted_protocols); + } + SasState::Confirmed => { + log::info!("SAS confirmed! Waiting for verification to finish"); + } + } + } + } + VerificationRequestState::Requested { their_methods, .. } => { + log::info!( + "Verification requested for {}: their methods: {:?}", + device.device_id(), + their_methods + ); + } + VerificationRequestState::Created { our_methods } => { + log::info!( + "Verification created for {}: our methods: {:?}", + device.device_id(), + our_methods + ); + } + } + } + + log::info!("Verification for {} fell through", device.device_id()); + + false + } + + pub fn room_messages( + room: &Room, + since: Option, + ) -> impl TryStream, Error = matrix_sdk::Error> + '_ { + stream::try_unfold(since, move |since| async move { + let mut opt = MessagesOptions::forward().from(since.as_deref()); + opt.limit = 32.try_into().expect("Failed to convert"); + + room.messages(opt).await.map(|r| { + if r.chunk.is_empty() { + None + } else { + Some((r.chunk, r.end)) + } + }) + }) + } +} diff --git a/src/e2e/mod.rs b/src/e2e.rs similarity index 97% rename from src/e2e/mod.rs rename to src/e2e.rs index 9dd75e7..b83a69b 100644 --- a/src/e2e/mod.rs +++ b/src/e2e.rs @@ -90,7 +90,7 @@ where } Some(Err(e)) => std::task::Poll::Ready(Some(Err(ErrOrWrongHash::Err(e)))), None => match self.hasher.take() { - None => return std::task::Poll::Ready(None), + None => std::task::Poll::Ready(None), Some(hash) => { if hash.finalize().as_slice() == self.expected { return std::task::Poll::Ready(None); @@ -120,5 +120,5 @@ pub async fn decrypt_file<'s, E: std::error::Error + 's>( sha256_expect, )); - try_decrypt(&file.key, data, &iv).await + try_decrypt(&file.key, data, iv).await } diff --git a/src/filter.rs b/src/filter.rs new file mode 100644 index 0000000..878ddbc --- /dev/null +++ b/src/filter.rs @@ -0,0 +1,30 @@ +use ruma_client_api::{ + filter::{EventFormat, FilterDefinition, RoomEventFilter, RoomFilter}, + sync::sync_events::v3::Filter, +}; + +pub fn minimal_sync_filter() -> Filter { + let mut filter = FilterDefinition::empty(); + + filter.event_format = EventFormat::Client; + filter.presence = ruma_client_api::filter::Filter::empty(); + + let mut room_filter = RoomFilter::empty(); + let mut room_event_filter = RoomEventFilter::empty(); + room_event_filter.types = Some( + vec![ + "m.room.encryption", + "m.room.encryption*", + "m.room.create", + "m.room.avatar", + ] + .into_iter() + .map(|s| s.to_string()) + .collect(), + ); + room_filter.timeline = room_event_filter; + + filter.room = room_filter; + + Filter::FilterDefinition(filter) +} diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000..3747242 --- /dev/null +++ b/src/io.rs @@ -0,0 +1,61 @@ +pub fn sanitize_filename(name: &str) -> String { + name.chars() + .map(|c| match c { + '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' | '!' => '_', + _ => c, + }) + .collect() +} + +pub async fn prompt( + input: &mut IS, + output: &mut OS, + prompt: impl AsRef, +) -> Result { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + output.write_all(prompt.as_ref().as_bytes()).await?; + output.flush().await?; + + let mut response = String::new(); + match input.read_u8().await { + Ok(b'\n') => {} + Ok(byte) => response.push(byte as char), + Err(e) => return Err(e), + } + + Ok(response.trim_end().to_string()) +} + +pub fn read_password() -> Result { + use crossterm::{execute, style::Print, terminal}; + + terminal::enable_raw_mode()?; + + let mut password = String::new(); + + execute!(std::io::stdout(), Print("Password:"))?; + + loop { + if let crossterm::event::Event::Key(event) = crossterm::event::read()? { + match event.code { + crossterm::event::KeyCode::Enter => break, + crossterm::event::KeyCode::Backspace => { + password.pop(); + } + crossterm::event::KeyCode::Char(c) => { + password.push(c); + } + _ => {} + } + } + } + + execute!(std::io::stdout(), Print("\n"))?; + + terminal::disable_raw_mode()?; + + println!(); + + Ok(password) +} diff --git a/src/lib.rs b/src/lib.rs index dd50b78..ccf6c12 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,31 +1,78 @@ -use std::{pin::Pin, sync::Arc}; +use std::{ + fs::OpenOptions, + os::unix::fs::OpenOptionsExt, + path::Path, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; -use e2e::ErrOrWrongHash; -use futures::{future::BoxFuture, stream, StreamExt, TryStream, TryStreamExt}; -use matrix_sdk::{ - bytes::Bytes, - deserialized_responses::TimelineEvent, - encryption::verification::{Verification, VerificationRequestState}, - room::MessagesOptions, - ruma::events::key::verification::VerificationMethod, - Client, HttpError, IdParseError, Room, -}; -use matrix_sdk_crypto::SasState; -use ruma_client_api::{ - filter::{EventFormat, FilterDefinition, RoomEventFilter, RoomFilter}, - sync::sync_events::v3::Filter, -}; +use client::{AuthMethod, FileStream}; +use matrix_sdk::{HttpError, IdParseError}; use ruma_common::MxcUriError; -use ruma_events::room::{ - message::{MessageType, RoomMessageEvent}, - MediaSource, -}; -use tokio::io::AsyncReadExt; +pub use crate::client::MatrixClient; +use futures::{TryFutureExt, TryStreamExt}; +use io::{read_password, sanitize_filename}; +use matrix_sdk::{config::SyncSettings, Client, ServerName}; +use model::{DumpEvent, RoomMeta}; +use reqwest::Url; +use ruma_client::http_client::Reqwest; +use ruma_common::serde::Raw; +use ruma_events::{AnyMessageLikeEvent, AnyTimelineEvent}; +use tokio::{io::AsyncWriteExt, sync::Semaphore, task::JoinSet}; + +pub mod client; pub mod e2e; +pub mod filter; +pub mod io; +pub mod model; pub mod serdes; -fn mxc_url_to_https(mxc_url: &str, homeserver: &str) -> String { +use clap::Parser; + +#[derive(Clone, Debug, Parser)] +pub struct Args { + #[clap(long = "home", default_value = "matrix.org")] + pub home_server: String, + + #[clap(short, long)] + pub username: Option, + + #[clap(long, default_value = "Matrix.org Protocol Dumper by Yumechi")] + pub device_name: Option, + + #[clap(long)] + pub device_id: Option, + + #[clap(long, default_value = "config/token.json")] + pub access_token_file: Option, + + #[clap(short, long, default_value = "dump")] + pub out_dir: String, + + #[clap(long)] + pub filter: Vec, + + #[clap(long, short = 'j', default_value = "4")] + pub concurrency: usize, + + #[clap(long, default_value = "config/e2e.db")] + pub e2e_db: String, + + #[clap(long)] + pub password_file: Option, + + #[clap( + long, + help = "The timeout for the key sync in seconds", + default_value = "300" + )] + pub key_sync_timeout: u64, +} + +pub fn mxc_url_to_https(mxc_url: &str, homeserver: &str) -> String { format!( "{}_matrix/media/r0/download/{}", homeserver, @@ -63,387 +110,363 @@ pub enum DumpError { InvalidId(#[from] IdParseError), } -pub struct MatrixClient { - client: Client, -} +pub async fn dump_room_messages( + room: &matrix_sdk::Room, + out_dir: &Path, + client: Arc, + http_client: &Reqwest, + concurrency: usize, +) -> Result<(), DumpError> { + let chunk_idx = &AtomicU64::new(0); -pub fn minimal_sync_filter() -> Filter { - let mut filter = FilterDefinition::empty(); + MatrixClient::room_messages(room, None) + .try_for_each_concurrent(Some(concurrency), |msg| { + let room_dir = out_dir.to_owned(); + let client = client.clone(); + let http_client = http_client.clone(); + async move { + let output = room_dir.join(format!( + "chunk-{}.json", + chunk_idx.fetch_add(1, Ordering::SeqCst) + )); - filter.event_format = EventFormat::Client; - filter.presence = ruma_client_api::filter::Filter::empty(); + let mut out = Vec::with_capacity(msg.len()); - let mut room_filter = RoomFilter::empty(); - let mut room_event_filter = RoomEventFilter::empty(); - room_event_filter.types = Some( - vec![ - "m.room.encryption", - "m.room.encryption*", - "m.room.create", - "m.room.avatar", - ] - .into_iter() - .map(|s| s.to_string()) - .collect(), - ); - room_filter.timeline = room_event_filter; + for event in msg.into_iter() { + let mut fm = None; + match event.event.clone().cast::().deserialize() { + Ok(event) => { + if let AnyTimelineEvent::MessageLike( + AnyMessageLikeEvent::RoomMessage(m), + ) = event + { + match client.clone().try_read_attachment(&http_client, &m) { + Ok(None) => {} + Ok(Some(fut)) => { + match fut.await { + Ok(FileStream { + filename, + url, + mut stream, + }) => { + let file_name = format!( + "attachment_{}_{}_{}", + m.event_id().as_str(), + sanitize_filename( + Url::parse(&url) + .unwrap() + .path_segments() + .unwrap() + .last() + .unwrap() + ), + sanitize_filename(&filename), + ); - filter.room = room_filter; + let file = tokio::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(room_dir.join(&file_name)) + .await?; - Filter::FilterDefinition(filter) -} + let mut file = tokio::io::BufWriter::new(file); -impl MatrixClient { - pub fn new(client: Client) -> Self { - Self { client } - } - - pub fn new_arc(client: Client) -> Arc { - Arc::new(Self::new(client)) - } - - pub fn client(&self) -> &Client { - &self.client - } - - pub fn try_read_attachment<'c, 'a: 'c>( - self: Arc, - client: &'c matrix_sdk::reqwest::Client, - msg: &'a RoomMessageEvent, - ) -> Result< - Option< - BoxFuture< - 'c, - Result< - ( - String, - String, - Pin< - Box< - dyn TryStream< - Ok = Bytes, - Error = DumpError, - Item = Result, - > + Send - + 'c, - >, - >, - ), - DumpError, - >, - >, - >, - DumpError, - > { - macro_rules! impl_file_like { - ($msg:expr, $($variant:ident),*) => { - match $msg { - $( - MessageType::$variant(file) => { - Ok(Some(Box::pin( - async move { - let src = match &file.source { - MediaSource::Plain(s) => s, - MediaSource::Encrypted(e) => &e.url, - }; - let filename = file.filename.as_deref().map(|s| s.to_string()).unwrap_or(file.body.clone()); - let url = mxc_url_to_https(src.as_str(), self.client.homeserver().as_str()); - let resp = client.get(&url).send().await?; - let body = resp.bytes_stream(); - Ok((filename, url, match &file.source { - MediaSource::Plain(_) => { - Box::pin(body.map_err(DumpError::from)) as Pin> + Send>> + loop { + match stream.try_next().await { + Ok(Some(chunk)) => { + file.write_all(&chunk).await?; + } + Ok(None) => { + fm = Some((url, file_name)); + break; + } + Err(e) => { + log::warn!( + "Failed to get attachment data for {}: {}", + m.event_id(), + e + ); + } + } + } + file.shutdown().await?; + } + Err(e) => { + log::warn!("Failed to get attachment data for {}: {}", m.event_id(), e); + } + }; + } + Err(e) => { + log::warn!("Failed to get attachment data for {}: {}", m.event_id(), e); + } } - MediaSource::Encrypted(e) => Box::pin(e2e::decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err( - |e| match e { - ErrOrWrongHash::Err(e) => e.into(), - ErrOrWrongHash::WrongHash => DumpError::HashMismatch, - }, - )) as Pin> + Send>>, - })) - })))} - )* - _ => Ok(None), + } + } + Err(e) => { + log::warn!("Failed to deserialize event: {}", e); + } + } + + out.push(DumpEvent { + event: Raw::from_json(event.event.into_json()), + file_mapping: fm, + encryption_info: event.encryption_info.clone(), + }); } - }; - } - match msg { - RoomMessageEvent::Original(msg) => { - impl_file_like!(&msg.content.msgtype, Image, Video, Audio, File) + + serde_json::to_writer_pretty( + std::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(output)?, + &out, + )?; + + Ok(()) } - _ => Ok(None), + }) + .await?; + + Ok(()) +} + +pub async fn run( + js: &mut JoinSet>, + bg_js: &mut JoinSet>, +) { + let args = Args::parse(); + log::info!("Starting matrix dump, args: {:?}", args); + let http_client = Reqwest::builder().https_only(true).build().unwrap(); + let client = Client::builder() + .server_name(<&ServerName>::try_from(args.home_server.as_str()).unwrap()) + .http_client(http_client.clone()) + .sqlite_store(&args.e2e_db, None) + .build() + .await + .expect("Failed to create client"); + + let client = MatrixClient::new_arc(client); + + let client1 = client.clone(); + let access_token_file_clone = args.access_token_file.clone(); + async move { + log::info!("Logging in"); + if let Some(ref access_token_file) = access_token_file_clone { + if std::fs::exists(access_token_file).expect("Failed to check access token file") { + let token = serde_json::from_str( + &tokio::fs::read_to_string(access_token_file) + .await + .expect("Failed to read access token file"), + ) + .expect("Failed to parse access token file"); + + client1 + .login(AuthMethod::Session(token)) + .await + .expect("Failed to login"); + + log::info!("Restored session"); + return; + } + + if let Some(ref password_file) = args.password_file { + if std::fs::exists(password_file).expect("Failed to check password file") { + log::info!("Logging in with password file"); + let password = tokio::fs::read_to_string(password_file) + .await + .expect("Failed to read password file"); + client1 + .login(AuthMethod::Password { + user: args.username.as_ref().expect("No username provided"), + password: password.trim(), + initial_device_display_name: args.device_name.as_deref(), + }) + .await + .expect("Failed to login"); + return; + } + } + + if let Some(ref username) = args.username { + log::info!("Logging in with password prompt"); + let password = read_password().expect("Failed to read password"); + client1 + .login(AuthMethod::Password { + user: username, + password: &password, + initial_device_display_name: args.device_name.as_deref(), + }) + .await + .expect("Failed to login"); + return; + } + + panic!("No login method provided"); + } + } + .await; + + if !client.clone().client().logged_in() { + log::error!("Failed to login"); + return; + } + + if let Some(s) = client.client().matrix_auth().session() { + if let Some(ref access_token_file) = args.access_token_file { + let f = OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .mode(0o600) + .open(access_token_file) + .expect("Failed to open access token file"); + + serde_json::to_writer(&f, &s).expect("Failed to write access token file"); } } - pub async fn setup_e2e(self: Arc) -> bool { - let client = &self.client; - - log::info!("Preparing e2e machine"); + { + log::info!("Starting sync, may take up to a few minutes"); client - .encryption() - .wait_for_e2ee_initialization_tasks() - .await; - log::info!("E2E machine ready"); - - let own_device = client - .encryption() - .get_own_device() + .clone() + .client() + .sync_once(SyncSettings::default()) .await - .expect("Failed to get own device") - .expect("No own device found"); + .expect("Failed to sync"); + log::info!("Sync done"); - if own_device.is_cross_signed_by_owner() { - log::info!("Cross-signing keys are already set up"); - return true; + let client1 = client.clone(); + bg_js.spawn(async move { + client1 + .client() + .sync(SyncSettings::default()) + .map_err(DumpError::from) + .await + }); + } + + { + log::info!("Starting E2E setup"); + match client.clone().setup_e2e().await { + true => log::info!("E2E setup done"), + false => log::error!("E2E setup failed, E2E will not be decrypted"), } + } - let mut stdin = tokio::io::stdin(); + { + log::info!("Starting room dump"); - let devices = client - .encryption() - .get_user_devices(own_device.user_id()) - .await - .expect("Failed to get devices") - .devices() - .collect::>(); + let sem = Arc::new(Semaphore::new(args.concurrency)); - for (i, d) in devices.iter().enumerate() { - log::info!( - "Device {}: {} ({})", - i, - d.display_name().unwrap_or_else(|| "Unnamed"), - d.device_id() - ); - } + let (synced_keys_tx, synced_keys_rx) = + tokio::sync::broadcast::channel::(1); - println!("Enter the device number to verify with: "); - let mut response = String::new(); - while let Some(c) = stdin.read_u8().await.ok() { - if c == b'\n' { - break; - } - response.push(c as char); - } + let synced_keys_tx = Arc::new(synced_keys_tx); - let device_num = response - .trim() - .parse::() - .expect("Failed to parse device number"); - - let device = devices.get(device_num).expect("Invalid device number"); - - log::info!( - "Requesting verification with {} ({})", - device.display_name().unwrap_or_else(|| "Unnamed"), - device.device_id() + client.client().add_event_handler( + move |ev: matrix_sdk::ruma::events::forwarded_room_key::ToDeviceForwardedRoomKeyEvent| async move { + synced_keys_tx.send(ev.content.room_id.clone()).unwrap(); + }, ); - let req = match device - .request_verification_with_methods(vec![VerificationMethod::SasV1]) - .await - { - Ok(req) => req, - Err(e) => { - log::error!( - "Failed to request verification for {}: {}", - device.device_id(), - e - ); - return false; - } - }; + for room in client.client().rooms() { + let mut synced_keys_rx = synced_keys_rx.resubscribe(); - let device_name = format!( - "{} ({})", - device.display_name().unwrap_or_else(|| "Unnamed"), - device.device_id() - ); + let sem = sem.clone(); - let device_name_clone = device_name.clone(); + let filter = args.filter.clone(); + let out_dir = args.out_dir.clone(); + let room_id = room.room_id().to_owned(); + let room_id_clone = room_id.clone(); + let client1 = client.clone(); + let http_client = http_client.clone(); - let mut c = req.changes(); - - while let Some(change) = c.next().await { - match change { - VerificationRequestState::Done => { - log::info!("Verification successful for {}", device_name_clone); - return true; - } - VerificationRequestState::Cancelled(info) => { + js.spawn(async move { + if room.is_encrypted().await.unwrap_or(false) && !room.is_encryption_state_synced() { log::info!( - "Verification canceled for {}: {:?}", - device.device_id(), - info + "Room {} is encrypted, waiting for at most {} seconds for key sync", + room_id_clone, + args.key_sync_timeout ); - return false; - } - VerificationRequestState::Transitioned { verification } => { - log::info!( - "Verification transitioned for {}: {:?}", - device.device_id(), - verification - ); - match verification { - Verification::SasV1(v) => { - v.accept().await.expect("Failed to accept verification"); - let emoji_str = v - .emoji() - .map(|emojis| { - emojis - .iter() - .map(|e| format!("{} ({})", e.symbol, e.description)) - .collect::>() - .join(", ") - }) - .unwrap_or_else(|| "No emojis".to_string()); - let decimals = v - .decimals() - .map(|(n1, n2, n3)| format!("{} {} {}", n1, n2, n3)) - .unwrap_or_else(|| "No decimals".to_string()); - println!( - "Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n)", - device.device_id(), - emoji_str, - decimals - ); - - let mut response = String::new(); - while let Some(c) = stdin.read_u8().await.ok() { - if c == b'\n' { + let room_id_clone1 = room_id_clone.clone(); + let room_clone = room.clone(); + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(args.key_sync_timeout)) => { + log::warn!("Key sync timed out for room {}", room_id); + } + _ = async move { + while let Ok(room_id) = synced_keys_rx.recv().await { + if room_id == room_id_clone1 { break; } - response.push(c as char); } - - if response.trim() == "y" { - v.confirm().await.expect("Failed to confirm"); - } else { - v.cancel().await.expect("Failed to cancel"); + if !room_clone.is_encryption_state_synced() { + log::warn!("Waiting for another 10 seconds for key sync to finish"); } - } - _ => unimplemented!(), - } - } - VerificationRequestState::Ready { - their_methods, - our_methods, - .. - } => { - log::info!( - "Verification ready for {}: their methods: {:?}, our methods: {:?}", - device.device_id(), - their_methods, - our_methods - ); - - req.accept_with_methods(vec![VerificationMethod::SasV1]) - .await - .expect("Failed to accept verification"); - - let sas = req - .start_sas() - .await - .expect("Failed to start SAS") - .expect("No SAS"); - - sas.accept().await.expect("Failed to accept SAS"); - - while let Some(event) = sas.changes().next().await { - log::info!("SAS event: {:?}", event); - match event { - SasState::Started { protocols } => { - log::info!("SAS started with protocols: {:?}", protocols); - sas.accept().await.expect("Failed to accept SAS"); - } - SasState::Cancelled(c) => { - log::info!("SAS canceled: {:?}", c); - return false; - } - SasState::Done { - verified_devices, - verified_identities, - } => { - log::info!( - "SAS done: verified devices: {:?}, verified identities: {:?}", - verified_devices, - verified_identities - ); - } - SasState::KeysExchanged { emojis, decimals } => { - println!( - "Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n)", - device.device_id(), - emojis.map(|e| e.emojis - .iter() - .map(|e| format!("{} ({})", e.symbol, e.description)) - .collect::>() - .join(", "), - ).unwrap_or_else(|| "No emojis".to_string()), - format!("{} {} {}", decimals.0, decimals.1, decimals.2) - ); - - let mut response = String::new(); - - while let Some(c) = stdin.read_u8().await.ok() { - if c == b'\n' { - break; - } - response.push(c as char); - } - - if response.trim() == "y" { - sas.confirm().await.expect("Failed to confirm SAS"); - } else { - sas.cancel().await.expect("Failed to cancel SAS"); - } - } - SasState::Accepted { accepted_protocols } => { - log::info!("SAS accepted with protocols: {:?}", accepted_protocols); - } - SasState::Confirmed => { - log::info!("SAS confirmed! Waiting for verification to finish"); + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + if !room_clone.is_encryption_state_synced() { + log::warn!("Key sync timed out for room {}", room_id_clone1); } + } => { + log::info!("Key sync done for room {}", room_id_clone); } } + } - VerificationRequestState::Requested { their_methods, .. } => { - log::info!( - "Verification requested for {}: their methods: {:?}", - device.device_id(), - their_methods - ); - } - VerificationRequestState::Created { our_methods } => { - log::info!( - "Verification created for {}: our methods: {:?}", - device.device_id(), - our_methods - ); - } - } - } - log::info!("Verification for {} fell through", device.device_id()); + let permit = sem.clone().acquire_owned().await; - false - } + let room_name = room.display_name().await.map(|d| d.to_string()) + .unwrap_or(room.name().unwrap_or("unknown".to_string())); - pub fn room_messages( - room: &Room, - since: Option, - ) -> impl TryStream, Error = matrix_sdk::Error> + '_ { - stream::try_unfold(since, move |since| async move { - let mut opt = MessagesOptions::forward().from(since.as_deref()); - opt.limit = 32.try_into().expect("Failed to convert"); + let room_dir = + Path::new(&out_dir).join(sanitize_filename(&format!("{}_{}", room_id, room_name))); - room.messages(opt).await.map(|r| { - if r.chunk.is_empty() { - None + let match_filter = if filter.is_empty() { + true } else { - Some((r.chunk, r.end)) + filter + .iter() + .any(|filter| room_name.contains(filter) || room_id.as_str().contains(filter)) + }; + + tokio::fs::create_dir_all(&room_dir) + .await + .expect("Failed to create room directory"); + + let meta_path = room_dir.join("meta.json"); + + serde_json::to_writer_pretty( + OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(&meta_path) + .expect("Failed to open meta file"), + &RoomMeta { + id: room_id.as_str(), + name: Some(room_name.as_str()), + state: &room.state(), + }, + )?; + + if !match_filter { + log::debug!("Skipping room: {} ({})", room_id, room_name); + return Ok(()); } - }) - }) + + log::info!("Dumping room: {} ({})", room_id, room_name); + + dump_room_messages(&room, &room_dir, client1, &http_client, args.concurrency).await?; + + drop(permit); + + Ok(()) + }); /* js.spawn */ + } + drop(synced_keys_rx); } } diff --git a/src/main.rs b/src/main.rs index 7dda5c1..54c06ef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,121 +1,4 @@ -use std::{ - fs::OpenOptions, - os::unix::fs::OpenOptionsExt, - path::Path, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, -}; - -use clap::Parser; -use futures::TryStreamExt; -use matrix_dump::{DumpError, MatrixClient}; -use matrix_sdk::{ - config::SyncSettings, deserialized_responses::EncryptionInfo, Client, RoomState, ServerName, -}; -use reqwest::Url; -use ruma_client::http_client::Reqwest; -use ruma_common::serde::Raw; -use ruma_events::{AnyMessageLikeEvent, AnyTimelineEvent}; -use tokio::{io::AsyncWriteExt, sync::Semaphore, task::JoinSet}; - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct DumpEvent { - event: Raw, - file_mapping: Option<(String, String)>, - encryption_info: Option, -} - -#[derive(Clone, Debug, serde::Serialize)] -pub struct RoomMeta<'a> { - pub id: &'a str, - pub name: Option<&'a str>, - pub state: &'a RoomState, -} - -fn sanitize_filename(name: &str) -> String { - name.chars() - .map(|c| match c { - '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' | '!' => '_', - _ => c, - }) - .collect() -} - -#[derive(Clone, Debug, Parser)] -pub struct Args { - #[clap(long = "home", default_value = "matrix.org")] - pub home_server: String, - - #[clap(short, long)] - pub username: Option, - - #[clap(long, default_value = "Matrix.org Protocol Dumper by Yumechi")] - pub device_name: Option, - - #[clap(long)] - pub device_id: Option, - - #[clap(long, default_value = "config/token.json")] - pub access_token_file: Option, - - #[clap(short, long, default_value = "dump")] - pub out_dir: String, - - #[clap(long)] - pub filter: Vec, - - #[clap(long, short = 'j', default_value = "4")] - pub concurrency: usize, - - #[clap(long, default_value = "config/e2e.db")] - pub e2e_db: String, - - #[clap(long)] - pub password_file: Option, - - #[clap( - long, - help = "The timeout for the key sync in seconds", - default_value = "300" - )] - pub key_sync_timeout: u64, -} - -fn read_password() -> Result { - use crossterm::{execute, style::Print, terminal}; - - terminal::enable_raw_mode()?; - - let mut password = String::new(); - - execute!(std::io::stdout(), Print("Password:"))?; - - loop { - match crossterm::event::read()? { - crossterm::event::Event::Key(event) => match event.code { - crossterm::event::KeyCode::Enter => break, - crossterm::event::KeyCode::Backspace => { - password.pop(); - } - crossterm::event::KeyCode::Char(c) => { - password.push(c); - } - _ => {} - }, - _ => {} - } - } - - execute!(std::io::stdout(), Print("\n"))?; - - terminal::disable_raw_mode()?; - - println!(); - - Ok(password) -} +use tokio::task::JoinSet; #[tokio::main] async fn main() { @@ -127,342 +10,24 @@ async fn main() { let mut js = JoinSet::new(); let mut bg_js = JoinSet::new(); tokio::select! { - _ = run(&mut js, &mut bg_js) => {}, + _ = matrix_dump::run(&mut js, &mut bg_js) => {}, _ = tokio::signal::ctrl_c() => { log::info!("Received Ctrl-C, exiting"); js.abort_all(); }, } - log::info!("Waiting for tasks to finish"); + log::info!("Waiting for tasks to finish, press Ctrl-C to force exit"); + tokio::select! { + _ = async { + bg_js.abort_all(); + while bg_js.join_next().await.is_some() {} + while js.join_next().await.is_some() {} + } => {}, + _ = tokio::signal::ctrl_c() => { + log::info!("Received Ctrl-C again, force exiting"); + }, + } bg_js.abort_all(); - while let Some(_) = bg_js.join_next().await {} + js.abort_all(); js.join_all().await; } - -async fn run(js: &mut JoinSet>, bg_js: &mut JoinSet>) { - let args = Args::parse(); - log::info!("Starting matrix dump, args: {:?}", args); - let http_client = Reqwest::builder().https_only(true).build().unwrap(); - let client = Client::builder() - .server_name(<&ServerName>::try_from(args.home_server.as_str()).unwrap()) - .http_client(http_client.clone()) - .sqlite_store(&args.e2e_db, None) - .build() - .await - .expect("Failed to create client"); - - let client = MatrixClient::new_arc(client); - let client1 = client.clone(); - let access_token_file_clone = args.access_token_file.clone(); - (move || async move { - if let Some(ref access_token_file) = access_token_file_clone { - if std::fs::exists(access_token_file).expect("Failed to check access token file") { - let token = serde_json::from_str( - &tokio::fs::read_to_string(access_token_file) - .await - .expect("Failed to read access token file"), - ) - .expect("Failed to parse access token file"); - - client1 - .client() - .matrix_auth() - .restore_session(token) - .await - .expect("Failed to restore session"); - - log::info!("Restored session"); - return; - } - - if let Some(ref password_file) = args.password_file { - if std::fs::exists(password_file).expect("Failed to check password file") { - log::info!("Logging in with password file"); - let password = tokio::fs::read_to_string(password_file) - .await - .expect("Failed to read password file"); - client1 - .client() - .matrix_auth() - .login_username( - args.username.clone().expect("Username not provided"), - &password, - ) - .initial_device_display_name("Matrix Protocol Dumper By Yumechi") - .await - .expect("Failed to login"); - return; - } - } - - if let Some(ref username) = args.username { - log::info!("Logging in with password prompt"); - let password = read_password().expect("Failed to read password"); - client1 - .client() - .matrix_auth() - .login_username(username.clone(), &password) - .initial_device_display_name("Matrix Protocol Dumper By Yumechi") - .await - .expect("Failed to login"); - return; - } - - panic!("No login method provided"); - } - })() - .await; - - if !client.clone().client().logged_in() { - log::error!("Failed to login"); - return; - } - - if let Some(s) = client.client().matrix_auth().session() { - if let Some(ref access_token_file) = args.access_token_file { - let f = OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .mode(0o600) - .open(access_token_file) - .expect("Failed to open access token file"); - - serde_json::to_writer(&f, &s).expect("Failed to write access token file"); - } - } - - log::info!("Starting sync"); - client - .clone() - .client() - .sync_once(SyncSettings::default()) - .await - .expect("Failed to sync"); - log::info!("Sync done"); - - let client1 = client.clone(); - bg_js.spawn(async move { - client1.client().sync(SyncSettings::default()).await?; - - Ok(()) - }); - - log::info!("Starting E2E setup"); - match client.clone().setup_e2e().await { - true => log::info!("E2E setup done"), - false => log::error!("E2E setup failed"), - } - - log::info!("Starting room dump"); - - let sem = Arc::new(Semaphore::new(args.concurrency)); - - let (synced_keys_tx, _) = tokio::sync::broadcast::channel::(1); - - let synced_keys_tx = Arc::new(synced_keys_tx); - let synced_keys_tx1 = synced_keys_tx.clone(); - - client.client().add_event_handler( - |ev: matrix_sdk::ruma::events::forwarded_room_key::ToDeviceForwardedRoomKeyEvent| async move { - synced_keys_tx.send(ev.content.room_id.clone()).unwrap(); - }, - ); - - for room in client.client().rooms() { - let mut synced_keys_rx = synced_keys_tx1.subscribe(); - - let sem = sem.clone(); - - let filter = args.filter.clone(); - let out_dir = args.out_dir.clone(); - let room_id = room.room_id().to_owned(); - let room_id_clone = room_id.clone(); - let client1 = client.clone(); - let http_client = http_client.clone(); - js.spawn(async move { - if room.is_encrypted().await.unwrap_or(false) && !room.is_encryption_state_synced() { - log::info!( - "Room {} is encrypted, waiting for at most {} seconds for key sync", - room_id_clone, - args.key_sync_timeout - ); - - let room_id_clone1 = room_id_clone.clone(); - let room_clone = room.clone(); - tokio::select! { - _ = tokio::time::sleep(std::time::Duration::from_secs(args.key_sync_timeout)) => { - log::warn!("Key sync timed out for room {}", room_id); - } - _ = async move { - while let Ok(room_id) = synced_keys_rx.recv().await { - if room_id == room_id_clone1 { - break; - } - } - if !room_clone.is_encryption_state_synced() { - log::warn!("Waiting for another 10 seconds for key sync to finish"); - } - tokio::time::sleep(std::time::Duration::from_secs(10)).await; - if !room_clone.is_encryption_state_synced() { - log::warn!("Key sync timed out for room {}", room_id_clone1); - } - } => { - log::info!("Key sync done for room {}", room_id_clone); - } - } - - } - - let permit = sem.clone().acquire_owned().await; - - - let room_name = room.display_name().await.map(|d| d.to_string()) - .unwrap_or(room.name().unwrap_or("unknown".to_string())); - - let room_dir = - Path::new(&out_dir).join(sanitize_filename(&format!("{}_{}", room_id, room_name))); - - let match_filter = if filter.is_empty() { - true - } else { - filter - .iter() - .any(|filter| room_name.contains(filter) || room_id.as_str().contains(filter)) - }; - - tokio::fs::create_dir_all(&room_dir) - .await - .expect("Failed to create room directory"); - - let meta_path = room_dir.join("meta.json"); - - serde_json::to_writer_pretty( - OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open(&meta_path) - .expect("Failed to open meta file"), - &RoomMeta { - id: room_id.as_str(), - name: Some(room_name.as_str()), - state: &room.state(), - }, - )?; - - if !match_filter { - log::debug!("Skipping room: {} ({})", room_id, room_name); - return Ok(()); - } - - log::info!("Dumping room: {} ({})", room_id, room_name); - - let chunk_idx = &AtomicU64::new(0); - let client1 = client1.clone(); - MatrixClient::room_messages(&room, None) - .try_for_each_concurrent(Some(args.concurrency), |msg| { - let room_dir = room_dir.clone(); - let client1 = client1.clone(); - let http_client = http_client.clone(); - async move { - let output = room_dir.join(format!( - "chunk-{}.json", - chunk_idx.fetch_add(1, Ordering::SeqCst) - )); - - let mut out = Vec::with_capacity(msg.len()); - - for event in msg.into_iter() { - let mut fm = None; - match event.event.clone().cast::().deserialize() { - Ok(event) => match event { - AnyTimelineEvent::MessageLike(msg) => match msg { - AnyMessageLikeEvent::RoomMessage(m) => { - match client1.clone() - .try_read_attachment(&http_client, &m) { - Ok(None) => {} - Ok(Some(fut)) => { - match fut.await { - Ok((filename, url, mut byte_stream)) => { - - let file_name = format!("attachment_{}_{}_{}", - m.event_id().as_str(), - sanitize_filename(Url::parse(&url).unwrap().path_segments().unwrap().last().unwrap()), - sanitize_filename(&filename), - ); - - let file = tokio::fs::OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open(room_dir.join(&file_name)) - .await?; - - let mut file = tokio::io::BufWriter::new(file); - - loop { - match byte_stream.try_next().await { - Ok(Some(chunk)) => { - file.write_all(&chunk).await?; - } - Ok(None) => { - fm = Some((url, file_name)); - break; - } - Err(e) => { - log::warn!("Failed to get attachment data: {}", e); - } - } - } - file.shutdown().await?; - } - Err(e) => { - log::warn!("Failed to get attachment data: {}", e); - } - }; - - - } - Err(e) => { - log::warn!("Failed to get attachment: {}", e); - } - } - } - _ => {} - } - _ => {} - } - Err(e) => { - log::warn!("Failed to deserialize event: {}", e); - } - } - - out.push(DumpEvent { - event: Raw::from_json(event.event.into_json()), - file_mapping: fm, - encryption_info: event.encryption_info.clone(), - }); - } - - serde_json::to_writer_pretty( - std::fs::OpenOptions::new() - .create(true) - .write(true) - .truncate(true) - .open(output)?, - &out, - )?; - - Ok(()) - } - }) - .await - .expect("Failed to get messages"); - - drop(permit); - - Ok::<_, DumpError>(()) - }); - } -} diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..709f3ac --- /dev/null +++ b/src/model.rs @@ -0,0 +1,17 @@ +use matrix_sdk::{deserialized_responses::EncryptionInfo, RoomState}; +use ruma_common::serde::Raw; +use ruma_events::AnyTimelineEvent; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct DumpEvent { + pub event: Raw, + pub file_mapping: Option<(String, String)>, + pub encryption_info: Option, +} + +#[derive(Clone, Debug, serde::Serialize)] +pub struct RoomMeta<'a> { + pub id: &'a str, + pub name: Option<&'a str>, + pub state: &'a RoomState, +}