Minor improvements to subscriptions functionality (#591)

Co-authored-by: Christian Legnitto <LegNeato@users.noreply.github.com>
This commit is contained in:
nWacky 2020-03-31 06:43:00 +03:00 committed by GitHub
parent c91b989e2d
commit dbe2c67cb8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 134 additions and 97 deletions

View file

@ -14,7 +14,6 @@ serde_json = "1.0"
tokio = { version = "0.2", features = ["rt-core", "macros"] } tokio = { version = "0.2", features = ["rt-core", "macros"] }
warp = "0.2.1" warp = "0.2.1"
# TODO#433: get crates from GitHub juniper = { git = "https://github.com/graphql-rust/juniper" }
juniper = { path = "../../juniper" } juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" }
juniper_subscriptions = { path = "../../juniper_subscriptions"} juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] }
juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] }

View file

@ -165,7 +165,13 @@ async fn main() {
ctx: Context, ctx: Context,
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| { coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> { ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
graphql_subscriptions(websocket, coordinator, ctx).boxed() graphql_subscriptions(websocket, coordinator, ctx)
.map(|r| {
if let Err(e) = r {
println!("Websocket error: {}", e);
}
})
.boxed()
}) })
}, },
)) ))

View file

@ -442,6 +442,8 @@ fn playground_response(
/// Cannot be merged to `juniper_warp` yet as GraphQL over WS[1] /// Cannot be merged to `juniper_warp` yet as GraphQL over WS[1]
/// is not fully supported in current implementation. /// is not fully supported in current implementation.
/// ///
/// *Note: this implementation is in an alpha state.*
///
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md /// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
#[cfg(feature = "subscriptions")] #[cfg(feature = "subscriptions")]
pub mod subscriptions { pub mod subscriptions {
@ -453,7 +455,7 @@ pub mod subscriptions {
}, },
}; };
use futures::{channel::mpsc, stream::StreamExt as _, Future}; use futures::{channel::mpsc, Future, StreamExt as _};
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _}; use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
use juniper_subscriptions::Coordinator; use juniper_subscriptions::Coordinator;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -467,7 +469,7 @@ pub mod subscriptions {
websocket: warp::ws::WebSocket, websocket: warp::ws::WebSocket,
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>, coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
context: Context, context: Context,
) -> impl Future<Output = ()> + Send ) -> impl Future<Output = Result<(), failure::Error>> + Send
where where
S: ScalarValue + Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static,
Context: Clone + Send + Sync + 'static, Context: Clone + Send + Sync + 'static,
@ -489,21 +491,34 @@ pub mod subscriptions {
); );
let context = Arc::new(context); let context = Arc::new(context);
let running = Arc::new(AtomicBool::new(false));
let got_close_signal = Arc::new(AtomicBool::new(false)); let got_close_signal = Arc::new(AtomicBool::new(false));
sink_rx.for_each(move |msg| { sink_rx.fold(Ok(()), move |_, msg| {
let msg = msg.unwrap_or_else(|e| panic!("Websocket receive error: {}", e));
if msg.is_close() {
return futures::future::ready(());
}
let coordinator = coordinator.clone(); let coordinator = coordinator.clone();
let context = context.clone(); let context = context.clone();
let running = running.clone();
let got_close_signal = got_close_signal.clone(); let got_close_signal = got_close_signal.clone();
let ws_tx = ws_tx.clone();
let msg = msg.to_str().expect("Non-text messages are not accepted"); async move {
let request: WsPayload<S> = serde_json::from_str(msg).expect("Invalid WsPayload"); let msg = match msg {
Ok(m) => m,
Err(e) => {
got_close_signal.store(true, Ordering::Relaxed);
return Err(failure::format_err!("Websocket error: {}", e));
}
};
if msg.is_close() {
return Ok(());
}
let msg = msg
.to_str()
.map_err(|_| failure::format_err!("Non-text messages are not accepted"))?;
let request: WsPayload<S> = serde_json::from_str(msg)
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
match request.type_name.as_str() { match request.type_name.as_str() {
"connection_init" => {} "connection_init" => {}
@ -511,18 +526,32 @@ pub mod subscriptions {
{ {
let closed = got_close_signal.load(Ordering::Relaxed); let closed = got_close_signal.load(Ordering::Relaxed);
if closed { if closed {
return futures::future::ready(()); return Ok(());
} }
if running.load(Ordering::Relaxed) {
return Ok(());
}
running.store(true, Ordering::Relaxed);
} }
let ws_tx = ws_tx.clone(); let ws_tx = ws_tx.clone();
if let Some(ref payload) = request.payload {
if payload.query.is_none() {
return Err(failure::format_err!("Query not found"));
}
} else {
return Err(failure::format_err!("Payload not found"));
}
tokio::task::spawn(async move { tokio::task::spawn(async move {
let payload = request.payload.expect("Could not deserialize payload"); let payload = request.payload.unwrap();
let request_id = request.id.unwrap_or("1".to_owned()); let request_id = request.id.unwrap_or("1".to_owned());
let graphql_request = GraphQLRequest::<S>::new( let graphql_request = GraphQLRequest::<S>::new(
payload.query.expect("Could not deserialize query"), payload.query.unwrap(),
None, None,
payload.variables, payload.variables,
); );
@ -531,7 +560,8 @@ pub mod subscriptions {
match coordinator.subscribe(&graphql_request, &context).await { match coordinator.subscribe(&graphql_request, &context).await {
Ok(s) => s, Ok(s) => s,
Err(err) => { Err(err) => {
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(format!( let _ =
ws_tx.unbounded_send(Some(Ok(Message::text(format!(
r#"{{"type":"error","id":"{}","payload":{}}}"#, r#"{{"type":"error","id":"{}","payload":{}}}"#,
request_id, request_id,
serde_json::ser::to_string(&err).unwrap_or( serde_json::ser::to_string(&err).unwrap_or(
@ -557,7 +587,7 @@ pub mod subscriptions {
let closed = got_close_signal.load(Ordering::Relaxed); let closed = got_close_signal.load(Ordering::Relaxed);
if !closed { if !closed {
let mut response_text = serde_json::to_string(&response) let mut response_text = serde_json::to_string(&response)
.unwrap_or("Error deserializing respone".to_owned()); .unwrap_or("Error deserializing response".to_owned());
response_text = format!( response_text = format!(
r#"{{"type":"data","id":"{}","payload":{} }}"#, r#"{{"type":"data","id":"{}","payload":{} }}"#,
@ -567,6 +597,7 @@ pub mod subscriptions {
let _ = ws_tx let _ = ws_tx
.unbounded_send(Some(Ok(Message::text(response_text)))); .unbounded_send(Some(Ok(Message::text(response_text))));
} }
async move { !closed } async move { !closed }
}) })
.for_each(|_| async {}) .for_each(|_| async {})
@ -589,7 +620,8 @@ pub mod subscriptions {
_ => {} _ => {}
} }
futures::future::ready(()) Ok(())
}
}) })
} }