matrix-dump/src/client.rs
eternal-flame-AD bd6d8625ed
refactor code structure
Signed-off-by: eternal-flame-AD <yume@yumechi.jp>
2024-09-04 22:28:59 -05:00

404 lines
15 KiB
Rust

use std::{pin::Pin, sync::Arc};
use crate::{
e2e::{decrypt_file, ErrOrWrongHash},
io::prompt,
mxc_url_to_https, DumpError,
};
use futures::{future::BoxFuture, stream, StreamExt, TryStream, TryStreamExt};
use matrix_sdk::{
bytes::Bytes,
deserialized_responses::TimelineEvent,
encryption::verification::{Verification, VerificationRequestState},
matrix_auth::MatrixSession,
room::MessagesOptions,
ruma::events::key::verification::VerificationMethod,
Client, Room,
};
use matrix_sdk_crypto::SasState;
use ruma_events::room::{
message::{MessageType, RoomMessageEvent},
MediaSource,
};
use tokio::io::AsyncReadExt;
/// Wrapper around a [`matrix_sdk::Client`] that adds login, E2EE device
/// verification, attachment download and timeline pagination helpers.
#[derive(Debug)]
pub struct MatrixClient {
    client: Client,
}
/// A pinned, boxed, type-erased `TryStream` that is `Send`, may borrow data
/// for `'s`, and yields `Result<O, E>` items.
pub type PinTryStream<'s, O, E> =
Pin<Box<dyn TryStream<Ok = O, Error = E, Item = Result<O, E>> + Send + 's>>;
/// A download in progress for a message attachment, as produced by
/// `MatrixClient::try_read_attachment`.
pub struct FileStream<'s> {
/// File name taken from the message's `filename` field, falling back to the
/// message `body` when no explicit file name is set.
pub filename: String,
/// HTTPS URL the media is fetched from, derived from its `mxc://` URI and
/// the client's homeserver.
pub url: String,
/// The response body as a byte stream; decrypted on the fly when the media
/// source was encrypted.
pub stream: PinTryStream<'s, Bytes, DumpError>,
}
/// Credentials accepted by `MatrixClient::login`.
///
/// NOTE(review): `Password` carries a plaintext password; avoid deriving or
/// implementing `Debug` on this type so the password cannot end up in logs.
pub enum AuthMethod<'a> {
/// Log in with a user identifier and password.
Password {
/// User identifier passed to `login_username`.
user: &'a str,
/// Plaintext account password.
password: &'a str,
/// When set, passed to the login builder's `initial_device_display_name`.
initial_device_display_name: Option<&'a str>,
},
/// Restore a previously saved session instead of logging in again.
Session(MatrixSession),
}
impl MatrixClient {
pub fn new(client: Client) -> Self {
Self { client }
}
pub fn new_arc(client: Client) -> Arc<Self> {
Arc::new(Self::new(client))
}
pub fn client(&self) -> &Client {
&self.client
}
pub async fn login(&self, auth: AuthMethod<'_>) -> Result<(), matrix_sdk::Error> {
match auth {
AuthMethod::Password {
user,
password,
initial_device_display_name,
} => {
let tmp = self.client.matrix_auth().login_username(user, &password);
if let Some(name) = initial_device_display_name {
tmp.initial_device_display_name(&name)
} else {
tmp
}
.await
.map(|_| ())
}
AuthMethod::Session(session) => {
self.client.matrix_auth().restore_session(session).await
}
}
}
pub fn try_read_attachment<'c, 'a: 'c>(
self: Arc<Self>,
client: &'c matrix_sdk::reqwest::Client,
msg: &'a RoomMessageEvent,
) -> Result<Option<BoxFuture<'c, Result<FileStream<'c>, DumpError>>>, DumpError> {
macro_rules! impl_file_like {
($msg:expr, $($variant:ident),*) => {
match $msg {
$(
MessageType::$variant(file) => {
Ok(Some(Box::pin(
async move {
let src = match &file.source {
MediaSource::Plain(s) => s,
MediaSource::Encrypted(e) => &e.url,
};
let filename = file.filename.as_deref().map(|s| s.to_string()).unwrap_or(file.body.clone());
let url = mxc_url_to_https(src.as_str(), self.client.homeserver().as_str());
let resp = client.get(&url).send().await?;
let body = resp.bytes_stream();
Ok(FileStream {
filename,
url,
stream: match &file.source {
MediaSource::Plain(_) => {
Box::pin(body.map_err(DumpError::from)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>
}
MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err(
|e| match e {
ErrOrWrongHash::Err(e) => e.into(),
ErrOrWrongHash::WrongHash => DumpError::HashMismatch,
},
)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>,
}
})
})))}
)*
_ => Ok(None),
}
};
}
match msg {
RoomMessageEvent::Original(msg) => {
impl_file_like!(&msg.content.msgtype, Image, Video, Audio, File)
}
_ => Ok(None),
}
}
pub async fn setup_e2e(self: Arc<Self>) -> bool {
let client = &self.client;
log::info!("Preparing e2e machine");
client
.encryption()
.wait_for_e2ee_initialization_tasks()
.await;
log::info!("E2E machine ready");
let own_device = client
.encryption()
.get_own_device()
.await
.expect("Failed to get own device")
.expect("No own device found");
if own_device.is_cross_signed_by_owner() {
log::info!("Cross-signing keys are already set up");
return true;
}
let mut stdin = tokio::io::stdin();
let mut stdout = tokio::io::stdout();
let devices = client
.encryption()
.get_user_devices(own_device.user_id())
.await
.expect("Failed to get devices")
.devices()
.collect::<Vec<_>>();
for (i, d) in devices.iter().enumerate() {
log::info!(
"Device {}: {} ({})",
i,
d.display_name().unwrap_or("Unnamed"),
d.device_id()
);
}
let device_num = prompt(
&mut stdin,
&mut stdout,
"Enter device number to verify with: ",
)
.await
.unwrap_or_else(|e| {
log::error!("Failed to read device number: {}", e);
String::new()
})
.trim()
.parse::<usize>()
.expect("Failed to parse device number");
let device = devices.get(device_num).expect("Invalid device number");
log::info!(
"Requesting verification with {} ({})",
device.display_name().unwrap_or("Unnamed"),
device.device_id()
);
let req = match device
.request_verification_with_methods(vec![VerificationMethod::SasV1])
.await
{
Ok(req) => req,
Err(e) => {
log::error!(
"Failed to request verification for {}: {}",
device.device_id(),
e
);
return false;
}
};
let device_name = format!(
"{} ({})",
device.display_name().unwrap_or("Unnamed"),
device.device_id()
);
let device_name_clone = device_name.clone();
let mut c = req.changes();
while let Some(change) = c.next().await {
match change {
VerificationRequestState::Done => {
log::info!("Verification successful for {}", device_name_clone);
return true;
}
VerificationRequestState::Cancelled(info) => {
log::info!(
"Verification canceled for {}: {:?}",
device.device_id(),
info
);
return false;
}
VerificationRequestState::Transitioned { verification } => {
log::info!(
"Verification transitioned for {}: {:?}",
device.device_id(),
verification
);
match verification {
Verification::SasV1(v) => {
v.accept().await.expect("Failed to accept verification");
let emoji_str = v
.emoji()
.map(|emojis| {
emojis
.iter()
.map(|e| format!("{} ({})", e.symbol, e.description))
.collect::<Vec<_>>()
.join(", ")
})
.unwrap_or_else(|| "No emojis".to_string());
let decimals = v
.decimals()
.map(|(n1, n2, n3)| format!("{} {} {}", n1, n2, n3))
.unwrap_or_else(|| "No decimals".to_string());
if prompt(
&mut stdin,
&mut stdout,
&format!(
"Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n): ",
device.device_id(),
emoji_str,
decimals
),
)
.await
.unwrap_or_else(|e| {
log::error!("Failed to read verification response: {}", e);
"n".to_string()
})
.trim() == "y" {
v.confirm().await.expect("Failed to confirm");
} else {
v.cancel().await.expect("Failed to cancel");
}
}
_ => unimplemented!(),
}
}
VerificationRequestState::Ready {
their_methods,
our_methods,
..
} => {
log::info!(
"Verification ready for {}: their methods: {:?}, our methods: {:?}",
device.device_id(),
their_methods,
our_methods
);
req.accept_with_methods(vec![VerificationMethod::SasV1])
.await
.expect("Failed to accept verification");
let sas = req
.start_sas()
.await
.expect("Failed to start SAS")
.expect("No SAS");
sas.accept().await.expect("Failed to accept SAS");
while let Some(event) = sas.changes().next().await {
log::info!("SAS event: {:?}", event);
match event {
SasState::Started { protocols } => {
log::info!("SAS started with protocols: {:?}", protocols);
sas.accept().await.expect("Failed to accept SAS");
}
SasState::Cancelled(c) => {
log::info!("SAS canceled: {:?}", c);
return false;
}
SasState::Done {
verified_devices,
verified_identities,
} => {
log::info!(
"SAS done: verified devices: {:?}, verified identities: {:?}",
verified_devices,
verified_identities
);
}
SasState::KeysExchanged { emojis, decimals } => {
println!(
"Verification for {}:\nEmoji: {}\nDecimals: [{}, {}, {}]\n Confirm? (y/n)",
device.device_id(),
emojis.map(|e| e.emojis
.iter()
.map(|e| format!("{} ({})", e.symbol, e.description))
.collect::<Vec<_>>()
.join(", "),
).unwrap_or_else(|| "No emojis".to_string()),
decimals.0, decimals.1, decimals.2
);
let mut response = String::new();
while let Ok(c) = stdin.read_u8().await {
if c == b'\n' {
break;
}
response.push(c as char);
}
if response.trim() == "y" {
sas.confirm().await.expect("Failed to confirm SAS");
} else {
sas.cancel().await.expect("Failed to cancel SAS");
}
}
SasState::Accepted { accepted_protocols } => {
log::info!("SAS accepted with protocols: {:?}", accepted_protocols);
}
SasState::Confirmed => {
log::info!("SAS confirmed! Waiting for verification to finish");
}
}
}
}
VerificationRequestState::Requested { their_methods, .. } => {
log::info!(
"Verification requested for {}: their methods: {:?}",
device.device_id(),
their_methods
);
}
VerificationRequestState::Created { our_methods } => {
log::info!(
"Verification created for {}: our methods: {:?}",
device.device_id(),
our_methods
);
}
}
}
log::info!("Verification for {} fell through", device.device_id());
false
}
pub fn room_messages(
room: &Room,
since: Option<String>,
) -> impl TryStream<Ok = Vec<TimelineEvent>, Error = matrix_sdk::Error> + '_ {
stream::try_unfold(since, move |since| async move {
let mut opt = MessagesOptions::forward().from(since.as_deref());
opt.limit = 32.try_into().expect("Failed to convert");
room.messages(opt).await.map(|r| {
if r.chunk.is_empty() {
None
} else {
Some((r.chunk, r.end))
}
})
})
}
}