GraphQL-WS crate and Warp subscriptions update (#721)

* update pre-existing juniper_warp::subscriptions

* initial draft

* finish up, update example

* polish + timing test

* fix pre-existing bug

* rebase updates

* address comments

* add release.toml

* makefile and initial changelog

* add new Cargo.toml to juniper/release.toml
This commit is contained in:
Chris 2020-07-29 04:23:44 -04:00 committed by GitHub
parent dc309b83b7
commit 84c9720b53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1696 additions and 238 deletions

View file

@ -15,6 +15,7 @@ members = [
"juniper_rocket",
"juniper_rocket_async",
"juniper_subscriptions",
"juniper_graphql_ws",
"juniper_warp",
"juniper_actix",
]

View file

@ -13,6 +13,6 @@ serde_json = "1.0"
tokio = { version = "0.2", features = ["rt-core", "macros"] }
warp = "0.2.1"
juniper = { git = "https://github.com/graphql-rust/juniper" }
juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" }
juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] }
juniper = { path = "../../juniper" }
juniper_graphql_ws = { path = "../../juniper_graphql_ws" }
juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] }

View file

@ -2,10 +2,10 @@
use std::{env, pin::Pin, sync::Arc, time::Duration};
use futures::{Future, FutureExt as _, Stream};
use futures::{FutureExt as _, Stream};
use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode};
use juniper_subscriptions::Coordinator;
use juniper_warp::{playground_filter, subscriptions::graphql_subscriptions};
use juniper_graphql_ws::ConnectionConfig;
use juniper_warp::{playground_filter, subscriptions::serve_graphql_ws};
use warp::{http::Response, Filter};
#[derive(Clone)]
@ -151,30 +151,24 @@ async fn main() {
let qm_state = warp::any().map(move || Context {});
let qm_graphql_filter = juniper_warp::make_graphql_filter(qm_schema, qm_state.boxed());
let sub_state = warp::any().map(move || Context {});
let coordinator = Arc::new(juniper_subscriptions::Coordinator::new(schema()));
let root_node = Arc::new(schema());
log::info!("Listening on 127.0.0.1:8080");
let routes = (warp::path("subscriptions")
.and(warp::ws())
.and(sub_state.clone())
.and(warp::any().map(move || Arc::clone(&coordinator)))
.map(
|ws: warp::ws::Ws,
ctx: Context,
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
graphql_subscriptions(websocket, coordinator, ctx)
.map(|r| {
if let Err(e) = r {
println!("Websocket error: {}", e);
}
})
.boxed()
})
},
))
.map(move |ws: warp::ws::Ws| {
let root_node = root_node.clone();
ws.on_upgrade(move |websocket| async move {
serve_graphql_ws(websocket, root_node, ConnectionConfig::new(Context {}))
.map(|r| {
if let Err(e) = r {
println!("Websocket error: {}", e);
}
})
.await
})
}))
.map(|reply| {
// TODO#584: remove this workaround
warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")

View file

@ -30,6 +30,8 @@ pre-release-replacements = [
{file="../juniper_warp/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
# Subscriptions
{file="../juniper_subscriptions/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
# GraphQL-WS
{file="../juniper_graphql_ws/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
# Actix-Web
{file="../juniper_actix/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
{file="../juniper_actix/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},

View file

@ -0,0 +1,3 @@
# master
- Initial Release

View file

@ -0,0 +1,19 @@
[package]
name = "juniper_graphql_ws"
version = "0.1.0"
authors = ["Christopher Brown <ccbrown112@gmail.com>"]
license = "BSD-2-Clause"
description = "Graphql-ws protocol implementation for Juniper"
documentation = "https://docs.rs/juniper_graphql_ws"
repository = "https://github.com/graphql-rust/juniper"
keywords = ["graphql-ws", "juniper", "graphql", "apollo"]
edition = "2018"
[dependencies]
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
juniper_subscriptions = { path = "../juniper_subscriptions" }
serde = { version = "1.0.8", features = ["derive"] }
tokio = { version = "0.2", features = ["macros", "rt-core", "time"] }
[dev-dependencies]
serde_json = { version = "1.0.2" }

View file

@ -0,0 +1,20 @@
[env]
CARGO_MAKE_CARGO_ALL_FEATURES = ""
[tasks.build-verbose]
condition = { rust_version = { min = "1.29.0" } }
[tasks.build-verbose.windows]
condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }
[tasks.test-verbose]
condition = { rust_version = { min = "1.29.0" } }
[tasks.test-verbose.windows]
condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }
[tasks.ci-coverage-flow]
condition = { rust_version = { min = "1.29.0" } }
[tasks.ci-coverage-flow.windows]
disabled = true

View file

@ -0,0 +1,8 @@
no-dev-version = true
pre-release-commit-message = "Release {{crate_name}} {{version}}"
pro-release-commit-message = "Bump {{crate_name}} version to {{next_version}}"
tag-message = "Release {{crate_name}} {{version}}"
upload-doc = false
pre-release-replacements = [
{file="src/lib.rs", search="docs.rs/juniper_graphql_ws/[a-z0-9\\.-]+", replace="docs.rs/juniper_graphql_ws/{{version}}"},
]

View file

@ -0,0 +1,131 @@
use juniper::{ScalarValue, Variables};
/// The payload for a client's "start" message. This triggers execution of a query, mutation, or
/// subscription.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(bound(deserialize = "S: ScalarValue"))]
#[serde(rename_all = "camelCase")]
pub struct StartPayload<S: ScalarValue> {
/// The document body.
pub query: String,
/// The optional variables.
#[serde(default)]
pub variables: Variables<S>,
/// The optional operation name (required if the document contains multiple operations).
pub operation_name: Option<String>,
}
/// ClientMessage defines the message types that clients can send.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(bound(deserialize = "S: ScalarValue"))]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum ClientMessage<S: ScalarValue> {
/// ConnectionInit is sent by the client upon connecting.
ConnectionInit {
/// Optional parameters of any type sent from the client. These are often used for
/// authentication.
#[serde(default)]
payload: Variables<S>,
},
/// Start messages are used to execute a GraphQL operation.
Start {
/// The id of the operation. This can be anything, but must be unique. If there are other
/// in-flight operations with the same id, the message will be ignored or cause an error.
id: String,
/// The query, variables, and operation name.
payload: StartPayload<S>,
},
/// Stop messages are used to unsubscribe from a subscription.
Stop {
/// The id of the operation to stop.
id: String,
},
/// ConnectionTerminate is used to terminate the connection.
ConnectionTerminate,
}
#[cfg(test)]
mod test {
use super::*;
use juniper::{DefaultScalarValue, InputValue};
#[test]
fn test_deserialization() {
type ClientMessage = super::ClientMessage<DefaultScalarValue>;
assert_eq!(
ClientMessage::ConnectionInit {
payload: [("foo".to_string(), InputValue::scalar("bar"))]
.iter()
.cloned()
.collect(),
},
serde_json::from_str(r##"{"type": "connection_init", "payload": {"foo": "bar"}}"##)
.unwrap(),
);
assert_eq!(
ClientMessage::ConnectionInit {
payload: Variables::default(),
},
serde_json::from_str(r##"{"type": "connection_init"}"##).unwrap(),
);
assert_eq!(
ClientMessage::Start {
id: "foo".to_string(),
payload: StartPayload {
query: "query MyQuery { __typename }".to_string(),
variables: [("foo".to_string(), InputValue::scalar("bar"))]
.iter()
.cloned()
.collect(),
operation_name: Some("MyQuery".to_string()),
},
},
serde_json::from_str(
r##"{"type": "start", "id": "foo", "payload": {
"query": "query MyQuery { __typename }",
"variables": {
"foo": "bar"
},
"operationName": "MyQuery"
}}"##
)
.unwrap(),
);
assert_eq!(
ClientMessage::Start {
id: "foo".to_string(),
payload: StartPayload {
query: "query MyQuery { __typename }".to_string(),
variables: Variables::default(),
operation_name: None,
},
},
serde_json::from_str(
r##"{"type": "start", "id": "foo", "payload": {
"query": "query MyQuery { __typename }"
}}"##
)
.unwrap(),
);
assert_eq!(
ClientMessage::Stop {
id: "foo".to_string()
},
serde_json::from_str(r##"{"type": "stop", "id": "foo"}"##).unwrap(),
);
assert_eq!(
ClientMessage::ConnectionTerminate,
serde_json::from_str(r##"{"type": "connection_terminate"}"##).unwrap(),
);
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,131 @@
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
use std::sync::Arc;
/// Schema defines the requirements for schemas that can be used for operations. Typically this is
/// just an `Arc<RootNode<...>>` and you should not have to implement it yourself.
pub trait Schema: Unpin + Clone + Send + Sync + 'static {
/// The context type.
type Context: Unpin + Send + Sync;
/// The scalar value type.
type ScalarValue: ScalarValue + Send + Sync;
/// The query type info.
type QueryTypeInfo: Send + Sync;
/// The query type.
type Query: GraphQLTypeAsync<Self::ScalarValue, Context = Self::Context, TypeInfo = Self::QueryTypeInfo>
+ Send;
/// The mutation type info.
type MutationTypeInfo: Send + Sync;
/// The mutation type.
type Mutation: GraphQLTypeAsync<
Self::ScalarValue,
Context = Self::Context,
TypeInfo = Self::MutationTypeInfo,
> + Send;
/// The subscription type info.
type SubscriptionTypeInfo: Send + Sync;
/// The subscription type.
type Subscription: GraphQLSubscriptionType<
Self::ScalarValue,
Context = Self::Context,
TypeInfo = Self::SubscriptionTypeInfo,
> + Send;
/// Returns the root node for the schema.
fn root_node(
&self,
) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>;
}
/// This exists as a work-around for this issue: https://github.com/rust-lang/rust/issues/64552
///
/// It can be used in generators where using Arc directly would result in an error.
// TODO: Remove this once that issue is resolved.
#[doc(hidden)]
pub struct ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>(
pub Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
)
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync,
S: ScalarValue + Send + Sync + 'static;
impl<QueryT, MutationT, SubscriptionT, CtxT, S> Clone
for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync,
S: ScalarValue + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
{
type Context = CtxT;
type ScalarValue = S;
type QueryTypeInfo = QueryT::TypeInfo;
type Query = QueryT;
type MutationTypeInfo = MutationT::TypeInfo;
type Mutation = MutationT;
type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
type Subscription = SubscriptionT;
fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
&self.0
}
}
impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
for Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
MutationT::TypeInfo: Send + Sync,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync,
S: ScalarValue + Send + Sync + 'static,
{
type Context = CtxT;
type ScalarValue = S;
type QueryTypeInfo = QueryT::TypeInfo;
type Query = QueryT;
type MutationTypeInfo = MutationT::TypeInfo;
type Mutation = MutationT;
type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
type Subscription = SubscriptionT;
fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
self
}
}

View file

@ -0,0 +1,191 @@
use juniper::{ExecutionError, GraphQLError, ScalarValue, Value};
use serde::{Serialize, Serializer};
use std::{any::Any, fmt, marker::PhantomPinned};
/// The payload for errors that are not associated with a GraphQL operation.
#[derive(Debug, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ConnectionErrorPayload {
/// The error message.
pub message: String,
}
/// Sent after execution of an operation. For queries and mutations, this is sent to the client
/// once. For subscriptions, this is sent for every event in the event stream.
#[derive(Debug, Serialize, PartialEq)]
#[serde(bound(serialize = "S: ScalarValue"))]
#[serde(rename_all = "camelCase")]
pub struct DataPayload<S> {
/// The result data.
pub data: Value<S>,
/// The errors that have occurred during execution. Note that parse and validation errors are
/// not included here. They are sent via Error messages.
pub errors: Vec<ExecutionError<S>>,
}
/// A payload for errors that can happen before execution. Errors that happen during execution are
/// instead sent to the client via `DataPayload`. `ErrorPayload` is a wrapper for an owned
/// `GraphQLError`.
// XXX: Think carefully before deriving traits. This is self-referential (error references
// _execution_params).
pub struct ErrorPayload {
_execution_params: Option<Box<dyn Any + Send>>,
error: GraphQLError<'static>,
_marker: PhantomPinned,
}
impl ErrorPayload {
/// For this to be okay, the caller must guarantee that the error can only reference data from
/// execution_params and that execution_params has not been modified or moved.
pub(crate) unsafe fn new_unchecked<'a>(
execution_params: Box<dyn Any + Send>,
error: GraphQLError<'a>,
) -> Self {
Self {
_execution_params: Some(execution_params),
error: std::mem::transmute(error),
_marker: PhantomPinned,
}
}
/// Returns the contained GraphQLError.
pub fn graphql_error<'a>(&'a self) -> &GraphQLError<'a> {
&self.error
}
}
impl fmt::Debug for ErrorPayload {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.error.fmt(f)
}
}
impl PartialEq for ErrorPayload {
fn eq(&self, other: &Self) -> bool {
self.error.eq(&other.error)
}
}
impl Serialize for ErrorPayload {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.error.serialize(serializer)
}
}
impl From<GraphQLError<'static>> for ErrorPayload {
fn from(error: GraphQLError<'static>) -> Self {
Self {
_execution_params: None,
error,
_marker: PhantomPinned,
}
}
}
/// ServerMessage defines the message types that servers can send.
#[derive(Debug, Serialize, PartialEq)]
#[serde(bound(serialize = "S: ScalarValue"))]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum ServerMessage<S: ScalarValue> {
/// ConnectionError is used for errors that are not associated with a GraphQL operation. For
/// example, this will be used when:
///
/// * The server is unable to parse a client's message.
/// * The client's initialization parameters are rejected.
ConnectionError {
/// The error that occurred.
payload: ConnectionErrorPayload,
},
/// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a
/// connection.
ConnectionAck,
/// Data contains the result of a query, mutation, or subscription event.
Data {
/// The id of the operation that the data is for.
id: String,
/// The data and errors that occurred during execution.
payload: DataPayload<S>,
},
/// Error contains an error that occurs before execution, such as validation errors.
Error {
/// The id of the operation that triggered this error.
id: String,
/// The error(s).
payload: ErrorPayload,
},
/// Complete indicates that no more data will be sent for the given operation.
Complete {
/// The id of the operation that has completed.
id: String,
},
/// ConnectionKeepAlive is sent periodically after accepting a connection.
#[serde(rename = "ka")]
ConnectionKeepAlive,
}
#[cfg(test)]
mod test {
use super::*;
use juniper::DefaultScalarValue;
#[test]
fn test_serialization() {
type ServerMessage = super::ServerMessage<DefaultScalarValue>;
assert_eq!(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: ConnectionErrorPayload {
message: "foo".to_string(),
},
})
.unwrap(),
r##"{"type":"connection_error","payload":{"message":"foo"}}"##,
);
assert_eq!(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
r##"{"type":"connection_ack"}"##,
);
assert_eq!(
serde_json::to_string(&ServerMessage::Data {
id: "foo".to_string(),
payload: DataPayload {
data: Value::null(),
errors: vec![],
},
})
.unwrap(),
r##"{"type":"data","id":"foo","payload":{"data":null,"errors":[]}}"##,
);
assert_eq!(
serde_json::to_string(&ServerMessage::Error {
id: "foo".to_string(),
payload: GraphQLError::UnknownOperationName.into(),
})
.unwrap(),
r##"{"type":"error","id":"foo","payload":[{"message":"Unknown operation"}]}"##,
);
assert_eq!(
serde_json::to_string(&ServerMessage::Complete {
id: "foo".to_string(),
})
.unwrap(),
r##"{"type":"complete","id":"foo"}"##,
);
assert_eq!(
serde_json::to_string(&ServerMessage::ConnectionKeepAlive).unwrap(),
r##"{"type":"ka"}"##,
);
}
}

View file

@ -222,19 +222,25 @@ where
}
if filled_count == obj_len {
let mut errors = vec![];
filled_count = 0;
let new_vec = (0..obj_len).map(|_| None).collect::<Vec<_>>();
let ready_vec = std::mem::replace(&mut ready_vec, new_vec);
let ready_vec_iterator = ready_vec.into_iter().map(|el| {
let (name, val) = el.unwrap();
if let Ok(value) = val {
(name, value)
} else {
(name, Value::Null)
match val {
Ok(value) => (name, value),
Err(e) => {
errors.push(e);
(name, Value::Null)
}
}
});
let obj = Object::from_iter(ready_vec_iterator);
Poll::Ready(Some(ExecutionOutput::from_data(Value::Object(obj))))
Poll::Ready(Some(ExecutionOutput {
data: Value::Object(obj),
errors,
}))
} else {
Poll::Pending
}

View file

@ -9,7 +9,7 @@ repository = "https://github.com/graphql-rust/juniper"
edition = "2018"
[features]
subscriptions = ["juniper_subscriptions"]
subscriptions = ["juniper_graphql_ws"]
[dependencies]
bytes = "0.5"
@ -17,7 +17,7 @@ anyhow = "1.0"
thiserror = "1.0"
futures = "0.3.1"
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
juniper_subscriptions = { path = "../juniper_subscriptions", optional = true }
juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true }
serde = { version = "1.0.75", features = ["derive"] }
serde_json = "1.0.24"
tokio = { version = "0.2", features = ["blocking", "rt-core"] }

View file

@ -393,224 +393,103 @@ fn playground_response(
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
#[cfg(feature = "subscriptions")]
pub mod subscriptions {
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
use juniper::{
futures::{
future::{self, Either},
sink::SinkExt,
stream::StreamExt,
},
GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue,
};
use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init};
use std::{convert::Infallible, fmt, sync::Arc};
use anyhow::anyhow;
use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
use juniper_subscriptions::Coordinator;
use serde::{Deserialize, Serialize};
use warp::ws::Message;
struct Message(warp::ws::Message);
/// Listen to incoming messages and do one of the following:
/// - execute subscription and return values from stream
/// - stop stream and close ws connection
#[allow(dead_code)]
pub fn graphql_subscriptions<Query, Mutation, Subscription, CtxT, S>(
websocket: warp::ws::WebSocket,
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, CtxT, S>>,
context: CtxT,
) -> impl Future<Output = Result<(), anyhow::Error>> + Send
where
Query: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
{
let (sink_tx, sink_rx) = websocket.split();
let (ws_tx, ws_rx) = mpsc::unbounded();
tokio::task::spawn(
ws_rx
.take_while(|v: &Option<_>| futures::future::ready(v.is_some()))
.map(|x| x.unwrap())
.forward(sink_tx),
);
impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> {
type Error = serde_json::Error;
let context = Arc::new(context);
let got_close_signal = Arc::new(AtomicBool::new(false));
let got_close_signal2 = got_close_signal.clone();
struct SubscriptionState {
should_stop: AtomicBool,
fn try_from(msg: Message) -> serde_json::Result<Self> {
serde_json::from_slice(msg.0.as_bytes())
}
let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new();
sink_rx
.map_err(move |e| {
got_close_signal2.store(true, Ordering::Relaxed);
anyhow!("Websocket error: {}", e)
})
.try_fold(subscription_states, move |mut subscription_states, msg| {
let coordinator = coordinator.clone();
let context = context.clone();
let got_close_signal = got_close_signal.clone();
let ws_tx = ws_tx.clone();
async move {
if msg.is_close() {
return Ok(subscription_states);
}
let msg = msg
.to_str()
.map_err(|_| anyhow!("Non-text messages are not accepted"))?;
let request: WsPayload<S> = serde_json::from_str(msg)
.map_err(|e| anyhow!("Invalid WsPayload: {}", e))?;
match request.type_name.as_str() {
"connection_init" => {}
"start" => {
if got_close_signal.load(Ordering::Relaxed) {
return Ok(subscription_states);
}
let request_id = request.id.clone().unwrap_or("1".to_owned());
if let Some(existing) = subscription_states.get(&request_id) {
existing.should_stop.store(true, Ordering::Relaxed);
}
let state = Arc::new(SubscriptionState {
should_stop: AtomicBool::new(false),
});
subscription_states.insert(request_id.clone(), state.clone());
let ws_tx = ws_tx.clone();
if let Some(ref payload) = request.payload {
if payload.query.is_none() {
return Err(anyhow!("Query not found"));
}
} else {
return Err(anyhow!("Payload not found"));
}
tokio::task::spawn(async move {
let payload = request.payload.unwrap();
let graphql_request = GraphQLRequest::<S>::new(
payload.query.unwrap(),
None,
payload.variables,
);
let values_stream = match coordinator
.subscribe(&graphql_request, &context)
.await
{
Ok(s) => s,
Err(err) => {
let _ =
ws_tx.unbounded_send(Some(Ok(Message::text(format!(
r#"{{"type":"error","id":"{}","payload":{}}}"#,
request_id,
serde_json::ser::to_string(&err).unwrap_or(
"Error deserializing GraphQLError".to_owned()
)
)))));
let close_message = format!(
r#"{{"type":"complete","id":"{}","payload":null}}"#,
request_id
);
let _ = ws_tx
.unbounded_send(Some(Ok(Message::text(close_message))));
// close channel
let _ = ws_tx.unbounded_send(None);
return;
}
};
values_stream
.take_while(move |response| {
let request_id = request_id.clone();
let should_stop = state.should_stop.load(Ordering::Relaxed)
|| got_close_signal.load(Ordering::Relaxed);
if !should_stop {
let mut response_text = serde_json::to_string(
&response,
)
.unwrap_or("Error deserializing response".to_owned());
response_text = format!(
r#"{{"type":"data","id":"{}","payload":{} }}"#,
request_id, response_text
);
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(
response_text,
))));
}
async move { !should_stop }
})
.for_each(|_| async {})
.await;
});
}
"stop" => {
let request_id = request.id.unwrap_or("1".to_owned());
if let Some(existing) = subscription_states.get(&request_id) {
existing.should_stop.store(true, Ordering::Relaxed);
subscription_states.remove(&request_id);
}
let close_message = format!(
r#"{{"type":"complete","id":"{}","payload":null}}"#,
request_id
);
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
// close channel
let _ = ws_tx.unbounded_send(None);
}
_ => {}
}
Ok(subscription_states)
}
})
.map_ok(|_| ())
}
#[derive(Deserialize)]
#[serde(bound = "GraphQLPayload<S>: Deserialize<'de>")]
struct WsPayload<S>
/// Errors that can happen while serving a connection.
#[derive(Debug)]
pub enum Error {
/// Errors that can happen in Warp while serving a connection.
Warp(warp::Error),
/// Errors that can happen while serializing outgoing messages. Note that errors that occur
/// while deserializing internal messages are handled internally by the protocol.
Serde(serde_json::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Warp(e) => write!(f, "warp error: {}", e),
Self::Serde(e) => write!(f, "serde error: {}", e),
}
}
}
impl std::error::Error for Error {}
impl From<warp::Error> for Error {
fn from(err: warp::Error) -> Self {
Self::Warp(err)
}
}
impl From<Infallible> for Error {
fn from(_err: Infallible) -> Self {
unreachable!()
}
}
/// Serves the graphql-ws protocol over a WebSocket connection.
///
/// The `init` argument is used to provide the context and additional configuration for
/// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and
/// configuration are already known, or it can be a closure that gets executed asynchronously
/// when the client sends the ConnectionInit message. Using a closure allows you to perform
/// authentication based on the parameters provided by the client.
pub async fn serve_graphql_ws<Query, Mutation, Subscription, CtxT, S, I>(
websocket: warp::ws::WebSocket,
root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
init: I,
) -> Result<(), Error>
where
S: ScalarValue + Send + Sync,
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send,
{
id: Option<String>,
#[serde(rename(deserialize = "type"))]
type_name: String,
payload: Option<GraphQLPayload<S>>,
}
let (ws_tx, ws_rx) = websocket.split();
let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split();
#[derive(Debug, Deserialize)]
#[serde(bound = "InputValue<S>: Deserialize<'de>")]
struct GraphQLPayload<S>
where
S: ScalarValue + Send + Sync,
{
variables: Option<InputValue<S>>,
extensions: Option<HashMap<String, String>>,
#[serde(rename(deserialize = "operationName"))]
operaton_name: Option<String>,
query: Option<String>,
}
let ws_rx = ws_rx.map(|r| r.map(|msg| Message(msg)));
let s_rx = s_rx.map(|msg| {
serde_json::to_string(&msg)
.map(|t| warp::ws::Message::text(t))
.map_err(|e| Error::Serde(e))
});
#[derive(Serialize)]
struct Output {
data: String,
variables: String,
match future::select(
ws_rx.forward(s_tx.sink_err_into()),
s_rx.forward(ws_tx.sink_err_into()),
)
.await
{
Either::Left((r, _)) => r.map_err(|e| e.into()),
Either::Right((r, _)) => r,
}
}
}