From 84c9720b535c37dfc1d8bc6b142bf2f63e5fd166 Mon Sep 17 00:00:00 2001 From: Chris <ccbrown112@gmail.com> Date: Wed, 29 Jul 2020 04:23:44 -0400 Subject: [PATCH] GraphQL-WS crate and Warp subscriptions update (#721) * update pre-existing juniper_warp::subscriptions * initial draft * finish up, update example * polish + timing test * fix pre-existing bug * rebase updates * address comments * add release.toml * makefile and initial changelog * add new Cargo.toml to juniper/release.toml --- Cargo.toml | 1 + examples/warp_subscriptions/Cargo.toml | 6 +- examples/warp_subscriptions/src/main.rs | 38 +- juniper/release.toml | 2 + juniper_graphql_ws/CHANGELOG.md | 3 + juniper_graphql_ws/Cargo.toml | 19 + juniper_graphql_ws/Makefile.toml | 20 + juniper_graphql_ws/release.toml | 8 + juniper_graphql_ws/src/client_message.rs | 131 +++ juniper_graphql_ws/src/lib.rs | 1073 ++++++++++++++++++++++ juniper_graphql_ws/src/schema.rs | 131 +++ juniper_graphql_ws/src/server_message.rs | 191 ++++ juniper_subscriptions/src/lib.rs | 16 +- juniper_warp/Cargo.toml | 4 +- juniper_warp/src/lib.rs | 291 ++---- 15 files changed, 1696 insertions(+), 238 deletions(-) create mode 100644 juniper_graphql_ws/CHANGELOG.md create mode 100644 juniper_graphql_ws/Cargo.toml create mode 100644 juniper_graphql_ws/Makefile.toml create mode 100644 juniper_graphql_ws/release.toml create mode 100644 juniper_graphql_ws/src/client_message.rs create mode 100644 juniper_graphql_ws/src/lib.rs create mode 100644 juniper_graphql_ws/src/schema.rs create mode 100644 juniper_graphql_ws/src/server_message.rs diff --git a/Cargo.toml b/Cargo.toml index 79429a10..d37670be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "juniper_rocket", "juniper_rocket_async", "juniper_subscriptions", + "juniper_graphql_ws", "juniper_warp", "juniper_actix", ] diff --git a/examples/warp_subscriptions/Cargo.toml b/examples/warp_subscriptions/Cargo.toml index 7fc8fb4a..5c69129e 100644 --- a/examples/warp_subscriptions/Cargo.toml +++ b/examples/warp_subscriptions/Cargo.toml @@ -13,6 +13,6 @@ serde_json = "1.0" tokio = { version = "0.2", features = ["rt-core", "macros"] } warp = "0.2.1" -juniper = { git = "https://github.com/graphql-rust/juniper" } -juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" } -juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] } +juniper = { path = "../../juniper" } +juniper_graphql_ws = { path = "../../juniper_graphql_ws" } +juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] } diff --git a/examples/warp_subscriptions/src/main.rs b/examples/warp_subscriptions/src/main.rs index f0f9f737..0d4f31a6 100644 --- a/examples/warp_subscriptions/src/main.rs +++ b/examples/warp_subscriptions/src/main.rs @@ -2,10 +2,10 @@ use std::{env, pin::Pin, sync::Arc, time::Duration}; -use futures::{Future, FutureExt as _, Stream}; +use futures::{FutureExt as _, Stream}; use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode}; -use juniper_subscriptions::Coordinator; -use juniper_warp::{playground_filter, subscriptions::graphql_subscriptions}; +use juniper_graphql_ws::ConnectionConfig; +use juniper_warp::{playground_filter, subscriptions::serve_graphql_ws}; use warp::{http::Response, Filter}; #[derive(Clone)] @@ -151,30 +151,24 @@ async fn main() { let qm_state = warp::any().map(move || Context {}); let qm_graphql_filter = juniper_warp::make_graphql_filter(qm_schema, qm_state.boxed()); - let sub_state = warp::any().map(move || Context {}); - let coordinator = Arc::new(juniper_subscriptions::Coordinator::new(schema())); + let root_node = Arc::new(schema()); log::info!("Listening on 127.0.0.1:8080"); let routes = (warp::path("subscriptions") .and(warp::ws()) - .and(sub_state.clone()) - .and(warp::any().map(move || Arc::clone(&coordinator))) - .map( - |ws: warp::ws::Ws, - ctx: Context, - coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| { - ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> { - graphql_subscriptions(websocket, coordinator, ctx) - .map(|r| { - if let Err(e) = r { - println!("Websocket error: {}", e); - } - }) - .boxed() - }) - }, - )) + .map(move |ws: warp::ws::Ws| { + let root_node = root_node.clone(); + ws.on_upgrade(move |websocket| async move { + serve_graphql_ws(websocket, root_node, ConnectionConfig::new(Context {})) + .map(|r| { + if let Err(e) = r { + println!("Websocket error: {}", e); + } + }) + .await + }) + })) .map(|reply| { // TODO#584: remove this workaround warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws") diff --git a/juniper/release.toml b/juniper/release.toml index 72391149..ab15f4e5 100644 --- a/juniper/release.toml +++ b/juniper/release.toml @@ -30,6 +30,8 @@ pre-release-replacements = [ {file="../juniper_warp/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""}, # Subscriptions {file="../juniper_subscriptions/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""}, + # GraphQL-WS + {file="../juniper_graphql_ws/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""}, # Actix-Web {file="../juniper_actix/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""}, {file="../juniper_actix/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""}, diff --git a/juniper_graphql_ws/CHANGELOG.md b/juniper_graphql_ws/CHANGELOG.md new file mode 100644 index 00000000..05232472 --- /dev/null +++ b/juniper_graphql_ws/CHANGELOG.md @@ -0,0 +1,3 @@ +# master + +- Initial Release diff --git a/juniper_graphql_ws/Cargo.toml b/juniper_graphql_ws/Cargo.toml new file mode 100644 index 00000000..8bf19ee7 --- /dev/null +++ b/juniper_graphql_ws/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "juniper_graphql_ws" +version = "0.1.0" +authors = ["Christopher Brown <ccbrown112@gmail.com>"] +license = "BSD-2-Clause" +description = "Graphql-ws protocol implementation for Juniper" +documentation = "https://docs.rs/juniper_graphql_ws" +repository = "https://github.com/graphql-rust/juniper" +keywords = ["graphql-ws", "juniper", "graphql", "apollo"] +edition = "2018" + +[dependencies] +juniper = { version = "0.14.2", path = "../juniper", default-features = false } +juniper_subscriptions = { path = "../juniper_subscriptions" } +serde = { version = "1.0.8", features = ["derive"] } +tokio = { version = "0.2", features = ["macros", "rt-core", "time"] } + +[dev-dependencies] +serde_json = { version = "1.0.2" } diff --git a/juniper_graphql_ws/Makefile.toml b/juniper_graphql_ws/Makefile.toml new file mode 100644 index 00000000..ba858470 --- /dev/null +++ b/juniper_graphql_ws/Makefile.toml @@ -0,0 +1,20 @@ +[env] +CARGO_MAKE_CARGO_ALL_FEATURES = "" + +[tasks.build-verbose] +condition = { rust_version = { min = "1.29.0" } } + +[tasks.build-verbose.windows] +condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } } + +[tasks.test-verbose] +condition = { rust_version = { min = "1.29.0" } } + +[tasks.test-verbose.windows] +condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } } + +[tasks.ci-coverage-flow] +condition = { rust_version = { min = "1.29.0" } } + +[tasks.ci-coverage-flow.windows] +disabled = true diff --git a/juniper_graphql_ws/release.toml b/juniper_graphql_ws/release.toml new file mode 100644 index 00000000..98e70594 --- /dev/null +++ b/juniper_graphql_ws/release.toml @@ -0,0 +1,8 @@ +no-dev-version = true +pre-release-commit-message = "Release {{crate_name}} {{version}}" +pro-release-commit-message = "Bump {{crate_name}} version to {{next_version}}" +tag-message = "Release {{crate_name}} {{version}}" +upload-doc = false +pre-release-replacements = [ + {file="src/lib.rs", search="docs.rs/juniper_graphql_ws/[a-z0-9\\.-]+", replace="docs.rs/juniper_graphql_ws/{{version}}"}, +] diff --git a/juniper_graphql_ws/src/client_message.rs b/juniper_graphql_ws/src/client_message.rs new file mode 100644 index 00000000..1e20caef --- /dev/null +++ b/juniper_graphql_ws/src/client_message.rs @@ -0,0 +1,131 @@ +use juniper::{ScalarValue, Variables}; + +/// The payload for a client's "start" message. This triggers execution of a query, mutation, or +/// subscription. +#[derive(Debug, Deserialize, PartialEq)] +#[serde(bound(deserialize = "S: ScalarValue"))] +#[serde(rename_all = "camelCase")] +pub struct StartPayload<S: ScalarValue> { + /// The document body. + pub query: String, + + /// The optional variables. + #[serde(default)] + pub variables: Variables<S>, + + /// The optional operation name (required if the document contains multiple operations). + pub operation_name: Option<String>, +} + +/// ClientMessage defines the message types that clients can send. +#[derive(Debug, Deserialize, PartialEq)] +#[serde(bound(deserialize = "S: ScalarValue"))] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum ClientMessage<S: ScalarValue> { + /// ConnectionInit is sent by the client upon connecting. + ConnectionInit { + /// Optional parameters of any type sent from the client. These are often used for + /// authentication. + #[serde(default)] + payload: Variables<S>, + }, + /// Start messages are used to execute a GraphQL operation. + Start { + /// The id of the operation. This can be anything, but must be unique. If there are other + /// in-flight operations with the same id, the message will be ignored or cause an error. + id: String, + + /// The query, variables, and operation name. + payload: StartPayload<S>, + }, + /// Stop messages are used to unsubscribe from a subscription. + Stop { + /// The id of the operation to stop. + id: String, + }, + /// ConnectionTerminate is used to terminate the connection. + ConnectionTerminate, +} + +#[cfg(test)] +mod test { + use super::*; + use juniper::{DefaultScalarValue, InputValue}; + + #[test] + fn test_deserialization() { + type ClientMessage = super::ClientMessage<DefaultScalarValue>; + + assert_eq!( + ClientMessage::ConnectionInit { + payload: [("foo".to_string(), InputValue::scalar("bar"))] + .iter() + .cloned() + .collect(), + }, + serde_json::from_str(r##"{"type": "connection_init", "payload": {"foo": "bar"}}"##) + .unwrap(), + ); + + assert_eq!( + ClientMessage::ConnectionInit { + payload: Variables::default(), + }, + serde_json::from_str(r##"{"type": "connection_init"}"##).unwrap(), + ); + + assert_eq!( + ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "query MyQuery { __typename }".to_string(), + variables: [("foo".to_string(), InputValue::scalar("bar"))] + .iter() + .cloned() + .collect(), + operation_name: Some("MyQuery".to_string()), + }, + }, + serde_json::from_str( + r##"{"type": "start", "id": "foo", "payload": { + "query": "query MyQuery { __typename }", + "variables": { + "foo": "bar" + }, + "operationName": "MyQuery" + }}"## + ) + .unwrap(), + ); + + assert_eq!( + ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "query MyQuery { __typename }".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }, + serde_json::from_str( + r##"{"type": "start", "id": "foo", "payload": { + "query": "query MyQuery { __typename }" + }}"## + ) + .unwrap(), + ); + + assert_eq!( + ClientMessage::Stop { + id: "foo".to_string() + }, + serde_json::from_str(r##"{"type": "stop", "id": "foo"}"##).unwrap(), + ); + + assert_eq!( + ClientMessage::ConnectionTerminate, + serde_json::from_str(r##"{"type": "connection_terminate"}"##).unwrap(), + ); + } +} diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs new file mode 100644 index 00000000..7cf73eec --- /dev/null +++ b/juniper_graphql_ws/src/lib.rs @@ -0,0 +1,1073 @@ +/*! + +# juniper_graphql_ws + +This crate contains an implementation of the [graphql-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/263844b5c1a850c1e29814564eb62cb587e5eaaf/PROTOCOL.md), as used by Apollo. + +*/ + +#![deny(missing_docs)] +#![deny(warnings)] + +#[macro_use] +extern crate serde; + +mod client_message; +pub use client_message::*; + +mod server_message; +pub use server_message::*; + +mod schema; +pub use schema::*; + +use juniper::{ + futures::{ + channel::oneshot, + future::{self, BoxFuture, Either, Future, FutureExt, TryFutureExt}, + stream::{self, BoxStream, SelectAll, StreamExt}, + task::{Context, Poll, Waker}, + Sink, Stream, + }, + GraphQLError, RuleError, ScalarValue, Variables, +}; +use std::{ + collections::HashMap, + convert::{Infallible, TryInto}, + error::Error, + marker::PhantomPinned, + pin::Pin, + sync::Arc, + time::Duration, +}; + +struct ExecutionParams<S: Schema> { + start_payload: StartPayload<S::ScalarValue>, + config: Arc<ConnectionConfig<S::Context>>, + schema: S, +} + +/// ConnectionConfig is used to configure the connection once the client sends the ConnectionInit +/// message. +pub struct ConnectionConfig<CtxT> { + context: CtxT, + max_in_flight_operations: usize, + keep_alive_interval: Duration, +} + +impl<CtxT> ConnectionConfig<CtxT> { + /// Constructs the configuration required for a connection to be accepted. + pub fn new(context: CtxT) -> Self { + Self { + context, + max_in_flight_operations: 0, + keep_alive_interval: Duration::from_secs(30), + } + } + + /// Specifies the maximum number of in-flight operations that a connection can have. If this + /// number is exceeded, attempting to start more will result in an error. By default, there is + /// no limit to in-flight operations. + pub fn with_max_in_flight_operations(mut self, max: usize) -> Self { + self.max_in_flight_operations = max; + self + } + + /// Specifies the interval at which to send keep-alives. Specifying a zero duration will + /// disable keep-alives. By default, keep-alives are sent every + /// 30 seconds. + pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self { + self.keep_alive_interval = interval; + self + } +} + +impl<S: ScalarValue, CtxT: Unpin + Send + 'static> Init<S, CtxT> for ConnectionConfig<CtxT> { + type Error = Infallible; + type Future = future::Ready<Result<Self, Self::Error>>; + + fn init(self, _params: Variables<S>) -> Self::Future { + future::ready(Ok(self)) + } +} + +enum Reaction<S: Schema> { + ServerMessage(ServerMessage<S::ScalarValue>), + EndStream, +} + +impl<S: Schema> Reaction<S> { + /// Converts the reaction into a one-item stream. + fn to_stream(self) -> BoxStream<'static, Self> { + stream::once(future::ready(self)).boxed() + } +} + +/// Init defines the requirements for types that can provide connection configurations when +/// ConnectionInit messages are received. Implementations are provided for `ConnectionConfig` and +/// closures that meet the requirements. +pub trait Init<S: ScalarValue, CtxT>: Unpin + 'static { + /// The error that is returned on failure. The formatted error will be used as the contents of + /// the "message" field sent back to the client. + type Error: Error; + + /// The future configuration type. + type Future: Future<Output = Result<ConnectionConfig<CtxT>, Self::Error>> + Send + 'static; + + /// Returns a future for the configuration to use. + fn init(self, params: Variables<S>) -> Self::Future; +} + +impl<F, S, CtxT, Fut, E> Init<S, CtxT> for F +where + S: ScalarValue, + F: FnOnce(Variables<S>) -> Fut + Unpin + 'static, + Fut: Future<Output = Result<ConnectionConfig<CtxT>, E>> + Send + 'static, + E: Error, +{ + type Error = E; + type Future = Fut; + + fn init(self, params: Variables<S>) -> Fut { + self(params) + } +} + +enum ConnectionState<S: Schema, I: Init<S::ScalarValue, S::Context>> { + /// PreInit is the state before a ConnectionInit message has been accepted. + PreInit { init: I, schema: S }, + /// Active is the state after a ConnectionInit message has been accepted. + Active { + config: Arc<ConnectionConfig<S::Context>>, + stoppers: HashMap<String, oneshot::Sender<()>>, + schema: S, + }, + /// Terminated is the state after a ConnectionInit message has been rejected. + Terminated, +} + +impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> { + // Each message we receive results in a stream of zero or more reactions. For example, a + // ConnectionTerminate message results in a one-item stream with the EndStream reaction. + async fn handle_message( + self, + msg: ClientMessage<S::ScalarValue>, + ) -> (Self, BoxStream<'static, Reaction<S>>) { + if let ClientMessage::ConnectionTerminate = msg { + return (self, Reaction::EndStream.to_stream()); + } + + match self { + Self::PreInit { init, schema } => match msg { + ClientMessage::ConnectionInit { payload } => match init.init(payload).await { + Ok(config) => { + let keep_alive_interval = config.keep_alive_interval; + + let mut s = stream::iter(vec![Reaction::ServerMessage( + ServerMessage::ConnectionAck, + )]) + .boxed(); + + if keep_alive_interval > Duration::from_secs(0) { + s = s + .chain( + Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive) + .to_stream(), + ) + .boxed(); + s = s + .chain(stream::unfold((), move |_| async move { + tokio::time::delay_for(keep_alive_interval).await; + Some(( + Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive), + (), + )) + })) + .boxed(); + } + + ( + Self::Active { + config: Arc::new(config), + stoppers: HashMap::new(), + schema, + }, + s, + ) + } + Err(e) => ( + Self::Terminated, + stream::iter(vec![ + Reaction::ServerMessage(ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: e.to_string(), + }, + }), + Reaction::EndStream, + ]) + .boxed(), + ), + }, + _ => (Self::PreInit { init, schema }, stream::empty().boxed()), + }, + Self::Active { + config, + mut stoppers, + schema, + } => { + let reactions = match msg { + ClientMessage::Start { id, payload } => { + if stoppers.contains_key(&id) { + // We already have an operation with this id, so we can't start a new + // one. + stream::empty().boxed() + } else { + // Go ahead and prune canceled stoppers before adding a new one. + stoppers.retain(|_, tx| !tx.is_canceled()); + + if config.max_in_flight_operations > 0 + && stoppers.len() >= config.max_in_flight_operations + { + // Too many in-flight operations. Just send back a validation error. + stream::iter(vec![ + Reaction::ServerMessage(ServerMessage::Error { + id: id.clone(), + payload: GraphQLError::ValidationError(vec![ + RuleError::new("Too many in-flight operations.", &[]), + ]) + .into(), + }), + Reaction::ServerMessage(ServerMessage::Complete { id }), + ]) + .boxed() + } else { + // Create a channel that we can use to cancel the operation. + let (tx, rx) = oneshot::channel::<()>(); + stoppers.insert(id.clone(), tx); + + // Create the operation stream. This stream will emit Data and Error + // messages, but will not emit Complete – that part is up to us. + let s = Self::start( + id.clone(), + ExecutionParams { + start_payload: payload, + config: config.clone(), + schema: schema.clone(), + }, + ) + .into_stream() + .flatten(); + + // Combine this with our oneshot channel so that the stream ends if the + // oneshot is ever fired. + let s = stream::unfold((rx, s.boxed()), |(rx, mut s)| async move { + let next = match future::select(rx, s.next()).await { + Either::Left(_) => None, + Either::Right((r, rx)) => r.map(|r| (r, rx)), + }; + next.map(|(r, rx)| (r, (rx, s))) + }); + + // Once the stream ends, send the Complete message. + let s = s.chain( + Reaction::ServerMessage(ServerMessage::Complete { id }) + .to_stream(), + ); + + s.boxed() + } + } + } + ClientMessage::Stop { id } => { + stoppers.remove(&id); + stream::empty().boxed() + } + _ => stream::empty().boxed(), + }; + ( + Self::Active { + config, + stoppers, + schema, + }, + reactions, + ) + } + Self::Terminated => (self, stream::empty().boxed()), + } + } + + async fn start(id: String, params: ExecutionParams<S>) -> BoxStream<'static, Reaction<S>> { + // TODO: This could be made more efficient if juniper exposed functionality to allow us to + // parse and validate the query, determine whether it's a subscription, and then execute + // it. For now, the query gets parsed and validated twice. + + let params = Arc::new(params); + + // Try to execute this as a query or mutation. + match juniper::execute( + ¶ms.start_payload.query, + params + .start_payload + .operation_name + .as_ref() + .map(|s| s.as_str()), + params.schema.root_node(), + ¶ms.start_payload.variables, + ¶ms.config.context, + ) + .await + { + Ok((data, errors)) => { + return Reaction::ServerMessage(ServerMessage::Data { + id: id.clone(), + payload: DataPayload { data, errors }, + }) + .to_stream(); + } + Err(GraphQLError::IsSubscription) => {} + Err(e) => { + return Reaction::ServerMessage(ServerMessage::Error { + id: id.clone(), + // e only references data owned by params. The new ErrorPayload will continue to keep that data alive. + payload: unsafe { ErrorPayload::new_unchecked(Box::new(params.clone()), e) }, + }) + .to_stream(); + } + } + + // Try to execute as a subscription. + SubscriptionStart::new(id, params.clone()).boxed() + } +} + +struct InterruptableStream<S> { + stream: S, + rx: oneshot::Receiver<()>, +} + +impl<S: Stream + Unpin> Stream for InterruptableStream<S> { + type Item = S::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Ready(_) => return Poll::Ready(None), + Poll::Pending => {} + } + Pin::new(&mut self.stream).poll_next(cx) + } +} + +/// SubscriptionStartState is the state for a subscription operation. +enum SubscriptionStartState<S: Schema> { + /// Init is the start before being polled for the first time. + Init { id: String }, + /// ResolvingIntoStream is the state after being polled for the first time. In this state, + /// we're parsing, validating, and getting the actual event stream. + ResolvingIntoStream { + id: String, + future: BoxFuture< + 'static, + Result< + juniper_subscriptions::Connection<'static, S::ScalarValue>, + GraphQLError<'static>, + >, + >, + }, + /// Streaming is the state after we've successfully obtained the event stream for the + /// subscription. In this state, we're just forwarding events back to the client. + Streaming { + id: String, + stream: juniper_subscriptions::Connection<'static, S::ScalarValue>, + }, + /// Terminated is the state once we're all done. + Terminated, +} + +/// SubscriptionStart is the stream for a subscription operation. +struct SubscriptionStart<S: Schema> { + params: Arc<ExecutionParams<S>>, + state: SubscriptionStartState<S>, + _marker: PhantomPinned, +} + +impl<S: Schema> SubscriptionStart<S> { + fn new(id: String, params: Arc<ExecutionParams<S>>) -> Pin<Box<Self>> { + Box::pin(Self { + params, + state: SubscriptionStartState::Init { id }, + _marker: PhantomPinned, + }) + } +} + +impl<S: Schema> Stream for SubscriptionStart<S> { + type Item = Reaction<S>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { + let (params, state) = unsafe { + // XXX: The execution parameters are referenced by state and must not be modified. + // Modifying state is fine though. + let inner = self.get_unchecked_mut(); + (&inner.params, &mut inner.state) + }; + + loop { + match state { + SubscriptionStartState::Init { id } => { + // XXX: resolve_into_stream returns a Future that references the execution + // parameters, and the returned stream also references them. We can guarantee + // that everything has the same lifetime in this self-referential struct. + let params = Arc::as_ptr(params); + *state = SubscriptionStartState::ResolvingIntoStream { + id: id.clone(), + future: unsafe { + juniper::resolve_into_stream( + &(*params).start_payload.query, + (*params) + .start_payload + .operation_name + .as_ref() + .map(|s| s.as_str()), + (*params).schema.root_node(), + &(*params).start_payload.variables, + &(*params).config.context, + ) + } + .map_ok(|(stream, errors)| { + juniper_subscriptions::Connection::from_stream(stream, errors) + }) + .boxed(), + }; + } + SubscriptionStartState::ResolvingIntoStream { + ref id, + ref mut future, + } => match future.as_mut().poll(cx) { + Poll::Ready(r) => match r { + Ok(stream) => { + *state = SubscriptionStartState::Streaming { + id: id.clone(), + stream, + } + } + Err(e) => { + return Poll::Ready(Some(Reaction::ServerMessage( + ServerMessage::Error { + id: id.clone(), + // e only references data owned by params. The new ErrorPayload will continue to keep that data alive. + payload: unsafe { + ErrorPayload::new_unchecked(Box::new(params.clone()), e) + }, + }, + ))); + } + }, + Poll::Pending => return Poll::Pending, + }, + SubscriptionStartState::Streaming { + ref id, + ref mut stream, + } => match Pin::new(stream).poll_next(cx) { + Poll::Ready(Some(output)) => { + return Poll::Ready(Some(Reaction::ServerMessage(ServerMessage::Data { + id: id.clone(), + payload: DataPayload { + data: output.data, + errors: output.errors, + }, + }))); + } + Poll::Ready(None) => { + *state = SubscriptionStartState::Terminated; + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + }, + SubscriptionStartState::Terminated => return Poll::Ready(None), + } + } + } +} + +enum ConnectionSinkState<S: Schema, I: Init<S::ScalarValue, S::Context>> { + Ready { + state: ConnectionState<S, I>, + }, + HandlingMessage { + result: BoxFuture<'static, (ConnectionState<S, I>, BoxStream<'static, Reaction<S>>)>, + }, + Closed, +} + +/// Implements the graphql-ws protocol. This is a sink for `TryInto<ClientMessage>` and a stream of +/// `ServerMessage`. +pub struct Connection<S: Schema, I: Init<S::ScalarValue, S::Context>> { + reactions: SelectAll<BoxStream<'static, Reaction<S>>>, + stream_waker: Option<Waker>, + sink_state: ConnectionSinkState<S, I>, +} + +impl<S, I> Connection<S, I> +where + S: Schema, + I: Init<S::ScalarValue, S::Context>, +{ + /// Creates a new connection, which is a sink for `TryInto<ClientMessage>` and a stream of `ServerMessage`. + /// + /// The `schema` argument should typically be an `Arc<RootNode<...>>`. + /// + /// The `init` argument is used to provide the context and additional configuration for + /// connections. This can be a `ConnectionConfig` if the context and configuration are already + /// known, or it can be a closure that gets executed asynchronously when the client sends the + /// ConnectionInit message. Using a closure allows you to perform authentication based on the + /// parameters provided by the client. + pub fn new(schema: S, init: I) -> Self { + Self { + reactions: SelectAll::new(), + stream_waker: None, + sink_state: ConnectionSinkState::Ready { + state: ConnectionState::PreInit { init, schema }, + }, + } + } +} + +impl<S, I, T> Sink<T> for Connection<S, I> +where + T: TryInto<ClientMessage<S::ScalarValue>>, + T::Error: Error, + S: Schema, + I: Init<S::ScalarValue, S::Context> + Send, +{ + type Error = Infallible; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { + match &mut self.sink_state { + ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())), + ConnectionSinkState::HandlingMessage { ref mut result } => { + match Pin::new(result).poll(cx) { + Poll::Ready((state, reactions)) => { + self.reactions.push(reactions); + self.sink_state = ConnectionSinkState::Ready { state }; + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + } + ConnectionSinkState::Closed => panic!("poll_ready called after close"), + } + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let s = self.get_mut(); + let state = &mut s.sink_state; + *state = match std::mem::replace(state, ConnectionSinkState::Closed) { + ConnectionSinkState::Ready { state } => { + match item.try_into() { + Ok(msg) => ConnectionSinkState::HandlingMessage { + result: state.handle_message(msg).boxed(), + }, + Err(e) => { + // If we weren't able to parse the message, send back an error. + s.reactions.push( + Reaction::ServerMessage(ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: e.to_string(), + }, + }) + .to_stream(), + ); + ConnectionSinkState::Ready { state } + } + } + } + _ => panic!("start_send called when not ready"), + }; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { + <Self as Sink<T>>::poll_ready(self, cx) + } + + fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> { + self.sink_state = ConnectionSinkState::Closed; + if let Some(waker) = self.stream_waker.take() { + // Wake up the stream so it can close too. + waker.wake(); + } + Poll::Ready(Ok(())) + } +} + +impl<S, I> Stream for Connection<S, I> +where + S: Schema, + I: Init<S::ScalarValue, S::Context>, +{ + type Item = ServerMessage<S::ScalarValue>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { + self.stream_waker = Some(cx.waker().clone()); + + if let ConnectionSinkState::Closed = self.sink_state { + return Poll::Ready(None); + } + + // Poll the reactions for new outgoing messages. + loop { + if !self.reactions.is_empty() { + match Pin::new(&mut self.reactions).poll_next(cx) { + Poll::Ready(Some(reaction)) => match reaction { + Reaction::ServerMessage(msg) => return Poll::Ready(Some(msg)), + Reaction::EndStream => return Poll::Ready(None), + }, + Poll::Ready(None) => { + // In rare cases, the reaction stream may terminate. For example, this will + // happen if the first message we receive does not require any reaction. Just + // recreate it in that case. + self.reactions = SelectAll::new(); + return Poll::Pending; + } + Poll::Pending => return Poll::Pending, + } + } else { + return Poll::Pending; + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use juniper::{ + futures::sink::SinkExt, + parser::{ParseError, Spanning, Token}, + DefaultScalarValue, EmptyMutation, FieldError, FieldResult, InputValue, RootNode, Value, + }; + use std::{convert::Infallible, io}; + + struct Context(i32); + + struct Query; + + #[juniper::graphql_object(Context=Context)] + impl Query { + /// context just resolves to the current context. + async fn context(context: &Context) -> i32 { + context.0 + } + } + + struct Subscription; + + #[juniper::graphql_subscription(Context=Context)] + impl Subscription { + /// never never emits anything. + async fn never(context: &Context) -> BoxStream<'static, FieldResult<i32>> { + tokio::time::delay_for(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream() + .boxed() + } + + /// context emits the current context once, then never emits anything else. + async fn context(context: &Context) -> BoxStream<'static, FieldResult<i32>> { + stream::once(future::ready(Ok(context.0))) + .chain( + tokio::time::delay_for(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream(), + ) + .boxed() + } + + /// error emits an error once, then never emits anything else. + async fn error(context: &Context) -> BoxStream<'static, FieldResult<i32>> { + stream::once(future::ready(Err(FieldError::new( + "field error", + Value::null(), + )))) + .chain( + tokio::time::delay_for(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream(), + ) + .boxed() + } + } + + type ClientMessage = super::ClientMessage<DefaultScalarValue>; + type ServerMessage = super::ServerMessage<DefaultScalarValue>; + + fn new_test_schema() -> Arc<RootNode<'static, Query, EmptyMutation<Context>, Subscription>> { + Arc::new(RootNode::new(Query, EmptyMutation::new(), Subscription)) + } + + #[tokio::test] + async fn test_query() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "{context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + assert_eq!( + ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::Object( + [("context", Value::Scalar(DefaultScalarValue::Int(1)))] + .iter() + .cloned() + .collect() + ), + errors: vec![], + }, + }, + conn.next().await.unwrap() + ); + + assert_eq!( + ServerMessage::Complete { + id: "foo".to_string(), + }, + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_subscriptions() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "subscription Foo {context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + assert_eq!( + ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()), + errors: vec![], + }, + }, + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Start { + id: "bar".to_string(), + payload: StartPayload { + query: "subscription Bar {context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + assert_eq!( + ServerMessage::Data { + id: "bar".to_string(), + payload: DataPayload { + data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()), + errors: vec![], + }, + }, + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Stop { + id: "foo".to_string(), + }) + .await + .unwrap(); + + assert_eq!( + ServerMessage::Complete { + id: "foo".to_string(), + }, + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_init_params_ok() { + let mut conn = Connection::new(new_test_schema(), |params: Variables| async move { + assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar"))); + Ok(ConnectionConfig::new(Context(1))) as Result<_, Infallible> + }); + + conn.send(ClientMessage::ConnectionInit { + payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))] + .iter() + .cloned() + .collect(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + } + + #[tokio::test] + async fn test_init_params_error() { + let mut conn = Connection::new(new_test_schema(), |params: Variables| async move { + assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar"))); + Err(io::Error::new(io::ErrorKind::Other, "init error")) + }); + + conn.send(ClientMessage::ConnectionInit { + payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))] + .iter() + .cloned() + .collect(), + }) + .await + .unwrap(); + + assert_eq!( + ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: "init error".to_string(), + }, + }, + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_max_in_flight_operations() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)) + .with_keep_alive_interval(Duration::from_secs(0)) + .with_max_in_flight_operations(1), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "subscription Foo {never}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + conn.send(ClientMessage::Start { + id: "bar".to_string(), + payload: StartPayload { + query: "subscription Bar {never}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + ServerMessage::Error { id, .. } => { + assert_eq!(id, "bar"); + } + msg @ _ => panic!("expected error, got: {:?}", msg), + } + } + + #[tokio::test] + async fn test_parse_error() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "asd".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + ServerMessage::Error { id, payload } => { + assert_eq!(id, "foo"); + match payload.graphql_error() { + GraphQLError::ParseError(Spanning { + item: ParseError::UnexpectedToken(Token::Name("asd")), + .. + }) => {} + p @ _ => panic!("expected graphql parse error, got: {:?}", p), + } + } + msg @ _ => panic!("expected error, got: {:?}", msg), + } + } + + #[tokio::test] + async fn test_keep_alives() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_millis(20)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + for _ in 0..10 { + assert_eq!( + ServerMessage::ConnectionKeepAlive, + conn.next().await.unwrap() + ); + } + } + + #[tokio::test] + async fn test_slow_init() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + // If we send the start message before the init is handled, we should still get results. + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "{context}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + assert_eq!( + ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::Object( + [("context", Value::Scalar(DefaultScalarValue::Int(1)))] + .iter() + .cloned() + .collect() + ), + errors: vec![], + }, + }, + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_subscription_field_error() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: Variables::default(), + }) + .await + .unwrap(); + + assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap()); + + conn.send(ClientMessage::Start { + id: "foo".to_string(), + payload: StartPayload { + query: "subscription Foo {error}".to_string(), + variables: Variables::default(), + operation_name: None, + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + ServerMessage::Data { + id, + payload: DataPayload { data, errors }, + } => { + assert_eq!(id, "foo"); + assert_eq!( + data, + Value::Object([("error", Value::null())].iter().cloned().collect()) + ); + assert_eq!(errors.len(), 1); + } + msg @ _ => panic!("expected data, got: {:?}", msg), + } + } +} diff --git a/juniper_graphql_ws/src/schema.rs b/juniper_graphql_ws/src/schema.rs new file mode 100644 index 00000000..68d282f0 --- /dev/null +++ b/juniper_graphql_ws/src/schema.rs @@ -0,0 +1,131 @@ +use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; +use std::sync::Arc; + +/// Schema defines the requirements for schemas that can be used for operations. Typically this is +/// just an `Arc<RootNode<...>>` and you should not have to implement it yourself. +pub trait Schema: Unpin + Clone + Send + Sync + 'static { + /// The context type. + type Context: Unpin + Send + Sync; + + /// The scalar value type. + type ScalarValue: ScalarValue + Send + Sync; + + /// The query type info. + type QueryTypeInfo: Send + Sync; + + /// The query type. + type Query: GraphQLTypeAsync<Self::ScalarValue, Context = Self::Context, TypeInfo = Self::QueryTypeInfo> + + Send; + + /// The mutation type info. + type MutationTypeInfo: Send + Sync; + + /// The mutation type. + type Mutation: GraphQLTypeAsync< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::MutationTypeInfo, + > + Send; + + /// The subscription type info. + type SubscriptionTypeInfo: Send + Sync; + + /// The subscription type. + type Subscription: GraphQLSubscriptionType< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::SubscriptionTypeInfo, + > + Send; + + /// Returns the root node for the schema. + fn root_node( + &self, + ) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>; +} + +/// This exists as a work-around for this issue: https://github.com/rust-lang/rust/issues/64552 +/// +/// It can be used in generators where using Arc directly would result in an error. +// TODO: Remove this once that issue is resolved. +#[doc(hidden)] +pub struct ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>( + pub Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>, +) +where + QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static; + +impl<QueryT, MutationT, SubscriptionT, CtxT, S> Clone + for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S> +where + QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static, +{ + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema + for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S> +where + QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, +{ + type Context = CtxT; + type ScalarValue = S; + type QueryTypeInfo = QueryT::TypeInfo; + type Query = QueryT; + type MutationTypeInfo = MutationT::TypeInfo; + type Mutation = MutationT; + type SubscriptionTypeInfo = SubscriptionT::TypeInfo; + type Subscription = SubscriptionT; + + fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { + &self.0 + } +} + +impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema + for Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>> +where + QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static, +{ + type Context = CtxT; + type ScalarValue = S; + type QueryTypeInfo = QueryT::TypeInfo; + type Query = QueryT; + type MutationTypeInfo = MutationT::TypeInfo; + type Mutation = MutationT; + type SubscriptionTypeInfo = SubscriptionT::TypeInfo; + type Subscription = SubscriptionT; + + fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { + self + } +} diff --git a/juniper_graphql_ws/src/server_message.rs b/juniper_graphql_ws/src/server_message.rs new file mode 100644 index 00000000..3c353164 --- /dev/null +++ b/juniper_graphql_ws/src/server_message.rs @@ -0,0 +1,191 @@ +use juniper::{ExecutionError, GraphQLError, ScalarValue, Value}; +use serde::{Serialize, Serializer}; +use std::{any::Any, fmt, marker::PhantomPinned}; + +/// The payload for errors that are not associated with a GraphQL operation. +#[derive(Debug, Serialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct ConnectionErrorPayload { + /// The error message. + pub message: String, +} + +/// Sent after execution of an operation. For queries and mutations, this is sent to the client +/// once. For subscriptions, this is sent for every event in the event stream. +#[derive(Debug, Serialize, PartialEq)] +#[serde(bound(serialize = "S: ScalarValue"))] +#[serde(rename_all = "camelCase")] +pub struct DataPayload<S> { + /// The result data. + pub data: Value<S>, + + /// The errors that have occurred during execution. Note that parse and validation errors are + /// not included here. They are sent via Error messages. + pub errors: Vec<ExecutionError<S>>, +} + +/// A payload for errors that can happen before execution. Errors that happen during execution are +/// instead sent to the client via `DataPayload`. `ErrorPayload` is a wrapper for an owned +/// `GraphQLError`. +// XXX: Think carefully before deriving traits. This is self-referential (error references +// _execution_params). +pub struct ErrorPayload { + _execution_params: Option<Box<dyn Any + Send>>, + error: GraphQLError<'static>, + _marker: PhantomPinned, +} + +impl ErrorPayload { + /// For this to be okay, the caller must guarantee that the error can only reference data from + /// execution_params and that execution_params has not been modified or moved. + pub(crate) unsafe fn new_unchecked<'a>( + execution_params: Box<dyn Any + Send>, + error: GraphQLError<'a>, + ) -> Self { + Self { + _execution_params: Some(execution_params), + error: std::mem::transmute(error), + _marker: PhantomPinned, + } + } + + /// Returns the contained GraphQLError. + pub fn graphql_error<'a>(&'a self) -> &GraphQLError<'a> { + &self.error + } +} + +impl fmt::Debug for ErrorPayload { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error.fmt(f) + } +} + +impl PartialEq for ErrorPayload { + fn eq(&self, other: &Self) -> bool { + self.error.eq(&other.error) + } +} + +impl Serialize for ErrorPayload { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + self.error.serialize(serializer) + } +} + +impl From<GraphQLError<'static>> for ErrorPayload { + fn from(error: GraphQLError<'static>) -> Self { + Self { + _execution_params: None, + error, + _marker: PhantomPinned, + } + } +} + +/// ServerMessage defines the message types that servers can send. +#[derive(Debug, Serialize, PartialEq)] +#[serde(bound(serialize = "S: ScalarValue"))] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum ServerMessage<S: ScalarValue> { + /// ConnectionError is used for errors that are not associated with a GraphQL operation. For + /// example, this will be used when: + /// + /// * The server is unable to parse a client's message. + /// * The client's initialization parameters are rejected. + ConnectionError { + /// The error that occurred. + payload: ConnectionErrorPayload, + }, + /// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a + /// connection. + ConnectionAck, + /// Data contains the result of a query, mutation, or subscription event. + Data { + /// The id of the operation that the data is for. + id: String, + + /// The data and errors that occurred during execution. + payload: DataPayload<S>, + }, + /// Error contains an error that occurs before execution, such as validation errors. + Error { + /// The id of the operation that triggered this error. + id: String, + + /// The error(s). + payload: ErrorPayload, + }, + /// Complete indicates that no more data will be sent for the given operation. + Complete { + /// The id of the operation that has completed. + id: String, + }, + /// ConnectionKeepAlive is sent periodically after accepting a connection. + #[serde(rename = "ka")] + ConnectionKeepAlive, +} + +#[cfg(test)] +mod test { + use super::*; + use juniper::DefaultScalarValue; + + #[test] + fn test_serialization() { + type ServerMessage = super::ServerMessage<DefaultScalarValue>; + + assert_eq!( + serde_json::to_string(&ServerMessage::ConnectionError { + payload: ConnectionErrorPayload { + message: "foo".to_string(), + }, + }) + .unwrap(), + r##"{"type":"connection_error","payload":{"message":"foo"}}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(), + r##"{"type":"connection_ack"}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Data { + id: "foo".to_string(), + payload: DataPayload { + data: Value::null(), + errors: vec![], + }, + }) + .unwrap(), + r##"{"type":"data","id":"foo","payload":{"data":null,"errors":[]}}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Error { + id: "foo".to_string(), + payload: GraphQLError::UnknownOperationName.into(), + }) + .unwrap(), + r##"{"type":"error","id":"foo","payload":[{"message":"Unknown operation"}]}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Complete { + id: "foo".to_string(), + }) + .unwrap(), + r##"{"type":"complete","id":"foo"}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::ConnectionKeepAlive).unwrap(), + r##"{"type":"ka"}"##, + ); + } +} diff --git a/juniper_subscriptions/src/lib.rs b/juniper_subscriptions/src/lib.rs index 0e78279c..3418c055 100644 --- a/juniper_subscriptions/src/lib.rs +++ b/juniper_subscriptions/src/lib.rs @@ -222,19 +222,25 @@ where } if filled_count == obj_len { + let mut errors = vec![]; filled_count = 0; let new_vec = (0..obj_len).map(|_| None).collect::<Vec<_>>(); let ready_vec = std::mem::replace(&mut ready_vec, new_vec); let ready_vec_iterator = ready_vec.into_iter().map(|el| { let (name, val) = el.unwrap(); - if let Ok(value) = val { - (name, value) - } else { - (name, Value::Null) + match val { + Ok(value) => (name, value), + Err(e) => { + errors.push(e); + (name, Value::Null) + } } }); let obj = Object::from_iter(ready_vec_iterator); - Poll::Ready(Some(ExecutionOutput::from_data(Value::Object(obj)))) + Poll::Ready(Some(ExecutionOutput { + data: Value::Object(obj), + errors, + })) } else { Poll::Pending } diff --git a/juniper_warp/Cargo.toml b/juniper_warp/Cargo.toml index cf14ae32..f2fcb5b5 100644 --- a/juniper_warp/Cargo.toml +++ b/juniper_warp/Cargo.toml @@ -9,7 +9,7 @@ repository = "https://github.com/graphql-rust/juniper" edition = "2018" [features] -subscriptions = ["juniper_subscriptions"] +subscriptions = ["juniper_graphql_ws"] [dependencies] bytes = "0.5" @@ -17,7 +17,7 @@ anyhow = "1.0" thiserror = "1.0" futures = "0.3.1" juniper = { version = "0.14.2", path = "../juniper", default-features = false } -juniper_subscriptions = { path = "../juniper_subscriptions", optional = true } +juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true } serde = { version = "1.0.75", features = ["derive"] } serde_json = "1.0.24" tokio = { version = "0.2", features = ["blocking", "rt-core"] } diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 291ba011..06898304 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -393,224 +393,103 @@ fn playground_response( /// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md #[cfg(feature = "subscriptions")] pub mod subscriptions { - use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, + use juniper::{ + futures::{ + future::{self, Either}, + sink::SinkExt, + stream::StreamExt, }, + GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue, }; + use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init}; + use std::{convert::Infallible, fmt, sync::Arc}; - use anyhow::anyhow; - use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _}; - use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _}; - use juniper_subscriptions::Coordinator; - use serde::{Deserialize, Serialize}; - use warp::ws::Message; + struct Message(warp::ws::Message); - /// Listen to incoming messages and do one of the following: - /// - execute subscription and return values from stream - /// - stop stream and close ws connection - #[allow(dead_code)] - pub fn graphql_subscriptions<Query, Mutation, Subscription, CtxT, S>( - websocket: warp::ws::WebSocket, - coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, CtxT, S>>, - context: CtxT, - ) -> impl Future<Output = Result<(), anyhow::Error>> + Send - where - Query: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, - Query::TypeInfo: Send + Sync, - Mutation: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, - Mutation::TypeInfo: Send + Sync, - Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, - Subscription::TypeInfo: Send + Sync, - CtxT: Send + Sync + 'static, - S: ScalarValue + Send + Sync + 'static, - { - let (sink_tx, sink_rx) = websocket.split(); - let (ws_tx, ws_rx) = mpsc::unbounded(); - tokio::task::spawn( - ws_rx - .take_while(|v: &Option<_>| futures::future::ready(v.is_some())) - .map(|x| x.unwrap()) - .forward(sink_tx), - ); + impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> { + type Error = serde_json::Error; - let context = Arc::new(context); - let got_close_signal = Arc::new(AtomicBool::new(false)); - let got_close_signal2 = got_close_signal.clone(); - - struct SubscriptionState { - should_stop: AtomicBool, + fn try_from(msg: Message) -> serde_json::Result<Self> { + serde_json::from_slice(msg.0.as_bytes()) } - let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new(); - - sink_rx - .map_err(move |e| { - got_close_signal2.store(true, Ordering::Relaxed); - anyhow!("Websocket error: {}", e) - }) - .try_fold(subscription_states, move |mut subscription_states, msg| { - let coordinator = coordinator.clone(); - let context = context.clone(); - let got_close_signal = got_close_signal.clone(); - let ws_tx = ws_tx.clone(); - - async move { - if msg.is_close() { - return Ok(subscription_states); - } - - let msg = msg - .to_str() - .map_err(|_| anyhow!("Non-text messages are not accepted"))?; - let request: WsPayload<S> = serde_json::from_str(msg) - .map_err(|e| anyhow!("Invalid WsPayload: {}", e))?; - - match request.type_name.as_str() { - "connection_init" => {} - "start" => { - if got_close_signal.load(Ordering::Relaxed) { - return Ok(subscription_states); - } - - let request_id = request.id.clone().unwrap_or("1".to_owned()); - - if let Some(existing) = subscription_states.get(&request_id) { - existing.should_stop.store(true, Ordering::Relaxed); - } - let state = Arc::new(SubscriptionState { - should_stop: AtomicBool::new(false), - }); - subscription_states.insert(request_id.clone(), state.clone()); - - let ws_tx = ws_tx.clone(); - - if let Some(ref payload) = request.payload { - if payload.query.is_none() { - return Err(anyhow!("Query not found")); - } - } else { - return Err(anyhow!("Payload not found")); - } - - tokio::task::spawn(async move { - let payload = request.payload.unwrap(); - - let graphql_request = GraphQLRequest::<S>::new( - payload.query.unwrap(), - None, - payload.variables, - ); - - let values_stream = match coordinator - .subscribe(&graphql_request, &context) - .await - { - Ok(s) => s, - Err(err) => { - let _ = - ws_tx.unbounded_send(Some(Ok(Message::text(format!( - r#"{{"type":"error","id":"{}","payload":{}}}"#, - request_id, - serde_json::ser::to_string(&err).unwrap_or( - "Error deserializing GraphQLError".to_owned() - ) - ))))); - - let close_message = format!( - r#"{{"type":"complete","id":"{}","payload":null}}"#, - request_id - ); - let _ = ws_tx - .unbounded_send(Some(Ok(Message::text(close_message)))); - // close channel - let _ = ws_tx.unbounded_send(None); - return; - } - }; - - values_stream - .take_while(move |response| { - let request_id = request_id.clone(); - let should_stop = state.should_stop.load(Ordering::Relaxed) - || got_close_signal.load(Ordering::Relaxed); - if !should_stop { - let mut response_text = serde_json::to_string( - &response, - ) - .unwrap_or("Error deserializing response".to_owned()); - - response_text = format!( - r#"{{"type":"data","id":"{}","payload":{} }}"#, - request_id, response_text - ); - - let _ = ws_tx.unbounded_send(Some(Ok(Message::text( - response_text, - )))); - } - - async move { !should_stop } - }) - .for_each(|_| async {}) - .await; - }); - } - "stop" => { - let request_id = request.id.unwrap_or("1".to_owned()); - if let Some(existing) = subscription_states.get(&request_id) { - existing.should_stop.store(true, Ordering::Relaxed); - subscription_states.remove(&request_id); - } - - let close_message = format!( - r#"{{"type":"complete","id":"{}","payload":null}}"#, - request_id - ); - let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message)))); - - // close channel - let _ = ws_tx.unbounded_send(None); - } - _ => {} - } - - Ok(subscription_states) - } - }) - .map_ok(|_| ()) } - #[derive(Deserialize)] - #[serde(bound = "GraphQLPayload<S>: Deserialize<'de>")] - struct WsPayload<S> + /// Errors that can happen while serving a connection. + #[derive(Debug)] + pub enum Error { + /// Errors that can happen in Warp while serving a connection. + Warp(warp::Error), + + /// Errors that can happen while serializing outgoing messages. Note that errors that occur + /// while deserializing internal messages are handled internally by the protocol. + Serde(serde_json::Error), + } + + impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Warp(e) => write!(f, "warp error: {}", e), + Self::Serde(e) => write!(f, "serde error: {}", e), + } + } + } + + impl std::error::Error for Error {} + + impl From<warp::Error> for Error { + fn from(err: warp::Error) -> Self { + Self::Warp(err) + } + } + + impl From<Infallible> for Error { + fn from(_err: Infallible) -> Self { + unreachable!() + } + } + + /// Serves the graphql-ws protocol over a WebSocket connection. + /// + /// The `init` argument is used to provide the context and additional configuration for + /// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and + /// configuration are already known, or it can be a closure that gets executed asynchronously + /// when the client sends the ConnectionInit message. Using a closure allows you to perform + /// authentication based on the parameters provided by the client. + pub async fn serve_graphql_ws<Query, Mutation, Subscription, CtxT, S, I>( + websocket: warp::ws::WebSocket, + root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>, + init: I, + ) -> Result<(), Error> where - S: ScalarValue + Send + Sync, + Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: Init<S, CtxT> + Send, { - id: Option<String>, - #[serde(rename(deserialize = "type"))] - type_name: String, - payload: Option<GraphQLPayload<S>>, - } + let (ws_tx, ws_rx) = websocket.split(); + let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split(); - #[derive(Debug, Deserialize)] - #[serde(bound = "InputValue<S>: Deserialize<'de>")] - struct GraphQLPayload<S> - where - S: ScalarValue + Send + Sync, - { - variables: Option<InputValue<S>>, - extensions: Option<HashMap<String, String>>, - #[serde(rename(deserialize = "operationName"))] - operaton_name: Option<String>, - query: Option<String>, - } + let ws_rx = ws_rx.map(|r| r.map(|msg| Message(msg))); + let s_rx = s_rx.map(|msg| { + serde_json::to_string(&msg) + .map(|t| warp::ws::Message::text(t)) + .map_err(|e| Error::Serde(e)) + }); - #[derive(Serialize)] - struct Output { - data: String, - variables: String, + match future::select( + ws_rx.forward(s_tx.sink_err_into()), + s_rx.forward(ws_tx.sink_err_into()), + ) + .await + { + Either::Left((r, _)) => r.map_err(|e| e.into()), + Either::Right((r, _)) => r, + } } }