diff --git a/juniper/src/http/mod.rs b/juniper/src/http/mod.rs index 45e310be..06244a6f 100644 --- a/juniper/src/http/mod.rs +++ b/juniper/src/http/mod.rs @@ -360,8 +360,11 @@ impl GraphQLBatchResponse { #[cfg(feature = "expose-test-schema")] #[allow(missing_docs)] pub mod tests { + use std::time::Duration; + + use serde_json::Value as Json; + use crate::LocalBoxFuture; - use serde_json::{self, Value as Json}; /// Normalized response content we expect to get back from /// the http framework integration we are testing. @@ -598,162 +601,281 @@ pub mod tests { ) -> LocalBoxFuture>; } - /// WebSocket framework integration message + /// WebSocket framework integration message. pub enum WsIntegrationMessage { - /// Send message through the WebSocket - /// Takes a message as a String - Send(String), - /// Expect message to come through the WebSocket - /// Takes expected message as a String and a timeout in milliseconds - Expect(String, u64), + /// Send a message through a WebSocket. + Send(Json), + + /// Expects a message to come through a WebSocket, with the specified timeout. + Expect(Json, Duration), } - /// Default value in milliseconds for how long to wait for an incoming message - pub const WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT: u64 = 100; + /// Default value in milliseconds for how long to wait for an incoming WebSocket message. + pub const WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT: Duration = Duration::from_millis(100); - #[allow(missing_docs)] - pub async fn run_ws_test_suite(integration: &T) { - println!("Running WebSocket Test suite for integration"); + /// Integration tests for the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old]. + /// + /// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md + pub mod graphql_ws { + use serde_json::json; - println!(" - test_ws_simple_subscription"); - test_ws_simple_subscription(integration).await; + use super::{WsIntegration, WsIntegrationMessage, WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT}; - println!(" - test_ws_invalid_json"); - test_ws_invalid_json(integration).await; + #[allow(missing_docs)] + pub async fn run_test_suite(integration: &T) { + println!("Running `graphql-ws` test suite for integration"); - println!(" - test_ws_invalid_query"); - test_ws_invalid_query(integration).await; + println!(" - graphql_ws::test_simple_subscription"); + test_simple_subscription(integration).await; + + println!(" - graphql_ws::test_invalid_json"); + test_invalid_json(integration).await; + + println!(" - graphql_ws::test_invalid_query"); + test_invalid_query(integration).await; + } + + async fn test_simple_subscription(integration: &T) { + let messages = vec![ + WsIntegrationMessage::Send(json!({ + "type": "connection_init", + "payload": {}, + })), + WsIntegrationMessage::Expect( + json!({"type": "connection_ack"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Expect( + json!({"type": "ka"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Send(json!({ + "id": "1", + "type": "start", + "payload": { + "variables": {}, + "extensions": {}, + "operationName": null, + "query": "subscription { asyncHuman { id, name, homePlanet } }", + }, + })), + WsIntegrationMessage::Expect( + json!({ + "type": "data", + "id": "1", + "payload": { + "data": { + "asyncHuman": { + "id": "1000", + "name": "Luke Skywalker", + "homePlanet": "Tatooine", + }, + }, + }, + }), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + ]; + + integration.run(messages).await.unwrap(); + } + + async fn test_invalid_json(integration: &T) { + let messages = vec![ + WsIntegrationMessage::Send(json!({"whatever": "invalid value"})), + WsIntegrationMessage::Expect( + json!({ + "type": "connection_error", + "payload": { + "message": "`serde` error: missing field `type` at line 1 column 28", + }, + }), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + ]; + + integration.run(messages).await.unwrap(); + } + + async fn test_invalid_query(integration: &T) { + let messages = vec![ + WsIntegrationMessage::Send(json!({ + "type": "connection_init", + "payload": {}, + })), + WsIntegrationMessage::Expect( + json!({"type": "connection_ack"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Expect( + json!({"type": "ka"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Send(json!({ + "id": "1", + "type": "start", + "payload": { + "variables": {}, + "extensions": {}, + "operationName": null, + "query": "subscription { asyncHuman }", + }, + })), + WsIntegrationMessage::Expect( + json!({ + "type": "error", + "id": "1", + "payload": [{ + "message": "Field \"asyncHuman\" of type \"Human!\" must have a selection \ + of subfields. Did you mean \"asyncHuman { ... }\"?", + "locations": [{ + "line": 1, + "column": 16, + }], + }], + }), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + ]; + + integration.run(messages).await.unwrap(); + } } - async fn test_ws_simple_subscription(integration: &T) { - let messages = vec![ - WsIntegrationMessage::Send( - r#"{ - "type":"connection_init", - "payload":{} - }"# - .into(), - ), - WsIntegrationMessage::Expect( - r#"{ - "type":"connection_ack" - }"# - .into(), - WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, - ), - WsIntegrationMessage::Expect( - r#"{ - "type":"ka" - }"# - .into(), - WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, - ), - WsIntegrationMessage::Send( - r#"{ - "id":"1", - "type":"start", - "payload":{ - "variables":{}, - "extensions":{}, - "operationName":null, - "query":"subscription { asyncHuman { id, name, homePlanet } }" - } - }"# - .into(), - ), - WsIntegrationMessage::Expect( - r#"{ - "type":"data", - "id":"1", - "payload":{ - "data":{ - "asyncHuman":{ - "id":"1000", - "name":"Luke Skywalker", - "homePlanet":"Tatooine" - } - } - } - }"# - .into(), - WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, - ), - ]; + /// Integration tests for the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new]. + /// + /// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md + pub mod graphql_transport_ws { + use serde_json::json; - integration.run(messages).await.unwrap(); - } + use super::{WsIntegration, WsIntegrationMessage, WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT}; - async fn test_ws_invalid_json(integration: &T) { - let messages = vec![ - WsIntegrationMessage::Send("invalid json".into()), - WsIntegrationMessage::Expect( - r#"{ - "type":"connection_error", - "payload":{ - "message":"serde error: expected value at line 1 column 1" - } - }"# - .into(), - WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, - ), - ]; + #[allow(missing_docs)] + pub async fn run_test_suite(integration: &T) { + println!("Running `graphql-ws` test suite for integration"); - integration.run(messages).await.unwrap(); - } + println!(" - graphql_ws::test_simple_subscription"); + test_simple_subscription(integration).await; - async fn test_ws_invalid_query(integration: &T) { - let messages = vec![ - WsIntegrationMessage::Send( - r#"{ - "type":"connection_init", - "payload":{} - }"# - .into(), - ), - WsIntegrationMessage::Expect( - r#"{ - "type":"connection_ack" - }"# - .into(), - WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT - ), - WsIntegrationMessage::Expect( - r#"{ - "type":"ka" - }"# - .into(), - WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT - ), - WsIntegrationMessage::Send( - r#"{ - "id":"1", - "type":"start", - "payload":{ - "variables":{}, - "extensions":{}, - "operationName":null, - "query":"subscription { asyncHuman }" - } - }"# - .into(), - ), - WsIntegrationMessage::Expect( - r#"{ - "type":"error", - "id":"1", - "payload":[{ - "message":"Field \"asyncHuman\" of type \"Human!\" must have a selection of subfields. Did you mean \"asyncHuman { ... }\"?", - "locations":[{ - "line":1, - "column":16 - }] - }] - }"# - .into(), - WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT - ) - ]; + println!(" - graphql_ws::test_invalid_json"); + test_invalid_json(integration).await; - integration.run(messages).await.unwrap(); + println!(" - graphql_ws::test_invalid_query"); + test_invalid_query(integration).await; + } + + async fn test_simple_subscription(integration: &T) { + let messages = vec![ + WsIntegrationMessage::Send(json!({ + "type": "connection_init", + "payload": {}, + })), + WsIntegrationMessage::Expect( + json!({"type": "connection_ack"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Expect( + json!({"type": "pong"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Send(json!({"type": "ping"})), + WsIntegrationMessage::Expect( + json!({"type": "pong"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Send(json!({ + "id": "1", + "type": "subscribe", + "payload": { + "variables": {}, + "extensions": {}, + "operationName": null, + "query": "subscription { asyncHuman { id, name, homePlanet } }", + }, + })), + WsIntegrationMessage::Expect( + json!({ + "id": "1", + "type": "next", + "payload": { + "data": { + "asyncHuman": { + "id": "1000", + "name": "Luke Skywalker", + "homePlanet": "Tatooine", + }, + }, + }, + }), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + ]; + + integration.run(messages).await.unwrap(); + } + + async fn test_invalid_json(integration: &T) { + let messages = vec![ + WsIntegrationMessage::Send(json!({"whatever": "invalid value"})), + WsIntegrationMessage::Expect( + json!({ + "code": 4400, + "description": "`serde` error: missing field `type` at line 1 column 28", + }), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + ]; + + integration.run(messages).await.unwrap(); + } + + async fn test_invalid_query(integration: &T) { + let messages = vec![ + WsIntegrationMessage::Send(json!({ + "type": "connection_init", + "payload": {}, + })), + WsIntegrationMessage::Expect( + json!({"type": "connection_ack"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Expect( + json!({"type": "pong"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Send(json!({"type": "ping"})), + WsIntegrationMessage::Expect( + json!({"type": "pong"}), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + WsIntegrationMessage::Send(json!({ + "id": "1", + "type": "subscribe", + "payload": { + "variables": {}, + "extensions": {}, + "operationName": null, + "query": "subscription { asyncHuman }", + }, + })), + WsIntegrationMessage::Expect( + json!({ + "type": "error", + "id": "1", + "payload": [{ + "message": "Field \"asyncHuman\" of type \"Human!\" must have a selection \ + of subfields. Did you mean \"asyncHuman { ... }\"?", + "locations": [{ + "line": 1, + "column": 16, + }], + }], + }), + WS_INTEGRATION_EXPECT_DEFAULT_TIMEOUT, + ), + ]; + + integration.run(messages).await.unwrap(); + } } } diff --git a/juniper_actix/CHANGELOG.md b/juniper_actix/CHANGELOG.md index acbbab68..2ddd1695 100644 --- a/juniper_actix/CHANGELOG.md +++ b/juniper_actix/CHANGELOG.md @@ -13,13 +13,13 @@ All user visible changes to `juniper_actix` crate will be documented in this fil - Switched to 4.0 version of [`actix-web` crate] and its ecosystem. ([#1034]) - Switched to 0.16 version of [`juniper` crate]. - Switched to 0.4 version of [`juniper_graphql_ws` crate]. -- Renamed `subscriptions::subscriptions_handler()` as `subscriptions::graphql_ws_handler()` for processing the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws]. ([#1191]) +- Switched to 0.2 version of [`actix-ws` crate]. ([#1197]) +- Renamed `subscriptions::subscriptions_handler()` as `subscriptions::graphql_ws_handler()` for processing the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws]. ([#1191], [#1197]) ### Added -- `subscriptions::graphql_transport_ws_handler()` allowing to process the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws]. ([#1191]) -- `subscriptions::ws_handler()` with auto-selection between the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws] and the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws], based on the `Sec-Websocket-Protocol` HTTP header value. ([#1191]) -- Support of 0.14 version of [`actix` crate]. ([#1189]) +- `subscriptions::graphql_transport_ws_handler()` allowing to process the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws]. ([#1191], [#1197]) +- `subscriptions::ws_handler()` with auto-selection between the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws] and the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws], based on the `Sec-Websocket-Protocol` HTTP header value. ([#1191], [#1197]) ### Fixed @@ -28,8 +28,8 @@ All user visible changes to `juniper_actix` crate will be documented in this fil [#1034]: /../../pull/1034 [#1169]: /../../issues/1169 [#1187]: /../../pull/1187 -[#1189]: /../../pull/1189 [#1191]: /../../pull/1191 +[#1197]: /../../pull/1197 @@ -43,6 +43,7 @@ See [old CHANGELOG](/../../blob/juniper_actix-v0.4.0/juniper_actix/CHANGELOG.md) [`actix` crate]: https://docs.rs/actix [`actix-web` crate]: https://docs.rs/actix-web +[`actix-ws` crate]: https://docs.rs/actix-ws [`juniper` crate]: https://docs.rs/juniper [`juniper_graphql_ws` crate]: https://docs.rs/juniper_graphql_ws [Semantic Versioning 2.0.0]: https://semver.org diff --git a/juniper_actix/Cargo.toml b/juniper_actix/Cargo.toml index 7b5d0246..84c38f23 100644 --- a/juniper_actix/Cargo.toml +++ b/juniper_actix/Cargo.toml @@ -19,18 +19,12 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] -subscriptions = [ - "dep:actix", - "dep:actix-web-actors", - "dep:juniper_graphql_ws", - "dep:tokio", -] +subscriptions = ["dep:actix-ws", "dep:juniper_graphql_ws"] [dependencies] -actix = { version = ">=0.12, <=0.14", optional = true } actix-http = "3.2" actix-web = "4.4" -actix-web-actors = { version = "4.1", optional = true } +actix-ws = { version = "0.2", optional = true } anyhow = "1.0.47" futures = "0.3.22" juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false } @@ -39,7 +33,6 @@ http = "0.2.4" serde = { version = "1.0.122", features = ["derive"] } serde_json = "1.0.18" thiserror = "1.0" -tokio = { version = "1.0", features = ["sync"], optional = true } [dev-dependencies] actix-cors = "0.6" diff --git a/juniper_actix/src/lib.rs b/juniper_actix/src/lib.rs index 2a476d33..b24a5c29 100644 --- a/juniper_actix/src/lib.rs +++ b/juniper_actix/src/lib.rs @@ -168,23 +168,15 @@ pub async fn playground_handler( #[cfg(feature = "subscriptions")] /// `juniper_actix` subscriptions handler implementation. pub mod subscriptions { - use std::{fmt, sync::Arc}; + use std::{fmt, pin::pin, sync::Arc}; - use actix::{ - AsyncContext as _, ContextFutureSpawner as _, Handler, StreamHandler, WrapFuture as _, - }; use actix_web::{ http::header::{HeaderName, HeaderValue}, web, HttpRequest, HttpResponse, }; - use actix_web_actors::ws; - use futures::{ - stream::{SplitSink, SplitStream}, - SinkExt as _, Stream, StreamExt as _, - }; + use futures::{future, SinkExt as _, StreamExt as _}; use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; use juniper_graphql_ws::{graphql_transport_ws, graphql_ws, ArcSchema, Init}; - use tokio::sync::Mutex; /// Serves by auto-selecting between the /// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] and the @@ -261,22 +253,45 @@ pub mod subscriptions { S: ScalarValue + Send + Sync + 'static, I: Init + Send, { - let (s_tx, s_rx) = graphql_ws::Connection::new(ArcSchema(schema), init).split::(); + let (mut resp, mut ws_tx, ws_rx) = actix_ws::handle(&req, stream)?; + let (s_tx, mut s_rx) = graphql_ws::Connection::new(ArcSchema(schema), init).split(); - let mut resp = ws::start( - Actor { - tx: Arc::new(Mutex::new(s_tx)), - rx: Arc::new(Mutex::new(s_rx)), - }, - &req, - stream, - )?; + actix_web::rt::spawn(async move { + let input = ws_rx + .map(|r| r.map(Message)) + .forward(s_tx.sink_map_err(|e| match e {})); + let output = pin!(async move { + while let Some(msg) = s_rx.next().await { + match serde_json::to_string(&msg) { + Ok(m) => { + if ws_tx.text(m).await.is_err() { + return; + } + } + Err(e) => { + _ = ws_tx + .close(Some(actix_ws::CloseReason { + code: actix_ws::CloseCode::Error, + description: Some(format!("error serializing response: {e}")), + })) + .await; + return; + } + } + } + _ = ws_tx + .close(Some((actix_ws::CloseCode::Normal, "Normal Closure").into())) + .await; + }); + + // No errors can be returned here, so ignoring is OK. + _ = future::select(input, output).await; + }); resp.headers_mut().insert( HeaderName::from_static("sec-websocket-protocol"), HeaderValue::from_static("graphql-ws"), ); - Ok(resp) } @@ -306,164 +321,80 @@ pub mod subscriptions { S: ScalarValue + Send + Sync + 'static, I: Init + Send, { - let (s_tx, s_rx) = - graphql_transport_ws::Connection::new(ArcSchema(schema), init).split::(); + let (mut resp, mut ws_tx, ws_rx) = actix_ws::handle(&req, stream)?; + let (s_tx, mut s_rx) = + graphql_transport_ws::Connection::new(ArcSchema(schema), init).split(); - let mut resp = ws::start( - Actor { - tx: Arc::new(Mutex::new(s_tx)), - rx: Arc::new(Mutex::new(s_rx)), - }, - &req, - stream, - )?; + actix_web::rt::spawn(async move { + let input = ws_rx + .map(|r| r.map(Message)) + .forward(s_tx.sink_map_err(|e| match e {})); + let output = pin!(async move { + while let Some(output) = s_rx.next().await { + match output { + graphql_transport_ws::Output::Message(msg) => { + match serde_json::to_string(&msg) { + Ok(m) => { + if ws_tx.text(m).await.is_err() { + return; + } + } + Err(e) => { + _ = ws_tx + .close(Some(actix_ws::CloseReason { + code: actix_ws::CloseCode::Error, + description: Some(format!( + "error serializing response: {e}", + )), + })) + .await; + return; + } + } + } + graphql_transport_ws::Output::Close { code, message } => { + _ = ws_tx + .close(Some(actix_ws::CloseReason { + code: code.into(), + description: Some(message), + })) + .await; + return; + } + } + } + _ = ws_tx + .close(Some((actix_ws::CloseCode::Normal, "Normal Closure").into())) + .await; + }); + + // No errors can be returned here, so ignoring is OK. + _ = future::select(input, output).await; + }); resp.headers_mut().insert( HeaderName::from_static("sec-websocket-protocol"), HeaderValue::from_static("graphql-transport-ws"), ); - Ok(resp) } - type ConnectionSplitSink = Arc>>; - type ConnectionSplitStream = Arc>>; - - /// [`actix::Actor`], coordinating messages between [`actix_web`] and [`juniper_graphql_ws`]: - /// - incoming [`ws::Message`] -> [`Actor`] -> [`juniper`] - /// - [`juniper`] -> [`Actor`] -> response [`ws::Message`] - struct Actor { - tx: ConnectionSplitSink, - rx: ConnectionSplitStream, - } - - impl StreamHandler> for Actor - where - Self: actix::Actor>, - Conn: futures::Sink, - >::Error: fmt::Debug, - { - fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { - #[allow(clippy::single_match)] - match msg { - Ok(msg) => { - let tx = self.tx.clone(); - - // TODO: Somehow this implementation always closes as `1006: Abnormal closure` - // due to excessive polling of `tx` part. - // Needs to be reworked. - async move { - tx.lock() - .await - .send(Message(msg)) - .await - .expect("Infallible: this should not happen"); - } - .into_actor(self) - .wait(ctx); - } - Err(_) => { - // TODO: trace - // ignore the message if there's a transport error - } - } - } - } - - /// [`juniper`] -> [`Actor`]. - impl actix::Actor for Actor - where - Conn: Stream + 'static, - ::Item: IntoWsResponse + Send, - { - type Context = ws::WebsocketContext; - - fn started(&mut self, ctx: &mut Self::Context) { - let stream = self.rx.clone(); - let addr = ctx.address(); - - let fut = async move { - let mut stream = stream.lock().await; - while let Some(msg) = stream.next().await { - // Sending the `msg` to `self`, so that it can be forwarded back to the client. - addr.do_send(ServerMessage(msg)); - } - } - .into_actor(self); - - // TODO: trace - ctx.spawn(fut); - } - - fn stopped(&mut self, _: &mut Self::Context) { - // TODO: trace - } - } - - /// [`Actor`] -> response [`ws::Message`]. - impl Handler> for Actor - where - Conn: Stream + 'static, - M: IntoWsResponse + Send, - { - type Result = (); - - fn handle(&mut self, msg: ServerMessage, ctx: &mut Self::Context) -> Self::Result { - match msg.0.into_ws_response() { - Ok(msg) => ctx.text(msg), - // TODO: trace - Err(reason) => ctx.close(Some(reason)), - } - } - } - - #[derive(actix::Message)] - #[rtype(result = "()")] - struct ServerMessage(T); - - /// Conversion of a [`ServerMessage`] into a response [`ws::Message`]. - pub trait IntoWsResponse { - /// Converts this [`ServerMessage`] into response [`ws::Message`]. - fn into_ws_response(self) -> Result; - } - - impl IntoWsResponse for graphql_transport_ws::Output { - fn into_ws_response(self) -> Result { - match self { - Self::Message(msg) => serde_json::to_string(&msg).map_err(|e| ws::CloseReason { - code: ws::CloseCode::Error, - description: Some(format!("error serializing response: {e}")), - }), - Self::Close { code, message } => Err(ws::CloseReason { - code: code.into(), - description: Some(message), - }), - } - } - } - - impl IntoWsResponse for graphql_ws::ServerMessage { - fn into_ws_response(self) -> Result { - serde_json::to_string(&self).map_err(|e| ws::CloseReason { - code: ws::CloseCode::Error, - description: Some(format!("error serializing response: {e}")), - }) - } - } - #[derive(Debug)] - struct Message(ws::Message); + struct Message(actix_ws::Message); impl TryFrom for graphql_transport_ws::Input { type Error = Error; fn try_from(msg: Message) -> Result { match msg.0 { - ws::Message::Text(text) => serde_json::from_slice(text.as_bytes()) + actix_ws::Message::Text(text) => serde_json::from_slice(text.as_bytes()) .map(Self::Message) .map_err(Error::Serde), - ws::Message::Close(_) => Ok(Self::Close), - _ => Err(Error::UnexpectedClientMessage), + actix_ws::Message::Binary(bytes) => serde_json::from_slice(bytes.as_ref()) + .map(Self::Message) + .map_err(Error::Serde), + actix_ws::Message::Close(_) => Ok(Self::Close), + other => Err(Error::UnexpectedClientMessage(other)), } } } @@ -473,31 +404,34 @@ pub mod subscriptions { fn try_from(msg: Message) -> Result { match msg.0 { - ws::Message::Text(text) => { + actix_ws::Message::Text(text) => { serde_json::from_slice(text.as_bytes()).map_err(Error::Serde) } - ws::Message::Close(_) => Ok(Self::ConnectionTerminate), - _ => Err(Error::UnexpectedClientMessage), + actix_ws::Message::Binary(bytes) => { + serde_json::from_slice(bytes.as_ref()).map_err(Error::Serde) + } + actix_ws::Message::Close(_) => Ok(Self::ConnectionTerminate), + other => Err(Error::UnexpectedClientMessage(other)), } } } - /// Errors that can happen while handling client messages + /// Possible errors of serving an [`actix_ws`] connection. #[derive(Debug)] enum Error { - /// Errors that can happen while deserializing client messages + /// Deserializing of a client or server message failed. Serde(serde_json::Error), - /// Error for unexpected client messages - UnexpectedClientMessage, + /// Unexpected client [`actix_ws::Message`]. + UnexpectedClientMessage(actix_ws::Message), } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Serde(e) => write!(f, "serde error: {e}"), - Self::UnexpectedClientMessage => { - write!(f, "unexpected message received from client") + Self::Serde(e) => write!(f, "`serde` error: {e}"), + Self::UnexpectedClientMessage(m) => { + write!(f, "unexpected message received from client: {m:?}") } } } @@ -514,7 +448,7 @@ mod tests { use actix_web::{ dev::ServiceResponse, http, - http::header::CONTENT_TYPE, + http::header::{ACCEPT, CONTENT_TYPE}, test::{self, TestRequest}, web::Data, App, @@ -527,7 +461,6 @@ mod tests { }; use super::*; - use actix_web::http::header::ACCEPT; type Schema = juniper::RootNode<'static, Query, EmptyMutation, EmptySubscription>; @@ -832,41 +765,37 @@ mod tests { #[cfg(feature = "subscriptions")] #[cfg(test)] mod subscription_tests { - use std::time::Duration; - + use actix_http::ws; use actix_test::start; - use actix_web::{ - web::{self, Data}, - App, Error, HttpRequest, HttpResponse, - }; - use actix_web_actors::ws; + use actix_web::{web, App, Error, HttpRequest, HttpResponse}; use juniper::{ futures::{SinkExt, StreamExt}, - http::tests::{run_ws_test_suite, WsIntegration, WsIntegrationMessage}, + http::tests::{graphql_transport_ws, graphql_ws, WsIntegration, WsIntegrationMessage}, tests::fixtures::starwars::schema::{Database, Query, Subscription}, EmptyMutation, LocalBoxFuture, }; use juniper_graphql_ws::ConnectionConfig; use tokio::time::timeout; - use super::subscriptions::graphql_ws_handler; + use super::subscriptions; - #[derive(Default)] - struct TestActixWsIntegration; + struct TestWsIntegration(&'static str); - impl TestActixWsIntegration { + impl TestWsIntegration { async fn run_async( &self, messages: Vec, ) -> Result<(), anyhow::Error> { + let proto = self.0; + let mut server = start(|| { App::new() - .app_data(Data::new(Schema::new( + .app_data(web::Data::new(Schema::new( Query, EmptyMutation::::new(), Subscription, ))) - .service(web::resource("/subscriptions").to(subscriptions)) + .service(web::resource("/subscriptions").to(subscription(proto))) }); let mut framed = server.ws_at("/subscriptions").await.unwrap(); @@ -874,12 +803,12 @@ mod subscription_tests { match message { WsIntegrationMessage::Send(body) => { framed - .send(ws::Message::Text(body.to_owned().into())) + .send(ws::Message::Text(body.to_string().into())) .await .map_err(|e| anyhow::anyhow!("WS error: {e:?}"))?; } WsIntegrationMessage::Expect(body, message_timeout) => { - let frame = timeout(Duration::from_millis(*message_timeout), framed.next()) + let frame = timeout(*message_timeout, framed.next()) .await .map_err(|_| anyhow::anyhow!("Timed-out waiting for message"))? .ok_or_else(|| anyhow::anyhow!("Empty message received"))? @@ -887,21 +816,29 @@ mod subscription_tests { match frame { ws::Frame::Text(ref bytes) => { - let expected_value = - serde_json::from_str::(body) - .map_err(|e| anyhow::anyhow!("Serde error: {e:?}"))?; - let value: serde_json::Value = serde_json::from_slice(bytes) .map_err(|e| anyhow::anyhow!("Serde error: {e:?}"))?; - if value != expected_value { + if value != *body { return Err(anyhow::anyhow!( - "Expected message: {expected_value}. \ + "Expected message: {body}. \ Received message: {value}", )); } } - _ => return Err(anyhow::anyhow!("Received non-text frame")), + ws::Frame::Close(Some(reason)) => { + let actual = serde_json::json!({ + "code": u16::from(reason.code), + "description": reason.description, + }); + if actual != *body { + return Err(anyhow::anyhow!( + "Expected message: {body}. \ + Received message: {actual}", + )); + } + } + f => return Err(anyhow::anyhow!("Received non-text frame: {f:?}")), } } } @@ -911,7 +848,7 @@ mod subscription_tests { } } - impl WsIntegration for TestActixWsIntegration { + impl WsIntegration for TestWsIntegration { fn run( &self, messages: Vec, @@ -922,20 +859,32 @@ mod subscription_tests { type Schema = juniper::RootNode<'static, Query, EmptyMutation, Subscription>; - async fn subscriptions( - req: HttpRequest, - stream: web::Payload, - schema: web::Data, - ) -> Result { - let context = Database::new(); - let schema = schema.into_inner(); - let config = ConnectionConfig::new(context); + fn subscription( + proto: &'static str, + ) -> impl actix_web::Handler< + (HttpRequest, web::Payload, web::Data), + Output = Result, + > { + move |req: HttpRequest, stream: web::Payload, schema: web::Data| async move { + let context = Database::new(); + let schema = schema.into_inner(); + let config = ConnectionConfig::new(context); - graphql_ws_handler(req, stream, schema, config).await + if proto == "graphql-ws" { + subscriptions::graphql_ws_handler(req, stream, schema, config).await + } else { + subscriptions::graphql_transport_ws_handler(req, stream, schema, config).await + } + } } #[actix_web::rt::test] - async fn test_actix_ws_integration() { - run_ws_test_suite(&mut TestActixWsIntegration).await; + async fn test_graphql_ws_integration() { + graphql_ws::run_test_suite(&mut TestWsIntegration("graphql-ws")).await; + } + + #[actix_web::rt::test] + async fn test_graphql_transport_ws_integration() { + graphql_transport_ws::run_test_suite(&mut TestWsIntegration("graphql-transport-ws")).await; } } diff --git a/juniper_graphql_ws/CHANGELOG.md b/juniper_graphql_ws/CHANGELOG.md index 5c84f57f..489d8568 100644 --- a/juniper_graphql_ws/CHANGELOG.md +++ b/juniper_graphql_ws/CHANGELOG.md @@ -16,7 +16,7 @@ All user visible changes to `juniper_graphql_ws` crate will be documented in thi ### Added -- `graphql_transport_ws` module implementing [`graphql-transport-ws` GraphQL over WebSocket Protocol][proto-5.14.0] as of 5.14.0 version of [`graphql-ws` npm package] behind `graphql-transport-ws` Cargo feature. ([#1158], [#1191], [#1196], [#1022]) +- `graphql_transport_ws` module implementing [`graphql-transport-ws` GraphQL over WebSocket Protocol][proto-5.14.0] as of 5.14.0 version of [`graphql-ws` npm package] behind `graphql-transport-ws` Cargo feature. ([#1158], [#1191], [#1196], [#1197], [#1022]) ### Changed @@ -26,6 +26,7 @@ All user visible changes to `juniper_graphql_ws` crate will be documented in thi [#1158]: /../../pull/1158 [#1191]: /../../pull/1191 [#1196]: /../../pull/1196 +[#1197]: /../../pull/1197 [proto-5.14.0]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md [proto-legacy]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md diff --git a/juniper_graphql_ws/src/graphql_transport_ws/mod.rs b/juniper_graphql_ws/src/graphql_transport_ws/mod.rs index 4d22c6a2..c08cf021 100644 --- a/juniper_graphql_ws/src/graphql_transport_ws/mod.rs +++ b/juniper_graphql_ws/src/graphql_transport_ws/mod.rs @@ -478,6 +478,30 @@ where }, } } + + /// Performs polling of the [`Sink`] part of this [`Connection`]. + /// + /// Effectively represents an implementation of [`Sink::poll_ready()`] and + /// [`Sink::poll_flush()`] methods. + fn poll_sink(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match &mut self.sink_state { + ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())), + ConnectionSinkState::HandlingMessage { ref mut result } => { + match Pin::new(result).poll(cx) { + Poll::Ready((state, reactions)) => { + self.reactions.push(reactions); + self.sink_state = ConnectionSinkState::Ready { state }; + if let Some(waker) = self.stream_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + } + } + ConnectionSinkState::Closed => Poll::Ready(Err("polled after close")), + } + } } impl Sink for Connection @@ -489,21 +513,9 @@ where { type Error = Infallible; - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match &mut self.sink_state { - ConnectionSinkState::Ready { .. } => Poll::Ready(Ok(())), - ConnectionSinkState::HandlingMessage { ref mut result } => { - match Pin::new(result).poll(cx) { - Poll::Ready((state, reactions)) => { - self.reactions.push(reactions); - self.sink_state = ConnectionSinkState::Ready { state }; - Poll::Ready(Ok(())) - } - Poll::Pending => Poll::Pending, - } - } - ConnectionSinkState::Closed => panic!("poll_ready called after close"), - } + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.poll_sink(cx) + .map_err(|e| panic!("`Connection::poll_ready()`: {e}")) } fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { @@ -538,19 +550,19 @@ where } } } - _ => panic!("start_send called when not ready"), + _ => panic!("`Sink::start_send()`: called when not ready"), }; Ok(()) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - >::poll_ready(self, cx) + self.poll_sink(cx).map(|_| Ok(())) } fn poll_close(mut self: Pin<&mut Self>, _: &mut Context) -> Poll> { self.sink_state = ConnectionSinkState::Closed; if let Some(waker) = self.stream_waker.take() { - // Wake up the stream so it can close too. + // Wake up the `Stream` so it can close too. waker.wake(); } Poll::Ready(Ok(())) @@ -567,9 +579,7 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.stream_waker = Some(cx.waker().clone()); - if let ConnectionSinkState::Closed = self.sink_state { - return Poll::Ready(None); - } else if self.stream_terminated { + if self.stream_terminated { return Poll::Ready(None); } @@ -590,6 +600,11 @@ where _ => (), } } + + if let ConnectionSinkState::Closed = self.sink_state { + return Poll::Ready(None); + } + Poll::Pending } } diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 595fe8dc..f1599e3b 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -392,8 +392,8 @@ pub mod subscriptions { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Warp(e) => write!(f, "warp error: {e}"), - Self::Serde(e) => write!(f, "serde error: {e}"), + Self::Warp(e) => write!(f, "`warp` error: {e}"), + Self::Serde(e) => write!(f, "`serde` error: {e}"), } } }