From 84c9720b535c37dfc1d8bc6b142bf2f63e5fd166 Mon Sep 17 00:00:00 2001
From: Chris <ccbrown112@gmail.com>
Date: Wed, 29 Jul 2020 04:23:44 -0400
Subject: [PATCH] GraphQL-WS crate and Warp subscriptions update (#721)

* update pre-existing juniper_warp::subscriptions

* initial draft

* finish up, update example

* polish + timing test

* fix pre-existing bug

* rebase updates

* address comments

* add release.toml

* makefile and initial changelog

* add new Cargo.toml to juniper/release.toml
---
 Cargo.toml                               |    1 +
 examples/warp_subscriptions/Cargo.toml   |    6 +-
 examples/warp_subscriptions/src/main.rs  |   38 +-
 juniper/release.toml                     |    2 +
 juniper_graphql_ws/CHANGELOG.md          |    3 +
 juniper_graphql_ws/Cargo.toml            |   19 +
 juniper_graphql_ws/Makefile.toml         |   20 +
 juniper_graphql_ws/release.toml          |    8 +
 juniper_graphql_ws/src/client_message.rs |  131 +++
 juniper_graphql_ws/src/lib.rs            | 1073 ++++++++++++++++++++++
 juniper_graphql_ws/src/schema.rs         |  131 +++
 juniper_graphql_ws/src/server_message.rs |  191 ++++
 juniper_subscriptions/src/lib.rs         |   16 +-
 juniper_warp/Cargo.toml                  |    4 +-
 juniper_warp/src/lib.rs                  |  291 ++----
 15 files changed, 1696 insertions(+), 238 deletions(-)
 create mode 100644 juniper_graphql_ws/CHANGELOG.md
 create mode 100644 juniper_graphql_ws/Cargo.toml
 create mode 100644 juniper_graphql_ws/Makefile.toml
 create mode 100644 juniper_graphql_ws/release.toml
 create mode 100644 juniper_graphql_ws/src/client_message.rs
 create mode 100644 juniper_graphql_ws/src/lib.rs
 create mode 100644 juniper_graphql_ws/src/schema.rs
 create mode 100644 juniper_graphql_ws/src/server_message.rs

diff --git a/Cargo.toml b/Cargo.toml
index 79429a10..d37670be 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,7 @@ members = [
   "juniper_rocket",
   "juniper_rocket_async",
   "juniper_subscriptions",
+  "juniper_graphql_ws",
   "juniper_warp",
   "juniper_actix",
 ]
diff --git a/examples/warp_subscriptions/Cargo.toml b/examples/warp_subscriptions/Cargo.toml
index 7fc8fb4a..5c69129e 100644
--- a/examples/warp_subscriptions/Cargo.toml
+++ b/examples/warp_subscriptions/Cargo.toml
@@ -13,6 +13,6 @@ serde_json = "1.0"
 tokio = { version = "0.2", features = ["rt-core", "macros"] }
 warp = "0.2.1"
 
-juniper = { git = "https://github.com/graphql-rust/juniper" }
-juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" }
-juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] }
+juniper = { path = "../../juniper" }
+juniper_graphql_ws = { path = "../../juniper_graphql_ws" }
+juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] }
diff --git a/examples/warp_subscriptions/src/main.rs b/examples/warp_subscriptions/src/main.rs
index f0f9f737..0d4f31a6 100644
--- a/examples/warp_subscriptions/src/main.rs
+++ b/examples/warp_subscriptions/src/main.rs
@@ -2,10 +2,10 @@
 
 use std::{env, pin::Pin, sync::Arc, time::Duration};
 
-use futures::{Future, FutureExt as _, Stream};
+use futures::{FutureExt as _, Stream};
 use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode};
-use juniper_subscriptions::Coordinator;
-use juniper_warp::{playground_filter, subscriptions::graphql_subscriptions};
+use juniper_graphql_ws::ConnectionConfig;
+use juniper_warp::{playground_filter, subscriptions::serve_graphql_ws};
 use warp::{http::Response, Filter};
 
 #[derive(Clone)]
@@ -151,30 +151,24 @@ async fn main() {
     let qm_state = warp::any().map(move || Context {});
     let qm_graphql_filter = juniper_warp::make_graphql_filter(qm_schema, qm_state.boxed());
 
-    let sub_state = warp::any().map(move || Context {});
-    let coordinator = Arc::new(juniper_subscriptions::Coordinator::new(schema()));
+    let root_node = Arc::new(schema());
 
     log::info!("Listening on 127.0.0.1:8080");
 
     let routes = (warp::path("subscriptions")
         .and(warp::ws())
-        .and(sub_state.clone())
-        .and(warp::any().map(move || Arc::clone(&coordinator)))
-        .map(
-            |ws: warp::ws::Ws,
-             ctx: Context,
-             coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
-                ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
-                    graphql_subscriptions(websocket, coordinator, ctx)
-                        .map(|r| {
-                            if let Err(e) = r {
-                                println!("Websocket error: {}", e);
-                            }
-                        })
-                        .boxed()
-                })
-            },
-        ))
+        .map(move |ws: warp::ws::Ws| {
+            let root_node = root_node.clone();
+            ws.on_upgrade(move |websocket| async move {
+                serve_graphql_ws(websocket, root_node, ConnectionConfig::new(Context {}))
+                    .map(|r| {
+                        if let Err(e) = r {
+                            println!("Websocket error: {}", e);
+                        }
+                    })
+                    .await
+            })
+        }))
     .map(|reply| {
         // TODO#584: remove this workaround
         warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")
diff --git a/juniper/release.toml b/juniper/release.toml
index 72391149..ab15f4e5 100644
--- a/juniper/release.toml
+++ b/juniper/release.toml
@@ -30,6 +30,8 @@ pre-release-replacements = [
   {file="../juniper_warp/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
   # Subscriptions
   {file="../juniper_subscriptions/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
+  # GraphQL-WS
+  {file="../juniper_graphql_ws/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
   # Actix-Web
   {file="../juniper_actix/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
   {file="../juniper_actix/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
diff --git a/juniper_graphql_ws/CHANGELOG.md b/juniper_graphql_ws/CHANGELOG.md
new file mode 100644
index 00000000..05232472
--- /dev/null
+++ b/juniper_graphql_ws/CHANGELOG.md
@@ -0,0 +1,3 @@
+# master
+
+- Initial Release
diff --git a/juniper_graphql_ws/Cargo.toml b/juniper_graphql_ws/Cargo.toml
new file mode 100644
index 00000000..8bf19ee7
--- /dev/null
+++ b/juniper_graphql_ws/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "juniper_graphql_ws"
+version = "0.1.0"
+authors = ["Christopher Brown <ccbrown112@gmail.com>"]
+license = "BSD-2-Clause"
+description = "Graphql-ws protocol implementation for Juniper"
+documentation = "https://docs.rs/juniper_graphql_ws"
+repository = "https://github.com/graphql-rust/juniper"
+keywords = ["graphql-ws", "juniper", "graphql", "apollo"]
+edition = "2018"
+
+[dependencies]
+juniper = { version = "0.14.2", path = "../juniper", default-features = false }
+juniper_subscriptions = { path = "../juniper_subscriptions" }
+serde = { version = "1.0.8", features = ["derive"] }
+tokio = { version = "0.2", features = ["macros", "rt-core", "time"] }
+
+[dev-dependencies]
+serde_json = { version = "1.0.2" }
diff --git a/juniper_graphql_ws/Makefile.toml b/juniper_graphql_ws/Makefile.toml
new file mode 100644
index 00000000..ba858470
--- /dev/null
+++ b/juniper_graphql_ws/Makefile.toml
@@ -0,0 +1,20 @@
+[env]
+CARGO_MAKE_CARGO_ALL_FEATURES = ""
+
+[tasks.build-verbose]
+condition = { rust_version = { min = "1.29.0" } }
+
+[tasks.build-verbose.windows]
+condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }
+
+[tasks.test-verbose]
+condition = { rust_version = { min = "1.29.0" } }
+
+[tasks.test-verbose.windows]
+condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }
+
+[tasks.ci-coverage-flow]
+condition = { rust_version = { min = "1.29.0" } }
+
+[tasks.ci-coverage-flow.windows]
+disabled = true
diff --git a/juniper_graphql_ws/release.toml b/juniper_graphql_ws/release.toml
new file mode 100644
index 00000000..98e70594
--- /dev/null
+++ b/juniper_graphql_ws/release.toml
@@ -0,0 +1,8 @@
+no-dev-version = true
+pre-release-commit-message = "Release {{crate_name}} {{version}}"
+pro-release-commit-message = "Bump {{crate_name}} version to {{next_version}}"
+tag-message = "Release {{crate_name}} {{version}}"
+upload-doc = false
+pre-release-replacements = [
+  {file="src/lib.rs", search="docs.rs/juniper_graphql_ws/[a-z0-9\\.-]+", replace="docs.rs/juniper_graphql_ws/{{version}}"},
+]
diff --git a/juniper_graphql_ws/src/client_message.rs b/juniper_graphql_ws/src/client_message.rs
new file mode 100644
index 00000000..1e20caef
--- /dev/null
+++ b/juniper_graphql_ws/src/client_message.rs
@@ -0,0 +1,131 @@
+use juniper::{ScalarValue, Variables};
+
+/// The payload for a client's "start" message. This triggers execution of a query, mutation, or
+/// subscription.
+#[derive(Debug, Deserialize, PartialEq)]
+#[serde(bound(deserialize = "S: ScalarValue"))]
+#[serde(rename_all = "camelCase")]
+pub struct StartPayload<S: ScalarValue> {
+    /// The document body.
+    pub query: String,
+
+    /// The optional variables.
+    #[serde(default)]
+    pub variables: Variables<S>,
+
+    /// The optional operation name (required if the document contains multiple operations).
+    pub operation_name: Option<String>,
+}
+
+/// ClientMessage defines the message types that clients can send.
+#[derive(Debug, Deserialize, PartialEq)]
+#[serde(bound(deserialize = "S: ScalarValue"))]
+#[serde(rename_all = "snake_case")]
+#[serde(tag = "type")]
+pub enum ClientMessage<S: ScalarValue> {
+    /// ConnectionInit is sent by the client upon connecting.
+    ConnectionInit {
+        /// Optional parameters of any type sent from the client. These are often used for
+        /// authentication.
+        #[serde(default)]
+        payload: Variables<S>,
+    },
+    /// Start messages are used to execute a GraphQL operation.
+    Start {
+        /// The id of the operation. This can be anything, but must be unique. If there are other
+        /// in-flight operations with the same id, the message will be ignored or cause an error.
+        id: String,
+
+        /// The query, variables, and operation name.
+        payload: StartPayload<S>,
+    },
+    /// Stop messages are used to unsubscribe from a subscription.
+    Stop {
+        /// The id of the operation to stop.
+        id: String,
+    },
+    /// ConnectionTerminate is used to terminate the connection.
+    ConnectionTerminate,
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use juniper::{DefaultScalarValue, InputValue};
+
+    #[test]
+    fn test_deserialization() {
+        type ClientMessage = super::ClientMessage<DefaultScalarValue>;
+
+        assert_eq!(
+            ClientMessage::ConnectionInit {
+                payload: [("foo".to_string(), InputValue::scalar("bar"))]
+                    .iter()
+                    .cloned()
+                    .collect(),
+            },
+            serde_json::from_str(r##"{"type": "connection_init", "payload": {"foo": "bar"}}"##)
+                .unwrap(),
+        );
+
+        assert_eq!(
+            ClientMessage::ConnectionInit {
+                payload: Variables::default(),
+            },
+            serde_json::from_str(r##"{"type": "connection_init"}"##).unwrap(),
+        );
+
+        assert_eq!(
+            ClientMessage::Start {
+                id: "foo".to_string(),
+                payload: StartPayload {
+                    query: "query MyQuery { __typename }".to_string(),
+                    variables: [("foo".to_string(), InputValue::scalar("bar"))]
+                        .iter()
+                        .cloned()
+                        .collect(),
+                    operation_name: Some("MyQuery".to_string()),
+                },
+            },
+            serde_json::from_str(
+                r##"{"type": "start", "id": "foo", "payload": {
+                "query": "query MyQuery { __typename }",
+                "variables": {
+                    "foo": "bar"
+                },
+                "operationName": "MyQuery"
+            }}"##
+            )
+            .unwrap(),
+        );
+
+        assert_eq!(
+            ClientMessage::Start {
+                id: "foo".to_string(),
+                payload: StartPayload {
+                    query: "query MyQuery { __typename }".to_string(),
+                    variables: Variables::default(),
+                    operation_name: None,
+                },
+            },
+            serde_json::from_str(
+                r##"{"type": "start", "id": "foo", "payload": {
+                "query": "query MyQuery { __typename }"
+            }}"##
+            )
+            .unwrap(),
+        );
+
+        assert_eq!(
+            ClientMessage::Stop {
+                id: "foo".to_string()
+            },
+            serde_json::from_str(r##"{"type": "stop", "id": "foo"}"##).unwrap(),
+        );
+
+        assert_eq!(
+            ClientMessage::ConnectionTerminate,
+            serde_json::from_str(r##"{"type": "connection_terminate"}"##).unwrap(),
+        );
+    }
+}
diff --git a/juniper_graphql_ws/src/lib.rs b/juniper_graphql_ws/src/lib.rs
new file mode 100644
index 00000000..7cf73eec
--- /dev/null
+++ b/juniper_graphql_ws/src/lib.rs
@@ -0,0 +1,1073 @@
+/*!
+
+# juniper_graphql_ws
+
+This crate contains an implementation of the [graphql-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/263844b5c1a850c1e29814564eb62cb587e5eaaf/PROTOCOL.md), as used by Apollo.
+
+*/
+
+#![deny(missing_docs)]
+#![deny(warnings)]
+
+#[macro_use]
+extern crate serde;
+
+mod client_message;
+pub use client_message::*;
+
+mod server_message;
+pub use server_message::*;
+
+mod schema;
+pub use schema::*;
+
+use juniper::{
+    futures::{
+        channel::oneshot,
+        future::{self, BoxFuture, Either, Future, FutureExt, TryFutureExt},
+        stream::{self, BoxStream, SelectAll, StreamExt},
+        task::{Context, Poll, Waker},
+        Sink, Stream,
+    },
+    GraphQLError, RuleError, ScalarValue, Variables,
+};
+use std::{
+    collections::HashMap,
+    convert::{Infallible, TryInto},
+    error::Error,
+    marker::PhantomPinned,
+    pin::Pin,
+    sync::Arc,
+    time::Duration,
+};
+
+struct ExecutionParams<S: Schema> {
+    start_payload: StartPayload<S::ScalarValue>,
+    config: Arc<ConnectionConfig<S::Context>>,
+    schema: S,
+}
+
+/// ConnectionConfig is used to configure the connection once the client sends the ConnectionInit
+/// message.
+pub struct ConnectionConfig<CtxT> {
+    context: CtxT,
+    max_in_flight_operations: usize,
+    keep_alive_interval: Duration,
+}
+
+impl<CtxT> ConnectionConfig<CtxT> {
+    /// Constructs the configuration required for a connection to be accepted.
+    pub fn new(context: CtxT) -> Self {
+        Self {
+            context,
+            max_in_flight_operations: 0,
+            keep_alive_interval: Duration::from_secs(30),
+        }
+    }
+
+    /// Specifies the maximum number of in-flight operations that a connection can have. If this
+    /// number is exceeded, attempting to start more will result in an error. By default, there is
+    /// no limit to in-flight operations.
+    pub fn with_max_in_flight_operations(mut self, max: usize) -> Self {
+        self.max_in_flight_operations = max;
+        self
+    }
+
+    /// Specifies the interval at which to send keep-alives. Specifying a zero duration will
+    /// disable keep-alives. By default, keep-alives are sent every
+    /// 30 seconds.
+    pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self {
+        self.keep_alive_interval = interval;
+        self
+    }
+}
+
+impl<S: ScalarValue, CtxT: Unpin + Send + 'static> Init<S, CtxT> for ConnectionConfig<CtxT> {
+    type Error = Infallible;
+    type Future = future::Ready<Result<Self, Self::Error>>;
+
+    fn init(self, _params: Variables<S>) -> Self::Future {
+        future::ready(Ok(self))
+    }
+}
+
+enum Reaction<S: Schema> {
+    ServerMessage(ServerMessage<S::ScalarValue>),
+    EndStream,
+}
+
+impl<S: Schema> Reaction<S> {
+    /// Converts the reaction into a one-item stream.
+    fn to_stream(self) -> BoxStream<'static, Self> {
+        stream::once(future::ready(self)).boxed()
+    }
+}
+
+/// Init defines the requirements for types that can provide connection configurations when
+/// ConnectionInit messages are received. Implementations are provided for `ConnectionConfig` and
+/// closures that meet the requirements.
+pub trait Init<S: ScalarValue, CtxT>: Unpin + 'static {
+    /// The error that is returned on failure. The formatted error will be used as the contents of
+    /// the "message" field sent back to the client.
+    type Error: Error;
+
+    /// The future configuration type.
+    type Future: Future<Output = Result<ConnectionConfig<CtxT>, Self::Error>> + Send + 'static;
+
+    /// Returns a future for the configuration to use.
+    fn init(self, params: Variables<S>) -> Self::Future;
+}
+
+impl<F, S, CtxT, Fut, E> Init<S, CtxT> for F
+where
+    S: ScalarValue,
+    F: FnOnce(Variables<S>) -> Fut + Unpin + 'static,
+    Fut: Future<Output = Result<ConnectionConfig<CtxT>, E>> + Send + 'static,
+    E: Error,
+{
+    type Error = E;
+    type Future = Fut;
+
+    fn init(self, params: Variables<S>) -> Fut {
+        self(params)
+    }
+}
+
+enum ConnectionState<S: Schema, I: Init<S::ScalarValue, S::Context>> {
+    /// PreInit is the state before a ConnectionInit message has been accepted.
+    PreInit { init: I, schema: S },
+    /// Active is the state after a ConnectionInit message has been accepted.
+    Active {
+        config: Arc<ConnectionConfig<S::Context>>,
+        stoppers: HashMap<String, oneshot::Sender<()>>,
+        schema: S,
+    },
+    /// Terminated is the state after a ConnectionInit message has been rejected.
+    Terminated,
+}
+
+impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
+    // Each message we receive results in a stream of zero or more reactions. For example, a
+    // ConnectionTerminate message results in a one-item stream with the EndStream reaction.
+    async fn handle_message(
+        self,
+        msg: ClientMessage<S::ScalarValue>,
+    ) -> (Self, BoxStream<'static, Reaction<S>>) {
+        if let ClientMessage::ConnectionTerminate = msg {
+            return (self, Reaction::EndStream.to_stream());
+        }
+
+        match self {
+            Self::PreInit { init, schema } => match msg {
+                ClientMessage::ConnectionInit { payload } => match init.init(payload).await {
+                    Ok(config) => {
+                        let keep_alive_interval = config.keep_alive_interval;
+
+                        let mut s = stream::iter(vec![Reaction::ServerMessage(
+                            ServerMessage::ConnectionAck,
+                        )])
+                        .boxed();
+
+                        if keep_alive_interval > Duration::from_secs(0) {
+                            s = s
+                                .chain(
+                                    Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive)
+                                        .to_stream(),
+                                )
+                                .boxed();
+                            s = s
+                                .chain(stream::unfold((), move |_| async move {
+                                    tokio::time::delay_for(keep_alive_interval).await;
+                                    Some((
+                                        Reaction::ServerMessage(ServerMessage::ConnectionKeepAlive),
+                                        (),
+                                    ))
+                                }))
+                                .boxed();
+                        }
+
+                        (
+                            Self::Active {
+                                config: Arc::new(config),
+                                stoppers: HashMap::new(),
+                                schema,
+                            },
+                            s,
+                        )
+                    }
+                    Err(e) => (
+                        Self::Terminated,
+                        stream::iter(vec![
+                            Reaction::ServerMessage(ServerMessage::ConnectionError {
+                                payload: ConnectionErrorPayload {
+                                    message: e.to_string(),
+                                },
+                            }),
+                            Reaction::EndStream,
+                        ])
+                        .boxed(),
+                    ),
+                },
+                _ => (Self::PreInit { init, schema }, stream::empty().boxed()),
+            },
+            Self::Active {
+                config,
+                mut stoppers,
+                schema,
+            } => {
+                let reactions = match msg {
+                    ClientMessage::Start { id, payload } => {
+                        if stoppers.contains_key(&id) {
+                            // We already have an operation with this id, so we can't start a new
+                            // one.
+                            stream::empty().boxed()
+                        } else {
+                            // Go ahead and prune canceled stoppers before adding a new one.
+                            stoppers.retain(|_, tx| !tx.is_canceled());
+
+                            if config.max_in_flight_operations > 0
+                                && stoppers.len() >= config.max_in_flight_operations
+                            {
+                                // Too many in-flight operations. Just send back a validation error.
+                                stream::iter(vec![
+                                    Reaction::ServerMessage(ServerMessage::Error {
+                                        id: id.clone(),
+                                        payload: GraphQLError::ValidationError(vec![
+                                            RuleError::new("Too many in-flight operations.", &[]),
+                                        ])
+                                        .into(),
+                                    }),
+                                    Reaction::ServerMessage(ServerMessage::Complete { id }),
+                                ])
+                                .boxed()
+                            } else {
+                                // Create a channel that we can use to cancel the operation.
+                                let (tx, rx) = oneshot::channel::<()>();
+                                stoppers.insert(id.clone(), tx);
+
+                                // Create the operation stream. This stream will emit Data and Error
+                                // messages, but will not emit Complete – that part is up to us.
+                                let s = Self::start(
+                                    id.clone(),
+                                    ExecutionParams {
+                                        start_payload: payload,
+                                        config: config.clone(),
+                                        schema: schema.clone(),
+                                    },
+                                )
+                                .into_stream()
+                                .flatten();
+
+                                // Combine this with our oneshot channel so that the stream ends if the
+                                // oneshot is ever fired.
+                                let s = stream::unfold((rx, s.boxed()), |(rx, mut s)| async move {
+                                    let next = match future::select(rx, s.next()).await {
+                                        Either::Left(_) => None,
+                                        Either::Right((r, rx)) => r.map(|r| (r, rx)),
+                                    };
+                                    next.map(|(r, rx)| (r, (rx, s)))
+                                });
+
+                                // Once the stream ends, send the Complete message.
+                                let s = s.chain(
+                                    Reaction::ServerMessage(ServerMessage::Complete { id })
+                                        .to_stream(),
+                                );
+
+                                s.boxed()
+                            }
+                        }
+                    }
+                    ClientMessage::Stop { id } => {
+                        stoppers.remove(&id);
+                        stream::empty().boxed()
+                    }
+                    _ => stream::empty().boxed(),
+                };
+                (
+                    Self::Active {
+                        config,
+                        stoppers,
+                        schema,
+                    },
+                    reactions,
+                )
+            }
+            Self::Terminated => (self, stream::empty().boxed()),
+        }
+    }
+
+    async fn start(id: String, params: ExecutionParams<S>) -> BoxStream<'static, Reaction<S>> {
+        // TODO: This could be made more efficient if juniper exposed functionality to allow us to
+        // parse and validate the query, determine whether it's a subscription, and then execute
+        // it. For now, the query gets parsed and validated twice.
+
+        let params = Arc::new(params);
+
+        // Try to execute this as a query or mutation.
+        match juniper::execute(
+            &params.start_payload.query,
+            params
+                .start_payload
+                .operation_name
+                .as_ref()
+                .map(|s| s.as_str()),
+            params.schema.root_node(),
+            &params.start_payload.variables,
+            &params.config.context,
+        )
+        .await
+        {
+            Ok((data, errors)) => {
+                return Reaction::ServerMessage(ServerMessage::Data {
+                    id: id.clone(),
+                    payload: DataPayload { data, errors },
+                })
+                .to_stream();
+            }
+            Err(GraphQLError::IsSubscription) => {}
+            Err(e) => {
+                return Reaction::ServerMessage(ServerMessage::Error {
+                    id: id.clone(),
+                    // e only references data owned by params. The new ErrorPayload will continue to keep that data alive.
+                    payload: unsafe { ErrorPayload::new_unchecked(Box::new(params.clone()), e) },
+                })
+                .to_stream();
+            }
+        }
+
+        // Try to execute as a subscription.
+        SubscriptionStart::new(id, params.clone()).boxed()
+    }
+}
+
+struct InterruptableStream<S> {
+    stream: S,
+    rx: oneshot::Receiver<()>,
+}
+
+impl<S: Stream + Unpin> Stream for InterruptableStream<S> {
+    type Item = S::Item;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        match Pin::new(&mut self.rx).poll(cx) {
+            Poll::Ready(_) => return Poll::Ready(None),
+            Poll::Pending => {}
+        }
+        Pin::new(&mut self.stream).poll_next(cx)
+    }
+}
+
+/// SubscriptionStartState is the state for a subscription operation.
+enum SubscriptionStartState<S: Schema> {
+    /// Init is the start before being polled for the first time.
+    Init { id: String },
+    /// ResolvingIntoStream is the state after being polled for the first time. In this state,
+    /// we're parsing, validating, and getting the actual event stream.
+    ResolvingIntoStream {
+        id: String,
+        future: BoxFuture<
+            'static,
+            Result<
+                juniper_subscriptions::Connection<'static, S::ScalarValue>,
+                GraphQLError<'static>,
+            >,
+        >,
+    },
+    /// Streaming is the state after we've successfully obtained the event stream for the
+    /// subscription. In this state, we're just forwarding events back to the client.
+    Streaming {
+        id: String,
+        stream: juniper_subscriptions::Connection<'static, S::ScalarValue>,
+    },
+    /// Terminated is the state once we're all done.
+    Terminated,
+}
+
+/// SubscriptionStart is the stream for a subscription operation.
+struct SubscriptionStart<S: Schema> {
+    params: Arc<ExecutionParams<S>>,
+    state: SubscriptionStartState<S>,
+    _marker: PhantomPinned,
+}
+
+impl<S: Schema> SubscriptionStart<S> {
+    fn new(id: String, params: Arc<ExecutionParams<S>>) -> Pin<Box<Self>> {
+        Box::pin(Self {
+            params,
+            state: SubscriptionStartState::Init { id },
+            _marker: PhantomPinned,
+        })
+    }
+}
+
+impl<S: Schema> Stream for SubscriptionStart<S> {
+    type Item = Reaction<S>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        let (params, state) = unsafe {
+            // XXX: The execution parameters are referenced by state and must not be modified.
+            // Modifying state is fine though.
+            let inner = self.get_unchecked_mut();
+            (&inner.params, &mut inner.state)
+        };
+
+        loop {
+            match state {
+                SubscriptionStartState::Init { id } => {
+                    // XXX: resolve_into_stream returns a Future that references the execution
+                    // parameters, and the returned stream also references them. We can guarantee
+                    // that everything has the same lifetime in this self-referential struct.
+                    let params = Arc::as_ptr(params);
+                    *state = SubscriptionStartState::ResolvingIntoStream {
+                        id: id.clone(),
+                        future: unsafe {
+                            juniper::resolve_into_stream(
+                                &(*params).start_payload.query,
+                                (*params)
+                                    .start_payload
+                                    .operation_name
+                                    .as_ref()
+                                    .map(|s| s.as_str()),
+                                (*params).schema.root_node(),
+                                &(*params).start_payload.variables,
+                                &(*params).config.context,
+                            )
+                        }
+                        .map_ok(|(stream, errors)| {
+                            juniper_subscriptions::Connection::from_stream(stream, errors)
+                        })
+                        .boxed(),
+                    };
+                }
+                SubscriptionStartState::ResolvingIntoStream {
+                    ref id,
+                    ref mut future,
+                } => match future.as_mut().poll(cx) {
+                    Poll::Ready(r) => match r {
+                        Ok(stream) => {
+                            *state = SubscriptionStartState::Streaming {
+                                id: id.clone(),
+                                stream,
+                            }
+                        }
+                        Err(e) => {
+                            return Poll::Ready(Some(Reaction::ServerMessage(
+                                ServerMessage::Error {
+                                    id: id.clone(),
+                                    // e only references data owned by params. The new ErrorPayload will continue to keep that data alive.
+                                    payload: unsafe {
+                                        ErrorPayload::new_unchecked(Box::new(params.clone()), e)
+                                    },
+                                },
+                            )));
+                        }
+                    },
+                    Poll::Pending => return Poll::Pending,
+                },
+                SubscriptionStartState::Streaming {
+                    ref id,
+                    ref mut stream,
+                } => match Pin::new(stream).poll_next(cx) {
+                    Poll::Ready(Some(output)) => {
+                        return Poll::Ready(Some(Reaction::ServerMessage(ServerMessage::Data {
+                            id: id.clone(),
+                            payload: DataPayload {
+                                data: output.data,
+                                errors: output.errors,
+                            },
+                        })));
+                    }
+                    Poll::Ready(None) => {
+                        *state = SubscriptionStartState::Terminated;
+                        return Poll::Ready(None);
+                    }
+                    Poll::Pending => return Poll::Pending,
+                },
+                SubscriptionStartState::Terminated => return Poll::Ready(None),
+            }
+        }
+    }
+}
+
+enum ConnectionSinkState<S: Schema, I: Init<S::ScalarValue, S::Context>> {
+    Ready {
+        state: ConnectionState<S, I>,
+    },
+    HandlingMessage {
+        result: BoxFuture<'static, (ConnectionState<S, I>, BoxStream<'static, Reaction<S>>)>,
+    },
+    Closed,
+}
+
+/// Implements the graphql-ws protocol. This is a sink for `TryInto<ClientMessage>` and a stream of
+/// `ServerMessage`.
+pub struct Connection<S: Schema, I: Init<S::ScalarValue, S::Context>> {
+    reactions: SelectAll<BoxStream<'static, Reaction<S>>>,
+    stream_waker: Option<Waker>,
+    sink_state: ConnectionSinkState<S, I>,
+}
+
+impl<S, I> Connection<S, I>
+where
+    S: Schema,
+    I: Init<S::ScalarValue, S::Context>,
+{
+    /// Creates a new connection, which is a sink for `TryInto<ClientMessage>` and a stream of `ServerMessage`.
+    ///
+    /// The `schema` argument should typically be an `Arc<RootNode<...>>`.
+    ///
+    /// The `init` argument is used to provide the context and additional configuration for
+    /// connections. This can be a `ConnectionConfig` if the context and configuration are already
+    /// known, or it can be a closure that gets executed asynchronously when the client sends the
+    /// ConnectionInit message. Using a closure allows you to perform authentication based on the
+    /// parameters provided by the client.
+    pub fn new(schema: S, init: I) -> Self {
+        Self {
+            reactions: SelectAll::new(),
+            stream_waker: None,
+            sink_state: ConnectionSinkState::Ready {
+                state: ConnectionState::PreInit { init, schema },
+            },
+        }
+    }
+}
+
+impl<S, I, T> Sink<T> for Connection<S, I>
+where
+    T: TryInto<ClientMessage<S::ScalarValue>>,
+    T::Error: Error,
+    S: Schema,
+    I: Init<S::ScalarValue, S::Context> + Send,
+{
+    type Error = Infallible;
+
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        match &mut self.sink_state {
+            ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())),
+            ConnectionSinkState::HandlingMessage { ref mut result } => {
+                match Pin::new(result).poll(cx) {
+                    Poll::Ready((state, reactions)) => {
+                        self.reactions.push(reactions);
+                        self.sink_state = ConnectionSinkState::Ready { state };
+                        Poll::Ready(Ok(()))
+                    }
+                    Poll::Pending => Poll::Pending,
+                }
+            }
+            ConnectionSinkState::Closed => panic!("poll_ready called after close"),
+        }
+    }
+
+    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
+        let s = self.get_mut();
+        let state = &mut s.sink_state;
+        *state = match std::mem::replace(state, ConnectionSinkState::Closed) {
+            ConnectionSinkState::Ready { state } => {
+                match item.try_into() {
+                    Ok(msg) => ConnectionSinkState::HandlingMessage {
+                        result: state.handle_message(msg).boxed(),
+                    },
+                    Err(e) => {
+                        // If we weren't able to parse the message, send back an error.
+                        s.reactions.push(
+                            Reaction::ServerMessage(ServerMessage::ConnectionError {
+                                payload: ConnectionErrorPayload {
+                                    message: e.to_string(),
+                                },
+                            })
+                            .to_stream(),
+                        );
+                        ConnectionSinkState::Ready { state }
+                    }
+                }
+            }
+            _ => panic!("start_send called when not ready"),
+        };
+        Ok(())
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        <Self as Sink<T>>::poll_ready(self, cx)
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        self.sink_state = ConnectionSinkState::Closed;
+        if let Some(waker) = self.stream_waker.take() {
+            // Wake up the stream so it can close too.
+            waker.wake();
+        }
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl<S, I> Stream for Connection<S, I>
+where
+    S: Schema,
+    I: Init<S::ScalarValue, S::Context>,
+{
+    type Item = ServerMessage<S::ScalarValue>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        self.stream_waker = Some(cx.waker().clone());
+
+        if let ConnectionSinkState::Closed = self.sink_state {
+            return Poll::Ready(None);
+        }
+
+        // Poll the reactions for new outgoing messages.
+        loop {
+            if !self.reactions.is_empty() {
+                match Pin::new(&mut self.reactions).poll_next(cx) {
+                    Poll::Ready(Some(reaction)) => match reaction {
+                        Reaction::ServerMessage(msg) => return Poll::Ready(Some(msg)),
+                        Reaction::EndStream => return Poll::Ready(None),
+                    },
+                    Poll::Ready(None) => {
+                        // In rare cases, the reaction stream may terminate. For example, this will
+                        // happen if the first message we receive does not require any reaction. Just
+                        // recreate it in that case.
+                        self.reactions = SelectAll::new();
+                        return Poll::Pending;
+                    }
+                    Poll::Pending => return Poll::Pending,
+                }
+            } else {
+                return Poll::Pending;
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use juniper::{
+        futures::sink::SinkExt,
+        parser::{ParseError, Spanning, Token},
+        DefaultScalarValue, EmptyMutation, FieldError, FieldResult, InputValue, RootNode, Value,
+    };
+    use std::{convert::Infallible, io};
+
+    struct Context(i32);
+
+    struct Query;
+
+    #[juniper::graphql_object(Context=Context)]
+    impl Query {
+        /// context just resolves to the current context.
+        async fn context(context: &Context) -> i32 {
+            context.0
+        }
+    }
+
+    struct Subscription;
+
+    #[juniper::graphql_subscription(Context=Context)]
+    impl Subscription {
+        /// never never emits anything.
+        async fn never(context: &Context) -> BoxStream<'static, FieldResult<i32>> {
+            tokio::time::delay_for(Duration::from_secs(10000))
+                .map(|_| unreachable!())
+                .into_stream()
+                .boxed()
+        }
+
+        /// context emits the current context once, then never emits anything else.
+        async fn context(context: &Context) -> BoxStream<'static, FieldResult<i32>> {
+            stream::once(future::ready(Ok(context.0)))
+                .chain(
+                    tokio::time::delay_for(Duration::from_secs(10000))
+                        .map(|_| unreachable!())
+                        .into_stream(),
+                )
+                .boxed()
+        }
+
+        /// error emits an error once, then never emits anything else.
+        async fn error(context: &Context) -> BoxStream<'static, FieldResult<i32>> {
+            stream::once(future::ready(Err(FieldError::new(
+                "field error",
+                Value::null(),
+            ))))
+            .chain(
+                tokio::time::delay_for(Duration::from_secs(10000))
+                    .map(|_| unreachable!())
+                    .into_stream(),
+            )
+            .boxed()
+        }
+    }
+
+    type ClientMessage = super::ClientMessage<DefaultScalarValue>;
+    type ServerMessage = super::ServerMessage<DefaultScalarValue>;
+
+    fn new_test_schema() -> Arc<RootNode<'static, Query, EmptyMutation<Context>, Subscription>> {
+        Arc::new(RootNode::new(Query, EmptyMutation::new(), Subscription))
+    }
+
+    #[tokio::test]
+    async fn test_query() {
+        let mut conn = Connection::new(
+            new_test_schema(),
+            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
+        );
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: Variables::default(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+
+        conn.send(ClientMessage::Start {
+            id: "foo".to_string(),
+            payload: StartPayload {
+                query: "{context}".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(
+            ServerMessage::Data {
+                id: "foo".to_string(),
+                payload: DataPayload {
+                    data: Value::Object(
+                        [("context", Value::Scalar(DefaultScalarValue::Int(1)))]
+                            .iter()
+                            .cloned()
+                            .collect()
+                    ),
+                    errors: vec![],
+                },
+            },
+            conn.next().await.unwrap()
+        );
+
+        assert_eq!(
+            ServerMessage::Complete {
+                id: "foo".to_string(),
+            },
+            conn.next().await.unwrap()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_subscriptions() {
+        let mut conn = Connection::new(
+            new_test_schema(),
+            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
+        );
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: Variables::default(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+
+        conn.send(ClientMessage::Start {
+            id: "foo".to_string(),
+            payload: StartPayload {
+                query: "subscription Foo {context}".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(
+            ServerMessage::Data {
+                id: "foo".to_string(),
+                payload: DataPayload {
+                    data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()),
+                    errors: vec![],
+                },
+            },
+            conn.next().await.unwrap()
+        );
+
+        conn.send(ClientMessage::Start {
+            id: "bar".to_string(),
+            payload: StartPayload {
+                query: "subscription Bar {context}".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(
+            ServerMessage::Data {
+                id: "bar".to_string(),
+                payload: DataPayload {
+                    data: Value::Object([("context", Value::scalar(1))].iter().cloned().collect()),
+                    errors: vec![],
+                },
+            },
+            conn.next().await.unwrap()
+        );
+
+        conn.send(ClientMessage::Stop {
+            id: "foo".to_string(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(
+            ServerMessage::Complete {
+                id: "foo".to_string(),
+            },
+            conn.next().await.unwrap()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_init_params_ok() {
+        let mut conn = Connection::new(new_test_schema(), |params: Variables| async move {
+            assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar")));
+            Ok(ConnectionConfig::new(Context(1))) as Result<_, Infallible>
+        });
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))]
+                .iter()
+                .cloned()
+                .collect(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+    }
+
+    #[tokio::test]
+    async fn test_init_params_error() {
+        let mut conn = Connection::new(new_test_schema(), |params: Variables| async move {
+            assert_eq!(params.get("foo"), Some(&InputValue::scalar("bar")));
+            Err(io::Error::new(io::ErrorKind::Other, "init error"))
+        });
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: [("foo".to_string(), InputValue::scalar("bar".to_string()))]
+                .iter()
+                .cloned()
+                .collect(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(
+            ServerMessage::ConnectionError {
+                payload: ConnectionErrorPayload {
+                    message: "init error".to_string(),
+                },
+            },
+            conn.next().await.unwrap()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_max_in_flight_operations() {
+        let mut conn = Connection::new(
+            new_test_schema(),
+            ConnectionConfig::new(Context(1))
+                .with_keep_alive_interval(Duration::from_secs(0))
+                .with_max_in_flight_operations(1),
+        );
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: Variables::default(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+
+        conn.send(ClientMessage::Start {
+            id: "foo".to_string(),
+            payload: StartPayload {
+                query: "subscription Foo {never}".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        conn.send(ClientMessage::Start {
+            id: "bar".to_string(),
+            payload: StartPayload {
+                query: "subscription Bar {never}".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        match conn.next().await.unwrap() {
+            ServerMessage::Error { id, .. } => {
+                assert_eq!(id, "bar");
+            }
+            msg @ _ => panic!("expected error, got: {:?}", msg),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_parse_error() {
+        let mut conn = Connection::new(
+            new_test_schema(),
+            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
+        );
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: Variables::default(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+
+        conn.send(ClientMessage::Start {
+            id: "foo".to_string(),
+            payload: StartPayload {
+                query: "asd".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        match conn.next().await.unwrap() {
+            ServerMessage::Error { id, payload } => {
+                assert_eq!(id, "foo");
+                match payload.graphql_error() {
+                    GraphQLError::ParseError(Spanning {
+                        item: ParseError::UnexpectedToken(Token::Name("asd")),
+                        ..
+                    }) => {}
+                    p @ _ => panic!("expected graphql parse error, got: {:?}", p),
+                }
+            }
+            msg @ _ => panic!("expected error, got: {:?}", msg),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_keep_alives() {
+        let mut conn = Connection::new(
+            new_test_schema(),
+            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_millis(20)),
+        );
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: Variables::default(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+
+        for _ in 0..10 {
+            assert_eq!(
+                ServerMessage::ConnectionKeepAlive,
+                conn.next().await.unwrap()
+            );
+        }
+    }
+
+    #[tokio::test]
+    async fn test_slow_init() {
+        let mut conn = Connection::new(
+            new_test_schema(),
+            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
+        );
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: Variables::default(),
+        })
+        .await
+        .unwrap();
+
+        // If we send the start message before the init is handled, we should still get results.
+        conn.send(ClientMessage::Start {
+            id: "foo".to_string(),
+            payload: StartPayload {
+                query: "{context}".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+
+        assert_eq!(
+            ServerMessage::Data {
+                id: "foo".to_string(),
+                payload: DataPayload {
+                    data: Value::Object(
+                        [("context", Value::Scalar(DefaultScalarValue::Int(1)))]
+                            .iter()
+                            .cloned()
+                            .collect()
+                    ),
+                    errors: vec![],
+                },
+            },
+            conn.next().await.unwrap()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_subscription_field_error() {
+        let mut conn = Connection::new(
+            new_test_schema(),
+            ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_secs(0)),
+        );
+
+        conn.send(ClientMessage::ConnectionInit {
+            payload: Variables::default(),
+        })
+        .await
+        .unwrap();
+
+        assert_eq!(ServerMessage::ConnectionAck, conn.next().await.unwrap());
+
+        conn.send(ClientMessage::Start {
+            id: "foo".to_string(),
+            payload: StartPayload {
+                query: "subscription Foo {error}".to_string(),
+                variables: Variables::default(),
+                operation_name: None,
+            },
+        })
+        .await
+        .unwrap();
+
+        match conn.next().await.unwrap() {
+            ServerMessage::Data {
+                id,
+                payload: DataPayload { data, errors },
+            } => {
+                assert_eq!(id, "foo");
+                assert_eq!(
+                    data,
+                    Value::Object([("error", Value::null())].iter().cloned().collect())
+                );
+                assert_eq!(errors.len(), 1);
+            }
+            msg @ _ => panic!("expected data, got: {:?}", msg),
+        }
+    }
+}
diff --git a/juniper_graphql_ws/src/schema.rs b/juniper_graphql_ws/src/schema.rs
new file mode 100644
index 00000000..68d282f0
--- /dev/null
+++ b/juniper_graphql_ws/src/schema.rs
@@ -0,0 +1,131 @@
+use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
+use std::sync::Arc;
+
+/// Schema defines the requirements for schemas that can be used for operations. Typically this is
+/// just an `Arc<RootNode<...>>` and you should not have to implement it yourself.
+pub trait Schema: Unpin + Clone + Send + Sync + 'static {
+    /// The context type.
+    type Context: Unpin + Send + Sync;
+
+    /// The scalar value type.
+    type ScalarValue: ScalarValue + Send + Sync;
+
+    /// The query type info.
+    type QueryTypeInfo: Send + Sync;
+
+    /// The query type.
+    type Query: GraphQLTypeAsync<Self::ScalarValue, Context = Self::Context, TypeInfo = Self::QueryTypeInfo>
+        + Send;
+
+    /// The mutation type info.
+    type MutationTypeInfo: Send + Sync;
+
+    /// The mutation type.
+    type Mutation: GraphQLTypeAsync<
+            Self::ScalarValue,
+            Context = Self::Context,
+            TypeInfo = Self::MutationTypeInfo,
+        > + Send;
+
+    /// The subscription type info.
+    type SubscriptionTypeInfo: Send + Sync;
+
+    /// The subscription type.
+    type Subscription: GraphQLSubscriptionType<
+            Self::ScalarValue,
+            Context = Self::Context,
+            TypeInfo = Self::SubscriptionTypeInfo,
+        > + Send;
+
+    /// Returns the root node for the schema.
+    fn root_node(
+        &self,
+    ) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>;
+}
+
+/// This exists as a work-around for this issue: https://github.com/rust-lang/rust/issues/64552
+///
+/// It can be used in generators where using Arc directly would result in an error.
+// TODO: Remove this once that issue is resolved.
+#[doc(hidden)]
+pub struct ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>(
+    pub Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
+)
+where
+    QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    QueryT::TypeInfo: Send + Sync,
+    MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    MutationT::TypeInfo: Send + Sync,
+    SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+    SubscriptionT::TypeInfo: Send + Sync,
+    CtxT: Unpin + Send + Sync,
+    S: ScalarValue + Send + Sync + 'static;
+
+impl<QueryT, MutationT, SubscriptionT, CtxT, S> Clone
+    for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
+where
+    QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    QueryT::TypeInfo: Send + Sync,
+    MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    MutationT::TypeInfo: Send + Sync,
+    SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+    SubscriptionT::TypeInfo: Send + Sync,
+    CtxT: Unpin + Send + Sync,
+    S: ScalarValue + Send + Sync + 'static,
+{
+    fn clone(&self) -> Self {
+        Self(self.0.clone())
+    }
+}
+
+impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
+    for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
+where
+    QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    QueryT::TypeInfo: Send + Sync,
+    MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    MutationT::TypeInfo: Send + Sync,
+    SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+    SubscriptionT::TypeInfo: Send + Sync,
+    CtxT: Unpin + Send + Sync + 'static,
+    S: ScalarValue + Send + Sync + 'static,
+{
+    type Context = CtxT;
+    type ScalarValue = S;
+    type QueryTypeInfo = QueryT::TypeInfo;
+    type Query = QueryT;
+    type MutationTypeInfo = MutationT::TypeInfo;
+    type Mutation = MutationT;
+    type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
+    type Subscription = SubscriptionT;
+
+    fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
+        &self.0
+    }
+}
+
+impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
+    for Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>
+where
+    QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    QueryT::TypeInfo: Send + Sync,
+    MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+    MutationT::TypeInfo: Send + Sync,
+    SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+    SubscriptionT::TypeInfo: Send + Sync,
+    CtxT: Unpin + Send + Sync,
+    S: ScalarValue + Send + Sync + 'static,
+{
+    type Context = CtxT;
+    type ScalarValue = S;
+    type QueryTypeInfo = QueryT::TypeInfo;
+    type Query = QueryT;
+    type MutationTypeInfo = MutationT::TypeInfo;
+    type Mutation = MutationT;
+    type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
+    type Subscription = SubscriptionT;
+
+    fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
+        self
+    }
+}
diff --git a/juniper_graphql_ws/src/server_message.rs b/juniper_graphql_ws/src/server_message.rs
new file mode 100644
index 00000000..3c353164
--- /dev/null
+++ b/juniper_graphql_ws/src/server_message.rs
@@ -0,0 +1,191 @@
+use juniper::{ExecutionError, GraphQLError, ScalarValue, Value};
+use serde::{Serialize, Serializer};
+use std::{any::Any, fmt, marker::PhantomPinned};
+
+/// The payload for errors that are not associated with a GraphQL operation.
+#[derive(Debug, Serialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct ConnectionErrorPayload {
+    /// The error message.
+    pub message: String,
+}
+
+/// Sent after execution of an operation. For queries and mutations, this is sent to the client
+/// once. For subscriptions, this is sent for every event in the event stream.
+#[derive(Debug, Serialize, PartialEq)]
+#[serde(bound(serialize = "S: ScalarValue"))]
+#[serde(rename_all = "camelCase")]
+pub struct DataPayload<S> {
+    /// The result data.
+    pub data: Value<S>,
+
+    /// The errors that have occurred during execution. Note that parse and validation errors are
+    /// not included here. They are sent via Error messages.
+    pub errors: Vec<ExecutionError<S>>,
+}
+
+/// A payload for errors that can happen before execution. Errors that happen during execution are
+/// instead sent to the client via `DataPayload`. `ErrorPayload` is a wrapper for an owned
+/// `GraphQLError`.
+// XXX: Think carefully before deriving traits. This is self-referential (error references
+// _execution_params).
+pub struct ErrorPayload {
+    _execution_params: Option<Box<dyn Any + Send>>,
+    error: GraphQLError<'static>,
+    _marker: PhantomPinned,
+}
+
+impl ErrorPayload {
+    /// For this to be okay, the caller must guarantee that the error can only reference data from
+    /// execution_params and that execution_params has not been modified or moved.
+    pub(crate) unsafe fn new_unchecked<'a>(
+        execution_params: Box<dyn Any + Send>,
+        error: GraphQLError<'a>,
+    ) -> Self {
+        Self {
+            _execution_params: Some(execution_params),
+            error: std::mem::transmute(error),
+            _marker: PhantomPinned,
+        }
+    }
+
+    /// Returns the contained GraphQLError.
+    pub fn graphql_error<'a>(&'a self) -> &GraphQLError<'a> {
+        &self.error
+    }
+}
+
+impl fmt::Debug for ErrorPayload {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.error.fmt(f)
+    }
+}
+
+impl PartialEq for ErrorPayload {
+    fn eq(&self, other: &Self) -> bool {
+        self.error.eq(&other.error)
+    }
+}
+
+impl Serialize for ErrorPayload {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.error.serialize(serializer)
+    }
+}
+
+impl From<GraphQLError<'static>> for ErrorPayload {
+    fn from(error: GraphQLError<'static>) -> Self {
+        Self {
+            _execution_params: None,
+            error,
+            _marker: PhantomPinned,
+        }
+    }
+}
+
+/// ServerMessage defines the message types that servers can send.
+#[derive(Debug, Serialize, PartialEq)]
+#[serde(bound(serialize = "S: ScalarValue"))]
+#[serde(rename_all = "snake_case")]
+#[serde(tag = "type")]
+pub enum ServerMessage<S: ScalarValue> {
+    /// ConnectionError is used for errors that are not associated with a GraphQL operation. For
+    /// example, this will be used when:
+    ///
+    ///   * The server is unable to parse a client's message.
+    ///   * The client's initialization parameters are rejected.
+    ConnectionError {
+        /// The error that occurred.
+        payload: ConnectionErrorPayload,
+    },
+    /// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a
+    /// connection.
+    ConnectionAck,
+    /// Data contains the result of a query, mutation, or subscription event.
+    Data {
+        /// The id of the operation that the data is for.
+        id: String,
+
+        /// The data and errors that occurred during execution.
+        payload: DataPayload<S>,
+    },
+    /// Error contains an error that occurs before execution, such as validation errors.
+    Error {
+        /// The id of the operation that triggered this error.
+        id: String,
+
+        /// The error(s).
+        payload: ErrorPayload,
+    },
+    /// Complete indicates that no more data will be sent for the given operation.
+    Complete {
+        /// The id of the operation that has completed.
+        id: String,
+    },
+    /// ConnectionKeepAlive is sent periodically after accepting a connection.
+    #[serde(rename = "ka")]
+    ConnectionKeepAlive,
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use juniper::DefaultScalarValue;
+
+    #[test]
+    fn test_serialization() {
+        type ServerMessage = super::ServerMessage<DefaultScalarValue>;
+
+        assert_eq!(
+            serde_json::to_string(&ServerMessage::ConnectionError {
+                payload: ConnectionErrorPayload {
+                    message: "foo".to_string(),
+                },
+            })
+            .unwrap(),
+            r##"{"type":"connection_error","payload":{"message":"foo"}}"##,
+        );
+
+        assert_eq!(
+            serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
+            r##"{"type":"connection_ack"}"##,
+        );
+
+        assert_eq!(
+            serde_json::to_string(&ServerMessage::Data {
+                id: "foo".to_string(),
+                payload: DataPayload {
+                    data: Value::null(),
+                    errors: vec![],
+                },
+            })
+            .unwrap(),
+            r##"{"type":"data","id":"foo","payload":{"data":null,"errors":[]}}"##,
+        );
+
+        assert_eq!(
+            serde_json::to_string(&ServerMessage::Error {
+                id: "foo".to_string(),
+                payload: GraphQLError::UnknownOperationName.into(),
+            })
+            .unwrap(),
+            r##"{"type":"error","id":"foo","payload":[{"message":"Unknown operation"}]}"##,
+        );
+
+        assert_eq!(
+            serde_json::to_string(&ServerMessage::Complete {
+                id: "foo".to_string(),
+            })
+            .unwrap(),
+            r##"{"type":"complete","id":"foo"}"##,
+        );
+
+        assert_eq!(
+            serde_json::to_string(&ServerMessage::ConnectionKeepAlive).unwrap(),
+            r##"{"type":"ka"}"##,
+        );
+    }
+}
diff --git a/juniper_subscriptions/src/lib.rs b/juniper_subscriptions/src/lib.rs
index 0e78279c..3418c055 100644
--- a/juniper_subscriptions/src/lib.rs
+++ b/juniper_subscriptions/src/lib.rs
@@ -222,19 +222,25 @@ where
                 }
 
                 if filled_count == obj_len {
+                    let mut errors = vec![];
                     filled_count = 0;
                     let new_vec = (0..obj_len).map(|_| None).collect::<Vec<_>>();
                     let ready_vec = std::mem::replace(&mut ready_vec, new_vec);
                     let ready_vec_iterator = ready_vec.into_iter().map(|el| {
                         let (name, val) = el.unwrap();
-                        if let Ok(value) = val {
-                            (name, value)
-                        } else {
-                            (name, Value::Null)
+                        match val {
+                            Ok(value) => (name, value),
+                            Err(e) => {
+                                errors.push(e);
+                                (name, Value::Null)
+                            }
                         }
                     });
                     let obj = Object::from_iter(ready_vec_iterator);
-                    Poll::Ready(Some(ExecutionOutput::from_data(Value::Object(obj))))
+                    Poll::Ready(Some(ExecutionOutput {
+                        data: Value::Object(obj),
+                        errors,
+                    }))
                 } else {
                     Poll::Pending
                 }
diff --git a/juniper_warp/Cargo.toml b/juniper_warp/Cargo.toml
index cf14ae32..f2fcb5b5 100644
--- a/juniper_warp/Cargo.toml
+++ b/juniper_warp/Cargo.toml
@@ -9,7 +9,7 @@ repository = "https://github.com/graphql-rust/juniper"
 edition = "2018"
 
 [features]
-subscriptions = ["juniper_subscriptions"]
+subscriptions = ["juniper_graphql_ws"]
 
 [dependencies]
 bytes = "0.5"
@@ -17,7 +17,7 @@ anyhow = "1.0"
 thiserror = "1.0"
 futures = "0.3.1"
 juniper = { version = "0.14.2", path = "../juniper", default-features = false  }
-juniper_subscriptions = { path = "../juniper_subscriptions", optional = true }
+juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true }
 serde = { version = "1.0.75", features = ["derive"] }
 serde_json = "1.0.24"
 tokio = { version = "0.2", features = ["blocking", "rt-core"] }
diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs
index 291ba011..06898304 100644
--- a/juniper_warp/src/lib.rs
+++ b/juniper_warp/src/lib.rs
@@ -393,224 +393,103 @@ fn playground_response(
 /// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
 #[cfg(feature = "subscriptions")]
 pub mod subscriptions {
-    use std::{
-        collections::HashMap,
-        sync::{
-            atomic::{AtomicBool, Ordering},
-            Arc,
+    use juniper::{
+        futures::{
+            future::{self, Either},
+            sink::SinkExt,
+            stream::StreamExt,
         },
+        GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue,
     };
+    use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init};
+    use std::{convert::Infallible, fmt, sync::Arc};
 
-    use anyhow::anyhow;
-    use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
-    use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
-    use juniper_subscriptions::Coordinator;
-    use serde::{Deserialize, Serialize};
-    use warp::ws::Message;
+    struct Message(warp::ws::Message);
 
-    /// Listen to incoming messages and do one of the following:
-    ///  - execute subscription and return values from stream
-    ///  - stop stream and close ws connection
-    #[allow(dead_code)]
-    pub fn graphql_subscriptions<Query, Mutation, Subscription, CtxT, S>(
-        websocket: warp::ws::WebSocket,
-        coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, CtxT, S>>,
-        context: CtxT,
-    ) -> impl Future<Output = Result<(), anyhow::Error>> + Send
-    where
-        Query: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
-        Query::TypeInfo: Send + Sync,
-        Mutation: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
-        Mutation::TypeInfo: Send + Sync,
-        Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
-        Subscription::TypeInfo: Send + Sync,
-        CtxT: Send + Sync + 'static,
-        S: ScalarValue + Send + Sync + 'static,
-    {
-        let (sink_tx, sink_rx) = websocket.split();
-        let (ws_tx, ws_rx) = mpsc::unbounded();
-        tokio::task::spawn(
-            ws_rx
-                .take_while(|v: &Option<_>| futures::future::ready(v.is_some()))
-                .map(|x| x.unwrap())
-                .forward(sink_tx),
-        );
+    impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> {
+        type Error = serde_json::Error;
 
-        let context = Arc::new(context);
-        let got_close_signal = Arc::new(AtomicBool::new(false));
-        let got_close_signal2 = got_close_signal.clone();
-
-        struct SubscriptionState {
-            should_stop: AtomicBool,
+        fn try_from(msg: Message) -> serde_json::Result<Self> {
+            serde_json::from_slice(msg.0.as_bytes())
         }
-        let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new();
-
-        sink_rx
-            .map_err(move |e| {
-                got_close_signal2.store(true, Ordering::Relaxed);
-                anyhow!("Websocket error: {}", e)
-            })
-            .try_fold(subscription_states, move |mut subscription_states, msg| {
-                let coordinator = coordinator.clone();
-                let context = context.clone();
-                let got_close_signal = got_close_signal.clone();
-                let ws_tx = ws_tx.clone();
-
-                async move {
-                    if msg.is_close() {
-                        return Ok(subscription_states);
-                    }
-
-                    let msg = msg
-                        .to_str()
-                        .map_err(|_| anyhow!("Non-text messages are not accepted"))?;
-                    let request: WsPayload<S> = serde_json::from_str(msg)
-                        .map_err(|e| anyhow!("Invalid WsPayload: {}", e))?;
-
-                    match request.type_name.as_str() {
-                        "connection_init" => {}
-                        "start" => {
-                            if got_close_signal.load(Ordering::Relaxed) {
-                                return Ok(subscription_states);
-                            }
-
-                            let request_id = request.id.clone().unwrap_or("1".to_owned());
-
-                            if let Some(existing) = subscription_states.get(&request_id) {
-                                existing.should_stop.store(true, Ordering::Relaxed);
-                            }
-                            let state = Arc::new(SubscriptionState {
-                                should_stop: AtomicBool::new(false),
-                            });
-                            subscription_states.insert(request_id.clone(), state.clone());
-
-                            let ws_tx = ws_tx.clone();
-
-                            if let Some(ref payload) = request.payload {
-                                if payload.query.is_none() {
-                                    return Err(anyhow!("Query not found"));
-                                }
-                            } else {
-                                return Err(anyhow!("Payload not found"));
-                            }
-
-                            tokio::task::spawn(async move {
-                                let payload = request.payload.unwrap();
-
-                                let graphql_request = GraphQLRequest::<S>::new(
-                                    payload.query.unwrap(),
-                                    None,
-                                    payload.variables,
-                                );
-
-                                let values_stream = match coordinator
-                                    .subscribe(&graphql_request, &context)
-                                    .await
-                                {
-                                    Ok(s) => s,
-                                    Err(err) => {
-                                        let _ =
-                                            ws_tx.unbounded_send(Some(Ok(Message::text(format!(
-                                                r#"{{"type":"error","id":"{}","payload":{}}}"#,
-                                                request_id,
-                                                serde_json::ser::to_string(&err).unwrap_or(
-                                                    "Error deserializing GraphQLError".to_owned()
-                                                )
-                                            )))));
-
-                                        let close_message = format!(
-                                            r#"{{"type":"complete","id":"{}","payload":null}}"#,
-                                            request_id
-                                        );
-                                        let _ = ws_tx
-                                            .unbounded_send(Some(Ok(Message::text(close_message))));
-                                        // close channel
-                                        let _ = ws_tx.unbounded_send(None);
-                                        return;
-                                    }
-                                };
-
-                                values_stream
-                                    .take_while(move |response| {
-                                        let request_id = request_id.clone();
-                                        let should_stop = state.should_stop.load(Ordering::Relaxed)
-                                            || got_close_signal.load(Ordering::Relaxed);
-                                        if !should_stop {
-                                            let mut response_text = serde_json::to_string(
-                                                &response,
-                                            )
-                                            .unwrap_or("Error deserializing response".to_owned());
-
-                                            response_text = format!(
-                                                r#"{{"type":"data","id":"{}","payload":{} }}"#,
-                                                request_id, response_text
-                                            );
-
-                                            let _ = ws_tx.unbounded_send(Some(Ok(Message::text(
-                                                response_text,
-                                            ))));
-                                        }
-
-                                        async move { !should_stop }
-                                    })
-                                    .for_each(|_| async {})
-                                    .await;
-                            });
-                        }
-                        "stop" => {
-                            let request_id = request.id.unwrap_or("1".to_owned());
-                            if let Some(existing) = subscription_states.get(&request_id) {
-                                existing.should_stop.store(true, Ordering::Relaxed);
-                                subscription_states.remove(&request_id);
-                            }
-
-                            let close_message = format!(
-                                r#"{{"type":"complete","id":"{}","payload":null}}"#,
-                                request_id
-                            );
-                            let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
-
-                            // close channel
-                            let _ = ws_tx.unbounded_send(None);
-                        }
-                        _ => {}
-                    }
-
-                    Ok(subscription_states)
-                }
-            })
-            .map_ok(|_| ())
     }
 
-    #[derive(Deserialize)]
-    #[serde(bound = "GraphQLPayload<S>: Deserialize<'de>")]
-    struct WsPayload<S>
+    /// Errors that can happen while serving a connection.
+    #[derive(Debug)]
+    pub enum Error {
+        /// Errors that can happen in Warp while serving a connection.
+        Warp(warp::Error),
+
+        /// Errors that can happen while serializing outgoing messages. Note that errors that occur
+        /// while deserializing internal messages are handled internally by the protocol.
+        Serde(serde_json::Error),
+    }
+
+    impl fmt::Display for Error {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            match self {
+                Self::Warp(e) => write!(f, "warp error: {}", e),
+                Self::Serde(e) => write!(f, "serde error: {}", e),
+            }
+        }
+    }
+
+    impl std::error::Error for Error {}
+
+    impl From<warp::Error> for Error {
+        fn from(err: warp::Error) -> Self {
+            Self::Warp(err)
+        }
+    }
+
+    impl From<Infallible> for Error {
+        fn from(_err: Infallible) -> Self {
+            unreachable!()
+        }
+    }
+
+    /// Serves the graphql-ws protocol over a WebSocket connection.
+    ///
+    /// The `init` argument is used to provide the context and additional configuration for
+    /// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and
+    /// configuration are already known, or it can be a closure that gets executed asynchronously
+    /// when the client sends the ConnectionInit message. Using a closure allows you to perform
+    /// authentication based on the parameters provided by the client.
+    pub async fn serve_graphql_ws<Query, Mutation, Subscription, CtxT, S, I>(
+        websocket: warp::ws::WebSocket,
+        root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
+        init: I,
+    ) -> Result<(), Error>
     where
-        S: ScalarValue + Send + Sync,
+        Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Query::TypeInfo: Send + Sync,
+        Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
+        Mutation::TypeInfo: Send + Sync,
+        Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
+        Subscription::TypeInfo: Send + Sync,
+        CtxT: Unpin + Send + Sync + 'static,
+        S: ScalarValue + Send + Sync + 'static,
+        I: Init<S, CtxT> + Send,
     {
-        id: Option<String>,
-        #[serde(rename(deserialize = "type"))]
-        type_name: String,
-        payload: Option<GraphQLPayload<S>>,
-    }
+        let (ws_tx, ws_rx) = websocket.split();
+        let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split();
 
-    #[derive(Debug, Deserialize)]
-    #[serde(bound = "InputValue<S>: Deserialize<'de>")]
-    struct GraphQLPayload<S>
-    where
-        S: ScalarValue + Send + Sync,
-    {
-        variables: Option<InputValue<S>>,
-        extensions: Option<HashMap<String, String>>,
-        #[serde(rename(deserialize = "operationName"))]
-        operaton_name: Option<String>,
-        query: Option<String>,
-    }
+        let ws_rx = ws_rx.map(|r| r.map(|msg| Message(msg)));
+        let s_rx = s_rx.map(|msg| {
+            serde_json::to_string(&msg)
+                .map(|t| warp::ws::Message::text(t))
+                .map_err(|e| Error::Serde(e))
+        });
 
-    #[derive(Serialize)]
-    struct Output {
-        data: String,
-        variables: String,
+        match future::select(
+            ws_rx.forward(s_tx.sink_err_into()),
+            s_rx.forward(ws_tx.sink_err_into()),
+        )
+        .await
+        {
+            Either::Left((r, _)) => r.map_err(|e| e.into()),
+            Either::Right((r, _)) => r,
+        }
     }
 }