refactor code structure

Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
This commit is contained in:
ゆめ 2024-09-04 22:28:59 -05:00
parent d891dff9e5
commit bd6d8625ed
No known key found for this signature in database
7 changed files with 912 additions and 812 deletions

404
src/client.rs Normal file
View file

@ -0,0 +1,404 @@
use std::{pin::Pin, sync::Arc};
use crate::{
e2e::{decrypt_file, ErrOrWrongHash},
io::prompt,
mxc_url_to_https, DumpError,
};
use futures::{future::BoxFuture, stream, StreamExt, TryStream, TryStreamExt};
use matrix_sdk::{
bytes::Bytes,
deserialized_responses::TimelineEvent,
encryption::verification::{Verification, VerificationRequestState},
matrix_auth::MatrixSession,
room::MessagesOptions,
ruma::events::key::verification::VerificationMethod,
Client, Room,
};
use matrix_sdk_crypto::SasState;
use ruma_events::room::{
message::{MessageType, RoomMessageEvent},
MediaSource,
};
use tokio::io::AsyncReadExt;
pub struct MatrixClient {
client: Client,
}
pub type PinTryStream<'s, O, E> =
Pin<Box<dyn TryStream<Ok = O, Error = E, Item = Result<O, E>> + Send + 's>>;
pub struct FileStream<'s> {
pub filename: String,
pub url: String,
pub stream: PinTryStream<'s, Bytes, DumpError>,
}
pub enum AuthMethod<'a> {
Password {
user: &'a str,
password: &'a str,
initial_device_display_name: Option<&'a str>,
},
Session(MatrixSession),
}
impl MatrixClient {
pub fn new(client: Client) -> Self {
Self { client }
}
pub fn new_arc(client: Client) -> Arc<Self> {
Arc::new(Self::new(client))
}
pub fn client(&self) -> &Client {
&self.client
}
pub async fn login(&self, auth: AuthMethod<'_>) -> Result<(), matrix_sdk::Error> {
match auth {
AuthMethod::Password {
user,
password,
initial_device_display_name,
} => {
let tmp = self.client.matrix_auth().login_username(user, &password);
if let Some(name) = initial_device_display_name {
tmp.initial_device_display_name(&name)
} else {
tmp
}
.await
.map(|_| ())
}
AuthMethod::Session(session) => {
self.client.matrix_auth().restore_session(session).await
}
}
}
pub fn try_read_attachment<'c, 'a: 'c>(
self: Arc<Self>,
client: &'c matrix_sdk::reqwest::Client,
msg: &'a RoomMessageEvent,
) -> Result<Option<BoxFuture<'c, Result<FileStream<'c>, DumpError>>>, DumpError> {
macro_rules! impl_file_like {
($msg:expr, $($variant:ident),*) => {
match $msg {
$(
MessageType::$variant(file) => {
Ok(Some(Box::pin(
async move {
let src = match &file.source {
MediaSource::Plain(s) => s,
MediaSource::Encrypted(e) => &e.url,
};
let filename = file.filename.as_deref().map(|s| s.to_string()).unwrap_or(file.body.clone());
let url = mxc_url_to_https(src.as_str(), self.client.homeserver().as_str());
let resp = client.get(&url).send().await?;
let body = resp.bytes_stream();
Ok(FileStream {
filename,
url,
stream: match &file.source {
MediaSource::Plain(_) => {
Box::pin(body.map_err(DumpError::from)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>
}
MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err(
|e| match e {
ErrOrWrongHash::Err(e) => e.into(),
ErrOrWrongHash::WrongHash => DumpError::HashMismatch,
},
)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>,
}
})
})))}
)*
_ => Ok(None),
}
};
}
match msg {
RoomMessageEvent::Original(msg) => {
impl_file_like!(&msg.content.msgtype, Image, Video, Audio, File)
}
_ => Ok(None),
}
}
pub async fn setup_e2e(self: Arc<Self>) -> bool {
let client = &self.client;
log::info!("Preparing e2e machine");
client
.encryption()
.wait_for_e2ee_initialization_tasks()
.await;
log::info!("E2E machine ready");
let own_device = client
.encryption()
.get_own_device()
.await
.expect("Failed to get own device")
.expect("No own device found");
if own_device.is_cross_signed_by_owner() {
log::info!("Cross-signing keys are already set up");
return true;
}
let mut stdin = tokio::io::stdin();
let mut stdout = tokio::io::stdout();
let devices = client
.encryption()
.get_user_devices(own_device.user_id())
.await
.expect("Failed to get devices")
.devices()
.collect::<Vec<_>>();
for (i, d) in devices.iter().enumerate() {
log::info!(
"Device {}: {} ({})",
i,
d.display_name().unwrap_or("Unnamed"),
d.device_id()
);
}
let device_num = prompt(
&mut stdin,
&mut stdout,
"Enter device number to verify with: ",
)
.await
.unwrap_or_else(|e| {
log::error!("Failed to read device number: {}", e);
String::new()
})
.trim()
.parse::<usize>()
.expect("Failed to parse device number");
let device = devices.get(device_num).expect("Invalid device number");
log::info!(
"Requesting verification with {} ({})",
device.display_name().unwrap_or("Unnamed"),
device.device_id()
);
let req = match device
.request_verification_with_methods(vec![VerificationMethod::SasV1])
.await
{
Ok(req) => req,
Err(e) => {
log::error!(
"Failed to request verification for {}: {}",
device.device_id(),
e
);
return false;
}
};
let device_name = format!(
"{} ({})",
device.display_name().unwrap_or("Unnamed"),
device.device_id()
);
let device_name_clone = device_name.clone();
let mut c = req.changes();
while let Some(change) = c.next().await {
match change {
VerificationRequestState::Done => {
log::info!("Verification successful for {}", device_name_clone);
return true;
}
VerificationRequestState::Cancelled(info) => {
log::info!(
"Verification canceled for {}: {:?}",
device.device_id(),
info
);
return false;
}
VerificationRequestState::Transitioned { verification } => {
log::info!(
"Verification transitioned for {}: {:?}",
device.device_id(),
verification
);
match verification {
Verification::SasV1(v) => {
v.accept().await.expect("Failed to accept verification");
let emoji_str = v
.emoji()
.map(|emojis| {
emojis
.iter()
.map(|e| format!("{} ({})", e.symbol, e.description))
.collect::<Vec<_>>()
.join(", ")
})
.unwrap_or_else(|| "No emojis".to_string());
let decimals = v
.decimals()
.map(|(n1, n2, n3)| format!("{} {} {}", n1, n2, n3))
.unwrap_or_else(|| "No decimals".to_string());
if prompt(
&mut stdin,
&mut stdout,
&format!(
"Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n): ",
device.device_id(),
emoji_str,
decimals
),
)
.await
.unwrap_or_else(|e| {
log::error!("Failed to read verification response: {}", e);
"n".to_string()
})
.trim() == "y" {
v.confirm().await.expect("Failed to confirm");
} else {
v.cancel().await.expect("Failed to cancel");
}
}
_ => unimplemented!(),
}
}
VerificationRequestState::Ready {
their_methods,
our_methods,
..
} => {
log::info!(
"Verification ready for {}: their methods: {:?}, our methods: {:?}",
device.device_id(),
their_methods,
our_methods
);
req.accept_with_methods(vec![VerificationMethod::SasV1])
.await
.expect("Failed to accept verification");
let sas = req
.start_sas()
.await
.expect("Failed to start SAS")
.expect("No SAS");
sas.accept().await.expect("Failed to accept SAS");
while let Some(event) = sas.changes().next().await {
log::info!("SAS event: {:?}", event);
match event {
SasState::Started { protocols } => {
log::info!("SAS started with protocols: {:?}", protocols);
sas.accept().await.expect("Failed to accept SAS");
}
SasState::Cancelled(c) => {
log::info!("SAS canceled: {:?}", c);
return false;
}
SasState::Done {
verified_devices,
verified_identities,
} => {
log::info!(
"SAS done: verified devices: {:?}, verified identities: {:?}",
verified_devices,
verified_identities
);
}
SasState::KeysExchanged { emojis, decimals } => {
println!(
"Verification for {}:\nEmoji: {}\nDecimals: [{}, {}, {}]\n Confirm? (y/n)",
device.device_id(),
emojis.map(|e| e.emojis
.iter()
.map(|e| format!("{} ({})", e.symbol, e.description))
.collect::<Vec<_>>()
.join(", "),
).unwrap_or_else(|| "No emojis".to_string()),
decimals.0, decimals.1, decimals.2
);
let mut response = String::new();
while let Ok(c) = stdin.read_u8().await {
if c == b'\n' {
break;
}
response.push(c as char);
}
if response.trim() == "y" {
sas.confirm().await.expect("Failed to confirm SAS");
} else {
sas.cancel().await.expect("Failed to cancel SAS");
}
}
SasState::Accepted { accepted_protocols } => {
log::info!("SAS accepted with protocols: {:?}", accepted_protocols);
}
SasState::Confirmed => {
log::info!("SAS confirmed! Waiting for verification to finish");
}
}
}
}
VerificationRequestState::Requested { their_methods, .. } => {
log::info!(
"Verification requested for {}: their methods: {:?}",
device.device_id(),
their_methods
);
}
VerificationRequestState::Created { our_methods } => {
log::info!(
"Verification created for {}: our methods: {:?}",
device.device_id(),
our_methods
);
}
}
}
log::info!("Verification for {} fell through", device.device_id());
false
}
pub fn room_messages(
room: &Room,
since: Option<String>,
) -> impl TryStream<Ok = Vec<TimelineEvent>, Error = matrix_sdk::Error> + '_ {
stream::try_unfold(since, move |since| async move {
let mut opt = MessagesOptions::forward().from(since.as_deref());
opt.limit = 32.try_into().expect("Failed to convert");
room.messages(opt).await.map(|r| {
if r.chunk.is_empty() {
None
} else {
Some((r.chunk, r.end))
}
})
})
}
}

View file

@ -90,7 +90,7 @@ where
} }
Some(Err(e)) => std::task::Poll::Ready(Some(Err(ErrOrWrongHash::Err(e)))), Some(Err(e)) => std::task::Poll::Ready(Some(Err(ErrOrWrongHash::Err(e)))),
None => match self.hasher.take() { None => match self.hasher.take() {
None => return std::task::Poll::Ready(None), None => std::task::Poll::Ready(None),
Some(hash) => { Some(hash) => {
if hash.finalize().as_slice() == self.expected { if hash.finalize().as_slice() == self.expected {
return std::task::Poll::Ready(None); return std::task::Poll::Ready(None);
@ -120,5 +120,5 @@ pub async fn decrypt_file<'s, E: std::error::Error + 's>(
sha256_expect, sha256_expect,
)); ));
try_decrypt(&file.key, data, &iv).await try_decrypt(&file.key, data, iv).await
} }

30
src/filter.rs Normal file
View file

@ -0,0 +1,30 @@
use ruma_client_api::{
filter::{EventFormat, FilterDefinition, RoomEventFilter, RoomFilter},
sync::sync_events::v3::Filter,
};
pub fn minimal_sync_filter() -> Filter {
let mut filter = FilterDefinition::empty();
filter.event_format = EventFormat::Client;
filter.presence = ruma_client_api::filter::Filter::empty();
let mut room_filter = RoomFilter::empty();
let mut room_event_filter = RoomEventFilter::empty();
room_event_filter.types = Some(
vec![
"m.room.encryption",
"m.room.encryption*",
"m.room.create",
"m.room.avatar",
]
.into_iter()
.map(|s| s.to_string())
.collect(),
);
room_filter.timeline = room_event_filter;
filter.room = room_filter;
Filter::FilterDefinition(filter)
}

61
src/io.rs Normal file
View file

@ -0,0 +1,61 @@
pub fn sanitize_filename(name: &str) -> String {
name.chars()
.map(|c| match c {
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' | '!' => '_',
_ => c,
})
.collect()
}
pub async fn prompt<IS: tokio::io::AsyncRead + Unpin, OS: tokio::io::AsyncWrite + Unpin>(
input: &mut IS,
output: &mut OS,
prompt: impl AsRef<str>,
) -> Result<String, tokio::io::Error> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
output.write_all(prompt.as_ref().as_bytes()).await?;
output.flush().await?;
let mut response = String::new();
match input.read_u8().await {
Ok(b'\n') => {}
Ok(byte) => response.push(byte as char),
Err(e) => return Err(e),
}
Ok(response.trim_end().to_string())
}
pub fn read_password() -> Result<String, std::io::Error> {
use crossterm::{execute, style::Print, terminal};
terminal::enable_raw_mode()?;
let mut password = String::new();
execute!(std::io::stdout(), Print("Password:"))?;
loop {
if let crossterm::event::Event::Key(event) = crossterm::event::read()? {
match event.code {
crossterm::event::KeyCode::Enter => break,
crossterm::event::KeyCode::Backspace => {
password.pop();
}
crossterm::event::KeyCode::Char(c) => {
password.push(c);
}
_ => {}
}
}
}
execute!(std::io::stdout(), Print("\n"))?;
terminal::disable_raw_mode()?;
println!();
Ok(password)
}

View file

@ -1,31 +1,78 @@
use std::{pin::Pin, sync::Arc}; use std::{
fs::OpenOptions,
os::unix::fs::OpenOptionsExt,
path::Path,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use e2e::ErrOrWrongHash; use client::{AuthMethod, FileStream};
use futures::{future::BoxFuture, stream, StreamExt, TryStream, TryStreamExt}; use matrix_sdk::{HttpError, IdParseError};
use matrix_sdk::{
bytes::Bytes,
deserialized_responses::TimelineEvent,
encryption::verification::{Verification, VerificationRequestState},
room::MessagesOptions,
ruma::events::key::verification::VerificationMethod,
Client, HttpError, IdParseError, Room,
};
use matrix_sdk_crypto::SasState;
use ruma_client_api::{
filter::{EventFormat, FilterDefinition, RoomEventFilter, RoomFilter},
sync::sync_events::v3::Filter,
};
use ruma_common::MxcUriError; use ruma_common::MxcUriError;
use ruma_events::room::{
message::{MessageType, RoomMessageEvent},
MediaSource,
};
use tokio::io::AsyncReadExt;
pub use crate::client::MatrixClient;
use futures::{TryFutureExt, TryStreamExt};
use io::{read_password, sanitize_filename};
use matrix_sdk::{config::SyncSettings, Client, ServerName};
use model::{DumpEvent, RoomMeta};
use reqwest::Url;
use ruma_client::http_client::Reqwest;
use ruma_common::serde::Raw;
use ruma_events::{AnyMessageLikeEvent, AnyTimelineEvent};
use tokio::{io::AsyncWriteExt, sync::Semaphore, task::JoinSet};
pub mod client;
pub mod e2e; pub mod e2e;
pub mod filter;
pub mod io;
pub mod model;
pub mod serdes; pub mod serdes;
fn mxc_url_to_https(mxc_url: &str, homeserver: &str) -> String { use clap::Parser;
#[derive(Clone, Debug, Parser)]
pub struct Args {
#[clap(long = "home", default_value = "matrix.org")]
pub home_server: String,
#[clap(short, long)]
pub username: Option<String>,
#[clap(long, default_value = "Matrix.org Protocol Dumper by Yumechi")]
pub device_name: Option<String>,
#[clap(long)]
pub device_id: Option<String>,
#[clap(long, default_value = "config/token.json")]
pub access_token_file: Option<String>,
#[clap(short, long, default_value = "dump")]
pub out_dir: String,
#[clap(long)]
pub filter: Vec<String>,
#[clap(long, short = 'j', default_value = "4")]
pub concurrency: usize,
#[clap(long, default_value = "config/e2e.db")]
pub e2e_db: String,
#[clap(long)]
pub password_file: Option<String>,
#[clap(
long,
help = "The timeout for the key sync in seconds",
default_value = "300"
)]
pub key_sync_timeout: u64,
}
pub fn mxc_url_to_https(mxc_url: &str, homeserver: &str) -> String {
format!( format!(
"{}_matrix/media/r0/download/{}", "{}_matrix/media/r0/download/{}",
homeserver, homeserver,
@ -63,387 +110,363 @@ pub enum DumpError {
InvalidId(#[from] IdParseError), InvalidId(#[from] IdParseError),
} }
pub struct MatrixClient { pub async fn dump_room_messages(
client: Client, room: &matrix_sdk::Room,
} out_dir: &Path,
client: Arc<MatrixClient>,
http_client: &Reqwest,
concurrency: usize,
) -> Result<(), DumpError> {
let chunk_idx = &AtomicU64::new(0);
pub fn minimal_sync_filter() -> Filter { MatrixClient::room_messages(room, None)
let mut filter = FilterDefinition::empty(); .try_for_each_concurrent(Some(concurrency), |msg| {
let room_dir = out_dir.to_owned();
let client = client.clone();
let http_client = http_client.clone();
async move {
let output = room_dir.join(format!(
"chunk-{}.json",
chunk_idx.fetch_add(1, Ordering::SeqCst)
));
filter.event_format = EventFormat::Client; let mut out = Vec::with_capacity(msg.len());
filter.presence = ruma_client_api::filter::Filter::empty();
let mut room_filter = RoomFilter::empty(); for event in msg.into_iter() {
let mut room_event_filter = RoomEventFilter::empty(); let mut fm = None;
room_event_filter.types = Some( match event.event.clone().cast::<AnyTimelineEvent>().deserialize() {
vec![ Ok(event) => {
"m.room.encryption", if let AnyTimelineEvent::MessageLike(
"m.room.encryption*", AnyMessageLikeEvent::RoomMessage(m),
"m.room.create", ) = event
"m.room.avatar", {
] match client.clone().try_read_attachment(&http_client, &m) {
.into_iter() Ok(None) => {}
.map(|s| s.to_string()) Ok(Some(fut)) => {
.collect(), match fut.await {
); Ok(FileStream {
room_filter.timeline = room_event_filter; filename,
url,
mut stream,
}) => {
let file_name = format!(
"attachment_{}_{}_{}",
m.event_id().as_str(),
sanitize_filename(
Url::parse(&url)
.unwrap()
.path_segments()
.unwrap()
.last()
.unwrap()
),
sanitize_filename(&filename),
);
filter.room = room_filter; let file = tokio::fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(room_dir.join(&file_name))
.await?;
Filter::FilterDefinition(filter) let mut file = tokio::io::BufWriter::new(file);
}
impl MatrixClient { loop {
pub fn new(client: Client) -> Self { match stream.try_next().await {
Self { client } Ok(Some(chunk)) => {
} file.write_all(&chunk).await?;
}
pub fn new_arc(client: Client) -> Arc<Self> { Ok(None) => {
Arc::new(Self::new(client)) fm = Some((url, file_name));
} break;
}
pub fn client(&self) -> &Client { Err(e) => {
&self.client log::warn!(
} "Failed to get attachment data for {}: {}",
m.event_id(),
pub fn try_read_attachment<'c, 'a: 'c>( e
self: Arc<Self>, );
client: &'c matrix_sdk::reqwest::Client, }
msg: &'a RoomMessageEvent, }
) -> Result< }
Option< file.shutdown().await?;
BoxFuture< }
'c, Err(e) => {
Result< log::warn!("Failed to get attachment data for {}: {}", m.event_id(), e);
( }
String, };
String, }
Pin< Err(e) => {
Box< log::warn!("Failed to get attachment data for {}: {}", m.event_id(), e);
dyn TryStream< }
Ok = Bytes,
Error = DumpError,
Item = Result<Bytes, DumpError>,
> + Send
+ 'c,
>,
>,
),
DumpError,
>,
>,
>,
DumpError,
> {
macro_rules! impl_file_like {
($msg:expr, $($variant:ident),*) => {
match $msg {
$(
MessageType::$variant(file) => {
Ok(Some(Box::pin(
async move {
let src = match &file.source {
MediaSource::Plain(s) => s,
MediaSource::Encrypted(e) => &e.url,
};
let filename = file.filename.as_deref().map(|s| s.to_string()).unwrap_or(file.body.clone());
let url = mxc_url_to_https(src.as_str(), self.client.homeserver().as_str());
let resp = client.get(&url).send().await?;
let body = resp.bytes_stream();
Ok((filename, url, match &file.source {
MediaSource::Plain(_) => {
Box::pin(body.map_err(DumpError::from)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>
} }
MediaSource::Encrypted(e) => Box::pin(e2e::decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err( }
|e| match e { }
ErrOrWrongHash::Err(e) => e.into(), Err(e) => {
ErrOrWrongHash::WrongHash => DumpError::HashMismatch, log::warn!("Failed to deserialize event: {}", e);
}, }
)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>, }
}))
})))} out.push(DumpEvent {
)* event: Raw::from_json(event.event.into_json()),
_ => Ok(None), file_mapping: fm,
encryption_info: event.encryption_info.clone(),
});
} }
};
} serde_json::to_writer_pretty(
match msg { std::fs::OpenOptions::new()
RoomMessageEvent::Original(msg) => { .create(true)
impl_file_like!(&msg.content.msgtype, Image, Video, Audio, File) .write(true)
.truncate(true)
.open(output)?,
&out,
)?;
Ok(())
} }
_ => Ok(None), })
.await?;
Ok(())
}
pub async fn run(
js: &mut JoinSet<Result<(), DumpError>>,
bg_js: &mut JoinSet<Result<(), DumpError>>,
) {
let args = Args::parse();
log::info!("Starting matrix dump, args: {:?}", args);
let http_client = Reqwest::builder().https_only(true).build().unwrap();
let client = Client::builder()
.server_name(<&ServerName>::try_from(args.home_server.as_str()).unwrap())
.http_client(http_client.clone())
.sqlite_store(&args.e2e_db, None)
.build()
.await
.expect("Failed to create client");
let client = MatrixClient::new_arc(client);
let client1 = client.clone();
let access_token_file_clone = args.access_token_file.clone();
async move {
log::info!("Logging in");
if let Some(ref access_token_file) = access_token_file_clone {
if std::fs::exists(access_token_file).expect("Failed to check access token file") {
let token = serde_json::from_str(
&tokio::fs::read_to_string(access_token_file)
.await
.expect("Failed to read access token file"),
)
.expect("Failed to parse access token file");
client1
.login(AuthMethod::Session(token))
.await
.expect("Failed to login");
log::info!("Restored session");
return;
}
if let Some(ref password_file) = args.password_file {
if std::fs::exists(password_file).expect("Failed to check password file") {
log::info!("Logging in with password file");
let password = tokio::fs::read_to_string(password_file)
.await
.expect("Failed to read password file");
client1
.login(AuthMethod::Password {
user: args.username.as_ref().expect("No username provided"),
password: password.trim(),
initial_device_display_name: args.device_name.as_deref(),
})
.await
.expect("Failed to login");
return;
}
}
if let Some(ref username) = args.username {
log::info!("Logging in with password prompt");
let password = read_password().expect("Failed to read password");
client1
.login(AuthMethod::Password {
user: username,
password: &password,
initial_device_display_name: args.device_name.as_deref(),
})
.await
.expect("Failed to login");
return;
}
panic!("No login method provided");
}
}
.await;
if !client.clone().client().logged_in() {
log::error!("Failed to login");
return;
}
if let Some(s) = client.client().matrix_auth().session() {
if let Some(ref access_token_file) = args.access_token_file {
let f = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.mode(0o600)
.open(access_token_file)
.expect("Failed to open access token file");
serde_json::to_writer(&f, &s).expect("Failed to write access token file");
} }
} }
pub async fn setup_e2e(self: Arc<Self>) -> bool { {
let client = &self.client; log::info!("Starting sync, may take up to a few minutes");
log::info!("Preparing e2e machine");
client client
.encryption() .clone()
.wait_for_e2ee_initialization_tasks() .client()
.await; .sync_once(SyncSettings::default())
log::info!("E2E machine ready");
let own_device = client
.encryption()
.get_own_device()
.await .await
.expect("Failed to get own device") .expect("Failed to sync");
.expect("No own device found"); log::info!("Sync done");
if own_device.is_cross_signed_by_owner() { let client1 = client.clone();
log::info!("Cross-signing keys are already set up"); bg_js.spawn(async move {
return true; client1
.client()
.sync(SyncSettings::default())
.map_err(DumpError::from)
.await
});
}
{
log::info!("Starting E2E setup");
match client.clone().setup_e2e().await {
true => log::info!("E2E setup done"),
false => log::error!("E2E setup failed, E2E will not be decrypted"),
} }
}
let mut stdin = tokio::io::stdin(); {
log::info!("Starting room dump");
let devices = client let sem = Arc::new(Semaphore::new(args.concurrency));
.encryption()
.get_user_devices(own_device.user_id())
.await
.expect("Failed to get devices")
.devices()
.collect::<Vec<_>>();
for (i, d) in devices.iter().enumerate() { let (synced_keys_tx, synced_keys_rx) =
log::info!( tokio::sync::broadcast::channel::<matrix_sdk::ruma::OwnedRoomId>(1);
"Device {}: {} ({})",
i,
d.display_name().unwrap_or_else(|| "Unnamed"),
d.device_id()
);
}
println!("Enter the device number to verify with: "); let synced_keys_tx = Arc::new(synced_keys_tx);
let mut response = String::new();
while let Some(c) = stdin.read_u8().await.ok() {
if c == b'\n' {
break;
}
response.push(c as char);
}
let device_num = response client.client().add_event_handler(
.trim() move |ev: matrix_sdk::ruma::events::forwarded_room_key::ToDeviceForwardedRoomKeyEvent| async move {
.parse::<usize>() synced_keys_tx.send(ev.content.room_id.clone()).unwrap();
.expect("Failed to parse device number"); },
let device = devices.get(device_num).expect("Invalid device number");
log::info!(
"Requesting verification with {} ({})",
device.display_name().unwrap_or_else(|| "Unnamed"),
device.device_id()
); );
let req = match device for room in client.client().rooms() {
.request_verification_with_methods(vec![VerificationMethod::SasV1]) let mut synced_keys_rx = synced_keys_rx.resubscribe();
.await
{
Ok(req) => req,
Err(e) => {
log::error!(
"Failed to request verification for {}: {}",
device.device_id(),
e
);
return false;
}
};
let device_name = format!( let sem = sem.clone();
"{} ({})",
device.display_name().unwrap_or_else(|| "Unnamed"),
device.device_id()
);
let device_name_clone = device_name.clone(); let filter = args.filter.clone();
let out_dir = args.out_dir.clone();
let room_id = room.room_id().to_owned();
let room_id_clone = room_id.clone();
let client1 = client.clone();
let http_client = http_client.clone();
let mut c = req.changes(); js.spawn(async move {
if room.is_encrypted().await.unwrap_or(false) && !room.is_encryption_state_synced() {
while let Some(change) = c.next().await {
match change {
VerificationRequestState::Done => {
log::info!("Verification successful for {}", device_name_clone);
return true;
}
VerificationRequestState::Cancelled(info) => {
log::info!( log::info!(
"Verification canceled for {}: {:?}", "Room {} is encrypted, waiting for at most {} seconds for key sync",
device.device_id(), room_id_clone,
info args.key_sync_timeout
); );
return false;
}
VerificationRequestState::Transitioned { verification } => {
log::info!(
"Verification transitioned for {}: {:?}",
device.device_id(),
verification
);
match verification {
Verification::SasV1(v) => {
v.accept().await.expect("Failed to accept verification");
let emoji_str = v
.emoji()
.map(|emojis| {
emojis
.iter()
.map(|e| format!("{} ({})", e.symbol, e.description))
.collect::<Vec<_>>()
.join(", ")
})
.unwrap_or_else(|| "No emojis".to_string());
let decimals = v
.decimals()
.map(|(n1, n2, n3)| format!("{} {} {}", n1, n2, n3))
.unwrap_or_else(|| "No decimals".to_string());
println!( let room_id_clone1 = room_id_clone.clone();
"Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n)", let room_clone = room.clone();
device.device_id(), tokio::select! {
emoji_str, _ = tokio::time::sleep(std::time::Duration::from_secs(args.key_sync_timeout)) => {
decimals log::warn!("Key sync timed out for room {}", room_id);
); }
_ = async move {
let mut response = String::new(); while let Ok(room_id) = synced_keys_rx.recv().await {
while let Some(c) = stdin.read_u8().await.ok() { if room_id == room_id_clone1 {
if c == b'\n' {
break; break;
} }
response.push(c as char);
} }
if !room_clone.is_encryption_state_synced() {
if response.trim() == "y" { log::warn!("Waiting for another 10 seconds for key sync to finish");
v.confirm().await.expect("Failed to confirm");
} else {
v.cancel().await.expect("Failed to cancel");
} }
} tokio::time::sleep(std::time::Duration::from_secs(10)).await;
_ => unimplemented!(), if !room_clone.is_encryption_state_synced() {
} log::warn!("Key sync timed out for room {}", room_id_clone1);
}
VerificationRequestState::Ready {
their_methods,
our_methods,
..
} => {
log::info!(
"Verification ready for {}: their methods: {:?}, our methods: {:?}",
device.device_id(),
their_methods,
our_methods
);
req.accept_with_methods(vec![VerificationMethod::SasV1])
.await
.expect("Failed to accept verification");
let sas = req
.start_sas()
.await
.expect("Failed to start SAS")
.expect("No SAS");
sas.accept().await.expect("Failed to accept SAS");
while let Some(event) = sas.changes().next().await {
log::info!("SAS event: {:?}", event);
match event {
SasState::Started { protocols } => {
log::info!("SAS started with protocols: {:?}", protocols);
sas.accept().await.expect("Failed to accept SAS");
}
SasState::Cancelled(c) => {
log::info!("SAS canceled: {:?}", c);
return false;
}
SasState::Done {
verified_devices,
verified_identities,
} => {
log::info!(
"SAS done: verified devices: {:?}, verified identities: {:?}",
verified_devices,
verified_identities
);
}
SasState::KeysExchanged { emojis, decimals } => {
println!(
"Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n)",
device.device_id(),
emojis.map(|e| e.emojis
.iter()
.map(|e| format!("{} ({})", e.symbol, e.description))
.collect::<Vec<_>>()
.join(", "),
).unwrap_or_else(|| "No emojis".to_string()),
format!("{} {} {}", decimals.0, decimals.1, decimals.2)
);
let mut response = String::new();
while let Some(c) = stdin.read_u8().await.ok() {
if c == b'\n' {
break;
}
response.push(c as char);
}
if response.trim() == "y" {
sas.confirm().await.expect("Failed to confirm SAS");
} else {
sas.cancel().await.expect("Failed to cancel SAS");
}
}
SasState::Accepted { accepted_protocols } => {
log::info!("SAS accepted with protocols: {:?}", accepted_protocols);
}
SasState::Confirmed => {
log::info!("SAS confirmed! Waiting for verification to finish");
} }
} => {
log::info!("Key sync done for room {}", room_id_clone);
} }
} }
} }
VerificationRequestState::Requested { their_methods, .. } => {
log::info!(
"Verification requested for {}: their methods: {:?}",
device.device_id(),
their_methods
);
}
VerificationRequestState::Created { our_methods } => {
log::info!(
"Verification created for {}: our methods: {:?}",
device.device_id(),
our_methods
);
}
}
}
log::info!("Verification for {} fell through", device.device_id()); let permit = sem.clone().acquire_owned().await;
false let room_name = room.display_name().await.map(|d| d.to_string())
} .unwrap_or(room.name().unwrap_or("unknown".to_string()));
pub fn room_messages( let room_dir =
room: &Room, Path::new(&out_dir).join(sanitize_filename(&format!("{}_{}", room_id, room_name)));
since: Option<String>,
) -> impl TryStream<Ok = Vec<TimelineEvent>, Error = matrix_sdk::Error> + '_ {
stream::try_unfold(since, move |since| async move {
let mut opt = MessagesOptions::forward().from(since.as_deref());
opt.limit = 32.try_into().expect("Failed to convert");
room.messages(opt).await.map(|r| { let match_filter = if filter.is_empty() {
if r.chunk.is_empty() { true
None
} else { } else {
Some((r.chunk, r.end)) filter
.iter()
.any(|filter| room_name.contains(filter) || room_id.as_str().contains(filter))
};
tokio::fs::create_dir_all(&room_dir)
.await
.expect("Failed to create room directory");
let meta_path = room_dir.join("meta.json");
serde_json::to_writer_pretty(
OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&meta_path)
.expect("Failed to open meta file"),
&RoomMeta {
id: room_id.as_str(),
name: Some(room_name.as_str()),
state: &room.state(),
},
)?;
if !match_filter {
log::debug!("Skipping room: {} ({})", room_id, room_name);
return Ok(());
} }
})
}) log::info!("Dumping room: {} ({})", room_id, room_name);
dump_room_messages(&room, &room_dir, client1, &http_client, args.concurrency).await?;
drop(permit);
Ok(())
}); /* js.spawn */
}
drop(synced_keys_rx);
} }
} }

View file

@ -1,121 +1,4 @@
use std::{ use tokio::task::JoinSet;
fs::OpenOptions,
os::unix::fs::OpenOptionsExt,
path::Path,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use clap::Parser;
use futures::TryStreamExt;
use matrix_dump::{DumpError, MatrixClient};
use matrix_sdk::{
config::SyncSettings, deserialized_responses::EncryptionInfo, Client, RoomState, ServerName,
};
use reqwest::Url;
use ruma_client::http_client::Reqwest;
use ruma_common::serde::Raw;
use ruma_events::{AnyMessageLikeEvent, AnyTimelineEvent};
use tokio::{io::AsyncWriteExt, sync::Semaphore, task::JoinSet};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DumpEvent {
event: Raw<AnyTimelineEvent>,
file_mapping: Option<(String, String)>,
encryption_info: Option<EncryptionInfo>,
}
#[derive(Clone, Debug, serde::Serialize)]
pub struct RoomMeta<'a> {
pub id: &'a str,
pub name: Option<&'a str>,
pub state: &'a RoomState,
}
fn sanitize_filename(name: &str) -> String {
name.chars()
.map(|c| match c {
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' | '!' => '_',
_ => c,
})
.collect()
}
#[derive(Clone, Debug, Parser)]
pub struct Args {
#[clap(long = "home", default_value = "matrix.org")]
pub home_server: String,
#[clap(short, long)]
pub username: Option<String>,
#[clap(long, default_value = "Matrix.org Protocol Dumper by Yumechi")]
pub device_name: Option<String>,
#[clap(long)]
pub device_id: Option<String>,
#[clap(long, default_value = "config/token.json")]
pub access_token_file: Option<String>,
#[clap(short, long, default_value = "dump")]
pub out_dir: String,
#[clap(long)]
pub filter: Vec<String>,
#[clap(long, short = 'j', default_value = "4")]
pub concurrency: usize,
#[clap(long, default_value = "config/e2e.db")]
pub e2e_db: String,
#[clap(long)]
pub password_file: Option<String>,
#[clap(
long,
help = "The timeout for the key sync in seconds",
default_value = "300"
)]
pub key_sync_timeout: u64,
}
fn read_password() -> Result<String, std::io::Error> {
use crossterm::{execute, style::Print, terminal};
terminal::enable_raw_mode()?;
let mut password = String::new();
execute!(std::io::stdout(), Print("Password:"))?;
loop {
match crossterm::event::read()? {
crossterm::event::Event::Key(event) => match event.code {
crossterm::event::KeyCode::Enter => break,
crossterm::event::KeyCode::Backspace => {
password.pop();
}
crossterm::event::KeyCode::Char(c) => {
password.push(c);
}
_ => {}
},
_ => {}
}
}
execute!(std::io::stdout(), Print("\n"))?;
terminal::disable_raw_mode()?;
println!();
Ok(password)
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -127,342 +10,24 @@ async fn main() {
let mut js = JoinSet::new(); let mut js = JoinSet::new();
let mut bg_js = JoinSet::new(); let mut bg_js = JoinSet::new();
tokio::select! { tokio::select! {
_ = run(&mut js, &mut bg_js) => {}, _ = matrix_dump::run(&mut js, &mut bg_js) => {},
_ = tokio::signal::ctrl_c() => { _ = tokio::signal::ctrl_c() => {
log::info!("Received Ctrl-C, exiting"); log::info!("Received Ctrl-C, exiting");
js.abort_all(); js.abort_all();
}, },
} }
log::info!("Waiting for tasks to finish"); log::info!("Waiting for tasks to finish, press Ctrl-C to force exit");
tokio::select! {
_ = async {
bg_js.abort_all();
while bg_js.join_next().await.is_some() {}
while js.join_next().await.is_some() {}
} => {},
_ = tokio::signal::ctrl_c() => {
log::info!("Received Ctrl-C again, force exiting");
},
}
bg_js.abort_all(); bg_js.abort_all();
while let Some(_) = bg_js.join_next().await {} js.abort_all();
js.join_all().await; js.join_all().await;
} }
async fn run(js: &mut JoinSet<Result<(), DumpError>>, bg_js: &mut JoinSet<Result<(), DumpError>>) {
let args = Args::parse();
log::info!("Starting matrix dump, args: {:?}", args);
let http_client = Reqwest::builder().https_only(true).build().unwrap();
let client = Client::builder()
.server_name(<&ServerName>::try_from(args.home_server.as_str()).unwrap())
.http_client(http_client.clone())
.sqlite_store(&args.e2e_db, None)
.build()
.await
.expect("Failed to create client");
let client = MatrixClient::new_arc(client);
let client1 = client.clone();
let access_token_file_clone = args.access_token_file.clone();
(move || async move {
if let Some(ref access_token_file) = access_token_file_clone {
if std::fs::exists(access_token_file).expect("Failed to check access token file") {
let token = serde_json::from_str(
&tokio::fs::read_to_string(access_token_file)
.await
.expect("Failed to read access token file"),
)
.expect("Failed to parse access token file");
client1
.client()
.matrix_auth()
.restore_session(token)
.await
.expect("Failed to restore session");
log::info!("Restored session");
return;
}
if let Some(ref password_file) = args.password_file {
if std::fs::exists(password_file).expect("Failed to check password file") {
log::info!("Logging in with password file");
let password = tokio::fs::read_to_string(password_file)
.await
.expect("Failed to read password file");
client1
.client()
.matrix_auth()
.login_username(
args.username.clone().expect("Username not provided"),
&password,
)
.initial_device_display_name("Matrix Protocol Dumper By Yumechi")
.await
.expect("Failed to login");
return;
}
}
if let Some(ref username) = args.username {
log::info!("Logging in with password prompt");
let password = read_password().expect("Failed to read password");
client1
.client()
.matrix_auth()
.login_username(username.clone(), &password)
.initial_device_display_name("Matrix Protocol Dumper By Yumechi")
.await
.expect("Failed to login");
return;
}
panic!("No login method provided");
}
})()
.await;
if !client.clone().client().logged_in() {
log::error!("Failed to login");
return;
}
if let Some(s) = client.client().matrix_auth().session() {
if let Some(ref access_token_file) = args.access_token_file {
let f = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.mode(0o600)
.open(access_token_file)
.expect("Failed to open access token file");
serde_json::to_writer(&f, &s).expect("Failed to write access token file");
}
}
log::info!("Starting sync");
client
.clone()
.client()
.sync_once(SyncSettings::default())
.await
.expect("Failed to sync");
log::info!("Sync done");
let client1 = client.clone();
bg_js.spawn(async move {
client1.client().sync(SyncSettings::default()).await?;
Ok(())
});
log::info!("Starting E2E setup");
match client.clone().setup_e2e().await {
true => log::info!("E2E setup done"),
false => log::error!("E2E setup failed"),
}
log::info!("Starting room dump");
let sem = Arc::new(Semaphore::new(args.concurrency));
let (synced_keys_tx, _) = tokio::sync::broadcast::channel::<matrix_sdk::ruma::OwnedRoomId>(1);
let synced_keys_tx = Arc::new(synced_keys_tx);
let synced_keys_tx1 = synced_keys_tx.clone();
client.client().add_event_handler(
|ev: matrix_sdk::ruma::events::forwarded_room_key::ToDeviceForwardedRoomKeyEvent| async move {
synced_keys_tx.send(ev.content.room_id.clone()).unwrap();
},
);
for room in client.client().rooms() {
let mut synced_keys_rx = synced_keys_tx1.subscribe();
let sem = sem.clone();
let filter = args.filter.clone();
let out_dir = args.out_dir.clone();
let room_id = room.room_id().to_owned();
let room_id_clone = room_id.clone();
let client1 = client.clone();
let http_client = http_client.clone();
js.spawn(async move {
if room.is_encrypted().await.unwrap_or(false) && !room.is_encryption_state_synced() {
log::info!(
"Room {} is encrypted, waiting for at most {} seconds for key sync",
room_id_clone,
args.key_sync_timeout
);
let room_id_clone1 = room_id_clone.clone();
let room_clone = room.clone();
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_secs(args.key_sync_timeout)) => {
log::warn!("Key sync timed out for room {}", room_id);
}
_ = async move {
while let Ok(room_id) = synced_keys_rx.recv().await {
if room_id == room_id_clone1 {
break;
}
}
if !room_clone.is_encryption_state_synced() {
log::warn!("Waiting for another 10 seconds for key sync to finish");
}
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
if !room_clone.is_encryption_state_synced() {
log::warn!("Key sync timed out for room {}", room_id_clone1);
}
} => {
log::info!("Key sync done for room {}", room_id_clone);
}
}
}
let permit = sem.clone().acquire_owned().await;
let room_name = room.display_name().await.map(|d| d.to_string())
.unwrap_or(room.name().unwrap_or("unknown".to_string()));
let room_dir =
Path::new(&out_dir).join(sanitize_filename(&format!("{}_{}", room_id, room_name)));
let match_filter = if filter.is_empty() {
true
} else {
filter
.iter()
.any(|filter| room_name.contains(filter) || room_id.as_str().contains(filter))
};
tokio::fs::create_dir_all(&room_dir)
.await
.expect("Failed to create room directory");
let meta_path = room_dir.join("meta.json");
serde_json::to_writer_pretty(
OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&meta_path)
.expect("Failed to open meta file"),
&RoomMeta {
id: room_id.as_str(),
name: Some(room_name.as_str()),
state: &room.state(),
},
)?;
if !match_filter {
log::debug!("Skipping room: {} ({})", room_id, room_name);
return Ok(());
}
log::info!("Dumping room: {} ({})", room_id, room_name);
let chunk_idx = &AtomicU64::new(0);
let client1 = client1.clone();
MatrixClient::room_messages(&room, None)
.try_for_each_concurrent(Some(args.concurrency), |msg| {
let room_dir = room_dir.clone();
let client1 = client1.clone();
let http_client = http_client.clone();
async move {
let output = room_dir.join(format!(
"chunk-{}.json",
chunk_idx.fetch_add(1, Ordering::SeqCst)
));
let mut out = Vec::with_capacity(msg.len());
for event in msg.into_iter() {
let mut fm = None;
match event.event.clone().cast::<AnyTimelineEvent>().deserialize() {
Ok(event) => match event {
AnyTimelineEvent::MessageLike(msg) => match msg {
AnyMessageLikeEvent::RoomMessage(m) => {
match client1.clone()
.try_read_attachment(&http_client, &m) {
Ok(None) => {}
Ok(Some(fut)) => {
match fut.await {
Ok((filename, url, mut byte_stream)) => {
let file_name = format!("attachment_{}_{}_{}",
m.event_id().as_str(),
sanitize_filename(Url::parse(&url).unwrap().path_segments().unwrap().last().unwrap()),
sanitize_filename(&filename),
);
let file = tokio::fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(room_dir.join(&file_name))
.await?;
let mut file = tokio::io::BufWriter::new(file);
loop {
match byte_stream.try_next().await {
Ok(Some(chunk)) => {
file.write_all(&chunk).await?;
}
Ok(None) => {
fm = Some((url, file_name));
break;
}
Err(e) => {
log::warn!("Failed to get attachment data: {}", e);
}
}
}
file.shutdown().await?;
}
Err(e) => {
log::warn!("Failed to get attachment data: {}", e);
}
};
}
Err(e) => {
log::warn!("Failed to get attachment: {}", e);
}
}
}
_ => {}
}
_ => {}
}
Err(e) => {
log::warn!("Failed to deserialize event: {}", e);
}
}
out.push(DumpEvent {
event: Raw::from_json(event.event.into_json()),
file_mapping: fm,
encryption_info: event.encryption_info.clone(),
});
}
serde_json::to_writer_pretty(
std::fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(output)?,
&out,
)?;
Ok(())
}
})
.await
.expect("Failed to get messages");
drop(permit);
Ok::<_, DumpError>(())
});
}
}

17
src/model.rs Normal file
View file

@ -0,0 +1,17 @@
use matrix_sdk::{deserialized_responses::EncryptionInfo, RoomState};
use ruma_common::serde::Raw;
use ruma_events::AnyTimelineEvent;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DumpEvent {
pub event: Raw<AnyTimelineEvent>,
pub file_mapping: Option<(String, String)>,
pub encryption_info: Option<EncryptionInfo>,
}
#[derive(Clone, Debug, serde::Serialize)]
pub struct RoomMeta<'a> {
pub id: &'a str,
pub name: Option<&'a str>,
pub state: &'a RoomState,
}