diff --git a/Cargo.toml b/Cargo.toml index 9b94e2fa..42535782 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "juniper_rocket", "juniper_subscriptions", "juniper_graphql_ws", + "juniper_graphql_transport_ws", "juniper_warp", "juniper_actix", "tests/codegen", diff --git a/juniper_graphql_transport_ws/CHANGELOG.md b/juniper_graphql_transport_ws/CHANGELOG.md new file mode 100644 index 00000000..79402114 --- /dev/null +++ b/juniper_graphql_transport_ws/CHANGELOG.md @@ -0,0 +1,16 @@ +`juniper_graphql_transport_ws` changelog +============================== + +All user visible changes to `juniper_graphql_transport_ws` crate will be documented in this file. This project uses [Semantic Versioning 2.0.0]. + + + + +## master + + + + +[`juniper` crate]: https://docs.rs/juniper +[`juniper_subscriptions` crate]: https://docs.rs/juniper_subscriptions +[Semantic Versioning 2.0.0]: https://semver.org diff --git a/juniper_graphql_transport_ws/Cargo.toml b/juniper_graphql_transport_ws/Cargo.toml new file mode 100644 index 00000000..11c6fae6 --- /dev/null +++ b/juniper_graphql_transport_ws/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "juniper_graphql_transport_ws" +version = "0.4.0-dev" +edition = "2021" +rust-version = "1.65" +description = "GraphQL over WebSocket Protocol implementation for `juniper` crate." +license = "BSD-2-Clause" +authors = ["Christopher Brown "] +documentation = "https://docs.rs/juniper_graphql_transport_ws" +homepage = "https://github.com/graphql-rust/juniper/tree/master/juniper_graphql_transport_ws" +repository = "https://github.com/graphql-rust/juniper" +readme = "README.md" +categories = ["asynchronous", "web-programming", "web-programming::http-server"] +keywords = ["apollo", "graphql", "graphql-ws", "subscription", "websocket"] +exclude = ["/release.toml"] + +[dependencies] +juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false } +juniper_subscriptions = { version = "0.17.0-dev", path = "../juniper_subscriptions" } +serde = { version = "1.0.122", features = ["derive"], default-features = false } +tokio = { version = "1.0", features = ["macros", "rt", "time"], default-features = false } + +[dev-dependencies] +serde_json = "1.0.18" diff --git a/juniper_graphql_transport_ws/LICENSE b/juniper_graphql_transport_ws/LICENSE new file mode 100644 index 00000000..652dd95d --- /dev/null +++ b/juniper_graphql_transport_ws/LICENSE @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2018-2022, Christopher Brown +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/juniper_graphql_transport_ws/README.md b/juniper_graphql_transport_ws/README.md new file mode 100644 index 00000000..972b7f26 --- /dev/null +++ b/juniper_graphql_transport_ws/README.md @@ -0,0 +1,24 @@ +`juniper_graphql_transport_ws` crate +========================== + +[![Crates.io](https://img.shields.io/crates/v/juniper_graphql_transport_ws.svg?maxAge=2592000)](https://crates.io/crates/juniper_graphql_transport_ws) +[![Documentation](https://docs.rs/juniper_graphql_transport_ws/badge.svg)](https://docs.rs/juniper_graphql_transport_ws) +[![CI](https://github.com/graphql-rust/juniper/workflows/CI/badge.svg?branch=master "CI")](https://github.com/graphql-rust/juniper/actions?query=workflow%3ACI+branch%3Amaster) +[![Rust 1.65+](https://img.shields.io/badge/rustc-1.65+-lightgray.svg "Rust 1.65+")](https://blog.rust-lang.org/2022/11/03/Rust-1.65.0.html) + +- [Changelog](https://github.com/graphql-rust/juniper/blob/master/juniper_graphql_transport_ws/CHANGELOG.md) + +This crate contains an implementation of the [graphql-transport-ws WebSocket subprotocol], as used by [Apollo]. + + + + +## License + +This project is licensed under [BSD 2-Clause License](https://github.com/graphql-rust/juniper/blob/master/juniper_graphql_transport_ws/LICENSE). + + + + +[Apollo]: https://www.apollographql.com +[graphql-transport-ws WebSocket subprotocol]: https://github.com/enisdenjo/graphql-ws/blob/fbb763a662802a6a2584b0cbeb9cf1bde38158e0/PROTOCOL.md diff --git a/juniper_graphql_transport_ws/release.toml b/juniper_graphql_transport_ws/release.toml new file mode 100644 index 00000000..ecf64dd5 --- /dev/null +++ b/juniper_graphql_transport_ws/release.toml @@ -0,0 +1,24 @@ +[[pre-release-replacements]] +file = "../juniper_actix/Cargo.toml" +exactly = 1 +search = "juniper_graphql_transport_ws = \\{ version = \"[^\"]+\"" +replace = "juniper_graphql_transport_ws = { version = \"{{version}}\"" + +[[pre-release-replacements]] +file = "../juniper_warp/Cargo.toml" +exactly = 1 +search = "juniper_graphql_transport_ws = \\{ version = \"[^\"]+\"" +replace = "juniper_graphql_transport_ws = { version = \"{{version}}\"" + +[[pre-release-replacements]] +file = "CHANGELOG.md" +max = 1 +min = 0 +search = "## master" +replace = "## [{{version}}] · {{date}}\n[{{version}}]: /../../tree/{{crate_name}}-v{{version}}/{{crate_name}}" + +[[pre-release-replacements]] +file = "README.md" +exactly = 2 +search = "graphql-rust/juniper/blob/[^/]+/" +replace = "graphql-rust/juniper/blob/{{crate_name}}-v{{version}}/" diff --git a/juniper_graphql_transport_ws/src/client_message.rs b/juniper_graphql_transport_ws/src/client_message.rs new file mode 100644 index 00000000..592327fa --- /dev/null +++ b/juniper_graphql_transport_ws/src/client_message.rs @@ -0,0 +1,155 @@ +use juniper::Variables; +use serde::Deserialize; + +use crate::utils::default_for_null; + +/// The payload for a client's "start" message. This triggers execution of a query, mutation, or +/// subscription. +#[derive(Debug, Deserialize, PartialEq)] +#[serde(bound(deserialize = "S: Deserialize<'de>"))] +#[serde(rename_all = "camelCase")] +pub struct SubscribePayload { + /// The document body. + pub query: String, + + /// The optional variables. + #[serde(default, deserialize_with = "default_for_null")] + pub variables: Variables, + + /// The optional operation name (required if the document contains multiple operations). + pub operation_name: Option, + + /// The optional extension data. + #[serde(default, deserialize_with = "default_for_null")] + pub extensions: Variables, +} + +/// ClientMessage defines the message types that clients can send. +#[derive(Debug, Deserialize, PartialEq)] +#[serde(bound(deserialize = "S: Deserialize<'de>"))] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum ClientMessage { + /// ConnectionInit is sent by the client upon connecting. + ConnectionInit { + /// Optional parameters of any type sent from the client. These are often used for + /// authentication. + #[serde(default, deserialize_with = "default_for_null")] + payload: Variables, + }, + /// Ping is used for detecting failed connections, displaying latency metrics or other types of network probing. + Ping { + /// Optional parameters of any type used to transfer additional details about the ping. + #[serde(default, deserialize_with = "default_for_null")] + payload: Variables, + }, + /// The response to the `Ping` message. + Pong { + /// Optional parameters of any type used to transfer additional details about the pong. + #[serde(default, deserialize_with = "default_for_null")] + payload: Variables, + }, + /// Requests an operation specified in the message payload. + Subscribe { + /// The id of the operation. This can be anything, but must be unique. If there are other + /// in-flight operations with the same id, the message will cause an error. + id: String, + + /// The query, variables, and operation name. + payload: SubscribePayload, + }, + /// Indicates that the client has stopped listening and wants to complete the subscription. + Complete { + /// The id of the operation to stop. + id: String, + }, +} + +#[cfg(test)] +mod test { + use juniper::{graphql_vars, DefaultScalarValue}; + + use super::*; + + #[test] + fn test_deserialization() { + type ClientMessage = super::ClientMessage; + + assert_eq!( + ClientMessage::ConnectionInit { + payload: graphql_vars! {"foo": "bar"}, + }, + serde_json::from_str(r##"{"type": "connection_init", "payload": {"foo": "bar"}}"##) + .unwrap(), + ); + + assert_eq!( + ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }, + serde_json::from_str(r##"{"type": "connection_init"}"##).unwrap(), + ); + + assert_eq!( + ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "query MyQuery { __typename }".into(), + variables: graphql_vars! {"foo": "bar"}, + operation_name: Some("MyQuery".into()), + extensions: Default::default(), + }, + }, + serde_json::from_str( + r##"{"type": "subscribe", "id": "foo", "payload": { + "query": "query MyQuery { __typename }", + "variables": { + "foo": "bar" + }, + "operationName": "MyQuery" + }}"## + ) + .unwrap(), + ); + + assert_eq!( + ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "query MyQuery { __typename }".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }, + serde_json::from_str( + r##"{"type": "subscribe", "id": "foo", "payload": { + "query": "query MyQuery { __typename }" + }}"## + ) + .unwrap(), + ); + + assert_eq!( + ClientMessage::Complete { id: "foo".into() }, + serde_json::from_str(r##"{"type": "complete", "id": "foo"}"##).unwrap(), + ); + } + + #[test] + fn test_deserialization_of_null() -> serde_json::Result<()> { + let payload = r#"{"query":"query","variables":null}"#; + let payload: SubscribePayload = serde_json::from_str(payload)?; + + let expected = SubscribePayload { + query: "query".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }; + + assert_eq!(expected, payload); + + Ok(()) + } +} diff --git a/juniper_graphql_transport_ws/src/lib.rs b/juniper_graphql_transport_ws/src/lib.rs new file mode 100644 index 00000000..59b8deb3 --- /dev/null +++ b/juniper_graphql_transport_ws/src/lib.rs @@ -0,0 +1,1112 @@ +#![doc = include_str!("../README.md")] +#![deny(missing_docs, warnings)] + +mod client_message; +pub use client_message::*; + +mod server_message; +pub use server_message::*; + +mod schema; +pub use schema::*; + +mod utils; + +use std::{ + collections::HashMap, convert::Infallible, error::Error, marker::PhantomPinned, pin::Pin, + sync::Arc, time::Duration, +}; + +use juniper::{ + futures::{ + channel::oneshot, + future::{self, BoxFuture, Either, Future, FutureExt, TryFutureExt}, + stream::{self, BoxStream, SelectAll, StreamExt}, + task::{Context, Poll, Waker}, + Sink, Stream, + }, + GraphQLError, RuleError, ScalarValue, Variables, +}; + +struct ExecutionParams { + subscribe_payload: SubscribePayload, + config: Arc>, + schema: S, +} + +/// ConnectionConfig is used to configure the connection once the client sends the ConnectionInit +/// message. +pub struct ConnectionConfig { + context: CtxT, + max_in_flight_operations: usize, + keep_alive_interval: Duration, +} + +impl ConnectionConfig { + /// Constructs the configuration required for a connection to be accepted. + pub fn new(context: CtxT) -> Self { + Self { + context, + max_in_flight_operations: 0, + keep_alive_interval: Duration::from_secs(15), + } + } + + /// Specifies the maximum number of in-flight operations that a connection can have. If this + /// number is exceeded, attempting to start more will result in an error. By default, there is + /// no limit to in-flight operations. + #[must_use] + pub fn with_max_in_flight_operations(mut self, max: usize) -> Self { + self.max_in_flight_operations = max; + self + } + + /// Specifies the interval at which to send unsolicited pong messages as keep-alives. + /// Specifying a zero duration will disable keep-alives. By default, keep-alives are sent every + /// 15 seconds. + #[must_use] + pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self { + self.keep_alive_interval = interval; + self + } +} + +impl Init for ConnectionConfig { + type Error = Infallible; + type Future = future::Ready>; + + fn init(self, _params: Variables) -> Self::Future { + future::ready(Ok(self)) + } +} + +/// Output provides the responses that should be sent to the client. +#[derive(Debug, PartialEq)] +pub enum Output { + /// Message is a message that should be serialized and sent to the client. + Message(ServerMessage), + /// Close indicates that the connection should be closed and provides a code and message to + /// send to the client. This is always the last message in the output stream. + Close { + /// The WebSocket code that should be sent. + code: u16, + + /// A message describing the reason for the connection closing. + message: String, + }, +} + +impl Output { + /// Converts the reaction into a one-item stream. + fn into_stream(self) -> BoxStream<'static, Self> { + stream::once(future::ready(self)).boxed() + } +} + +/// Init defines the requirements for types that can provide connection configurations when +/// ConnectionInit messages are received. Implementations are provided for `ConnectionConfig` and +/// closures that meet the requirements. +pub trait Init: Unpin + 'static { + /// The error that is returned on failure. The formatted error will be used as the contents of + /// the "message" field sent back to the client. + type Error: Error; + + /// The future configuration type. + type Future: Future, Self::Error>> + Send + 'static; + + /// Returns a future for the configuration to use. + fn init(self, params: Variables) -> Self::Future; +} + +impl Init for F +where + S: ScalarValue, + F: FnOnce(Variables) -> Fut + Unpin + 'static, + Fut: Future, E>> + Send + 'static, + E: Error, +{ + type Error = E; + type Future = Fut; + + fn init(self, params: Variables) -> Fut { + self(params) + } +} + +enum ConnectionState> { + /// PreInit is the state before a ConnectionInit message has been accepted. + PreInit { init: I, schema: S }, + /// Active is the state after a ConnectionInit message has been accepted. + Active { + config: Arc>, + stoppers: HashMap>, + schema: S, + }, + /// Terminated is the state after a ConnectionInit message has been rejected. + Terminated, +} + +impl> ConnectionState { + // Each message we receive results in a stream of zero or more reactions. For example, a + // Ping message results in a one-item stream with the Pong message reaction. + async fn handle_message( + self, + msg: ClientMessage, + ) -> (Self, BoxStream<'static, Output>) { + match self { + Self::PreInit { init, schema } => match msg { + ClientMessage::ConnectionInit { payload } => match init.init(payload).await { + Ok(config) => { + let keep_alive_interval = config.keep_alive_interval; + + let mut s = + stream::iter(vec![Output::Message(ServerMessage::ConnectionAck)]) + .boxed(); + + if keep_alive_interval > Duration::from_secs(0) { + s = s + .chain(Output::Message(ServerMessage::Pong).into_stream()) + .boxed(); + s = s + .chain(stream::unfold((), move |_| async move { + tokio::time::sleep(keep_alive_interval).await; + Some((Output::Message(ServerMessage::Pong), ())) + })) + .boxed(); + } + + ( + Self::Active { + config: Arc::new(config), + stoppers: HashMap::new(), + schema, + }, + s, + ) + } + Err(e) => ( + Self::Terminated, + stream::iter(vec![Output::Close { + code: 4403, + message: e.to_string(), + }]) + .boxed(), + ), + }, + ClientMessage::Ping { .. } => ( + Self::PreInit { init, schema }, + stream::iter(vec![Output::Message(ServerMessage::Pong)]).boxed(), + ), + ClientMessage::Subscribe { .. } => ( + Self::PreInit { init, schema }, + stream::iter(vec![Output::Close { + code: 4401, + message: "Unauthorized".to_string(), + }]) + .boxed(), + ), + _ => (Self::PreInit { init, schema }, stream::empty().boxed()), + }, + Self::Active { + config, + mut stoppers, + schema, + } => { + let reactions = match msg { + ClientMessage::Subscribe { id, payload } => { + if stoppers.contains_key(&id) { + // We already have an operation with this id. We must close the connection. + Output::Close { + code: 4409, + message: format!("Subscriber for {} already exists", id), + } + .into_stream() + } else { + // Go ahead and prune canceled stoppers before adding a new one. + stoppers.retain(|_, tx| !tx.is_canceled()); + + if config.max_in_flight_operations > 0 + && stoppers.len() >= config.max_in_flight_operations + { + // Too many in-flight operations. Just send back a validation error. + stream::iter(vec![ + Output::Message(ServerMessage::Error { + id: id.clone(), + payload: GraphQLError::ValidationError(vec![ + RuleError::new("Too many in-flight operations.", &[]), + ]) + .into(), + }), + Output::Message(ServerMessage::Complete { id }), + ]) + .boxed() + } else { + // Create a channel that we can use to cancel the operation. + let (tx, rx) = oneshot::channel::<()>(); + stoppers.insert(id.clone(), tx); + + // Create the operation stream. This stream will emit Next and Error + // messages, but will not emit Complete – that part is up to us. + let s = Self::start( + id.clone(), + ExecutionParams { + subscribe_payload: payload, + config: config.clone(), + schema: schema.clone(), + }, + ) + .into_stream() + .flatten(); + + // Combine this with our oneshot channel so that the stream ends if the + // oneshot is ever fired. + let s = stream::unfold((rx, s.boxed()), |(rx, mut s)| async move { + let next = match future::select(rx, s.next()).await { + Either::Left(_) => None, + Either::Right((r, rx)) => r.map(|r| (r, rx)), + }; + next.map(|(r, rx)| (r, (rx, s))) + }); + + // Once the stream ends, send the Complete message. + let s = s.chain( + Output::Message(ServerMessage::Complete { id }).into_stream(), + ); + + s.boxed() + } + } + } + ClientMessage::Complete { id } => { + stoppers.remove(&id); + stream::empty().boxed() + } + ClientMessage::Ping { .. } => { + stream::iter(vec![Output::Message(ServerMessage::Pong)]).boxed() + } + _ => stream::empty().boxed(), + }; + ( + Self::Active { + config, + stoppers, + schema, + }, + reactions, + ) + } + Self::Terminated => (self, stream::empty().boxed()), + } + } + + async fn start( + id: String, + params: ExecutionParams, + ) -> BoxStream<'static, Output> { + // TODO: This could be made more efficient if `juniper` exposed + // functionality to allow us to parse and validate the query, + // determine whether it's a subscription, and then execute it. + // For now, the query gets parsed and validated twice. + + let params = Arc::new(params); + + // Try to execute this as a query or mutation. + match juniper::execute( + ¶ms.subscribe_payload.query, + params.subscribe_payload.operation_name.as_deref(), + params.schema.root_node(), + ¶ms.subscribe_payload.variables, + ¶ms.config.context, + ) + .await + { + Ok((data, errors)) => { + return Output::Message(ServerMessage::Next { + id: id.clone(), + payload: NextPayload { data, errors }, + }) + .into_stream(); + } + Err(GraphQLError::IsSubscription) => {} + Err(e) => { + return Output::Message(ServerMessage::Error { + id: id.clone(), + payload: ErrorPayload::new(Box::new(params.clone()), e), + }) + .into_stream(); + } + } + + // Try to execute as a subscription. + SubscriptionStart::new(id, params.clone()).boxed() + } +} + +struct InterruptableStream { + stream: S, + rx: oneshot::Receiver<()>, +} + +impl Stream for InterruptableStream { + type Item = S::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Ready(_) => return Poll::Ready(None), + Poll::Pending => {} + } + Pin::new(&mut self.stream).poll_next(cx) + } +} + +/// SubscriptionStartState is the state for a subscription operation. +enum SubscriptionStartState { + /// Init is the start before being polled for the first time. + Init { id: String }, + /// ResolvingIntoStream is the state after being polled for the first time. In this state, + /// we're parsing, validating, and getting the actual event stream. + ResolvingIntoStream { + id: String, + future: BoxFuture< + 'static, + Result, GraphQLError>, + >, + }, + /// Streaming is the state after we've successfully obtained the event stream for the + /// subscription. In this state, we're just forwarding events back to the client. + Streaming { + id: String, + stream: juniper_subscriptions::Connection<'static, S::ScalarValue>, + }, + /// Terminated is the state once we're all done. + Terminated, +} + +/// SubscriptionStart is the stream for a subscription operation. +struct SubscriptionStart { + params: Arc>, + state: SubscriptionStartState, + _marker: PhantomPinned, +} + +impl SubscriptionStart { + fn new(id: String, params: Arc>) -> Pin> { + Box::pin(Self { + params, + state: SubscriptionStartState::Init { id }, + _marker: PhantomPinned, + }) + } +} + +impl Stream for SubscriptionStart { + type Item = Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let (params, state) = unsafe { + // XXX: The execution parameters are referenced by state and must not be modified. + // Modifying state is fine though. + let inner = self.get_unchecked_mut(); + (&inner.params, &mut inner.state) + }; + + loop { + match state { + SubscriptionStartState::Init { id } => { + // XXX: resolve_into_stream returns a Future that references the execution + // parameters, and the returned stream also references them. We can guarantee + // that everything has the same lifetime in this self-referential struct. + let params = Arc::as_ptr(params); + *state = SubscriptionStartState::ResolvingIntoStream { + id: id.clone(), + future: unsafe { + juniper::resolve_into_stream( + &(*params).subscribe_payload.query, + (*params).subscribe_payload.operation_name.as_deref(), + (*params).schema.root_node(), + &(*params).subscribe_payload.variables, + &(*params).config.context, + ) + } + .map_ok(|(stream, errors)| { + juniper_subscriptions::Connection::from_stream(stream, errors) + }) + .boxed(), + }; + } + SubscriptionStartState::ResolvingIntoStream { + ref id, + ref mut future, + } => match future.as_mut().poll(cx) { + Poll::Ready(r) => match r { + Ok(stream) => { + *state = SubscriptionStartState::Streaming { + id: id.clone(), + stream, + } + } + Err(e) => { + return Poll::Ready(Some(Output::Message(ServerMessage::Error { + id: id.clone(), + payload: ErrorPayload::new(Box::new(params.clone()), e), + }))); + } + }, + Poll::Pending => return Poll::Pending, + }, + SubscriptionStartState::Streaming { + ref id, + ref mut stream, + } => match Pin::new(stream).poll_next(cx) { + Poll::Ready(Some(output)) => { + return Poll::Ready(Some(Output::Message(ServerMessage::Next { + id: id.clone(), + payload: NextPayload { + data: output.data, + errors: output.errors, + }, + }))); + } + Poll::Ready(None) => { + *state = SubscriptionStartState::Terminated; + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + }, + SubscriptionStartState::Terminated => return Poll::Ready(None), + } + } + } +} + +enum ConnectionSinkState> { + Ready { + state: ConnectionState, + }, + HandlingMessage { + #[allow(clippy::type_complexity)] + result: BoxFuture< + 'static, + ( + ConnectionState, + BoxStream<'static, Output>, + ), + >, + }, + Closed, +} + +/// Implements the graphql-ws protocol. This is a sink for `TryInto` and a stream of +/// `ServerMessage`. +pub struct Connection> { + reactions: SelectAll>>, + stream_waker: Option, + stream_terminated: bool, + sink_state: ConnectionSinkState, +} + +impl Connection +where + S: Schema, + I: Init, +{ + /// Creates a new connection, which is a sink for `TryInto` and a stream of `ServerMessage`. + /// + /// The `schema` argument should typically be an `Arc>`. + /// + /// The `init` argument is used to provide the context and additional configuration for + /// connections. This can be a `ConnectionConfig` if the context and configuration are already + /// known, or it can be a closure that gets executed asynchronously when the client sends the + /// ConnectionInit message. Using a closure allows you to perform authentication based on the + /// parameters provided by the client. + pub fn new(schema: S, init: I) -> Self { + Self { + reactions: SelectAll::new(), + stream_waker: None, + stream_terminated: false, + sink_state: ConnectionSinkState::Ready { + state: ConnectionState::PreInit { init, schema }, + }, + } + } +} + +impl Sink for Connection +where + T: TryInto>, + T::Error: Error, + S: Schema, + I: Init + Send, +{ + type Error = Infallible; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut self.sink_state { + ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())), + ConnectionSinkState::HandlingMessage { ref mut result } => { + match Pin::new(result).poll(cx) { + Poll::Ready((state, reactions)) => { + self.reactions.push(reactions); + self.sink_state = ConnectionSinkState::Ready { state }; + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + } + ConnectionSinkState::Closed => panic!("poll_ready called after close"), + } + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + let s = self.get_mut(); + let state = &mut s.sink_state; + *state = match std::mem::replace(state, ConnectionSinkState::Closed) { + ConnectionSinkState::Ready { state } => { + match item.try_into() { + Ok(msg) => ConnectionSinkState::HandlingMessage { + result: state.handle_message(msg).boxed(), + }, + Err(e) => { + // If we weren't able to parse the message, we must close the connection. + s.reactions.push( + Output::Close { + code: 4400, + message: e.to_string(), + } + .into_stream(), + ); + ConnectionSinkState::Closed + } + } + } + _ => panic!("start_send called when not ready"), + }; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + >::poll_ready(self, cx) + } + + fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + self.sink_state = ConnectionSinkState::Closed; + if let Some(waker) = self.stream_waker.take() { + // Wake up the stream so it can close too. + waker.wake(); + } + Poll::Ready(Ok(())) + } +} + +impl Stream for Connection +where + S: Schema, + I: Init, +{ + type Item = Output; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.stream_waker = Some(cx.waker().clone()); + + if let ConnectionSinkState::Closed = self.sink_state { + return Poll::Ready(None); + } else if self.stream_terminated { + return Poll::Ready(None); + } + + // Poll the reactions for new outgoing messages. + if !self.reactions.is_empty() { + match Pin::new(&mut self.reactions).poll_next(cx) { + Poll::Ready(Some(Output::Close { code, message })) => { + self.stream_terminated = true; + return Poll::Ready(Some(Output::Close { code, message })); + } + Poll::Ready(Some(reaction)) => return Poll::Ready(Some(reaction)), + Poll::Ready(None) => { + // In rare cases, the reaction stream may terminate. For example, this will + // happen if the first message we receive does not require any reaction. Just + // recreate it in that case. + self.reactions = SelectAll::new(); + } + _ => (), + } + } + Poll::Pending + } +} + +#[cfg(test)] +mod test { + use std::{convert::Infallible, io}; + + use juniper::{ + futures::sink::SinkExt, + graphql_input_value, graphql_object, graphql_subscription, graphql_value, graphql_vars, + parser::{ParseError, Spanning}, + DefaultScalarValue, EmptyMutation, FieldError, FieldResult, RootNode, + }; + + use super::*; + + struct Context(i32); + + impl juniper::Context for Context {} + + struct Query; + + #[graphql_object(context = Context)] + impl Query { + /// context just resolves to the current context. + async fn context(context: &Context) -> i32 { + context.0 + } + } + + struct Subscription; + + #[graphql_subscription(context = Context)] + impl Subscription { + /// never never emits anything. + async fn never(_context: &Context) -> BoxStream<'static, FieldResult> { + tokio::time::sleep(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream() + .boxed() + } + + /// context emits the current context once, then never emits anything else. + async fn context(context: &Context) -> BoxStream<'static, FieldResult> { + stream::once(future::ready(Ok(context.0))) + .chain( + tokio::time::sleep(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream(), + ) + .boxed() + } + + /// error emits an error once, then never emits anything else. + async fn error(_context: &Context) -> BoxStream<'static, FieldResult> { + stream::once(future::ready(Err(FieldError::new( + "field error", + graphql_value!(null), + )))) + .chain( + tokio::time::sleep(Duration::from_secs(10000)) + .map(|_| unreachable!()) + .into_stream(), + ) + .boxed() + } + } + + type ClientMessage = super::ClientMessage; + type ServerMessage = super::ServerMessage; + + fn new_test_schema() -> Arc, Subscription>> { + Arc::new(RootNode::new(Query, EmptyMutation::new(), Subscription)) + } + + #[tokio::test] + async fn test_query() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "{context}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::Next { + id: "foo".into(), + payload: NextPayload { + data: graphql_value!({"context": 1}), + errors: vec![], + }, + }), + conn.next().await.unwrap() + ); + + assert_eq!( + Output::Message(ServerMessage::Complete { id: "foo".into() }), + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_premature_query() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "{context}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + assert_eq!( + Output::Close { + code: 4401, + message: "Unauthorized".into(), + }, + conn.next().await.unwrap() + ); + + assert_eq!(None, conn.next().await); + } + + #[tokio::test] + async fn test_subscriptions() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "subscription Foo {context}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::Next { + id: "foo".into(), + payload: NextPayload { + data: graphql_value!({"context": 1}), + errors: vec![], + }, + }), + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Subscribe { + id: "bar".into(), + payload: SubscribePayload { + query: "subscription Bar {context}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::Next { + id: "bar".into(), + payload: NextPayload { + data: graphql_value!({"context": 1}), + errors: vec![], + }, + }), + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Complete { id: "foo".into() }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::Complete { id: "foo".into() }), + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_init_params_ok() { + let mut conn = Connection::new(new_test_schema(), |params: Variables| async move { + assert_eq!(params.get("foo"), Some(&graphql_input_value!("bar"))); + Ok(ConnectionConfig::new(Context(1))) as Result<_, Infallible> + }); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {"foo": "bar"}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_init_params_error() { + let mut conn = Connection::new(new_test_schema(), |params: Variables| async move { + assert_eq!(params.get("foo"), Some(&graphql_input_value!("bar"))); + Err(io::Error::new(io::ErrorKind::Other, "init error")) + }); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {"foo": "bar"}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Close { + code: 4403, + message: "init error".into(), + }, + conn.next().await.unwrap() + ); + + assert_eq!(None, conn.next().await); + } + + #[tokio::test] + async fn test_max_in_flight_operations() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)) + .with_keep_alive_interval(Duration::from_secs(0)) + .with_max_in_flight_operations(1), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "subscription Foo {never}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + conn.send(ClientMessage::Subscribe { + id: "bar".into(), + payload: SubscribePayload { + query: "subscription Bar {never}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + Output::Message(ServerMessage::Error { id, .. }) => { + assert_eq!(id, "bar"); + } + msg @ _ => panic!("expected error, got: {msg:?}"), + } + } + + #[tokio::test] + async fn test_parse_error() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "asd".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + Output::Message(ServerMessage::Error { id, payload }) => { + assert_eq!(id, "foo"); + match payload.graphql_error() { + GraphQLError::ParseError(Spanning { + item: ParseError::UnexpectedToken(token), + .. + }) => assert_eq!(token, "asd"), + p @ _ => panic!("expected graphql parse error, got: {p:?}"), + } + } + msg @ _ => panic!("expected error, got: {msg:?}"), + } + } + + #[tokio::test] + async fn test_keep_alives() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_millis(20)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + + for _ in 0..10 { + assert_eq!( + Output::Message(ServerMessage::Pong), + conn.next().await.unwrap() + ); + } + } + + #[tokio::test] + async fn test_slow_init() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }) + .await + .unwrap(); + + // If we send the start message before the init is handled, we should still get results. + conn.send(ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "{context}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + + assert_eq!( + Output::Message(ServerMessage::Next { + id: "foo".into(), + payload: NextPayload { + data: graphql_value!({"context": 1}), + errors: vec![], + }, + }), + conn.next().await.unwrap() + ); + } + + #[tokio::test] + async fn test_subscription_field_error() { + let mut conn = Connection::new( + new_test_schema(), + ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)), + ); + + conn.send(ClientMessage::ConnectionInit { + payload: graphql_vars! {}, + }) + .await + .unwrap(); + + assert_eq!( + Output::Message(ServerMessage::ConnectionAck), + conn.next().await.unwrap() + ); + + conn.send(ClientMessage::Subscribe { + id: "foo".into(), + payload: SubscribePayload { + query: "subscription Foo {error}".into(), + variables: graphql_vars! {}, + operation_name: None, + extensions: Default::default(), + }, + }) + .await + .unwrap(); + + match conn.next().await.unwrap() { + Output::Message(ServerMessage::Next { + id, + payload: NextPayload { data, errors }, + }) => { + assert_eq!(id, "foo"); + assert_eq!(data, graphql_value!({ "error": null })); + assert_eq!(errors.len(), 1); + } + msg @ _ => panic!("expected data, got: {msg:?}"), + } + } +} diff --git a/juniper_graphql_transport_ws/src/schema.rs b/juniper_graphql_transport_ws/src/schema.rs new file mode 100644 index 00000000..68d282f0 --- /dev/null +++ b/juniper_graphql_transport_ws/src/schema.rs @@ -0,0 +1,131 @@ +use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; +use std::sync::Arc; + +/// Schema defines the requirements for schemas that can be used for operations. Typically this is +/// just an `Arc>` and you should not have to implement it yourself. +pub trait Schema: Unpin + Clone + Send + Sync + 'static { + /// The context type. + type Context: Unpin + Send + Sync; + + /// The scalar value type. + type ScalarValue: ScalarValue + Send + Sync; + + /// The query type info. + type QueryTypeInfo: Send + Sync; + + /// The query type. + type Query: GraphQLTypeAsync + + Send; + + /// The mutation type info. + type MutationTypeInfo: Send + Sync; + + /// The mutation type. + type Mutation: GraphQLTypeAsync< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::MutationTypeInfo, + > + Send; + + /// The subscription type info. + type SubscriptionTypeInfo: Send + Sync; + + /// The subscription type. + type Subscription: GraphQLSubscriptionType< + Self::ScalarValue, + Context = Self::Context, + TypeInfo = Self::SubscriptionTypeInfo, + > + Send; + + /// Returns the root node for the schema. + fn root_node( + &self, + ) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>; +} + +/// This exists as a work-around for this issue: https://github.com/rust-lang/rust/issues/64552 +/// +/// It can be used in generators where using Arc directly would result in an error. +// TODO: Remove this once that issue is resolved. +#[doc(hidden)] +pub struct ArcSchema( + pub Arc>, +) +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static; + +impl Clone + for ArcSchema +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static, +{ + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Schema + for ArcSchema +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, +{ + type Context = CtxT; + type ScalarValue = S; + type QueryTypeInfo = QueryT::TypeInfo; + type Query = QueryT; + type MutationTypeInfo = MutationT::TypeInfo; + type Mutation = MutationT; + type SubscriptionTypeInfo = SubscriptionT::TypeInfo; + type Subscription = SubscriptionT; + + fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { + &self.0 + } +} + +impl Schema + for Arc> +where + QueryT: GraphQLTypeAsync + Send + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + 'static, + MutationT::TypeInfo: Send + Sync, + SubscriptionT: GraphQLSubscriptionType + Send + 'static, + SubscriptionT::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync, + S: ScalarValue + Send + Sync + 'static, +{ + type Context = CtxT; + type ScalarValue = S; + type QueryTypeInfo = QueryT::TypeInfo; + type Query = QueryT; + type MutationTypeInfo = MutationT::TypeInfo; + type Mutation = MutationT; + type SubscriptionTypeInfo = SubscriptionT::TypeInfo; + type Subscription = SubscriptionT; + + fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> { + self + } +} diff --git a/juniper_graphql_transport_ws/src/server_message.rs b/juniper_graphql_transport_ws/src/server_message.rs new file mode 100644 index 00000000..9a16b17a --- /dev/null +++ b/juniper_graphql_transport_ws/src/server_message.rs @@ -0,0 +1,158 @@ +use std::{any::Any, fmt, marker::PhantomPinned}; + +use juniper::{ExecutionError, GraphQLError, Value}; +use serde::{Serialize, Serializer}; + +/// Sent after execution of an operation. For queries and mutations, this is sent to the client +/// once. For subscriptions, this is sent for every event in the event stream. +#[derive(Debug, Serialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct NextPayload { + /// The result data. + pub data: Value, + + /// The errors that have occurred during execution. Note that parse and validation errors are + /// not included here. They are sent via Error messages. + #[serde(skip_serializing_if = "Vec::is_empty")] + pub errors: Vec>, +} + +/// A payload for errors that can happen before execution. Errors that happen during execution are +/// instead sent to the client via `NextPayload`. `ErrorPayload` is a wrapper for an owned +/// `GraphQLError`. +// XXX: Think carefully before deriving traits. This is self-referential (error references +// _execution_params). +pub struct ErrorPayload { + _execution_params: Option>, + error: GraphQLError, + _marker: PhantomPinned, +} + +impl ErrorPayload { + /// Creates a new [`ErrorPayload`] out of the provide `execution_params` and + /// [`GraphQLError`]. + pub(crate) fn new(execution_params: Box, error: GraphQLError) -> Self { + Self { + _execution_params: Some(execution_params), + error, + _marker: PhantomPinned, + } + } + + /// Returns the contained GraphQLError. + pub fn graphql_error(&self) -> &GraphQLError { + &self.error + } +} + +impl fmt::Debug for ErrorPayload { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error.fmt(f) + } +} + +impl PartialEq for ErrorPayload { + fn eq(&self, other: &Self) -> bool { + self.error.eq(&other.error) + } +} + +impl Serialize for ErrorPayload { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.error.serialize(serializer) + } +} + +impl From for ErrorPayload { + fn from(error: GraphQLError) -> Self { + Self { + _execution_params: None, + error, + _marker: PhantomPinned, + } + } +} + +/// ServerMessage defines the message types that servers can send. +#[derive(Debug, Serialize, PartialEq)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "type")] +pub enum ServerMessage { + /// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a + /// connection. + ConnectionAck, + /// The response to the `Ping` message. + Pong, + /// Data contains the result of a query, mutation, or subscription event. + Next { + /// The id of the operation that the data is for. + id: String, + + /// The data and errors that occurred during execution. + payload: NextPayload, + }, + /// Error contains an error that occurs before execution, such as validation errors. + Error { + /// The id of the operation that triggered this error. + id: String, + + /// The error(s). + payload: ErrorPayload, + }, + /// Complete indicates that no more data will be sent for the given operation. + Complete { + /// The id of the operation that has completed. + id: String, + }, +} + +#[cfg(test)] +mod test { + use juniper::{graphql_value, DefaultScalarValue}; + + use super::*; + + #[test] + fn test_serialization() { + type ServerMessage = super::ServerMessage; + + assert_eq!( + serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(), + r##"{"type":"connection_ack"}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Pong).unwrap(), + r##"{"type":"pong"}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Next { + id: "foo".into(), + payload: NextPayload { + data: graphql_value!(null), + errors: vec![], + }, + }) + .unwrap(), + r##"{"type":"next","id":"foo","payload":{"data":null}}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Error { + id: "foo".into(), + payload: GraphQLError::UnknownOperationName.into(), + }) + .unwrap(), + r##"{"type":"error","id":"foo","payload":[{"message":"Unknown operation"}]}"##, + ); + + assert_eq!( + serde_json::to_string(&ServerMessage::Complete { id: "foo".into() }).unwrap(), + r##"{"type":"complete","id":"foo"}"##, + ); + } +} diff --git a/juniper_graphql_transport_ws/src/utils.rs b/juniper_graphql_transport_ws/src/utils.rs new file mode 100644 index 00000000..75106a4c --- /dev/null +++ b/juniper_graphql_transport_ws/src/utils.rs @@ -0,0 +1,9 @@ +use serde::{Deserialize, Deserializer}; + +pub(crate) fn default_for_null<'de, D, T>(deserializer: D) -> Result +where + D: Deserializer<'de>, + T: Deserialize<'de> + Default, +{ + Ok(Option::::deserialize(deserializer)?.unwrap_or_default()) +} diff --git a/juniper_graphql_ws/README.md b/juniper_graphql_ws/README.md index 90175a53..ad39b0e1 100644 --- a/juniper_graphql_ws/README.md +++ b/juniper_graphql_ws/README.md @@ -8,7 +8,7 @@ - [Changelog](https://github.com/graphql-rust/juniper/blob/master/juniper_graphql_ws/CHANGELOG.md) -This crate contains an implementation of the [GraphQL over WebSocket Protocol][1], as used by [Apollo]. +This crate contains an implementation of the [graphql-ws WebSocket subprotocol], as formerly used by [Apollo]. It has now been deprecated in favor of the protocol implemented by the [`juniper_graphql_transport_ws` crate]. @@ -21,5 +21,5 @@ This project is licensed under [BSD 2-Clause License](https://github.com/graphql [Apollo]: https://www.apollographql.com - -[1]: https://github.com/apollographql/subscriptions-transport-ws/blob/0ce7a1e1eb687fe51214483e4735f50a2f2d5c79/PROTOCOL.md +[graphql-ws WebSocket subprotocol]: https://github.com/apollographql/subscriptions-transport-ws/blob/0ce7a1e1eb687fe51214483e4735f50a2f2d5c79/PROTOCOL.md +[`juniper_graphql_transport_ws` crate]: https://docs.rs/juniper_graphql_transport_ws diff --git a/juniper_subscriptions/release.toml b/juniper_subscriptions/release.toml index a12a9065..e7e23359 100644 --- a/juniper_subscriptions/release.toml +++ b/juniper_subscriptions/release.toml @@ -10,6 +10,12 @@ exactly = 1 search = "juniper_subscriptions = \\{ version = \"[^\"]+\"" replace = "juniper_subscriptions = { version = \"{{version}}\"" +[[pre-release-replacements]] +file = "../juniper_graphql_transport_ws/Cargo.toml" +exactly = 1 +search = "juniper_subscriptions = \\{ version = \"[^\"]+\"" +replace = "juniper_subscriptions = { version = \"{{version}}\"" + [[pre-release-replacements]] file = "CHANGELOG.md" max = 1 diff --git a/juniper_warp/Cargo.toml b/juniper_warp/Cargo.toml index 3fbe45fe..a4620dbe 100644 --- a/juniper_warp/Cargo.toml +++ b/juniper_warp/Cargo.toml @@ -19,13 +19,14 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] -subscriptions = ["dep:juniper_graphql_ws", "warp/websocket"] +subscriptions = ["dep:juniper_graphql_ws", "dep:juniper_graphql_transport_ws", "warp/websocket"] [dependencies] anyhow = "1.0.47" futures = "0.3.22" juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false } juniper_graphql_ws = { version = "0.4.0-dev", path = "../juniper_graphql_ws", optional = true } +juniper_graphql_transport_ws = { version = "0.4.0-dev", path = "../juniper_graphql_transport_ws", optional = true } serde = { version = "1.0.122", features = ["derive"] } serde_json = "1.0.18" thiserror = "1.0" diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 7b524d2a..d29be189 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -355,11 +355,20 @@ pub mod subscriptions { }, GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue, }; - use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init}; + use juniper_graphql_transport_ws; + use juniper_graphql_ws; struct Message(warp::ws::Message); - impl TryFrom for ClientMessage { + impl TryFrom for juniper_graphql_ws::ClientMessage { + type Error = serde_json::Error; + + fn try_from(msg: Message) -> serde_json::Result { + serde_json::from_slice(msg.0.as_bytes()) + } + } + + impl TryFrom for juniper_graphql_transport_ws::ClientMessage { type Error = serde_json::Error; fn try_from(msg: Message) -> serde_json::Result { @@ -408,6 +417,9 @@ pub mod subscriptions { /// configuration are already known, or it can be a closure that gets executed asynchronously /// when the client sends the ConnectionInit message. Using a closure allows you to perform /// authentication based on the parameters provided by the client. + /// + /// This protocol has been deprecated in favor of the `graphql-transport-ws` protocol, which is + /// provided by the `serve_graphql_transport_ws` function. pub async fn serve_graphql_ws( websocket: warp::ws::WebSocket, root_node: Arc>, @@ -422,10 +434,12 @@ pub mod subscriptions { Subscription::TypeInfo: Send + Sync, CtxT: Unpin + Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static, - I: Init + Send, + I: juniper_graphql_ws::Init + Send, { let (ws_tx, ws_rx) = websocket.split(); - let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split(); + let (s_tx, s_rx) = + juniper_graphql_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init) + .split(); let ws_rx = ws_rx.map(|r| r.map(Message)); let s_rx = s_rx.map(|msg| { @@ -444,6 +458,57 @@ pub mod subscriptions { Either::Right((r, _)) => r, } } + + /// Serves the graphql-transport-ws protocol over a WebSocket connection. + /// + /// The `init` argument is used to provide the context and additional configuration for + /// connections. This can be a `juniper_graphql_transport_ws::ConnectionConfig` if the context and + /// configuration are already known, or it can be a closure that gets executed asynchronously + /// when the client sends the ConnectionInit message. Using a closure allows you to perform + /// authentication based on the parameters provided by the client. + pub async fn serve_graphql_transport_ws( + websocket: warp::ws::WebSocket, + root_node: Arc>, + init: I, + ) -> Result<(), Error> + where + Query: GraphQLTypeAsync + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: juniper_graphql_transport_ws::Init + Send, + { + let (ws_tx, ws_rx) = websocket.split(); + let (s_tx, s_rx) = juniper_graphql_transport_ws::Connection::new( + juniper_graphql_transport_ws::ArcSchema(root_node), + init, + ) + .split(); + + let ws_rx = ws_rx.map(|r| r.map(Message)); + let s_rx = s_rx.map(|output| match output { + juniper_graphql_transport_ws::Output::Message(msg) => serde_json::to_string(&msg) + .map(warp::ws::Message::text) + .map_err(Error::Serde), + juniper_graphql_transport_ws::Output::Close { code, message } => { + Ok(warp::ws::Message::close_with(code, message)) + } + }); + + match future::select( + ws_rx.forward(s_tx.sink_err_into()), + s_rx.forward(ws_tx.sink_err_into()), + ) + .await + { + Either::Left((r, _)) => r.map_err(|e| e.into()), + Either::Right((r, _)) => r, + } + } } #[cfg(test)]