//! Wrapper around `matrix_sdk::Client`: login, attachment download,
//! interactive device verification and paged room history.
use std::{pin::Pin, sync::Arc};
|
|
|
|
use crate::{
|
|
e2e::{decrypt_file, ErrOrWrongHash},
|
|
io::prompt,
|
|
mxc_url_to_https, DumpError,
|
|
};
|
|
use futures::{future::BoxFuture, stream, StreamExt, TryStream, TryStreamExt};
|
|
use matrix_sdk::{
|
|
bytes::Bytes,
|
|
deserialized_responses::TimelineEvent,
|
|
encryption::verification::{Verification, VerificationRequestState},
|
|
matrix_auth::MatrixSession,
|
|
room::MessagesOptions,
|
|
ruma::events::key::verification::VerificationMethod,
|
|
Client, Room,
|
|
};
|
|
use matrix_sdk_crypto::SasState;
|
|
|
|
use ruma_events::room::{
|
|
message::{MessageType, RoomMessageEvent},
|
|
MediaSource,
|
|
};
|
|
use tokio::io::AsyncReadExt;
|
|
|
|
pub struct MatrixClient {
|
|
client: Client,
|
|
}
|
|
|
|
pub type PinTryStream<'s, O, E> =
|
|
Pin<Box<dyn TryStream<Ok = O, Error = E, Item = Result<O, E>> + Send + 's>>;
|
|
|
|
pub struct FileStream<'s> {
|
|
pub filename: String,
|
|
pub url: String,
|
|
pub stream: PinTryStream<'s, Bytes, DumpError>,
|
|
}
|
|
|
|
pub enum AuthMethod<'a> {
|
|
Password {
|
|
user: &'a str,
|
|
password: &'a str,
|
|
initial_device_display_name: Option<&'a str>,
|
|
},
|
|
Session(MatrixSession),
|
|
}
|
|
|
|
impl MatrixClient {
|
|
pub fn new(client: Client) -> Self {
|
|
Self { client }
|
|
}
|
|
|
|
pub fn new_arc(client: Client) -> Arc<Self> {
|
|
Arc::new(Self::new(client))
|
|
}
|
|
|
|
pub fn client(&self) -> &Client {
|
|
&self.client
|
|
}
|
|
|
|
pub async fn login(&self, auth: AuthMethod<'_>) -> Result<(), matrix_sdk::Error> {
|
|
match auth {
|
|
AuthMethod::Password {
|
|
user,
|
|
password,
|
|
initial_device_display_name,
|
|
} => {
|
|
let tmp = self.client.matrix_auth().login_username(user, &password);
|
|
if let Some(name) = initial_device_display_name {
|
|
tmp.initial_device_display_name(&name)
|
|
} else {
|
|
tmp
|
|
}
|
|
.await
|
|
.map(|_| ())
|
|
}
|
|
AuthMethod::Session(session) => {
|
|
self.client.matrix_auth().restore_session(session).await
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn try_read_attachment<'c, 'a: 'c>(
|
|
self: Arc<Self>,
|
|
client: &'c matrix_sdk::reqwest::Client,
|
|
msg: &'a RoomMessageEvent,
|
|
) -> Result<Option<BoxFuture<'c, Result<FileStream<'c>, DumpError>>>, DumpError> {
|
|
macro_rules! impl_file_like {
|
|
($msg:expr, $($variant:ident),*) => {
|
|
match $msg {
|
|
$(
|
|
MessageType::$variant(file) => {
|
|
Ok(Some(Box::pin(
|
|
async move {
|
|
let src = match &file.source {
|
|
MediaSource::Plain(s) => s,
|
|
MediaSource::Encrypted(e) => &e.url,
|
|
};
|
|
let filename = file.filename.as_deref().map(|s| s.to_string()).unwrap_or(file.body.clone());
|
|
let url = mxc_url_to_https(src.as_str(), self.client.homeserver().as_str());
|
|
let resp = client.get(&url).send().await?;
|
|
let body = resp.bytes_stream();
|
|
Ok(FileStream {
|
|
filename,
|
|
url,
|
|
stream: match &file.source {
|
|
MediaSource::Plain(_) => {
|
|
Box::pin(body.map_err(DumpError::from)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>
|
|
}
|
|
MediaSource::Encrypted(e) => Box::pin(decrypt_file(e.as_ref(), body).await?.map_ok(|v| Bytes::from(v)).map_err(
|
|
|e| match e {
|
|
ErrOrWrongHash::Err(e) => e.into(),
|
|
ErrOrWrongHash::WrongHash => DumpError::HashMismatch,
|
|
},
|
|
)) as Pin<Box<dyn TryStream<Ok = Bytes, Error = DumpError, Item = Result<Bytes, DumpError>> + Send>>,
|
|
}
|
|
})
|
|
})))}
|
|
)*
|
|
_ => Ok(None),
|
|
}
|
|
};
|
|
}
|
|
match msg {
|
|
RoomMessageEvent::Original(msg) => {
|
|
impl_file_like!(&msg.content.msgtype, Image, Video, Audio, File)
|
|
}
|
|
_ => Ok(None),
|
|
}
|
|
}
|
|
|
|
pub async fn setup_e2e(self: Arc<Self>) -> bool {
|
|
let client = &self.client;
|
|
|
|
log::info!("Preparing e2e machine");
|
|
client
|
|
.encryption()
|
|
.wait_for_e2ee_initialization_tasks()
|
|
.await;
|
|
log::info!("E2E machine ready");
|
|
|
|
let own_device = client
|
|
.encryption()
|
|
.get_own_device()
|
|
.await
|
|
.expect("Failed to get own device")
|
|
.expect("No own device found");
|
|
|
|
if own_device.is_cross_signed_by_owner() {
|
|
log::info!("Cross-signing keys are already set up");
|
|
return true;
|
|
}
|
|
|
|
let mut stdin = tokio::io::stdin();
|
|
let mut stdout = tokio::io::stdout();
|
|
|
|
let devices = client
|
|
.encryption()
|
|
.get_user_devices(own_device.user_id())
|
|
.await
|
|
.expect("Failed to get devices")
|
|
.devices()
|
|
.collect::<Vec<_>>();
|
|
|
|
for (i, d) in devices.iter().enumerate() {
|
|
log::info!(
|
|
"Device {}: {} ({})",
|
|
i,
|
|
d.display_name().unwrap_or("Unnamed"),
|
|
d.device_id()
|
|
);
|
|
}
|
|
|
|
let device_num = prompt(
|
|
&mut stdin,
|
|
&mut stdout,
|
|
"Enter device number to verify with: ",
|
|
)
|
|
.await
|
|
.unwrap_or_else(|e| {
|
|
log::error!("Failed to read device number: {}", e);
|
|
String::new()
|
|
})
|
|
.trim()
|
|
.parse::<usize>()
|
|
.expect("Failed to parse device number");
|
|
|
|
let device = devices.get(device_num).expect("Invalid device number");
|
|
|
|
log::info!(
|
|
"Requesting verification with {} ({})",
|
|
device.display_name().unwrap_or("Unnamed"),
|
|
device.device_id()
|
|
);
|
|
|
|
let req = match device
|
|
.request_verification_with_methods(vec![VerificationMethod::SasV1])
|
|
.await
|
|
{
|
|
Ok(req) => req,
|
|
Err(e) => {
|
|
log::error!(
|
|
"Failed to request verification for {}: {}",
|
|
device.device_id(),
|
|
e
|
|
);
|
|
return false;
|
|
}
|
|
};
|
|
|
|
let device_name = format!(
|
|
"{} ({})",
|
|
device.display_name().unwrap_or("Unnamed"),
|
|
device.device_id()
|
|
);
|
|
|
|
let device_name_clone = device_name.clone();
|
|
|
|
let mut c = req.changes();
|
|
|
|
while let Some(change) = c.next().await {
|
|
match change {
|
|
VerificationRequestState::Done => {
|
|
log::info!("Verification successful for {}", device_name_clone);
|
|
return true;
|
|
}
|
|
VerificationRequestState::Cancelled(info) => {
|
|
log::info!(
|
|
"Verification canceled for {}: {:?}",
|
|
device.device_id(),
|
|
info
|
|
);
|
|
return false;
|
|
}
|
|
VerificationRequestState::Transitioned { verification } => {
|
|
log::info!(
|
|
"Verification transitioned for {}: {:?}",
|
|
device.device_id(),
|
|
verification
|
|
);
|
|
match verification {
|
|
Verification::SasV1(v) => {
|
|
v.accept().await.expect("Failed to accept verification");
|
|
let emoji_str = v
|
|
.emoji()
|
|
.map(|emojis| {
|
|
emojis
|
|
.iter()
|
|
.map(|e| format!("{} ({})", e.symbol, e.description))
|
|
.collect::<Vec<_>>()
|
|
.join(", ")
|
|
})
|
|
.unwrap_or_else(|| "No emojis".to_string());
|
|
let decimals = v
|
|
.decimals()
|
|
.map(|(n1, n2, n3)| format!("{} {} {}", n1, n2, n3))
|
|
.unwrap_or_else(|| "No decimals".to_string());
|
|
|
|
if prompt(
|
|
&mut stdin,
|
|
&mut stdout,
|
|
&format!(
|
|
"Verification for {}:\nEmoji: {}\nDecimals: {}\n Confirm? (y/n): ",
|
|
device.device_id(),
|
|
emoji_str,
|
|
decimals
|
|
),
|
|
)
|
|
.await
|
|
.unwrap_or_else(|e| {
|
|
log::error!("Failed to read verification response: {}", e);
|
|
"n".to_string()
|
|
})
|
|
.trim() == "y" {
|
|
v.confirm().await.expect("Failed to confirm");
|
|
} else {
|
|
v.cancel().await.expect("Failed to cancel");
|
|
}
|
|
}
|
|
_ => unimplemented!(),
|
|
}
|
|
}
|
|
VerificationRequestState::Ready {
|
|
their_methods,
|
|
our_methods,
|
|
..
|
|
} => {
|
|
log::info!(
|
|
"Verification ready for {}: their methods: {:?}, our methods: {:?}",
|
|
device.device_id(),
|
|
their_methods,
|
|
our_methods
|
|
);
|
|
|
|
req.accept_with_methods(vec![VerificationMethod::SasV1])
|
|
.await
|
|
.expect("Failed to accept verification");
|
|
|
|
let sas = req
|
|
.start_sas()
|
|
.await
|
|
.expect("Failed to start SAS")
|
|
.expect("No SAS");
|
|
|
|
sas.accept().await.expect("Failed to accept SAS");
|
|
|
|
while let Some(event) = sas.changes().next().await {
|
|
log::info!("SAS event: {:?}", event);
|
|
match event {
|
|
SasState::Started { protocols } => {
|
|
log::info!("SAS started with protocols: {:?}", protocols);
|
|
sas.accept().await.expect("Failed to accept SAS");
|
|
}
|
|
SasState::Cancelled(c) => {
|
|
log::info!("SAS canceled: {:?}", c);
|
|
return false;
|
|
}
|
|
SasState::Done {
|
|
verified_devices,
|
|
verified_identities,
|
|
} => {
|
|
log::info!(
|
|
"SAS done: verified devices: {:?}, verified identities: {:?}",
|
|
verified_devices,
|
|
verified_identities
|
|
);
|
|
}
|
|
SasState::KeysExchanged { emojis, decimals } => {
|
|
println!(
|
|
"Verification for {}:\nEmoji: {}\nDecimals: [{}, {}, {}]\n Confirm? (y/n)",
|
|
device.device_id(),
|
|
emojis.map(|e| e.emojis
|
|
.iter()
|
|
.map(|e| format!("{} ({})", e.symbol, e.description))
|
|
.collect::<Vec<_>>()
|
|
.join(", "),
|
|
).unwrap_or_else(|| "No emojis".to_string()),
|
|
decimals.0, decimals.1, decimals.2
|
|
);
|
|
|
|
let mut response = String::new();
|
|
|
|
while let Ok(c) = stdin.read_u8().await {
|
|
if c == b'\n' {
|
|
break;
|
|
}
|
|
response.push(c as char);
|
|
}
|
|
|
|
if response.trim() == "y" {
|
|
sas.confirm().await.expect("Failed to confirm SAS");
|
|
} else {
|
|
sas.cancel().await.expect("Failed to cancel SAS");
|
|
}
|
|
}
|
|
SasState::Accepted { accepted_protocols } => {
|
|
log::info!("SAS accepted with protocols: {:?}", accepted_protocols);
|
|
}
|
|
SasState::Confirmed => {
|
|
log::info!("SAS confirmed! Waiting for verification to finish");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
VerificationRequestState::Requested { their_methods, .. } => {
|
|
log::info!(
|
|
"Verification requested for {}: their methods: {:?}",
|
|
device.device_id(),
|
|
their_methods
|
|
);
|
|
}
|
|
VerificationRequestState::Created { our_methods } => {
|
|
log::info!(
|
|
"Verification created for {}: our methods: {:?}",
|
|
device.device_id(),
|
|
our_methods
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
log::info!("Verification for {} fell through", device.device_id());
|
|
|
|
false
|
|
}
|
|
|
|
pub fn room_messages(
|
|
room: &Room,
|
|
since: Option<String>,
|
|
) -> impl TryStream<Ok = Vec<TimelineEvent>, Error = matrix_sdk::Error> + '_ {
|
|
stream::try_unfold(since, move |since| async move {
|
|
let mut opt = MessagesOptions::forward().from(since.as_deref());
|
|
opt.limit = 32.try_into().expect("Failed to convert");
|
|
|
|
room.messages(opt).await.map(|r| {
|
|
if r.chunk.is_empty() {
|
|
None
|
|
} else {
|
|
Some((r.chunk, r.end))
|
|
}
|
|
})
|
|
})
|
|
}
|
|
}
|