Minor improvements to subscriptions functionality (#591)
Co-authored-by: Christian Legnitto <LegNeato@users.noreply.github.com>
This commit is contained in:
parent
c91b989e2d
commit
dbe2c67cb8
3 changed files with 134 additions and 97 deletions
|
@ -14,7 +14,6 @@ serde_json = "1.0"
|
||||||
tokio = { version = "0.2", features = ["rt-core", "macros"] }
|
tokio = { version = "0.2", features = ["rt-core", "macros"] }
|
||||||
warp = "0.2.1"
|
warp = "0.2.1"
|
||||||
|
|
||||||
# TODO#433: get crates from GitHub
|
juniper = { git = "https://github.com/graphql-rust/juniper" }
|
||||||
juniper = { path = "../../juniper" }
|
juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" }
|
||||||
juniper_subscriptions = { path = "../../juniper_subscriptions"}
|
juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] }
|
||||||
juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] }
|
|
||||||
|
|
|
@ -165,7 +165,13 @@ async fn main() {
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
|
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
|
||||||
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
|
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
|
||||||
graphql_subscriptions(websocket, coordinator, ctx).boxed()
|
graphql_subscriptions(websocket, coordinator, ctx)
|
||||||
|
.map(|r| {
|
||||||
|
if let Err(e) = r {
|
||||||
|
println!("Websocket error: {}", e);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
))
|
))
|
||||||
|
|
|
@ -442,6 +442,8 @@ fn playground_response(
|
||||||
/// Cannot be merged to `juniper_warp` yet as GraphQL over WS[1]
|
/// Cannot be merged to `juniper_warp` yet as GraphQL over WS[1]
|
||||||
/// is not fully supported in current implementation.
|
/// is not fully supported in current implementation.
|
||||||
///
|
///
|
||||||
|
/// *Note: this implementation is in an alpha state.*
|
||||||
|
///
|
||||||
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
|
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
|
||||||
#[cfg(feature = "subscriptions")]
|
#[cfg(feature = "subscriptions")]
|
||||||
pub mod subscriptions {
|
pub mod subscriptions {
|
||||||
|
@ -453,7 +455,7 @@ pub mod subscriptions {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures::{channel::mpsc, stream::StreamExt as _, Future};
|
use futures::{channel::mpsc, Future, StreamExt as _};
|
||||||
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
|
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
|
||||||
use juniper_subscriptions::Coordinator;
|
use juniper_subscriptions::Coordinator;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -467,7 +469,7 @@ pub mod subscriptions {
|
||||||
websocket: warp::ws::WebSocket,
|
websocket: warp::ws::WebSocket,
|
||||||
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
|
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, Context, S>>,
|
||||||
context: Context,
|
context: Context,
|
||||||
) -> impl Future<Output = ()> + Send
|
) -> impl Future<Output = Result<(), failure::Error>> + Send
|
||||||
where
|
where
|
||||||
S: ScalarValue + Send + Sync + 'static,
|
S: ScalarValue + Send + Sync + 'static,
|
||||||
Context: Clone + Send + Sync + 'static,
|
Context: Clone + Send + Sync + 'static,
|
||||||
|
@ -489,21 +491,34 @@ pub mod subscriptions {
|
||||||
);
|
);
|
||||||
|
|
||||||
let context = Arc::new(context);
|
let context = Arc::new(context);
|
||||||
|
let running = Arc::new(AtomicBool::new(false));
|
||||||
let got_close_signal = Arc::new(AtomicBool::new(false));
|
let got_close_signal = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
sink_rx.for_each(move |msg| {
|
sink_rx.fold(Ok(()), move |_, msg| {
|
||||||
let msg = msg.unwrap_or_else(|e| panic!("Websocket receive error: {}", e));
|
|
||||||
|
|
||||||
if msg.is_close() {
|
|
||||||
return futures::future::ready(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let coordinator = coordinator.clone();
|
let coordinator = coordinator.clone();
|
||||||
let context = context.clone();
|
let context = context.clone();
|
||||||
|
let running = running.clone();
|
||||||
let got_close_signal = got_close_signal.clone();
|
let got_close_signal = got_close_signal.clone();
|
||||||
|
let ws_tx = ws_tx.clone();
|
||||||
|
|
||||||
let msg = msg.to_str().expect("Non-text messages are not accepted");
|
async move {
|
||||||
let request: WsPayload<S> = serde_json::from_str(msg).expect("Invalid WsPayload");
|
let msg = match msg {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(e) => {
|
||||||
|
got_close_signal.store(true, Ordering::Relaxed);
|
||||||
|
return Err(failure::format_err!("Websocket error: {}", e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if msg.is_close() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = msg
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| failure::format_err!("Non-text messages are not accepted"))?;
|
||||||
|
let request: WsPayload<S> = serde_json::from_str(msg)
|
||||||
|
.map_err(|e| failure::format_err!("Invalid WsPayload: {}", e))?;
|
||||||
|
|
||||||
match request.type_name.as_str() {
|
match request.type_name.as_str() {
|
||||||
"connection_init" => {}
|
"connection_init" => {}
|
||||||
|
@ -511,18 +526,32 @@ pub mod subscriptions {
|
||||||
{
|
{
|
||||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
let closed = got_close_signal.load(Ordering::Relaxed);
|
||||||
if closed {
|
if closed {
|
||||||
return futures::future::ready(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if running.load(Ordering::Relaxed) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
running.store(true, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
let ws_tx = ws_tx.clone();
|
let ws_tx = ws_tx.clone();
|
||||||
|
|
||||||
|
if let Some(ref payload) = request.payload {
|
||||||
|
if payload.query.is_none() {
|
||||||
|
return Err(failure::format_err!("Query not found"));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(failure::format_err!("Payload not found"));
|
||||||
|
}
|
||||||
|
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
let payload = request.payload.expect("Could not deserialize payload");
|
let payload = request.payload.unwrap();
|
||||||
|
|
||||||
let request_id = request.id.unwrap_or("1".to_owned());
|
let request_id = request.id.unwrap_or("1".to_owned());
|
||||||
|
|
||||||
let graphql_request = GraphQLRequest::<S>::new(
|
let graphql_request = GraphQLRequest::<S>::new(
|
||||||
payload.query.expect("Could not deserialize query"),
|
payload.query.unwrap(),
|
||||||
None,
|
None,
|
||||||
payload.variables,
|
payload.variables,
|
||||||
);
|
);
|
||||||
|
@ -531,7 +560,8 @@ pub mod subscriptions {
|
||||||
match coordinator.subscribe(&graphql_request, &context).await {
|
match coordinator.subscribe(&graphql_request, &context).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(format!(
|
let _ =
|
||||||
|
ws_tx.unbounded_send(Some(Ok(Message::text(format!(
|
||||||
r#"{{"type":"error","id":"{}","payload":{}}}"#,
|
r#"{{"type":"error","id":"{}","payload":{}}}"#,
|
||||||
request_id,
|
request_id,
|
||||||
serde_json::ser::to_string(&err).unwrap_or(
|
serde_json::ser::to_string(&err).unwrap_or(
|
||||||
|
@ -557,7 +587,7 @@ pub mod subscriptions {
|
||||||
let closed = got_close_signal.load(Ordering::Relaxed);
|
let closed = got_close_signal.load(Ordering::Relaxed);
|
||||||
if !closed {
|
if !closed {
|
||||||
let mut response_text = serde_json::to_string(&response)
|
let mut response_text = serde_json::to_string(&response)
|
||||||
.unwrap_or("Error deserializing respone".to_owned());
|
.unwrap_or("Error deserializing response".to_owned());
|
||||||
|
|
||||||
response_text = format!(
|
response_text = format!(
|
||||||
r#"{{"type":"data","id":"{}","payload":{} }}"#,
|
r#"{{"type":"data","id":"{}","payload":{} }}"#,
|
||||||
|
@ -567,6 +597,7 @@ pub mod subscriptions {
|
||||||
let _ = ws_tx
|
let _ = ws_tx
|
||||||
.unbounded_send(Some(Ok(Message::text(response_text))));
|
.unbounded_send(Some(Ok(Message::text(response_text))));
|
||||||
}
|
}
|
||||||
|
|
||||||
async move { !closed }
|
async move { !closed }
|
||||||
})
|
})
|
||||||
.for_each(|_| async {})
|
.for_each(|_| async {})
|
||||||
|
@ -589,7 +620,8 @@ pub mod subscriptions {
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
futures::future::ready(())
|
Ok(())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue