fix: endless poll on an errored ws stream (#683)

* fix: endless poll on an errored ws stream
This commit is contained in:
imspace 2020-06-25 12:23:13 +08:00 committed by GitHub
parent a08ce0760d
commit 37a37d462f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -420,7 +420,7 @@ pub mod subscriptions {
}, },
}; };
use futures::{channel::mpsc, Future, StreamExt as _}; use futures::{channel::mpsc, Future, StreamExt as _, TryStreamExt as _};
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _}; use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
use juniper_subscriptions::Coordinator; use juniper_subscriptions::Coordinator;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -458,71 +458,71 @@ pub mod subscriptions {
let context = Arc::new(context); let context = Arc::new(context);
let running = Arc::new(AtomicBool::new(false)); let running = Arc::new(AtomicBool::new(false));
let got_close_signal = Arc::new(AtomicBool::new(false)); let got_close_signal = Arc::new(AtomicBool::new(false));
let got_close_signal2 = got_close_signal.clone();
sink_rx.fold(Ok(()), move |_, msg| { sink_rx
let coordinator = coordinator.clone(); .map_err(move |e| {
let context = context.clone(); got_close_signal2.store(true, Ordering::Relaxed);
let running = running.clone(); failure::format_err!("Websocket error: {}", e)
let got_close_signal = got_close_signal.clone(); })
let ws_tx = ws_tx.clone(); .try_fold((), move |_, msg| {
let coordinator = coordinator.clone();
let context = context.clone();
let running = running.clone();
let got_close_signal = got_close_signal.clone();
let ws_tx = ws_tx.clone();
async move { async move {
let msg = match msg { if msg.is_close() {
Ok(m) => m, return Ok(());
Err(e) => {
got_close_signal.store(true, Ordering::Relaxed);
return Err(failure::format_err!("Websocket error: {}", e));
} }
};
if msg.is_close() { let msg = msg
return Ok(()); .to_str()
} .map_err(|_| failure::format_err!("Non-text messages are not accepted"))?;
let request: WsPayload<S> = serde_json::from_str(msg)
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
let msg = msg match request.type_name.as_str() {
.to_str() "connection_init" => {}
.map_err(|_| failure::format_err!("Non-text messages are not accepted"))?; "start" => {
let request: WsPayload<S> = serde_json::from_str(msg) {
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?; let closed = got_close_signal.load(Ordering::Relaxed);
if closed {
return Ok(());
}
match request.type_name.as_str() { if running.load(Ordering::Relaxed) {
"connection_init" => {} return Ok(());
"start" => { }
{ running.store(true, Ordering::Relaxed);
let closed = got_close_signal.load(Ordering::Relaxed);
if closed {
return Ok(());
} }
if running.load(Ordering::Relaxed) { let ws_tx = ws_tx.clone();
return Ok(());
if let Some(ref payload) = request.payload {
if payload.query.is_none() {
return Err(failure::format_err!("Query not found"));
}
} else {
return Err(failure::format_err!("Payload not found"));
} }
running.store(true, Ordering::Relaxed);
}
let ws_tx = ws_tx.clone(); tokio::task::spawn(async move {
let payload = request.payload.unwrap();
if let Some(ref payload) = request.payload { let request_id = request.id.unwrap_or("1".to_owned());
if payload.query.is_none() {
return Err(failure::format_err!("Query not found"));
}
} else {
return Err(failure::format_err!("Payload not found"));
}
tokio::task::spawn(async move { let graphql_request = GraphQLRequest::<S>::new(
let payload = request.payload.unwrap(); payload.query.unwrap(),
None,
payload.variables,
);
let request_id = request.id.unwrap_or("1".to_owned()); let values_stream = match coordinator
.subscribe(&graphql_request, &context)
let graphql_request = GraphQLRequest::<S>::new( .await
payload.query.unwrap(), {
None,
payload.variables,
);
let values_stream =
match coordinator.subscribe(&graphql_request, &context).await {
Ok(s) => s, Ok(s) => s,
Err(err) => { Err(err) => {
let _ = let _ =
@ -546,48 +546,51 @@ pub mod subscriptions {
} }
}; };
values_stream values_stream
.take_while(move |response| { .take_while(move |response| {
let request_id = request_id.clone(); let request_id = request_id.clone();
let closed = got_close_signal.load(Ordering::Relaxed); let closed = got_close_signal.load(Ordering::Relaxed);
if !closed { if !closed {
let mut response_text = serde_json::to_string(&response) let mut response_text = serde_json::to_string(
&response,
)
.unwrap_or("Error deserializing response".to_owned()); .unwrap_or("Error deserializing response".to_owned());
response_text = format!( response_text = format!(
r#"{{"type":"data","id":"{}","payload":{} }}"#, r#"{{"type":"data","id":"{}","payload":{} }}"#,
request_id, response_text request_id, response_text
); );
let _ = ws_tx let _ = ws_tx.unbounded_send(Some(Ok(Message::text(
.unbounded_send(Some(Ok(Message::text(response_text)))); response_text,
} ))));
}
async move { !closed } async move { !closed }
}) })
.for_each(|_| async {}) .for_each(|_| async {})
.await; .await;
}); });
}
"stop" => {
got_close_signal.store(true, Ordering::Relaxed);
let request_id = request.id.unwrap_or("1".to_owned());
let close_message = format!(
r#"{{"type":"complete","id":"{}","payload":null}}"#,
request_id
);
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
// close channel
let _ = ws_tx.unbounded_send(None);
}
_ => {}
} }
"stop" => {
got_close_signal.store(true, Ordering::Relaxed);
let request_id = request.id.unwrap_or("1".to_owned()); Ok(())
let close_message = format!(
r#"{{"type":"complete","id":"{}","payload":null}}"#,
request_id
);
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
// close channel
let _ = ws_tx.unbounded_send(None);
}
_ => {}
} }
})
Ok(())
}
})
} }
#[derive(Deserialize)] #[derive(Deserialize)]