Use actix-ws for juniper_actix subscriptions (#1197)

- fix panicking issues in `graphql-transport-ws` protocol implementation
- rework `graphql-ws` integration tests in `juniper::http`
- add `graphql-transport-ws` integration tests in `juniper::http`
This commit is contained in:
Kai Ren 2023-10-24 19:59:36 +02:00 committed by GitHub
parent 828d059b1b
commit d11e351a49
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 470 additions and 389 deletions

View file

@ -360,8 +360,11 @@ impl<S: ScalarValue> GraphQLBatchResponse<S> {
#[cfg(feature = "expose-test-schema")] #[cfg(feature = "expose-test-schema")]
#[allow(missing_docs)] #[allow(missing_docs)]
pub mod tests { pub mod tests {
use std::time::Duration;
use serde_json::Value as Json;
use crate::LocalBoxFuture; use crate::LocalBoxFuture;
use serde_json::{self, Value as Json};
/// Normalized response content we expect to get back from /// Normalized response content we expect to get back from
/// the http framework integration we are testing. /// the http framework integration we are testing.
@ -598,162 +601,281 @@ pub mod tests {
) -> LocalBoxFuture<Result<(), anyhow::Error>>; ) -> LocalBoxFuture<Result<(), anyhow::Error>>;
} }
/// WebSocket framework integration message /// WebSocket framework integration message.
pub enum WsIntegrationMessage { pub enum WsIntegrationMessage {
/// Send message through the WebSocket /// Send a message through a WebSocket.
/// Takes a message as a String Send(Json),
Send(String),
/// Expect message to come through the WebSocket /// Expects a message to come through a WebSocket, with the specified timeout.
/// Takes expected message as a String and a timeout in milliseconds Expect(Json, Duration),
Expect(String, u64),
} }
/// Default value in milliseconds for how long to wait for an incoming message /// Default value in milliseconds for how long to wait for an incoming WebSocket message.
pub const WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT: u64 = 100; pub const WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT: Duration = Duration::from_millis(100);
#[allow(missing_docs)] /// Integration tests for the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old].
pub async fn run_ws_test_suite<T: WsIntegration>(integration: &T) { ///
println!("Running WebSocket Test suite for integration"); /// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md
pub mod graphql_ws {
use serde_json::json;
println!(" - test_ws_simple_subscription"); use super::{WsIntegration, WsIntegrationMessage, WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT};
test_ws_simple_subscription(integration).await;
println!(" - test_ws_invalid_json"); #[allow(missing_docs)]
test_ws_invalid_json(integration).await; pub async fn run_test_suite<T: WsIntegration>(integration: &T) {
println!("Running `graphql-ws` test suite for integration");
println!(" - test_ws_invalid_query"); println!(" - graphql_ws::test_simple_subscription");
test_ws_invalid_query(integration).await; test_simple_subscription(integration).await;
println!(" - graphql_ws::test_invalid_json");
test_invalid_json(integration).await;
println!(" - graphql_ws::test_invalid_query");
test_invalid_query(integration).await;
}
async fn test_simple_subscription<T: WsIntegration>(integration: &T) {
let messages = vec![
WsIntegrationMessage::Send(json!({
"type": "connection_init",
"payload": {},
})),
WsIntegrationMessage::Expect(
json!({"type": "connection_ack"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Expect(
json!({"type": "ka"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Send(json!({
"id": "1",
"type": "start",
"payload": {
"variables": {},
"extensions": {},
"operationName": null,
"query": "subscription { asyncHuman { id, name, homePlanet } }",
},
})),
WsIntegrationMessage::Expect(
json!({
"type": "data",
"id": "1",
"payload": {
"data": {
"asyncHuman": {
"id": "1000",
"name": "Luke Skywalker",
"homePlanet": "Tatooine",
},
},
},
}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap();
}
async fn test_invalid_json<T: WsIntegration>(integration: &T) {
let messages = vec![
WsIntegrationMessage::Send(json!({"whatever": "invalid value"})),
WsIntegrationMessage::Expect(
json!({
"type": "connection_error",
"payload": {
"message": "`serde` error: missing field `type` at line 1 column 28",
},
}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap();
}
async fn test_invalid_query<T: WsIntegration>(integration: &T) {
let messages = vec![
WsIntegrationMessage::Send(json!({
"type": "connection_init",
"payload": {},
})),
WsIntegrationMessage::Expect(
json!({"type": "connection_ack"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Expect(
json!({"type": "ka"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Send(json!({
"id": "1",
"type": "start",
"payload": {
"variables": {},
"extensions": {},
"operationName": null,
"query": "subscription { asyncHuman }",
},
})),
WsIntegrationMessage::Expect(
json!({
"type": "error",
"id": "1",
"payload": [{
"message": "Field \"asyncHuman\" of type \"Human!\" must have a selection \
of subfields. Did you mean \"asyncHuman { ... }\"?",
"locations": [{
"line": 1,
"column": 16,
}],
}],
}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap();
}
} }
async fn test_ws_simple_subscription<T: WsIntegration>(integration: &T) { /// Integration tests for the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new].
let messages = vec![ ///
WsIntegrationMessage::Send( /// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
r#"{ pub mod graphql_transport_ws {
"type":"connection_init", use serde_json::json;
"payload":{}
}"#
.into(),
),
WsIntegrationMessage::Expect(
r#"{
"type":"connection_ack"
}"#
.into(),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Expect(
r#"{
"type":"ka"
}"#
.into(),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Send(
r#"{
"id":"1",
"type":"start",
"payload":{
"variables":{},
"extensions":{},
"operationName":null,
"query":"subscription { asyncHuman { id, name, homePlanet } }"
}
}"#
.into(),
),
WsIntegrationMessage::Expect(
r#"{
"type":"data",
"id":"1",
"payload":{
"data":{
"asyncHuman":{
"id":"1000",
"name":"Luke Skywalker",
"homePlanet":"Tatooine"
}
}
}
}"#
.into(),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap(); use super::{WsIntegration, WsIntegrationMessage, WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT};
}
async fn test_ws_invalid_json<T: WsIntegration>(integration: &T) { #[allow(missing_docs)]
let messages = vec![ pub async fn run_test_suite<T: WsIntegration>(integration: &T) {
WsIntegrationMessage::Send("invalid json".into()), println!("Running `graphql-ws` test suite for integration");
WsIntegrationMessage::Expect(
r#"{
"type":"connection_error",
"payload":{
"message":"serde error: expected value at line 1 column 1"
}
}"#
.into(),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap(); println!(" - graphql_ws::test_simple_subscription");
} test_simple_subscription(integration).await;
async fn test_ws_invalid_query<T: WsIntegration>(integration: &T) { println!(" - graphql_ws::test_invalid_json");
let messages = vec![ test_invalid_json(integration).await;
WsIntegrationMessage::Send(
r#"{
"type":"connection_init",
"payload":{}
}"#
.into(),
),
WsIntegrationMessage::Expect(
r#"{
"type":"connection_ack"
}"#
.into(),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT
),
WsIntegrationMessage::Expect(
r#"{
"type":"ka"
}"#
.into(),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT
),
WsIntegrationMessage::Send(
r#"{
"id":"1",
"type":"start",
"payload":{
"variables":{},
"extensions":{},
"operationName":null,
"query":"subscription { asyncHuman }"
}
}"#
.into(),
),
WsIntegrationMessage::Expect(
r#"{
"type":"error",
"id":"1",
"payload":[{
"message":"Field \"asyncHuman\" of type \"Human!\" must have a selection of subfields. Did you mean \"asyncHuman { ... }\"?",
"locations":[{
"line":1,
"column":16
}]
}]
}"#
.into(),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT
)
];
integration.run(messages).await.unwrap(); println!(" - graphql_ws::test_invalid_query");
test_invalid_query(integration).await;
}
async fn test_simple_subscription<T: WsIntegration>(integration: &T) {
let messages = vec![
WsIntegrationMessage::Send(json!({
"type": "connection_init",
"payload": {},
})),
WsIntegrationMessage::Expect(
json!({"type": "connection_ack"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Expect(
json!({"type": "pong"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Send(json!({"type": "ping"})),
WsIntegrationMessage::Expect(
json!({"type": "pong"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Send(json!({
"id": "1",
"type": "subscribe",
"payload": {
"variables": {},
"extensions": {},
"operationName": null,
"query": "subscription { asyncHuman { id, name, homePlanet } }",
},
})),
WsIntegrationMessage::Expect(
json!({
"id": "1",
"type": "next",
"payload": {
"data": {
"asyncHuman": {
"id": "1000",
"name": "Luke Skywalker",
"homePlanet": "Tatooine",
},
},
},
}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap();
}
async fn test_invalid_json<T: WsIntegration>(integration: &T) {
let messages = vec![
WsIntegrationMessage::Send(json!({"whatever": "invalid value"})),
WsIntegrationMessage::Expect(
json!({
"code": 4400,
"description": "`serde` error: missing field `type` at line 1 column 28",
}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap();
}
async fn test_invalid_query<T: WsIntegration>(integration: &T) {
let messages = vec![
WsIntegrationMessage::Send(json!({
"type": "connection_init",
"payload": {},
})),
WsIntegrationMessage::Expect(
json!({"type": "connection_ack"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Expect(
json!({"type": "pong"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Send(json!({"type": "ping"})),
WsIntegrationMessage::Expect(
json!({"type": "pong"}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
WsIntegrationMessage::Send(json!({
"id": "1",
"type": "subscribe",
"payload": {
"variables": {},
"extensions": {},
"operationName": null,
"query": "subscription { asyncHuman }",
},
})),
WsIntegrationMessage::Expect(
json!({
"type": "error",
"id": "1",
"payload": [{
"message": "Field \"asyncHuman\" of type \"Human!\" must have a selection \
of subfields. Did you mean \"asyncHuman { ... }\"?",
"locations": [{
"line": 1,
"column": 16,
}],
}],
}),
WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT,
),
];
integration.run(messages).await.unwrap();
}
} }
} }

View file

@ -13,13 +13,13 @@ All user visible changes to `juniper_actix` crate will be documented in this fil
- Switched to 4.0 version of [`actix-web` crate] and its ecosystem. ([#1034]) - Switched to 4.0 version of [`actix-web` crate] and its ecosystem. ([#1034])
- Switched to 0.16 version of [`juniper` crate]. - Switched to 0.16 version of [`juniper` crate].
- Switched to 0.4 version of [`juniper_graphql_ws` crate]. - Switched to 0.4 version of [`juniper_graphql_ws` crate].
- Renamed `subscriptions::subscriptions_handler()` as `subscriptions::graphql_ws_handler()` for processing the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws]. ([#1191]) - Switched to 0.2 version of [`actix-ws` crate]. ([#1197])
- Renamed `subscriptions::subscriptions_handler()` as `subscriptions::graphql_ws_handler()` for processing the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws]. ([#1191], [#1197])
### Added ### Added
- `subscriptions::graphql_transport_ws_handler()` allowing to process the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws]. ([#1191]) - `subscriptions::graphql_transport_ws_handler()` allowing to process the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws]. ([#1191], [#1197])
- `subscriptions::ws_handler()` with auto-selection between the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws] and the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws], based on the `Sec-Websocket-Protocol` HTTP header value. ([#1191]) - `subscriptions::ws_handler()` with auto-selection between the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws] and the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws], based on the `Sec-Websocket-Protocol` HTTP header value. ([#1191], [#1197])
- Support of 0.14 version of [`actix` crate]. ([#1189])
### Fixed ### Fixed
@ -28,8 +28,8 @@ All user visible changes to `juniper_actix` crate will be documented in this fil
[#1034]: /../../pull/1034 [#1034]: /../../pull/1034
[#1169]: /../../issues/1169 [#1169]: /../../issues/1169
[#1187]: /../../pull/1187 [#1187]: /../../pull/1187
[#1189]: /../../pull/1189
[#1191]: /../../pull/1191 [#1191]: /../../pull/1191
[#1197]: /../../pull/1197
@ -43,6 +43,7 @@ See [old CHANGELOG](/../../blob/juniper_actix-v0.4.0/juniper_actix/CHANGELOG.md)
[`actix` crate]: https://docs.rs/actix [`actix` crate]: https://docs.rs/actix
[`actix-web` crate]: https://docs.rs/actix-web [`actix-web` crate]: https://docs.rs/actix-web
[`actix-ws` crate]: https://docs.rs/actix-ws
[`juniper` crate]: https://docs.rs/juniper [`juniper` crate]: https://docs.rs/juniper
[`juniper_graphql_ws` crate]: https://docs.rs/juniper_graphql_ws [`juniper_graphql_ws` crate]: https://docs.rs/juniper_graphql_ws
[Semantic Versioning 2.0.0]: https://semver.org [Semantic Versioning 2.0.0]: https://semver.org

View file

@ -19,18 +19,12 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"] rustdoc-args = ["--cfg", "docsrs"]
[features] [features]
subscriptions = [ subscriptions = ["dep:actix-ws", "dep:juniper_graphql_ws"]
"dep:actix",
"dep:actix-web-actors",
"dep:juniper_graphql_ws",
"dep:tokio",
]
[dependencies] [dependencies]
actix = { version = ">=0.12, <=0.14", optional = true }
actix-http = "3.2" actix-http = "3.2"
actix-web = "4.4" actix-web = "4.4"
actix-web-actors = { version = "4.1", optional = true } actix-ws = { version = "0.2", optional = true }
anyhow = "1.0.47" anyhow = "1.0.47"
futures = "0.3.22" futures = "0.3.22"
juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false } juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false }
@ -39,7 +33,6 @@ http = "0.2.4"
serde = { version = "1.0.122", features = ["derive"] } serde = { version = "1.0.122", features = ["derive"] }
serde_json = "1.0.18" serde_json = "1.0.18"
thiserror = "1.0" thiserror = "1.0"
tokio = { version = "1.0", features = ["sync"], optional = true }
[dev-dependencies] [dev-dependencies]
actix-cors = "0.6" actix-cors = "0.6"

View file

@ -168,23 +168,15 @@ pub async fn playground_handler(
#[cfg(feature = "subscriptions")] #[cfg(feature = "subscriptions")]
/// `juniper_actix` subscriptions handler implementation. /// `juniper_actix` subscriptions handler implementation.
pub mod subscriptions { pub mod subscriptions {
use std::{fmt, sync::Arc}; use std::{fmt, pin::pin, sync::Arc};
use actix::{
AsyncContext as _, ContextFutureSpawner as _, Handler, StreamHandler, WrapFuture as _,
};
use actix_web::{ use actix_web::{
http::header::{HeaderName, HeaderValue}, http::header::{HeaderName, HeaderValue},
web, HttpRequest, HttpResponse, web, HttpRequest, HttpResponse,
}; };
use actix_web_actors::ws; use futures::{future, SinkExt as _, StreamExt as _};
use futures::{
stream::{SplitSink, SplitStream},
SinkExt as _, Stream, StreamExt as _,
};
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
use juniper_graphql_ws::{graphql_transport_ws, graphql_ws, ArcSchema, Init}; use juniper_graphql_ws::{graphql_transport_ws, graphql_ws, ArcSchema, Init};
use tokio::sync::Mutex;
/// Serves by auto-selecting between the /// Serves by auto-selecting between the
/// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] and the /// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] and the
@ -261,22 +253,45 @@ pub mod subscriptions {
S: ScalarValue + Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send, I: Init<S, CtxT> + Send,
{ {
let (s_tx, s_rx) = graphql_ws::Connection::new(ArcSchema(schema), init).split::<Message>(); let (mut resp, mut ws_tx, ws_rx) = actix_ws::handle(&req, stream)?;
let (s_tx, mut s_rx) = graphql_ws::Connection::new(ArcSchema(schema), init).split();
let mut resp = ws::start( actix_web::rt::spawn(async move {
Actor { let input = ws_rx
tx: Arc::new(Mutex::new(s_tx)), .map(|r| r.map(Message))
rx: Arc::new(Mutex::new(s_rx)), .forward(s_tx.sink_map_err(|e| match e {}));
}, let output = pin!(async move {
&req, while let Some(msg) = s_rx.next().await {
stream, match serde_json::to_string(&msg) {
)?; Ok(m) => {
if ws_tx.text(m).await.is_err() {
return;
}
}
Err(e) => {
_ = ws_tx
.close(Some(actix_ws::CloseReason {
code: actix_ws::CloseCode::Error,
description: Some(format!("error serializing response: {e}")),
}))
.await;
return;
}
}
}
_ = ws_tx
.close(Some((actix_ws::CloseCode::Normal, "Normal Closure").into()))
.await;
});
// No errors can be returned here, so ignoring is OK.
_ = future::select(input, output).await;
});
resp.headers_mut().insert( resp.headers_mut().insert(
HeaderName::from_static("sec-websocket-protocol"), HeaderName::from_static("sec-websocket-protocol"),
HeaderValue::from_static("graphql-ws"), HeaderValue::from_static("graphql-ws"),
); );
Ok(resp) Ok(resp)
} }
@ -306,164 +321,80 @@ pub mod subscriptions {
S: ScalarValue + Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static,
I: Init<S, CtxT> + Send, I: Init<S, CtxT> + Send,
{ {
let (s_tx, s_rx) = let (mut resp, mut ws_tx, ws_rx) = actix_ws::handle(&req, stream)?;
graphql_transport_ws::Connection::new(ArcSchema(schema), init).split::<Message>(); let (s_tx, mut s_rx) =
graphql_transport_ws::Connection::new(ArcSchema(schema), init).split();
let mut resp = ws::start( actix_web::rt::spawn(async move {
Actor { let input = ws_rx
tx: Arc::new(Mutex::new(s_tx)), .map(|r| r.map(Message))
rx: Arc::new(Mutex::new(s_rx)), .forward(s_tx.sink_map_err(|e| match e {}));
}, let output = pin!(async move {
&req, while let Some(output) = s_rx.next().await {
stream, match output {
)?; graphql_transport_ws::Output::Message(msg) => {
match serde_json::to_string(&msg) {
Ok(m) => {
if ws_tx.text(m).await.is_err() {
return;
}
}
Err(e) => {
_ = ws_tx
.close(Some(actix_ws::CloseReason {
code: actix_ws::CloseCode::Error,
description: Some(format!(
"error serializing response: {e}",
)),
}))
.await;
return;
}
}
}
graphql_transport_ws::Output::Close { code, message } => {
_ = ws_tx
.close(Some(actix_ws::CloseReason {
code: code.into(),
description: Some(message),
}))
.await;
return;
}
}
}
_ = ws_tx
.close(Some((actix_ws::CloseCode::Normal, "Normal Closure").into()))
.await;
});
// No errors can be returned here, so ignoring is OK.
_ = future::select(input, output).await;
});
resp.headers_mut().insert( resp.headers_mut().insert(
HeaderName::from_static("sec-websocket-protocol"), HeaderName::from_static("sec-websocket-protocol"),
HeaderValue::from_static("graphql-transport-ws"), HeaderValue::from_static("graphql-transport-ws"),
); );
Ok(resp) Ok(resp)
} }
type ConnectionSplitSink<Conn> = Arc<Mutex<SplitSink<Conn, Message>>>;
type ConnectionSplitStream<Conn> = Arc<Mutex<SplitStream<Conn>>>;
/// [`actix::Actor`], coordinating messages between [`actix_web`] and [`juniper_graphql_ws`]:
/// - incoming [`ws::Message`] -> [`Actor`] -> [`juniper`]
/// - [`juniper`] -> [`Actor`] -> response [`ws::Message`]
struct Actor<Conn> {
tx: ConnectionSplitSink<Conn>,
rx: ConnectionSplitStream<Conn>,
}
impl<Conn> StreamHandler<Result<ws::Message, ws::ProtocolError>> for Actor<Conn>
where
Self: actix::Actor<Context = ws::WebsocketContext<Self>>,
Conn: futures::Sink<Message>,
<Conn as futures::Sink<Message>>::Error: fmt::Debug,
{
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
#[allow(clippy::single_match)]
match msg {
Ok(msg) => {
let tx = self.tx.clone();
// TODO: Somehow this implementation always closes as `1006: Abnormal closure`
// due to excessive polling of `tx` part.
// Needs to be reworked.
async move {
tx.lock()
.await
.send(Message(msg))
.await
.expect("Infallible: this should not happen");
}
.into_actor(self)
.wait(ctx);
}
Err(_) => {
// TODO: trace
// ignore the message if there's a transport error
}
}
}
}
/// [`juniper`] -> [`Actor`].
impl<Conn> actix::Actor for Actor<Conn>
where
Conn: Stream + 'static,
<Conn as Stream>::Item: IntoWsResponse + Send,
{
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
let stream = self.rx.clone();
let addr = ctx.address();
let fut = async move {
let mut stream = stream.lock().await;
while let Some(msg) = stream.next().await {
// Sending the `msg` to `self`, so that it can be forwarded back to the client.
addr.do_send(ServerMessage(msg));
}
}
.into_actor(self);
// TODO: trace
ctx.spawn(fut);
}
fn stopped(&mut self, _: &mut Self::Context) {
// TODO: trace
}
}
/// [`Actor`] -> response [`ws::Message`].
impl<Conn, M> Handler<ServerMessage<M>> for Actor<Conn>
where
Conn: Stream<Item = M> + 'static,
M: IntoWsResponse + Send,
{
type Result = ();
fn handle(&mut self, msg: ServerMessage<M>, ctx: &mut Self::Context) -> Self::Result {
match msg.0.into_ws_response() {
Ok(msg) => ctx.text(msg),
// TODO: trace
Err(reason) => ctx.close(Some(reason)),
}
}
}
#[derive(actix::Message)]
#[rtype(result = "()")]
struct ServerMessage<T>(T);
/// Conversion of a [`ServerMessage`] into a response [`ws::Message`].
pub trait IntoWsResponse {
/// Converts this [`ServerMessage`] into response [`ws::Message`].
fn into_ws_response(self) -> Result<String, ws::CloseReason>;
}
impl<S: ScalarValue> IntoWsResponse for graphql_transport_ws::Output<S> {
fn into_ws_response(self) -> Result<String, ws::CloseReason> {
match self {
Self::Message(msg) => serde_json::to_string(&msg).map_err(|e| ws::CloseReason {
code: ws::CloseCode::Error,
description: Some(format!("error serializing response: {e}")),
}),
Self::Close { code, message } => Err(ws::CloseReason {
code: code.into(),
description: Some(message),
}),
}
}
}
impl<S: ScalarValue> IntoWsResponse for graphql_ws::ServerMessage<S> {
fn into_ws_response(self) -> Result<String, ws::CloseReason> {
serde_json::to_string(&self).map_err(|e| ws::CloseReason {
code: ws::CloseCode::Error,
description: Some(format!("error serializing response: {e}")),
})
}
}
#[derive(Debug)] #[derive(Debug)]
struct Message(ws::Message); struct Message(actix_ws::Message);
impl<S: ScalarValue> TryFrom<Message> for graphql_transport_ws::Input<S> { impl<S: ScalarValue> TryFrom<Message> for graphql_transport_ws::Input<S> {
type Error = Error; type Error = Error;
fn try_from(msg: Message) -> Result<Self, Self::Error> { fn try_from(msg: Message) -> Result<Self, Self::Error> {
match msg.0 { match msg.0 {
ws::Message::Text(text) => serde_json::from_slice(text.as_bytes()) actix_ws::Message::Text(text) => serde_json::from_slice(text.as_bytes())
.map(Self::Message) .map(Self::Message)
.map_err(Error::Serde), .map_err(Error::Serde),
ws::Message::Close(_) => Ok(Self::Close), actix_ws::Message::Binary(bytes) => serde_json::from_slice(bytes.as_ref())
_ => Err(Error::UnexpectedClientMessage), .map(Self::Message)
.map_err(Error::Serde),
actix_ws::Message::Close(_) => Ok(Self::Close),
other => Err(Error::UnexpectedClientMessage(other)),
} }
} }
} }
@ -473,31 +404,34 @@ pub mod subscriptions {
fn try_from(msg: Message) -> Result<Self, Self::Error> { fn try_from(msg: Message) -> Result<Self, Self::Error> {
match msg.0 { match msg.0 {
ws::Message::Text(text) => { actix_ws::Message::Text(text) => {
serde_json::from_slice(text.as_bytes()).map_err(Error::Serde) serde_json::from_slice(text.as_bytes()).map_err(Error::Serde)
} }
ws::Message::Close(_) => Ok(Self::ConnectionTerminate), actix_ws::Message::Binary(bytes) => {
_ => Err(Error::UnexpectedClientMessage), serde_json::from_slice(bytes.as_ref()).map_err(Error::Serde)
}
actix_ws::Message::Close(_) => Ok(Self::ConnectionTerminate),
other => Err(Error::UnexpectedClientMessage(other)),
} }
} }
} }
/// Errors that can happen while handling client messages /// Possible errors of serving an [`actix_ws`] connection.
#[derive(Debug)] #[derive(Debug)]
enum Error { enum Error {
/// Errors that can happen while deserializing client messages /// Deserializing of a client or server message failed.
Serde(serde_json::Error), Serde(serde_json::Error),
/// Error for unexpected client messages /// Unexpected client [`actix_ws::Message`].
UnexpectedClientMessage, UnexpectedClientMessage(actix_ws::Message),
} }
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Self::Serde(e) => write!(f, "serde error: {e}"), Self::Serde(e) => write!(f, "`serde` error: {e}"),
Self::UnexpectedClientMessage => { Self::UnexpectedClientMessage(m) => {
write!(f, "unexpected message received from client") write!(f, "unexpected message received from client: {m:?}")
} }
} }
} }
@ -514,7 +448,7 @@ mod tests {
use actix_web::{ use actix_web::{
dev::ServiceResponse, dev::ServiceResponse,
http, http,
http::header::CONTENT_TYPE, http::header::{ACCEPT, CONTENT_TYPE},
test::{self, TestRequest}, test::{self, TestRequest},
web::Data, web::Data,
App, App,
@ -527,7 +461,6 @@ mod tests {
}; };
use super::*; use super::*;
use actix_web::http::header::ACCEPT;
type Schema = type Schema =
juniper::RootNode<'static, Query, EmptyMutation<Database>, EmptySubscription<Database>>; juniper::RootNode<'static, Query, EmptyMutation<Database>, EmptySubscription<Database>>;
@ -832,41 +765,37 @@ mod tests {
#[cfg(feature = "subscriptions")] #[cfg(feature = "subscriptions")]
#[cfg(test)] #[cfg(test)]
mod subscription_tests { mod subscription_tests {
use std::time::Duration; use actix_http::ws;
use actix_test::start; use actix_test::start;
use actix_web::{ use actix_web::{web, App, Error, HttpRequest, HttpResponse};
web::{self, Data},
App, Error, HttpRequest, HttpResponse,
};
use actix_web_actors::ws;
use juniper::{ use juniper::{
futures::{SinkExt, StreamExt}, futures::{SinkExt, StreamExt},
http::tests::{run_ws_test_suite, WsIntegration, WsIntegrationMessage}, http::tests::{graphql_transport_ws, graphql_ws, WsIntegration, WsIntegrationMessage},
tests::fixtures::starwars::schema::{Database, Query, Subscription}, tests::fixtures::starwars::schema::{Database, Query, Subscription},
EmptyMutation, LocalBoxFuture, EmptyMutation, LocalBoxFuture,
}; };
use juniper_graphql_ws::ConnectionConfig; use juniper_graphql_ws::ConnectionConfig;
use tokio::time::timeout; use tokio::time::timeout;
use super::subscriptions::graphql_ws_handler; use super::subscriptions;
#[derive(Default)] struct TestWsIntegration(&'static str);
struct TestActixWsIntegration;
impl TestActixWsIntegration { impl TestWsIntegration {
async fn run_async( async fn run_async(
&self, &self,
messages: Vec<WsIntegrationMessage>, messages: Vec<WsIntegrationMessage>,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
let proto = self.0;
let mut server = start(|| { let mut server = start(|| {
App::new() App::new()
.app_data(Data::new(Schema::new( .app_data(web::Data::new(Schema::new(
Query, Query,
EmptyMutation::<Database>::new(), EmptyMutation::<Database>::new(),
Subscription, Subscription,
))) )))
.service(web::resource("/subscriptions").to(subscriptions)) .service(web::resource("/subscriptions").to(subscription(proto)))
}); });
let mut framed = server.ws_at("/subscriptions").await.unwrap(); let mut framed = server.ws_at("/subscriptions").await.unwrap();
@ -874,12 +803,12 @@ mod subscription_tests {
match message { match message {
WsIntegrationMessage::Send(body) => { WsIntegrationMessage::Send(body) => {
framed framed
.send(ws::Message::Text(body.to_owned().into())) .send(ws::Message::Text(body.to_string().into()))
.await .await
.map_err(|e| anyhow::anyhow!("WS error: {e:?}"))?; .map_err(|e| anyhow::anyhow!("WS error: {e:?}"))?;
} }
WsIntegrationMessage::Expect(body, message_timeout) => { WsIntegrationMessage::Expect(body, message_timeout) => {
let frame = timeout(Duration::from_millis(*message_timeout), framed.next()) let frame = timeout(*message_timeout, framed.next())
.await .await
.map_err(|_| anyhow::anyhow!("Timed-out waiting for message"))? .map_err(|_| anyhow::anyhow!("Timed-out waiting for message"))?
.ok_or_else(|| anyhow::anyhow!("Empty message received"))? .ok_or_else(|| anyhow::anyhow!("Empty message received"))?
@ -887,21 +816,29 @@ mod subscription_tests {
match frame { match frame {
ws::Frame::Text(ref bytes) => { ws::Frame::Text(ref bytes) => {
let expected_value =
serde_json::from_str::<serde_json::Value>(body)
.map_err(|e| anyhow::anyhow!("Serde error: {e:?}"))?;
let value: serde_json::Value = serde_json::from_slice(bytes) let value: serde_json::Value = serde_json::from_slice(bytes)
.map_err(|e| anyhow::anyhow!("Serde error: {e:?}"))?; .map_err(|e| anyhow::anyhow!("Serde error: {e:?}"))?;
if value != expected_value { if value != *body {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
"Expected message: {expected_value}. \ "Expected message: {body}. \
Received message: {value}", Received message: {value}",
)); ));
} }
} }
_ => return Err(anyhow::anyhow!("Received non-text frame")), ws::Frame::Close(Some(reason)) => {
let actual = serde_json::json!({
"code": u16::from(reason.code),
"description": reason.description,
});
if actual != *body {
return Err(anyhow::anyhow!(
"Expected message: {body}. \
Received message: {actual}",
));
}
}
f => return Err(anyhow::anyhow!("Received non-text frame: {f:?}")),
} }
} }
} }
@ -911,7 +848,7 @@ mod subscription_tests {
} }
} }
impl WsIntegration for TestActixWsIntegration { impl WsIntegration for TestWsIntegration {
fn run( fn run(
&self, &self,
messages: Vec<WsIntegrationMessage>, messages: Vec<WsIntegrationMessage>,
@ -922,20 +859,32 @@ mod subscription_tests {
type Schema = juniper::RootNode<'static, Query, EmptyMutation<Database>, Subscription>; type Schema = juniper::RootNode<'static, Query, EmptyMutation<Database>, Subscription>;
async fn subscriptions( fn subscription(
req: HttpRequest, proto: &'static str,
stream: web::Payload, ) -> impl actix_web::Handler<
schema: web::Data<Schema>, (HttpRequest, web::Payload, web::Data<Schema>),
) -> Result<HttpResponse, Error> { Output = Result<HttpResponse, Error>,
let context = Database::new(); > {
let schema = schema.into_inner(); move |req: HttpRequest, stream: web::Payload, schema: web::Data<Schema>| async move {
let config = ConnectionConfig::new(context); let context = Database::new();
let schema = schema.into_inner();
let config = ConnectionConfig::new(context);
graphql_ws_handler(req, stream, schema, config).await if proto == "graphql-ws" {
subscriptions::graphql_ws_handler(req, stream, schema, config).await
} else {
subscriptions::graphql_transport_ws_handler(req, stream, schema, config).await
}
}
} }
#[actix_web::rt::test] #[actix_web::rt::test]
async fn test_actix_ws_integration() { async fn test_graphql_ws_integration() {
run_ws_test_suite(&mut TestActixWsIntegration).await; graphql_ws::run_test_suite(&mut TestWsIntegration("graphql-ws")).await;
}
#[actix_web::rt::test]
async fn test_graphql_transport_ws_integration() {
graphql_transport_ws::run_test_suite(&mut TestWsIntegration("graphql-transport-ws")).await;
} }
} }

View file

@ -16,7 +16,7 @@ All user visible changes to `juniper_graphql_ws` crate will be documented in thi
### Added ### Added
- `graphql_transport_ws` module implementing [`graphql-transport-ws` GraphQL over WebSocket Protocol][proto-5.14.0] as of 5.14.0 version of [`graphql-ws` npm package] behind `graphql-transport-ws` Cargo feature. ([#1158], [#1191], [#1196], [#1022]) - `graphql_transport_ws` module implementing [`graphql-transport-ws` GraphQL over WebSocket Protocol][proto-5.14.0] as of 5.14.0 version of [`graphql-ws` npm package] behind `graphql-transport-ws` Cargo feature. ([#1158], [#1191], [#1196], [#1197], [#1022])
### Changed ### Changed
@ -26,6 +26,7 @@ All user visible changes to `juniper_graphql_ws` crate will be documented in thi
[#1158]: /../../pull/1158 [#1158]: /../../pull/1158
[#1191]: /../../pull/1191 [#1191]: /../../pull/1191
[#1196]: /../../pull/1196 [#1196]: /../../pull/1196
[#1197]: /../../pull/1197
[proto-5.14.0]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md [proto-5.14.0]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
[proto-legacy]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md [proto-legacy]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md

View file

@ -478,6 +478,30 @@ where
}, },
} }
} }
/// Performs polling of the [`Sink`] part of this [`Connection`].
///
/// Effectively represents an implementation of [`Sink::poll_ready()`] and
/// [`Sink::poll_flush()`] methods.
fn poll_sink(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), &'static str>> {
match &mut self.sink_state {
ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())),
ConnectionSinkState::HandlingMessage { ref mut result } => {
match Pin::new(result).poll(cx) {
Poll::Ready((state, reactions)) => {
self.reactions.push(reactions);
self.sink_state = ConnectionSinkState::Ready { state };
if let Some(waker) = self.stream_waker.take() {
waker.wake();
}
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
ConnectionSinkState::Closed => Poll::Ready(Err("polled after close")),
}
}
} }
impl<S, I, T> Sink<T> for Connection<S, I> impl<S, I, T> Sink<T> for Connection<S, I>
@ -489,21 +513,9 @@ where
{ {
type Error = Infallible; type Error = Infallible;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match &mut self.sink_state { self.poll_sink(cx)
ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())), .map_err(|e| panic!("`Connection::poll_ready()`: {e}"))
ConnectionSinkState::HandlingMessage { ref mut result } => {
match Pin::new(result).poll(cx) {
Poll::Ready((state, reactions)) => {
self.reactions.push(reactions);
self.sink_state = ConnectionSinkState::Ready { state };
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
}
ConnectionSinkState::Closed => panic!("poll_ready called after close"),
}
} }
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
@ -538,19 +550,19 @@ where
} }
} }
} }
_ => panic!("start_send called when not ready"), _ => panic!("`Sink::start_send()`: called when not ready"),
}; };
Ok(()) Ok(())
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
<Self as Sink<T>>::poll_ready(self, cx) self.poll_sink(cx).map(|_| Ok(()))
} }
fn poll_close(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_close(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Self::Error>> {
self.sink_state = ConnectionSinkState::Closed; self.sink_state = ConnectionSinkState::Closed;
if let Some(waker) = self.stream_waker.take() { if let Some(waker) = self.stream_waker.take() {
// Wake up the stream so it can close too. // Wake up the `Stream` so it can close too.
waker.wake(); waker.wake();
} }
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
@ -567,9 +579,7 @@ where
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.stream_waker = Some(cx.waker().clone()); self.stream_waker = Some(cx.waker().clone());
if let ConnectionSinkState::Closed = self.sink_state { if self.stream_terminated {
return Poll::Ready(None);
} else if self.stream_terminated {
return Poll::Ready(None); return Poll::Ready(None);
} }
@ -590,6 +600,11 @@ where
_ => (), _ => (),
} }
} }
if let ConnectionSinkState::Closed = self.sink_state {
return Poll::Ready(None);
}
Poll::Pending Poll::Pending
} }
} }

View file

@ -392,8 +392,8 @@ pub mod subscriptions {
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Self::Warp(e) => write!(f, "warp error: {e}"), Self::Warp(e) => write!(f, "`warp` error: {e}"),
Self::Serde(e) => write!(f, "serde error: {e}"), Self::Serde(e) => write!(f, "`serde` error: {e}"),
} }
} }
} }