GraphQL-WS crate and Warp subscriptions update (#721)
* update pre-existing juniper_warp::subscriptions * initial draft * finish up, update example * polish + timing test * fix pre-existing bug * rebase updates * address comments * add release.toml * makefile and initial changelog * add new Cargo.toml to juniper/release.toml
This commit is contained in:
parent
dc309b83b7
commit
84c9720b53
15 changed files with 1696 additions and 238 deletions
|
@ -15,6 +15,7 @@ members = [
|
|||
"juniper_rocket",
|
||||
"juniper_rocket_async",
|
||||
"juniper_subscriptions",
|
||||
"juniper_graphql_ws",
|
||||
"juniper_warp",
|
||||
"juniper_actix",
|
||||
]
|
||||
|
|
|
@ -13,6 +13,6 @@ serde_json = "1.0"
|
|||
tokio = { version = "0.2", features = ["rt-core", "macros"] }
|
||||
warp = "0.2.1"
|
||||
|
||||
juniper = { git = "https://github.com/graphql-rust/juniper" }
|
||||
juniper_subscriptions = { git = "https://github.com/graphql-rust/juniper" }
|
||||
juniper_warp = { git = "https://github.com/graphql-rust/juniper", features = ["subscriptions"] }
|
||||
juniper = { path = "../../juniper" }
|
||||
juniper_graphql_ws = { path = "../../juniper_graphql_ws" }
|
||||
juniper_warp = { path = "../../juniper_warp", features = ["subscriptions"] }
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
use std::{env, pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use futures::{Future, FutureExt as _, Stream};
|
||||
use futures::{FutureExt as _, Stream};
|
||||
use juniper::{DefaultScalarValue, EmptyMutation, FieldError, RootNode};
|
||||
use juniper_subscriptions::Coordinator;
|
||||
use juniper_warp::{playground_filter, subscriptions::graphql_subscriptions};
|
||||
use juniper_graphql_ws::ConnectionConfig;
|
||||
use juniper_warp::{playground_filter, subscriptions::serve_graphql_ws};
|
||||
use warp::{http::Response, Filter};
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -151,30 +151,24 @@ async fn main() {
|
|||
let qm_state = warp::any().map(move || Context {});
|
||||
let qm_graphql_filter = juniper_warp::make_graphql_filter(qm_schema, qm_state.boxed());
|
||||
|
||||
let sub_state = warp::any().map(move || Context {});
|
||||
let coordinator = Arc::new(juniper_subscriptions::Coordinator::new(schema()));
|
||||
let root_node = Arc::new(schema());
|
||||
|
||||
log::info!("Listening on 127.0.0.1:8080");
|
||||
|
||||
let routes = (warp::path("subscriptions")
|
||||
.and(warp::ws())
|
||||
.and(sub_state.clone())
|
||||
.and(warp::any().map(move || Arc::clone(&coordinator)))
|
||||
.map(
|
||||
|ws: warp::ws::Ws,
|
||||
ctx: Context,
|
||||
coordinator: Arc<Coordinator<'static, _, _, _, _, _>>| {
|
||||
ws.on_upgrade(|websocket| -> Pin<Box<dyn Future<Output = ()> + Send>> {
|
||||
graphql_subscriptions(websocket, coordinator, ctx)
|
||||
.map(|r| {
|
||||
if let Err(e) = r {
|
||||
println!("Websocket error: {}", e);
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
},
|
||||
))
|
||||
.map(move |ws: warp::ws::Ws| {
|
||||
let root_node = root_node.clone();
|
||||
ws.on_upgrade(move |websocket| async move {
|
||||
serve_graphql_ws(websocket, root_node, ConnectionConfig::new(Context {}))
|
||||
.map(|r| {
|
||||
if let Err(e) = r {
|
||||
println!("Websocket error: {}", e);
|
||||
}
|
||||
})
|
||||
.await
|
||||
})
|
||||
}))
|
||||
.map(|reply| {
|
||||
// TODO#584: remove this workaround
|
||||
warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")
|
||||
|
|
|
@ -30,6 +30,8 @@ pre-release-replacements = [
|
|||
{file="../juniper_warp/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
|
||||
# Subscriptions
|
||||
{file="../juniper_subscriptions/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
|
||||
# GraphQL-WS
|
||||
{file="../juniper_graphql_ws/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
|
||||
# Actix-Web
|
||||
{file="../juniper_actix/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
|
||||
{file="../juniper_actix/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
|
||||
|
|
3
juniper_graphql_ws/CHANGELOG.md
Normal file
3
juniper_graphql_ws/CHANGELOG.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# master
|
||||
|
||||
- Initial Release
|
19
juniper_graphql_ws/Cargo.toml
Normal file
19
juniper_graphql_ws/Cargo.toml
Normal file
|
@ -0,0 +1,19 @@
|
|||
[package]
|
||||
name = "juniper_graphql_ws"
|
||||
version = "0.1.0"
|
||||
authors = ["Christopher Brown <ccbrown112@gmail.com>"]
|
||||
license = "BSD-2-Clause"
|
||||
description = "Graphql-ws protocol implementation for Juniper"
|
||||
documentation = "https://docs.rs/juniper_graphql_ws"
|
||||
repository = "https://github.com/graphql-rust/juniper"
|
||||
keywords = ["graphql-ws", "juniper", "graphql", "apollo"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
|
||||
juniper_subscriptions = { path = "../juniper_subscriptions" }
|
||||
serde = { version = "1.0.8", features = ["derive"] }
|
||||
tokio = { version = "0.2", features = ["macros", "rt-core", "time"] }
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = { version = "1.0.2" }
|
20
juniper_graphql_ws/Makefile.toml
Normal file
20
juniper_graphql_ws/Makefile.toml
Normal file
|
@ -0,0 +1,20 @@
|
|||
[env]
|
||||
CARGO_MAKE_CARGO_ALL_FEATURES = ""
|
||||
|
||||
[tasks.build-verbose]
|
||||
condition = { rust_version = { min = "1.29.0" } }
|
||||
|
||||
[tasks.build-verbose.windows]
|
||||
condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }
|
||||
|
||||
[tasks.test-verbose]
|
||||
condition = { rust_version = { min = "1.29.0" } }
|
||||
|
||||
[tasks.test-verbose.windows]
|
||||
condition = { rust_version = { min = "1.29.0" }, env = { "TARGET" = "x86_64-pc-windows-msvc" } }
|
||||
|
||||
[tasks.ci-coverage-flow]
|
||||
condition = { rust_version = { min = "1.29.0" } }
|
||||
|
||||
[tasks.ci-coverage-flow.windows]
|
||||
disabled = true
|
8
juniper_graphql_ws/release.toml
Normal file
8
juniper_graphql_ws/release.toml
Normal file
|
@ -0,0 +1,8 @@
|
|||
no-dev-version = true
|
||||
pre-release-commit-message = "Release {{crate_name}} {{version}}"
|
||||
pro-release-commit-message = "Bump {{crate_name}} version to {{next_version}}"
|
||||
tag-message = "Release {{crate_name}} {{version}}"
|
||||
upload-doc = false
|
||||
pre-release-replacements = [
|
||||
{file="src/lib.rs", search="docs.rs/juniper_graphql_ws/[a-z0-9\\.-]+", replace="docs.rs/juniper_graphql_ws/{{version}}"},
|
||||
]
|
131
juniper_graphql_ws/src/client_message.rs
Normal file
131
juniper_graphql_ws/src/client_message.rs
Normal file
|
@ -0,0 +1,131 @@
|
|||
use juniper::{ScalarValue, Variables};
|
||||
|
||||
/// The payload for a client's "start" message. This triggers execution of a query, mutation, or
|
||||
/// subscription.
|
||||
#[derive(Debug, Deserialize, PartialEq)]
|
||||
#[serde(bound(deserialize = "S: ScalarValue"))]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct StartPayload<S: ScalarValue> {
|
||||
/// The document body.
|
||||
pub query: String,
|
||||
|
||||
/// The optional variables.
|
||||
#[serde(default)]
|
||||
pub variables: Variables<S>,
|
||||
|
||||
/// The optional operation name (required if the document contains multiple operations).
|
||||
pub operation_name: Option<String>,
|
||||
}
|
||||
|
||||
/// ClientMessage defines the message types that clients can send.
|
||||
#[derive(Debug, Deserialize, PartialEq)]
|
||||
#[serde(bound(deserialize = "S: ScalarValue"))]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ClientMessage<S: ScalarValue> {
|
||||
/// ConnectionInit is sent by the client upon connecting.
|
||||
ConnectionInit {
|
||||
/// Optional parameters of any type sent from the client. These are often used for
|
||||
/// authentication.
|
||||
#[serde(default)]
|
||||
payload: Variables<S>,
|
||||
},
|
||||
/// Start messages are used to execute a GraphQL operation.
|
||||
Start {
|
||||
/// The id of the operation. This can be anything, but must be unique. If there are other
|
||||
/// in-flight operations with the same id, the message will be ignored or cause an error.
|
||||
id: String,
|
||||
|
||||
/// The query, variables, and operation name.
|
||||
payload: StartPayload<S>,
|
||||
},
|
||||
/// Stop messages are used to unsubscribe from a subscription.
|
||||
Stop {
|
||||
/// The id of the operation to stop.
|
||||
id: String,
|
||||
},
|
||||
/// ConnectionTerminate is used to terminate the connection.
|
||||
ConnectionTerminate,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use juniper::{DefaultScalarValue, InputValue};
|
||||
|
||||
#[test]
|
||||
fn test_deserialization() {
|
||||
type ClientMessage = super::ClientMessage<DefaultScalarValue>;
|
||||
|
||||
assert_eq!(
|
||||
ClientMessage::ConnectionInit {
|
||||
payload: [("foo".to_string(), InputValue::scalar("bar"))]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect(),
|
||||
},
|
||||
serde_json::from_str(r##"{"type": "connection_init", "payload": {"foo": "bar"}}"##)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientMessage::ConnectionInit {
|
||||
payload: Variables::default(),
|
||||
},
|
||||
serde_json::from_str(r##"{"type": "connection_init"}"##).unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientMessage::Start {
|
||||
id: "foo".to_string(),
|
||||
payload: StartPayload {
|
||||
query: "query MyQuery { __typename }".to_string(),
|
||||
variables: [("foo".to_string(), InputValue::scalar("bar"))]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect(),
|
||||
operation_name: Some("MyQuery".to_string()),
|
||||
},
|
||||
},
|
||||
serde_json::from_str(
|
||||
r##"{"type": "start", "id": "foo", "payload": {
|
||||
"query": "query MyQuery { __typename }",
|
||||
"variables": {
|
||||
"foo": "bar"
|
||||
},
|
||||
"operationName": "MyQuery"
|
||||
}}"##
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientMessage::Start {
|
||||
id: "foo".to_string(),
|
||||
payload: StartPayload {
|
||||
query: "query MyQuery { __typename }".to_string(),
|
||||
variables: Variables::default(),
|
||||
operation_name: None,
|
||||
},
|
||||
},
|
||||
serde_json::from_str(
|
||||
r##"{"type": "start", "id": "foo", "payload": {
|
||||
"query": "query MyQuery { __typename }"
|
||||
}}"##
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientMessage::Stop {
|
||||
id: "foo".to_string()
|
||||
},
|
||||
serde_json::from_str(r##"{"type": "stop", "id": "foo"}"##).unwrap(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
ClientMessage::ConnectionTerminate,
|
||||
serde_json::from_str(r##"{"type": "connection_terminate"}"##).unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
1073
juniper_graphql_ws/src/lib.rs
Normal file
1073
juniper_graphql_ws/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
131
juniper_graphql_ws/src/schema.rs
Normal file
131
juniper_graphql_ws/src/schema.rs
Normal file
|
@ -0,0 +1,131 @@
|
|||
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Schema defines the requirements for schemas that can be used for operations. Typically this is
|
||||
/// just an `Arc<RootNode<...>>` and you should not have to implement it yourself.
|
||||
pub trait Schema: Unpin + Clone + Send + Sync + 'static {
|
||||
/// The context type.
|
||||
type Context: Unpin + Send + Sync;
|
||||
|
||||
/// The scalar value type.
|
||||
type ScalarValue: ScalarValue + Send + Sync;
|
||||
|
||||
/// The query type info.
|
||||
type QueryTypeInfo: Send + Sync;
|
||||
|
||||
/// The query type.
|
||||
type Query: GraphQLTypeAsync<Self::ScalarValue, Context = Self::Context, TypeInfo = Self::QueryTypeInfo>
|
||||
+ Send;
|
||||
|
||||
/// The mutation type info.
|
||||
type MutationTypeInfo: Send + Sync;
|
||||
|
||||
/// The mutation type.
|
||||
type Mutation: GraphQLTypeAsync<
|
||||
Self::ScalarValue,
|
||||
Context = Self::Context,
|
||||
TypeInfo = Self::MutationTypeInfo,
|
||||
> + Send;
|
||||
|
||||
/// The subscription type info.
|
||||
type SubscriptionTypeInfo: Send + Sync;
|
||||
|
||||
/// The subscription type.
|
||||
type Subscription: GraphQLSubscriptionType<
|
||||
Self::ScalarValue,
|
||||
Context = Self::Context,
|
||||
TypeInfo = Self::SubscriptionTypeInfo,
|
||||
> + Send;
|
||||
|
||||
/// Returns the root node for the schema.
|
||||
fn root_node(
|
||||
&self,
|
||||
) -> &RootNode<'static, Self::Query, Self::Mutation, Self::Subscription, Self::ScalarValue>;
|
||||
}
|
||||
|
||||
/// This exists as a work-around for this issue: https://github.com/rust-lang/rust/issues/64552
|
||||
///
|
||||
/// It can be used in generators where using Arc directly would result in an error.
|
||||
// TODO: Remove this once that issue is resolved.
|
||||
#[doc(hidden)]
|
||||
pub struct ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>(
|
||||
pub Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
|
||||
)
|
||||
where
|
||||
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
QueryT::TypeInfo: Send + Sync,
|
||||
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
MutationT::TypeInfo: Send + Sync,
|
||||
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
SubscriptionT::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync,
|
||||
S: ScalarValue + Send + Sync + 'static;
|
||||
|
||||
impl<QueryT, MutationT, SubscriptionT, CtxT, S> Clone
|
||||
for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
|
||||
where
|
||||
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
QueryT::TypeInfo: Send + Sync,
|
||||
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
MutationT::TypeInfo: Send + Sync,
|
||||
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
SubscriptionT::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
|
||||
for ArcSchema<QueryT, MutationT, SubscriptionT, CtxT, S>
|
||||
where
|
||||
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
QueryT::TypeInfo: Send + Sync,
|
||||
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
MutationT::TypeInfo: Send + Sync,
|
||||
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
SubscriptionT::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
{
|
||||
type Context = CtxT;
|
||||
type ScalarValue = S;
|
||||
type QueryTypeInfo = QueryT::TypeInfo;
|
||||
type Query = QueryT;
|
||||
type MutationTypeInfo = MutationT::TypeInfo;
|
||||
type Mutation = MutationT;
|
||||
type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
|
||||
type Subscription = SubscriptionT;
|
||||
|
||||
fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<QueryT, MutationT, SubscriptionT, CtxT, S> Schema
|
||||
for Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>
|
||||
where
|
||||
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
QueryT::TypeInfo: Send + Sync,
|
||||
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
MutationT::TypeInfo: Send + Sync,
|
||||
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
SubscriptionT::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
{
|
||||
type Context = CtxT;
|
||||
type ScalarValue = S;
|
||||
type QueryTypeInfo = QueryT::TypeInfo;
|
||||
type Query = QueryT;
|
||||
type MutationTypeInfo = MutationT::TypeInfo;
|
||||
type Mutation = MutationT;
|
||||
type SubscriptionTypeInfo = SubscriptionT::TypeInfo;
|
||||
type Subscription = SubscriptionT;
|
||||
|
||||
fn root_node(&self) -> &RootNode<'static, QueryT, MutationT, SubscriptionT, S> {
|
||||
self
|
||||
}
|
||||
}
|
191
juniper_graphql_ws/src/server_message.rs
Normal file
191
juniper_graphql_ws/src/server_message.rs
Normal file
|
@ -0,0 +1,191 @@
|
|||
use juniper::{ExecutionError, GraphQLError, ScalarValue, Value};
|
||||
use serde::{Serialize, Serializer};
|
||||
use std::{any::Any, fmt, marker::PhantomPinned};
|
||||
|
||||
/// The payload for errors that are not associated with a GraphQL operation.
|
||||
#[derive(Debug, Serialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConnectionErrorPayload {
|
||||
/// The error message.
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// Sent after execution of an operation. For queries and mutations, this is sent to the client
|
||||
/// once. For subscriptions, this is sent for every event in the event stream.
|
||||
#[derive(Debug, Serialize, PartialEq)]
|
||||
#[serde(bound(serialize = "S: ScalarValue"))]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DataPayload<S> {
|
||||
/// The result data.
|
||||
pub data: Value<S>,
|
||||
|
||||
/// The errors that have occurred during execution. Note that parse and validation errors are
|
||||
/// not included here. They are sent via Error messages.
|
||||
pub errors: Vec<ExecutionError<S>>,
|
||||
}
|
||||
|
||||
/// A payload for errors that can happen before execution. Errors that happen during execution are
|
||||
/// instead sent to the client via `DataPayload`. `ErrorPayload` is a wrapper for an owned
|
||||
/// `GraphQLError`.
|
||||
// XXX: Think carefully before deriving traits. This is self-referential (error references
|
||||
// _execution_params).
|
||||
pub struct ErrorPayload {
|
||||
_execution_params: Option<Box<dyn Any + Send>>,
|
||||
error: GraphQLError<'static>,
|
||||
_marker: PhantomPinned,
|
||||
}
|
||||
|
||||
impl ErrorPayload {
|
||||
/// For this to be okay, the caller must guarantee that the error can only reference data from
|
||||
/// execution_params and that execution_params has not been modified or moved.
|
||||
pub(crate) unsafe fn new_unchecked<'a>(
|
||||
execution_params: Box<dyn Any + Send>,
|
||||
error: GraphQLError<'a>,
|
||||
) -> Self {
|
||||
Self {
|
||||
_execution_params: Some(execution_params),
|
||||
error: std::mem::transmute(error),
|
||||
_marker: PhantomPinned,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the contained GraphQLError.
|
||||
pub fn graphql_error<'a>(&'a self) -> &GraphQLError<'a> {
|
||||
&self.error
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ErrorPayload {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
self.error.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ErrorPayload {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.error.eq(&other.error)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for ErrorPayload {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
self.error.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GraphQLError<'static>> for ErrorPayload {
|
||||
fn from(error: GraphQLError<'static>) -> Self {
|
||||
Self {
|
||||
_execution_params: None,
|
||||
error,
|
||||
_marker: PhantomPinned,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// ServerMessage defines the message types that servers can send.
|
||||
#[derive(Debug, Serialize, PartialEq)]
|
||||
#[serde(bound(serialize = "S: ScalarValue"))]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ServerMessage<S: ScalarValue> {
|
||||
/// ConnectionError is used for errors that are not associated with a GraphQL operation. For
|
||||
/// example, this will be used when:
|
||||
///
|
||||
/// * The server is unable to parse a client's message.
|
||||
/// * The client's initialization parameters are rejected.
|
||||
ConnectionError {
|
||||
/// The error that occurred.
|
||||
payload: ConnectionErrorPayload,
|
||||
},
|
||||
/// ConnectionAck is sent in response to a client's ConnectionInit message if the server accepted a
|
||||
/// connection.
|
||||
ConnectionAck,
|
||||
/// Data contains the result of a query, mutation, or subscription event.
|
||||
Data {
|
||||
/// The id of the operation that the data is for.
|
||||
id: String,
|
||||
|
||||
/// The data and errors that occurred during execution.
|
||||
payload: DataPayload<S>,
|
||||
},
|
||||
/// Error contains an error that occurs before execution, such as validation errors.
|
||||
Error {
|
||||
/// The id of the operation that triggered this error.
|
||||
id: String,
|
||||
|
||||
/// The error(s).
|
||||
payload: ErrorPayload,
|
||||
},
|
||||
/// Complete indicates that no more data will be sent for the given operation.
|
||||
Complete {
|
||||
/// The id of the operation that has completed.
|
||||
id: String,
|
||||
},
|
||||
/// ConnectionKeepAlive is sent periodically after accepting a connection.
|
||||
#[serde(rename = "ka")]
|
||||
ConnectionKeepAlive,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use juniper::DefaultScalarValue;
|
||||
|
||||
#[test]
|
||||
fn test_serialization() {
|
||||
type ServerMessage = super::ServerMessage<DefaultScalarValue>;
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_string(&ServerMessage::ConnectionError {
|
||||
payload: ConnectionErrorPayload {
|
||||
message: "foo".to_string(),
|
||||
},
|
||||
})
|
||||
.unwrap(),
|
||||
r##"{"type":"connection_error","payload":{"message":"foo"}}"##,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
|
||||
r##"{"type":"connection_ack"}"##,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_string(&ServerMessage::Data {
|
||||
id: "foo".to_string(),
|
||||
payload: DataPayload {
|
||||
data: Value::null(),
|
||||
errors: vec![],
|
||||
},
|
||||
})
|
||||
.unwrap(),
|
||||
r##"{"type":"data","id":"foo","payload":{"data":null,"errors":[]}}"##,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_string(&ServerMessage::Error {
|
||||
id: "foo".to_string(),
|
||||
payload: GraphQLError::UnknownOperationName.into(),
|
||||
})
|
||||
.unwrap(),
|
||||
r##"{"type":"error","id":"foo","payload":[{"message":"Unknown operation"}]}"##,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_string(&ServerMessage::Complete {
|
||||
id: "foo".to_string(),
|
||||
})
|
||||
.unwrap(),
|
||||
r##"{"type":"complete","id":"foo"}"##,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_string(&ServerMessage::ConnectionKeepAlive).unwrap(),
|
||||
r##"{"type":"ka"}"##,
|
||||
);
|
||||
}
|
||||
}
|
|
@ -222,19 +222,25 @@ where
|
|||
}
|
||||
|
||||
if filled_count == obj_len {
|
||||
let mut errors = vec![];
|
||||
filled_count = 0;
|
||||
let new_vec = (0..obj_len).map(|_| None).collect::<Vec<_>>();
|
||||
let ready_vec = std::mem::replace(&mut ready_vec, new_vec);
|
||||
let ready_vec_iterator = ready_vec.into_iter().map(|el| {
|
||||
let (name, val) = el.unwrap();
|
||||
if let Ok(value) = val {
|
||||
(name, value)
|
||||
} else {
|
||||
(name, Value::Null)
|
||||
match val {
|
||||
Ok(value) => (name, value),
|
||||
Err(e) => {
|
||||
errors.push(e);
|
||||
(name, Value::Null)
|
||||
}
|
||||
}
|
||||
});
|
||||
let obj = Object::from_iter(ready_vec_iterator);
|
||||
Poll::Ready(Some(ExecutionOutput::from_data(Value::Object(obj))))
|
||||
Poll::Ready(Some(ExecutionOutput {
|
||||
data: Value::Object(obj),
|
||||
errors,
|
||||
}))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ repository = "https://github.com/graphql-rust/juniper"
|
|||
edition = "2018"
|
||||
|
||||
[features]
|
||||
subscriptions = ["juniper_subscriptions"]
|
||||
subscriptions = ["juniper_graphql_ws"]
|
||||
|
||||
[dependencies]
|
||||
bytes = "0.5"
|
||||
|
@ -17,7 +17,7 @@ anyhow = "1.0"
|
|||
thiserror = "1.0"
|
||||
futures = "0.3.1"
|
||||
juniper = { version = "0.14.2", path = "../juniper", default-features = false }
|
||||
juniper_subscriptions = { path = "../juniper_subscriptions", optional = true }
|
||||
juniper_graphql_ws = { path = "../juniper_graphql_ws", optional = true }
|
||||
serde = { version = "1.0.75", features = ["derive"] }
|
||||
serde_json = "1.0.24"
|
||||
tokio = { version = "0.2", features = ["blocking", "rt-core"] }
|
||||
|
|
|
@ -393,224 +393,103 @@ fn playground_response(
|
|||
/// [1]: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
|
||||
#[cfg(feature = "subscriptions")]
|
||||
pub mod subscriptions {
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
use juniper::{
|
||||
futures::{
|
||||
future::{self, Either},
|
||||
sink::SinkExt,
|
||||
stream::StreamExt,
|
||||
},
|
||||
GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue,
|
||||
};
|
||||
use juniper_graphql_ws::{ArcSchema, ClientMessage, Connection, Init};
|
||||
use std::{convert::Infallible, fmt, sync::Arc};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use futures::{channel::mpsc, Future, StreamExt as _, TryFutureExt as _, TryStreamExt as _};
|
||||
use juniper::{http::GraphQLRequest, InputValue, ScalarValue, SubscriptionCoordinator as _};
|
||||
use juniper_subscriptions::Coordinator;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use warp::ws::Message;
|
||||
struct Message(warp::ws::Message);
|
||||
|
||||
/// Listen to incoming messages and do one of the following:
|
||||
/// - execute subscription and return values from stream
|
||||
/// - stop stream and close ws connection
|
||||
#[allow(dead_code)]
|
||||
pub fn graphql_subscriptions<Query, Mutation, Subscription, CtxT, S>(
|
||||
websocket: warp::ws::WebSocket,
|
||||
coordinator: Arc<Coordinator<'static, Query, Mutation, Subscription, CtxT, S>>,
|
||||
context: CtxT,
|
||||
) -> impl Future<Output = Result<(), anyhow::Error>> + Send
|
||||
where
|
||||
Query: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Query::TypeInfo: Send + Sync,
|
||||
Mutation: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Mutation::TypeInfo: Send + Sync,
|
||||
Subscription: juniper::GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
Subscription::TypeInfo: Send + Sync,
|
||||
CtxT: Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
{
|
||||
let (sink_tx, sink_rx) = websocket.split();
|
||||
let (ws_tx, ws_rx) = mpsc::unbounded();
|
||||
tokio::task::spawn(
|
||||
ws_rx
|
||||
.take_while(|v: &Option<_>| futures::future::ready(v.is_some()))
|
||||
.map(|x| x.unwrap())
|
||||
.forward(sink_tx),
|
||||
);
|
||||
impl<S: ScalarValue> std::convert::TryFrom<Message> for ClientMessage<S> {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
let context = Arc::new(context);
|
||||
let got_close_signal = Arc::new(AtomicBool::new(false));
|
||||
let got_close_signal2 = got_close_signal.clone();
|
||||
|
||||
struct SubscriptionState {
|
||||
should_stop: AtomicBool,
|
||||
fn try_from(msg: Message) -> serde_json::Result<Self> {
|
||||
serde_json::from_slice(msg.0.as_bytes())
|
||||
}
|
||||
let subscription_states = HashMap::<String, Arc<SubscriptionState>>::new();
|
||||
|
||||
sink_rx
|
||||
.map_err(move |e| {
|
||||
got_close_signal2.store(true, Ordering::Relaxed);
|
||||
anyhow!("Websocket error: {}", e)
|
||||
})
|
||||
.try_fold(subscription_states, move |mut subscription_states, msg| {
|
||||
let coordinator = coordinator.clone();
|
||||
let context = context.clone();
|
||||
let got_close_signal = got_close_signal.clone();
|
||||
let ws_tx = ws_tx.clone();
|
||||
|
||||
async move {
|
||||
if msg.is_close() {
|
||||
return Ok(subscription_states);
|
||||
}
|
||||
|
||||
let msg = msg
|
||||
.to_str()
|
||||
.map_err(|_| anyhow!("Non-text messages are not accepted"))?;
|
||||
let request: WsPayload<S> = serde_json::from_str(msg)
|
||||
.map_err(|e| anyhow!("Invalid WsPayload: {}", e))?;
|
||||
|
||||
match request.type_name.as_str() {
|
||||
"connection_init" => {}
|
||||
"start" => {
|
||||
if got_close_signal.load(Ordering::Relaxed) {
|
||||
return Ok(subscription_states);
|
||||
}
|
||||
|
||||
let request_id = request.id.clone().unwrap_or("1".to_owned());
|
||||
|
||||
if let Some(existing) = subscription_states.get(&request_id) {
|
||||
existing.should_stop.store(true, Ordering::Relaxed);
|
||||
}
|
||||
let state = Arc::new(SubscriptionState {
|
||||
should_stop: AtomicBool::new(false),
|
||||
});
|
||||
subscription_states.insert(request_id.clone(), state.clone());
|
||||
|
||||
let ws_tx = ws_tx.clone();
|
||||
|
||||
if let Some(ref payload) = request.payload {
|
||||
if payload.query.is_none() {
|
||||
return Err(anyhow!("Query not found"));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Payload not found"));
|
||||
}
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let payload = request.payload.unwrap();
|
||||
|
||||
let graphql_request = GraphQLRequest::<S>::new(
|
||||
payload.query.unwrap(),
|
||||
None,
|
||||
payload.variables,
|
||||
);
|
||||
|
||||
let values_stream = match coordinator
|
||||
.subscribe(&graphql_request, &context)
|
||||
.await
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
let _ =
|
||||
ws_tx.unbounded_send(Some(Ok(Message::text(format!(
|
||||
r#"{{"type":"error","id":"{}","payload":{}}}"#,
|
||||
request_id,
|
||||
serde_json::ser::to_string(&err).unwrap_or(
|
||||
"Error deserializing GraphQLError".to_owned()
|
||||
)
|
||||
)))));
|
||||
|
||||
let close_message = format!(
|
||||
r#"{{"type":"complete","id":"{}","payload":null}}"#,
|
||||
request_id
|
||||
);
|
||||
let _ = ws_tx
|
||||
.unbounded_send(Some(Ok(Message::text(close_message))));
|
||||
// close channel
|
||||
let _ = ws_tx.unbounded_send(None);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
values_stream
|
||||
.take_while(move |response| {
|
||||
let request_id = request_id.clone();
|
||||
let should_stop = state.should_stop.load(Ordering::Relaxed)
|
||||
|| got_close_signal.load(Ordering::Relaxed);
|
||||
if !should_stop {
|
||||
let mut response_text = serde_json::to_string(
|
||||
&response,
|
||||
)
|
||||
.unwrap_or("Error deserializing response".to_owned());
|
||||
|
||||
response_text = format!(
|
||||
r#"{{"type":"data","id":"{}","payload":{} }}"#,
|
||||
request_id, response_text
|
||||
);
|
||||
|
||||
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(
|
||||
response_text,
|
||||
))));
|
||||
}
|
||||
|
||||
async move { !should_stop }
|
||||
})
|
||||
.for_each(|_| async {})
|
||||
.await;
|
||||
});
|
||||
}
|
||||
"stop" => {
|
||||
let request_id = request.id.unwrap_or("1".to_owned());
|
||||
if let Some(existing) = subscription_states.get(&request_id) {
|
||||
existing.should_stop.store(true, Ordering::Relaxed);
|
||||
subscription_states.remove(&request_id);
|
||||
}
|
||||
|
||||
let close_message = format!(
|
||||
r#"{{"type":"complete","id":"{}","payload":null}}"#,
|
||||
request_id
|
||||
);
|
||||
let _ = ws_tx.unbounded_send(Some(Ok(Message::text(close_message))));
|
||||
|
||||
// close channel
|
||||
let _ = ws_tx.unbounded_send(None);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(subscription_states)
|
||||
}
|
||||
})
|
||||
.map_ok(|_| ())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(bound = "GraphQLPayload<S>: Deserialize<'de>")]
|
||||
struct WsPayload<S>
|
||||
/// Errors that can happen while serving a connection.
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
/// Errors that can happen in Warp while serving a connection.
|
||||
Warp(warp::Error),
|
||||
|
||||
/// Errors that can happen while serializing outgoing messages. Note that errors that occur
|
||||
/// while deserializing internal messages are handled internally by the protocol.
|
||||
Serde(serde_json::Error),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Warp(e) => write!(f, "warp error: {}", e),
|
||||
Self::Serde(e) => write!(f, "serde error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl From<warp::Error> for Error {
|
||||
fn from(err: warp::Error) -> Self {
|
||||
Self::Warp(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Infallible> for Error {
|
||||
fn from(_err: Infallible) -> Self {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
/// Serves the graphql-ws protocol over a WebSocket connection.
|
||||
///
|
||||
/// The `init` argument is used to provide the context and additional configuration for
|
||||
/// connections. This can be a `juniper_graphql_ws::ConnectionConfig` if the context and
|
||||
/// configuration are already known, or it can be a closure that gets executed asynchronously
|
||||
/// when the client sends the ConnectionInit message. Using a closure allows you to perform
|
||||
/// authentication based on the parameters provided by the client.
|
||||
pub async fn serve_graphql_ws<Query, Mutation, Subscription, CtxT, S, I>(
|
||||
websocket: warp::ws::WebSocket,
|
||||
root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
|
||||
init: I,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: ScalarValue + Send + Sync,
|
||||
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Query::TypeInfo: Send + Sync,
|
||||
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
|
||||
Mutation::TypeInfo: Send + Sync,
|
||||
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
|
||||
Subscription::TypeInfo: Send + Sync,
|
||||
CtxT: Unpin + Send + Sync + 'static,
|
||||
S: ScalarValue + Send + Sync + 'static,
|
||||
I: Init<S, CtxT> + Send,
|
||||
{
|
||||
id: Option<String>,
|
||||
#[serde(rename(deserialize = "type"))]
|
||||
type_name: String,
|
||||
payload: Option<GraphQLPayload<S>>,
|
||||
}
|
||||
let (ws_tx, ws_rx) = websocket.split();
|
||||
let (s_tx, s_rx) = Connection::new(ArcSchema(root_node), init).split();
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(bound = "InputValue<S>: Deserialize<'de>")]
|
||||
struct GraphQLPayload<S>
|
||||
where
|
||||
S: ScalarValue + Send + Sync,
|
||||
{
|
||||
variables: Option<InputValue<S>>,
|
||||
extensions: Option<HashMap<String, String>>,
|
||||
#[serde(rename(deserialize = "operationName"))]
|
||||
operaton_name: Option<String>,
|
||||
query: Option<String>,
|
||||
}
|
||||
let ws_rx = ws_rx.map(|r| r.map(|msg| Message(msg)));
|
||||
let s_rx = s_rx.map(|msg| {
|
||||
serde_json::to_string(&msg)
|
||||
.map(|t| warp::ws::Message::text(t))
|
||||
.map_err(|e| Error::Serde(e))
|
||||
});
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Output {
|
||||
data: String,
|
||||
variables: String,
|
||||
match future::select(
|
||||
ws_rx.forward(s_tx.sink_err_into()),
|
||||
s_rx.forward(ws_tx.sink_err_into()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Either::Left((r, _)) => r.map_err(|e| e.into()),
|
||||
Either::Right((r, _)) => r,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue