matrix-dump/src/lib.rs
eternal-flame-AD bd6d8625ed
refactor code structure
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
2024-09-04 22:28:59 -05:00

472 lines
17 KiB
Rust

use std::{
fs::OpenOptions,
os::unix::fs::OpenOptionsExt,
path::Path,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use client::{AuthMethod, FileStream};
use matrix_sdk::{HttpError, IdParseError};
use ruma_common::MxcUriError;
pub use crate::client::MatrixClient;
use futures::{TryFutureExt, TryStreamExt};
use io::{read_password, sanitize_filename};
use matrix_sdk::{config::SyncSettings, Client, ServerName};
use model::{DumpEvent, RoomMeta};
use reqwest::Url;
use ruma_client::http_client::Reqwest;
use ruma_common::serde::Raw;
use ruma_events::{AnyMessageLikeEvent, AnyTimelineEvent};
use tokio::{io::AsyncWriteExt, sync::Semaphore, task::JoinSet};
pub mod client;
pub mod e2e;
pub mod filter;
pub mod io;
pub mod model;
pub mod serdes;
use clap::Parser;
// Command-line arguments of the dumper.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///` doc comments into
// `--help` text, so doc comments here would change the CLI's help output.
#[derive(Clone, Debug, Parser)]
pub struct Args {
    // Home server *name* (e.g. `matrix.org`), parsed as a `ServerName` when the client is
    // built in `run`.
    #[clap(long = "home", default_value = "matrix.org")]
    pub home_server: String,
    // User to log in as. Required for password-file login. When set and no other login
    // method applies, the password is prompted for interactively.
    #[clap(short, long)]
    pub username: Option<String>,
    // Passed as `initial_device_display_name` on password login. Since a default is given,
    // clap always fills this in.
    #[clap(long, default_value = "Matrix.org Protocol Dumper by Yumechi")]
    pub device_name: Option<String>,
    // NOTE(review): not referenced anywhere in this file; TODO confirm whether it is used.
    #[clap(long)]
    pub device_id: Option<String>,
    // JSON file holding a saved session. It is restored at login if the file exists, and
    // written after a successful login (created with mode 0o600). Since a default is given,
    // clap always fills this in.
    #[clap(long, default_value = "config/token.json")]
    pub access_token_file: Option<String>,
    // Root output directory. Each room gets a sanitized `<room id>_<room name>` subdirectory.
    #[clap(short, long, default_value = "dump")]
    pub out_dir: String,
    // Substrings matched against room names and room IDs; empty means "dump every room".
    // Rooms that do not match still get a `meta.json`, but their messages are skipped.
    #[clap(long)]
    pub filter: Vec<String>,
    // Caps how many rooms are dumped at the same time. Also the number of message chunks
    // processed concurrently within one room.
    #[clap(long, short = 'j', default_value = "4")]
    pub concurrency: usize,
    // Path of the SQLite store handed to the Matrix client builder (`sqlite_store`).
    #[clap(long, default_value = "config/e2e.db")]
    pub e2e_db: String,
    // File whose contents (whitespace-trimmed) are used as the login password.
    // Requires `--username`.
    #[clap(long)]
    pub password_file: Option<String>,
    // Upper bound on how long an encrypted room waits for forwarded room keys before dumping.
    #[clap(
        long,
        help = "The timeout for the key sync in seconds",
        default_value = "300"
    )]
    pub key_sync_timeout: u64,
}
/// Converts an `mxc://<server>/<media id>` URI into an HTTPS download URL on `homeserver`.
///
/// `homeserver` is the base URL of the home server, with or without a trailing slash
/// (`https://matrix.org` and `https://matrix.org/` give the same result). A missing `/` is
/// inserted before `_matrix`. If `mxc_url` does not start with `mxc://`, it is appended
/// unchanged.
///
/// Note: this targets the legacy unauthenticated `r0` media endpoint. Newer home servers may
/// only serve media through the authenticated client media API.
pub fn mxc_url_to_https(mxc_url: &str, homeserver: &str) -> String {
    // Strip exactly one `mxc://` prefix (the old `trim_start_matches` stripped repeats).
    let media_path = mxc_url.strip_prefix("mxc://").unwrap_or(mxc_url);
    // Without this, "https://hs" + "_matrix/..." yielded the invalid host "hs_matrix".
    let separator = if homeserver.is_empty() || homeserver.ends_with('/') {
        ""
    } else {
        "/"
    };
    format!("{homeserver}{separator}_matrix/media/r0/download/{media_path}")
}
/// Error type of the dump routines and of the tasks spawned by [`run`].
///
/// Every variant except [`DumpError::HashMismatch`] wraps a lower-level error and has a
/// `#[from]` conversion, so `?` lifts those errors into `DumpError` automatically.
#[derive(Debug, thiserror::Error)]
pub enum DumpError {
    /// JSON serialization or deserialization failed (e.g. writing a chunk or `meta.json`).
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Error reported by the Matrix SDK (e.g. by the background sync loop).
    #[error("Matrix SDK error: {0}")]
    Matrix(#[from] matrix_sdk::Error),
    /// I/O failure such as creating directories or writing attachment/chunk files.
    #[error("IO error: {0}")]
    Tokio(#[from] tokio::io::Error),
    /// Wraps a `ruma_common::MxcUriError` (MXC URI validation failure).
    #[error("MxcUri error: {0}")]
    MxcUri(#[from] MxcUriError),
    /// HTTP-level error reported by the Matrix SDK.
    #[error("HTTP error: {0}")]
    HttpError(#[from] HttpError),
    /// An encrypted attachment could not be decrypted (see `e2e::DecryptError`).
    #[error("Unable to decrypt media: {0}")]
    Decrypt(#[from] e2e::DecryptError),
    /// Media hash verification failed; carries no source error.
    #[error("Failed to verify media hash")]
    HashMismatch,
    /// Error from the underlying `reqwest` HTTP client.
    #[error("Reqwest error: {0}")]
    Reqwest(#[from] matrix_sdk::reqwest::Error),
    /// A Matrix identifier failed to parse.
    #[error("Invalid ID: {0}")]
    InvalidId(#[from] IdParseError),
}
/// Returns the last path segment of `url`.
///
/// Returns `"unknown"` when `url` does not parse as a URL or has no path segments
/// (cannot-be-a-base URLs such as `data:`).
fn attachment_url_segment(url: &str) -> String {
    let segment = Url::parse(url).ok().and_then(|parsed| {
        let last = parsed.path_segments()?.last()?.to_owned();
        Some(last)
    });
    segment.unwrap_or_else(|| "unknown".to_owned())
}

/// Dumps every timeline event of `room` into `out_dir`, downloading attachments alongside.
///
/// # Chunk files
///
/// Events arrive in chunks from [`MatrixClient::room_messages`], and up to `concurrency`
/// chunks are processed at once. Each chunk is written as a pretty-printed JSON array of
/// [`DumpEvent`]s to `out_dir/chunk-<n>.json`. `<n>` comes from a shared counter taken when
/// the chunk starts processing, so numbering may not follow stream order when
/// `concurrency > 1`.
///
/// # Attachments
///
/// For every `m.room.message` event, [`MatrixClient::try_read_attachment`] is asked for an
/// attachment. If one is returned, it is streamed to
/// `out_dir/attachment_<event id>_<last URL path segment>_<file name>`, with each part passed
/// through `sanitize_filename`. The `(url, file name)` pair is then recorded as the event's
/// `file_mapping`.
///
/// A failed or interrupted download is logged and leaves `file_mapping` as `None`. It does
/// not abort the dump. The partially written file is left on disk.
///
/// # Errors
///
/// Returns an error if the message stream fails, or if creating/writing an attachment file,
/// serializing a chunk, or writing a chunk file fails.
pub async fn dump_room_messages(
    room: &matrix_sdk::Room,
    out_dir: &Path,
    client: Arc<MatrixClient>,
    http_client: &Reqwest,
    concurrency: usize,
) -> Result<(), DumpError> {
    let chunk_idx = &AtomicU64::new(0);
    MatrixClient::room_messages(room, None)
        .try_for_each_concurrent(Some(concurrency), |msg| {
            let room_dir = out_dir.to_owned();
            let client = client.clone();
            let http_client = http_client.clone();
            async move {
                let output = room_dir.join(format!(
                    "chunk-{}.json",
                    chunk_idx.fetch_add(1, Ordering::SeqCst)
                ));
                let mut out = Vec::with_capacity(msg.len());
                for event in msg.into_iter() {
                    let mut fm = None;
                    match event.event.clone().cast::<AnyTimelineEvent>().deserialize() {
                        Ok(AnyTimelineEvent::MessageLike(AnyMessageLikeEvent::RoomMessage(m))) => {
                            let event_id = m.event_id();
                            match client.clone().try_read_attachment(&http_client, &m) {
                                Ok(None) => {}
                                Ok(Some(fut)) => match fut.await {
                                    Ok(FileStream {
                                        filename,
                                        url,
                                        mut stream,
                                    }) => {
                                        // Event IDs (room v3) may contain '/', so they are
                                        // sanitized like the other path components.
                                        let file_name = format!(
                                            "attachment_{}_{}_{}",
                                            sanitize_filename(event_id.as_str()),
                                            sanitize_filename(&attachment_url_segment(&url)),
                                            sanitize_filename(&filename),
                                        );
                                        let file = tokio::fs::OpenOptions::new()
                                            .create(true)
                                            .write(true)
                                            .truncate(true)
                                            .open(room_dir.join(&file_name))
                                            .await?;
                                        let mut file = tokio::io::BufWriter::new(file);
                                        // Only map the file if the whole body arrived, so a
                                        // truncated download is never presented as valid.
                                        let mut complete = false;
                                        loop {
                                            match stream.try_next().await {
                                                Ok(Some(chunk)) => file.write_all(&chunk).await?,
                                                Ok(None) => {
                                                    complete = true;
                                                    break;
                                                }
                                                Err(e) => {
                                                    log::warn!(
                                                        "Failed to get attachment data for {}: {}",
                                                        event_id,
                                                        e
                                                    );
                                                    // Stop here: retrying a failed stream in a
                                                    // loop could spin forever.
                                                    break;
                                                }
                                            }
                                        }
                                        file.shutdown().await?;
                                        if complete {
                                            fm = Some((url, file_name));
                                        }
                                    }
                                    Err(e) => {
                                        log::warn!(
                                            "Failed to get attachment data for {}: {}",
                                            event_id,
                                            e
                                        );
                                    }
                                },
                                Err(e) => {
                                    log::warn!(
                                        "Failed to get attachment data for {}: {}",
                                        event_id,
                                        e
                                    );
                                }
                            }
                        }
                        // Not a room message: dumped as-is, without attachment handling.
                        Ok(_) => {}
                        Err(e) => {
                            log::warn!("Failed to deserialize event: {}", e);
                        }
                    }
                    out.push(DumpEvent {
                        event: Raw::from_json(event.event.into_json()),
                        file_mapping: fm,
                        encryption_info: event.encryption_info.clone(),
                    });
                }
                // Serialize in memory and write asynchronously, instead of doing blocking,
                // unbuffered std::fs writes from inside the async task.
                let json = serde_json::to_vec_pretty(&out)?;
                tokio::fs::write(&output, json).await?;
                Ok(())
            }
        })
        .await?;
    Ok(())
}
pub async fn run(
js: &mut JoinSet<Result<(), DumpError>>,
bg_js: &mut JoinSet<Result<(), DumpError>>,
) {
let args = Args::parse();
log::info!("Starting matrix dump, args: {:?}", args);
let http_client = Reqwest::builder().https_only(true).build().unwrap();
let client = Client::builder()
.server_name(<&ServerName>::try_from(args.home_server.as_str()).unwrap())
.http_client(http_client.clone())
.sqlite_store(&args.e2e_db, None)
.build()
.await
.expect("Failed to create client");
let client = MatrixClient::new_arc(client);
let client1 = client.clone();
let access_token_file_clone = args.access_token_file.clone();
async move {
log::info!("Logging in");
if let Some(ref access_token_file) = access_token_file_clone {
if std::fs::exists(access_token_file).expect("Failed to check access token file") {
let token = serde_json::from_str(
&tokio::fs::read_to_string(access_token_file)
.await
.expect("Failed to read access token file"),
)
.expect("Failed to parse access token file");
client1
.login(AuthMethod::Session(token))
.await
.expect("Failed to login");
log::info!("Restored session");
return;
}
if let Some(ref password_file) = args.password_file {
if std::fs::exists(password_file).expect("Failed to check password file") {
log::info!("Logging in with password file");
let password = tokio::fs::read_to_string(password_file)
.await
.expect("Failed to read password file");
client1
.login(AuthMethod::Password {
user: args.username.as_ref().expect("No username provided"),
password: password.trim(),
initial_device_display_name: args.device_name.as_deref(),
})
.await
.expect("Failed to login");
return;
}
}
if let Some(ref username) = args.username {
log::info!("Logging in with password prompt");
let password = read_password().expect("Failed to read password");
client1
.login(AuthMethod::Password {
user: username,
password: &password,
initial_device_display_name: args.device_name.as_deref(),
})
.await
.expect("Failed to login");
return;
}
panic!("No login method provided");
}
}
.await;
if !client.clone().client().logged_in() {
log::error!("Failed to login");
return;
}
if let Some(s) = client.client().matrix_auth().session() {
if let Some(ref access_token_file) = args.access_token_file {
let f = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.mode(0o600)
.open(access_token_file)
.expect("Failed to open access token file");
serde_json::to_writer(&f, &s).expect("Failed to write access token file");
}
}
{
log::info!("Starting sync, may take up to a few minutes");
client
.clone()
.client()
.sync_once(SyncSettings::default())
.await
.expect("Failed to sync");
log::info!("Sync done");
let client1 = client.clone();
bg_js.spawn(async move {
client1
.client()
.sync(SyncSettings::default())
.map_err(DumpError::from)
.await
});
}
{
log::info!("Starting E2E setup");
match client.clone().setup_e2e().await {
true => log::info!("E2E setup done"),
false => log::error!("E2E setup failed, E2E will not be decrypted"),
}
}
{
log::info!("Starting room dump");
let sem = Arc::new(Semaphore::new(args.concurrency));
let (synced_keys_tx, synced_keys_rx) =
tokio::sync::broadcast::channel::<matrix_sdk::ruma::OwnedRoomId>(1);
let synced_keys_tx = Arc::new(synced_keys_tx);
client.client().add_event_handler(
move |ev: matrix_sdk::ruma::events::forwarded_room_key::ToDeviceForwardedRoomKeyEvent| async move {
synced_keys_tx.send(ev.content.room_id.clone()).unwrap();
},
);
for room in client.client().rooms() {
let mut synced_keys_rx = synced_keys_rx.resubscribe();
let sem = sem.clone();
let filter = args.filter.clone();
let out_dir = args.out_dir.clone();
let room_id = room.room_id().to_owned();
let room_id_clone = room_id.clone();
let client1 = client.clone();
let http_client = http_client.clone();
js.spawn(async move {
if room.is_encrypted().await.unwrap_or(false) && !room.is_encryption_state_synced() {
log::info!(
"Room {} is encrypted, waiting for at most {} seconds for key sync",
room_id_clone,
args.key_sync_timeout
);
let room_id_clone1 = room_id_clone.clone();
let room_clone = room.clone();
tokio::select! {
_ = tokio::time::sleep(std::time::Duration::from_secs(args.key_sync_timeout)) => {
log::warn!("Key sync timed out for room {}", room_id);
}
_ = async move {
while let Ok(room_id) = synced_keys_rx.recv().await {
if room_id == room_id_clone1 {
break;
}
}
if !room_clone.is_encryption_state_synced() {
log::warn!("Waiting for another 10 seconds for key sync to finish");
}
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
if !room_clone.is_encryption_state_synced() {
log::warn!("Key sync timed out for room {}", room_id_clone1);
}
} => {
log::info!("Key sync done for room {}", room_id_clone);
}
}
}
let permit = sem.clone().acquire_owned().await;
let room_name = room.display_name().await.map(|d| d.to_string())
.unwrap_or(room.name().unwrap_or("unknown".to_string()));
let room_dir =
Path::new(&out_dir).join(sanitize_filename(&format!("{}_{}", room_id, room_name)));
let match_filter = if filter.is_empty() {
true
} else {
filter
.iter()
.any(|filter| room_name.contains(filter) || room_id.as_str().contains(filter))
};
tokio::fs::create_dir_all(&room_dir)
.await
.expect("Failed to create room directory");
let meta_path = room_dir.join("meta.json");
serde_json::to_writer_pretty(
OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&meta_path)
.expect("Failed to open meta file"),
&RoomMeta {
id: room_id.as_str(),
name: Some(room_name.as_str()),
state: &room.state(),
},
)?;
if !match_filter {
log::debug!("Skipping room: {} ({})", room_id, room_name);
return Ok(());
}
log::info!("Dumping room: {} ({})", room_id, room_name);
dump_room_messages(&room, &room_dir, client1, &http_client, args.concurrency).await?;
drop(permit);
Ok(())
}); /* js.spawn */
}
drop(synced_keys_rx);
}
}