From 761710205a5a928ce1d7be69894650059f91e877 Mon Sep 17 00:00:00 2001 From: Benno Tielen Date: Thu, 9 Nov 2023 11:57:00 +0100 Subject: [PATCH] Provide `axum` integration (#1088, #986, #1184) - create `juniper_axum` crate in Cargo workspace - implement `graphql` default `axum` handler for processing GraphQL requests - implement `extract::JuniperRequest` and `response::JuniperResponse` for custom processing GraphQL requests - implement `subscriptions::graphql_transport_ws()` default `axum` handler for processing the new `graphql-transport-ws` GraphQL over WebSocket Protocol - implement `subscriptions::graphql_ws()` default `axum` handler for processing the legacy `graphql-ws` GraphQL over WebSocket Protocol - implement `subscriptions::serve_graphql_transport_ws()` function for custom processing the new `graphql-transport-ws` GraphQL over WebSocket Protocol - implement `subscriptions::serve_graphql_ws()` function for custom processing the legacy `graphql-ws` GraphQL over WebSocket Protocol - provide `examples/simple.rs` of default `juniper_axum` integration - provide `examples/custom.rs` of custom `juniper_axum` integration Additionally: - fix `junper_actix` crate MSRV to 1.73 - add `test_post_with_variables()` case to integration `juniper::http::tests` Co-authored-by: ilslv Co-authored-by: Christian Legnitto Co-authored-by: Kai Ren --- .github/workflows/ci.yml | 5 + Cargo.toml | 1 + README.md | 2 + juniper/README.md | 3 + juniper/release.toml | 6 + juniper/src/http/mod.rs | 62 +- juniper/src/tests/fixtures/starwars/schema.rs | 2 + juniper_actix/Cargo.toml | 4 +- juniper_actix/README.md | 6 +- juniper_actix/src/lib.rs | 3 +- juniper_axum/CHANGELOG.md | 43 ++ juniper_axum/Cargo.toml | 61 ++ juniper_axum/LICENSE | 25 + juniper_axum/README.md | 47 ++ juniper_axum/examples/custom.rs | 86 +++ juniper_axum/examples/simple.rs | 87 +++ juniper_axum/release.toml | 12 + juniper_axum/src/extract.rs | 293 ++++++++ juniper_axum/src/lib.rs | 139 ++++ juniper_axum/src/response.rs | 24 + juniper_axum/src/subscriptions.rs | 694 ++++++++++++++++++ juniper_axum/tests/http_test_suite.rs | 112 +++ juniper_axum/tests/ws_test_suite.rs | 142 ++++ juniper_graphql_ws/release.toml | 6 + juniper_warp/src/lib.rs | 14 +- 25 files changed, 1854 insertions(+), 25 deletions(-) create mode 100644 juniper_axum/CHANGELOG.md create mode 100644 juniper_axum/Cargo.toml create mode 100644 juniper_axum/LICENSE create mode 100644 juniper_axum/README.md create mode 100644 juniper_axum/examples/custom.rs create mode 100644 juniper_axum/examples/simple.rs create mode 100644 juniper_axum/release.toml create mode 100644 juniper_axum/src/extract.rs create mode 100644 juniper_axum/src/lib.rs create mode 100644 juniper_axum/src/response.rs create mode 100644 juniper_axum/src/subscriptions.rs create mode 100644 juniper_axum/tests/http_test_suite.rs create mode 100644 juniper_axum/tests/ws_test_suite.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ce5075b1..fda35726 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -115,6 +115,8 @@ jobs: - { feature: graphql-ws, crate: juniper_graphql_ws } - { feature: , crate: juniper_actix } - { feature: subscriptions, crate: juniper_actix } + - { feature: , crate: juniper_axum } + - { feature: subscriptions, crate: juniper_axum } - { feature: , crate: juniper_warp } - { feature: subscriptions, crate: juniper_warp } runs-on: ubuntu-latest @@ -148,6 +150,7 @@ jobs: - juniper_subscriptions - juniper_graphql_ws - juniper_actix + - juniper_axum - juniper_hyper #- juniper_iron - juniper_rocket @@ -200,6 +203,7 @@ jobs: - juniper_integration_tests - juniper_codegen_tests - juniper_actix + - juniper_axum - juniper_hyper - juniper_iron - juniper_rocket @@ -326,6 +330,7 @@ jobs: - juniper_subscriptions - juniper_graphql_ws - juniper_actix + - juniper_axum - juniper_hyper - juniper_iron - juniper_rocket diff --git a/Cargo.toml b/Cargo.toml index 212c896d..ea4d8e2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "juniper_graphql_ws", "juniper_warp", "juniper_actix", + "juniper_axum", "tests/codegen", "tests/integration", ] diff --git a/README.md b/README.md index e5c1f96d..3aa641e5 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,7 @@ your Schemas automatically. ### Web Frameworks - [actix][actix] +- [axum][axum] - [hyper][hyper] - [rocket][rocket] - [iron][iron] @@ -93,6 +94,7 @@ your Schemas automatically. Juniper has not reached 1.0 yet, thus some API instability should be expected. [actix]: https://actix.rs/ +[axum]: https://docs.rs/axum [graphql]: http://graphql.org [graphiql]: https://github.com/graphql/graphiql [playground]: https://github.com/prisma/graphql-playground diff --git a/juniper/README.md b/juniper/README.md index bd3a77bf..b72bbd6c 100644 --- a/juniper/README.md +++ b/juniper/README.md @@ -58,6 +58,7 @@ As an exception to other [GraphQL] libraries for other languages, [Juniper] buil ### Web servers - [`actix-web`] ([`juniper_actix`] crate) +- [`axum`] ([`juniper_axum`] crate) - [`hyper`] ([`juniper_hyper`] crate) - [`iron`] ([`juniper_iron`] crate) - [`rocket`] ([`juniper_rocket`] crate) @@ -81,11 +82,13 @@ This project is licensed under [BSD 2-Clause License](https://github.com/graphql [`actix-web`]: https://docs.rs/actix-web +[`axum`]: https://docs.rs/axum [`bigdecimal`]: https://docs.rs/bigdecimal [`bson`]: https://docs.rs/bson [`chrono`]: https://docs.rs/chrono [`chrono-tz`]: https://docs.rs/chrono-tz [`juniper_actix`]: https://docs.rs/juniper_actix +[`juniper_axum`]: https://docs.rs/juniper_axum [`juniper_hyper`]: https://docs.rs/juniper_hyper [`juniper_iron`]: https://docs.rs/juniper_iron [`juniper_rocket`]: https://docs.rs/juniper_rocket diff --git a/juniper/release.toml b/juniper/release.toml index f490007e..fe1d0201 100644 --- a/juniper/release.toml +++ b/juniper/release.toml @@ -40,6 +40,12 @@ exactly = 2 search = "juniper = \\{ version = \"[^\"]+\"" replace = "juniper = { version = \"{{version}}\"" +[[pre-release-replacements]] +file = "../juniper_axum/Cargo.toml" +exactly = 2 +search = "juniper = \\{ version = \"[^\"]+\"" +replace = "juniper = { version = \"{{version}}\"" + [[pre-release-replacements]] file = "../juniper_graphql_ws/Cargo.toml" exactly = 1 diff --git a/juniper/src/http/mod.rs b/juniper/src/http/mod.rs index 06244a6f..b8b4cb00 100644 --- a/juniper/src/http/mod.rs +++ b/juniper/src/http/mod.rs @@ -37,6 +37,7 @@ where pub operation_name: Option, /// Optional variables to execute the GraphQL operation with. + // TODO: Use `Variables` instead of `InputValue`? #[serde(bound( deserialize = "InputValue: Deserialize<'de>", serialize = "InputValue: Serialize", @@ -238,11 +239,11 @@ where /// A batch operation request. /// /// Empty batch is considered as invalid value, so cannot be deserialized. - #[serde(deserialize_with = "deserialize_non_empty_vec")] + #[serde(deserialize_with = "deserialize_non_empty_batch")] Batch(Vec>), } -fn deserialize_non_empty_vec<'de, D, T>(deserializer: D) -> Result, D::Error> +fn deserialize_non_empty_batch<'de, D, T>(deserializer: D) -> Result, D::Error> where D: de::Deserializer<'de>, T: Deserialize<'de>, @@ -251,7 +252,10 @@ where let v = Vec::::deserialize(deserializer)?; if v.is_empty() { - Err(D::Error::invalid_length(0, &"a positive integer")) + Err(D::Error::invalid_length( + 0, + &"non-empty batch of GraphQL requests", + )) } else { Ok(v) } @@ -403,6 +407,9 @@ pub mod tests { println!(" - test_get_with_variables"); test_get_with_variables(integration); + println!(" - test_post_with_variables"); + test_post_with_variables(integration); + println!(" - test_simple_post"); test_simple_post(integration); @@ -501,13 +508,48 @@ pub mod tests { "NEW_HOPE", "EMPIRE", "JEDI" - ], - "homePlanet": "Tatooine", - "name": "Luke Skywalker", - "id": "1000" - } + ], + "homePlanet": "Tatooine", + "name": "Luke Skywalker", + "id": "1000" } - }"# + } + }"# + ) + .expect("Invalid JSON constant in test") + ); + } + + fn test_post_with_variables(integration: &T) { + let response = integration.post_json( + "/", + r#"{ + "query": + "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }", + "variables": {"id": "1000"} + }"#, + ); + + assert_eq!(response.status_code, 200); + assert_eq!(response.content_type, "application/json"); + + assert_eq!( + unwrap_json_response(&response), + serde_json::from_str::( + r#"{ + "data": { + "human": { + "appearsIn": [ + "NEW_HOPE", + "EMPIRE", + "JEDI" + ], + "homePlanet": "Tatooine", + "name": "Luke Skywalker", + "id": "1000" + } + } + }"# ) .expect("Invalid JSON constant in test") ); @@ -752,7 +794,7 @@ pub mod tests { #[allow(missing_docs)] pub async fn run_test_suite(integration: &T) { - println!("Running `graphql-ws` test suite for integration"); + println!("Running `graphql-transport-ws` test suite for integration"); println!(" - graphql_ws::test_simple_subscription"); test_simple_subscription(integration).await; diff --git a/juniper/src/tests/fixtures/starwars/schema.rs b/juniper/src/tests/fixtures/starwars/schema.rs index d89af6bc..c2b094e6 100644 --- a/juniper/src/tests/fixtures/starwars/schema.rs +++ b/juniper/src/tests/fixtures/starwars/schema.rs @@ -4,6 +4,7 @@ use std::{collections::HashMap, pin::Pin}; use crate::{graphql_interface, graphql_object, graphql_subscription, Context, GraphQLEnum}; +#[derive(Clone, Copy, Debug)] pub struct Query; #[graphql_object(context = Database)] @@ -33,6 +34,7 @@ impl Query { } } +#[derive(Clone, Copy, Debug)] pub struct Subscription; type HumanStream = Pin + Send>>; diff --git a/juniper_actix/Cargo.toml b/juniper_actix/Cargo.toml index 84c38f23..722024b8 100644 --- a/juniper_actix/Cargo.toml +++ b/juniper_actix/Cargo.toml @@ -2,7 +2,7 @@ name = "juniper_actix" version = "0.5.0-dev" edition = "2021" -rust-version = "1.68" +rust-version = "1.73" description = "`juniper` GraphQL integration with `actix-web`." license = "BSD-2-Clause" authors = ["Jordao Rosario "] @@ -12,7 +12,7 @@ repository = "https://github.com/graphql-rust/juniper" readme = "README.md" categories = ["asynchronous", "web-programming", "web-programming::http-server"] keywords = ["actix-web", "apollo", "graphql", "juniper", "websocket"] -exclude = ["/examples/", "/release.toml"] +exclude = ["/release.toml"] [package.metadata.docs.rs] all-features = true diff --git a/juniper_actix/README.md b/juniper_actix/README.md index b94de8af..0df1dad6 100644 --- a/juniper_actix/README.md +++ b/juniper_actix/README.md @@ -4,7 +4,7 @@ [![Crates.io](https://img.shields.io/crates/v/juniper_actix.svg?maxAge=2592000)](https://crates.io/crates/juniper_actix) [![Documentation](https://docs.rs/juniper_actix/badge.svg)](https://docs.rs/juniper_actix) [![CI](https://github.com/graphql-rust/juniper/workflows/CI/badge.svg?branch=master "CI")](https://github.com/graphql-rust/juniper/actions?query=workflow%3ACI+branch%3Amaster) -[![Rust 1.68+](https://img.shields.io/badge/rustc-1.68+-lightgray.svg "Rust 1.68+")](https://blog.rust-lang.org/2023/03/09/Rust-1.68.0.html) +[![Rust 1.73+](https://img.shields.io/badge/rustc-1.73+-lightgray.svg "Rust 1.73+")](https://blog.rust-lang.org/2023/10/05/Rust-1.73.0.html) - [Changelog](https://github.com/graphql-rust/juniper/blob/master/juniper_actix/CHANGELOG.md) @@ -26,7 +26,7 @@ A basic usage example can also be found in the [API docs][`juniper_actix`]. ## Examples -Check [`examples/actix_server.rs`][1] for example code of a working [`actix-web`] server with [GraphQL] handlers. +Check [`examples/subscription.rs`][1] for example code of a working [`actix-web`] server with [GraphQL] handlers. @@ -46,5 +46,5 @@ This project is licensed under [BSD 2-Clause License](https://github.com/graphql [Juniper Book]: https://graphql-rust.github.io [Rust]: https://www.rust-lang.org -[1]: https://github.com/graphql-rust/juniper/blob/master/juniper_actix/examples/actix_server.rs +[1]: https://github.com/graphql-rust/juniper/blob/master/juniper_actix/examples/subscription.rs diff --git a/juniper_actix/src/lib.rs b/juniper_actix/src/lib.rs index b24a5c29..4e8e6331 100644 --- a/juniper_actix/src/lib.rs +++ b/juniper_actix/src/lib.rs @@ -143,7 +143,6 @@ where /// let app = App::new() /// .route("/", web::get().to(|| graphiql_handler("/graphql", Some("/graphql/subscriptions")))); /// ``` -#[allow(dead_code)] pub async fn graphiql_handler( graphql_endpoint_url: &str, subscriptions_endpoint_url: Option<&'static str>, @@ -419,7 +418,7 @@ pub mod subscriptions { /// Possible errors of serving an [`actix_ws`] connection. #[derive(Debug)] enum Error { - /// Deserializing of a client or server message failed. + /// Deserializing of a client [`actix_ws::Message`] failed. Serde(serde_json::Error), /// Unexpected client [`actix_ws::Message`]. diff --git a/juniper_axum/CHANGELOG.md b/juniper_axum/CHANGELOG.md new file mode 100644 index 00000000..94ae9800 --- /dev/null +++ b/juniper_axum/CHANGELOG.md @@ -0,0 +1,43 @@ +`juniper_axum` changelog +======================== + +All user visible changes to `juniper_axum` crate will be documented in this file. This project uses [Semantic Versioning 2.0.0]. + + + + +## master + +### Initialized + +- Dependent on 0.6 version of [`axum` crate]. ([#1088]) +- Dependent on 0.16 version of [`juniper` crate]. ([#1088]) +- Dependent on 0.4 version of [`juniper_graphql_ws` crate]. ([#1088]) + +### Added + +- `extract::JuniperRequest` and `response::JuniperResponse` for using in custom [`axum` crate] handlers. ([#1088]) +- `graphql` handler processing [GraphQL] requests for the specified schema. ([#1088], [#1184]) +- `subscriptions::graphql_transport_ws()` handler and `subscriptions::serve_graphql_transport_ws()` function allowing to process the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws]. ([#1088], [#986]) +- `subscriptions::graphql_ws()` handler and `subscriptions::serve_graphql_ws()` function allowing to process the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws]. ([#1088], [#986]) +- `subscriptions::ws()` handler and `subscriptions::serve_ws()` function allowing to auto-select between the [legacy `graphql-ws` GraphQL over WebSocket Protocol][graphql-ws] and the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][graphql-transport-ws], based on the `Sec-Websocket-Protocol` HTTP header value. ([#1088], [#986]) +- `graphiql` handler serving [GraphiQL]. ([#1088]) +- `playground` handler serving [GraphQL Playground]. ([#1088]) +- `simple.rs` and `custom.rs` integration examples. ([#1088], [#986], [#1184]) + +[#986]: /../../issues/986 +[#1088]: /../../pull/1088 +[#1184]: /../../issues/1184 + + + + +[`axum` crate]: https://docs.rs/axum +[`juniper` crate]: https://docs.rs/juniper +[`juniper_graphql_ws` crate]: https://docs.rs/juniper_graphql_ws +[GraphiQL]: https://github.com/graphql/graphiql +[GraphQL]: http://graphql.org +[GraphQL Playground]: https://github.com/prisma/graphql-playground +[Semantic Versioning 2.0.0]: https://semver.org +[graphql-transport-ws]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +[graphql-ws]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md \ No newline at end of file diff --git a/juniper_axum/Cargo.toml b/juniper_axum/Cargo.toml new file mode 100644 index 00000000..b7cd896e --- /dev/null +++ b/juniper_axum/Cargo.toml @@ -0,0 +1,61 @@ +[package] +name = "juniper_axum" +version = "0.1.0" +edition = "2021" +rust-version = "1.73" +description = "`juniper` GraphQL integration with `axum`." +license = "BSD-2-Clause" +authors = [ + "Benno Tielen ", + "Kai Ren ", +] +documentation = "https://docs.rs/juniper_axum" +homepage = "https://github.com/graphql-rust/juniper/tree/master/juniper_axum" +repository = "https://github.com/graphql-rust/juniper" +readme = "README.md" +categories = ["asynchronous", "web-programming", "web-programming::http-server"] +keywords = ["apollo", "axum", "graphql", "juniper", "websocket"] +exclude = ["/release.toml"] + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[features] +subscriptions = ["axum/ws", "juniper_graphql_ws/graphql-ws", "dep:futures"] + +[dependencies] +axum = "0.6.20" +futures = { version = "0.3.22", optional = true } +juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false } +juniper_graphql_ws = { version = "0.4.0-dev", path = "../juniper_graphql_ws", features = ["graphql-transport-ws"] } +serde = { version = "1.0.122", features = ["derive"] } +serde_json = "1.0.18" + +# Fixes for `minimal-versions` check. +# TODO: Try remove on upgrade of `axum` crate. +bytes = "1.2" + +[dev-dependencies] +anyhow = "1.0" +axum = { version = "0.6", features = ["macros"] } +hyper = "0.14" +juniper = { version = "0.16.0-dev", path = "../juniper", features = ["expose-test-schema"] } +tokio = { version = "1.20", features = ["macros", "rt-multi-thread", "time"] } +tokio-stream = "0.1" +tokio-tungstenite = "0.20" +tracing = "0.1" +tracing-subscriber = "0.3" +urlencoding = "2.1" + +[[example]] +name = "custom" +required-features = ["subscriptions"] + +[[example]] +name = "simple" +required-features = ["subscriptions"] + +[[test]] +name = "ws_test_suite" +required-features = ["subscriptions"] diff --git a/juniper_axum/LICENSE b/juniper_axum/LICENSE new file mode 100644 index 00000000..7967e75f --- /dev/null +++ b/juniper_axum/LICENSE @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2022-2023, Benno Tielen, Kai Ren +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/juniper_axum/README.md b/juniper_axum/README.md new file mode 100644 index 00000000..7d59c60c --- /dev/null +++ b/juniper_axum/README.md @@ -0,0 +1,47 @@ +`juniper_axum` crate +==================== + +[![Crates.io](https://img.shields.io/crates/v/juniper_axum.svg?maxAge=2592000)](https://crates.io/crates/juniper_axum) +[![Documentation](https://docs.rs/juniper_axum/badge.svg)](https://docs.rs/juniper_axum) +[![CI](https://github.com/graphql-rust/juniper/workflows/CI/badge.svg?branch=master "CI")](https://github.com/graphql-rust/juniper/actions?query=workflow%3ACI+branch%3Amaster) +[![Rust 1.73+](https://img.shields.io/badge/rustc-1.73+-lightgray.svg "Rust 1.73+")](https://blog.rust-lang.org/2023/10/05/Rust-1.73.0.html) + +- [Changelog](https://github.com/graphql-rust/juniper/blob/master/juniper_axum/CHANGELOG.md) + +[`axum`] web server integration for [`juniper`] ([GraphQL] implementation for [Rust]). + + + + +## Documentation + +For documentation, including guides and examples, check out [Juniper Book]. + +A basic usage example can also be found in the [API docs][`juniper_axum`]. + + + + +## Examples + +Check [`examples/simple.rs`][1] and [`examples/custom.rs`][1] for example code of a working [`axum`] server with [GraphQL] handlers. + + + + +## License + +This project is licensed under [BSD 2-Clause License](https://github.com/graphql-rust/juniper/blob/master/juniper_axum/LICENSE). + + + + +[`axum`]: https://docs.rs/axum +[`juniper`]: https://docs.rs/juniper +[`juniper_axum`]: https://docs.rs/juniper_axum +[GraphQL]: http://graphql.org +[Juniper Book]: https://graphql-rust.github.io +[Rust]: https://www.rust-lang.org + +[1]: https://github.com/graphql-rust/juniper/blob/master/juniper_axum/examples/simple.rs +[2]: https://github.com/graphql-rust/juniper/blob/master/juniper_axum/examples/custom.rs diff --git a/juniper_axum/examples/custom.rs b/juniper_axum/examples/custom.rs new file mode 100644 index 00000000..a2d88b00 --- /dev/null +++ b/juniper_axum/examples/custom.rs @@ -0,0 +1,86 @@ +//! This example demonstrates custom [`Handler`]s with [`axum`], using the [`starwars::schema`]. +//! +//! [`Handler`]: axum::handler::Handler +//! [`starwars::schema`]: juniper::tests::fixtures::starwars::schema + +use std::{net::SocketAddr, sync::Arc}; + +use axum::{ + extract::WebSocketUpgrade, + response::{Html, Response}, + routing::{get, on, MethodFilter}, + Extension, Router, +}; +use juniper::{ + tests::fixtures::starwars::schema::{Database, Query, Subscription}, + EmptyMutation, RootNode, +}; +use juniper_axum::{ + extract::JuniperRequest, graphiql, playground, response::JuniperResponse, subscriptions, +}; +use juniper_graphql_ws::ConnectionConfig; + +type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; + +async fn homepage() -> Html<&'static str> { + "

juniper_axum/custom example

\ +
visit GraphiQL
\ + \ + " + .into() +} + +pub async fn custom_subscriptions( + Extension(schema): Extension>, + Extension(database): Extension, + ws: WebSocketUpgrade, +) -> Response { + ws.protocols(["graphql-transport-ws", "graphql-ws"]) + .max_frame_size(1024) + .max_message_size(1024) + .max_write_buffer_size(100) + .on_upgrade(move |socket| { + subscriptions::serve_ws( + socket, + schema, + ConnectionConfig::new(database).with_max_in_flight_operations(10), + ) + }) +} + +async fn custom_graphql( + Extension(schema): Extension>, + Extension(database): Extension, + JuniperRequest(request): JuniperRequest, +) -> JuniperResponse { + JuniperResponse(request.execute(&*schema, &database).await) +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let schema = Schema::new(Query, EmptyMutation::new(), Subscription); + let database = Database::new(); + + let app = Router::new() + .route( + "/graphql", + on(MethodFilter::GET | MethodFilter::POST, custom_graphql), + ) + .route("/subscriptions", get(custom_subscriptions)) + .route("/graphiql", get(graphiql("/graphql", "/subscriptions"))) + .route("/playground", get(playground("/graphql", "/subscriptions"))) + .route("/", get(homepage)) + .layer(Extension(Arc::new(schema))) + .layer(Extension(database)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); + tracing::info!("listening on {addr}"); + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap_or_else(|e| panic!("failed to run `axum::Server`: {e}")); +} diff --git a/juniper_axum/examples/simple.rs b/juniper_axum/examples/simple.rs new file mode 100644 index 00000000..9ccace30 --- /dev/null +++ b/juniper_axum/examples/simple.rs @@ -0,0 +1,87 @@ +//! This example demonstrates simple default integration with [`axum`]. + +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use axum::{ + response::Html, + routing::{get, on, MethodFilter}, + Extension, Router, +}; +use futures::stream::{BoxStream, StreamExt as _}; +use juniper::{graphql_object, graphql_subscription, EmptyMutation, FieldError, RootNode}; +use juniper_axum::{graphiql, graphql, playground, ws}; +use juniper_graphql_ws::ConnectionConfig; +use tokio::time::interval; +use tokio_stream::wrappers::IntervalStream; + +#[derive(Clone, Copy, Debug)] +pub struct Query; + +#[graphql_object] +impl Query { + /// Adds two `a` and `b` numbers. + fn add(a: i32, b: i32) -> i32 { + a + b + } +} + +#[derive(Clone, Copy, Debug)] +pub struct Subscription; + +type NumberStream = BoxStream<'static, Result>; + +#[graphql_subscription] +impl Subscription { + /// Counts seconds. + async fn count() -> NumberStream { + let mut value = 0; + let stream = IntervalStream::new(interval(Duration::from_secs(1))).map(move |_| { + value += 1; + Ok(value) + }); + Box::pin(stream) + } +} + +type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; + +async fn homepage() -> Html<&'static str> { + "

juniper_axum/simple example

\ +
visit GraphiQL
\ + \ + " + .into() +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let schema = Schema::new(Query, EmptyMutation::new(), Subscription); + + let app = Router::new() + .route( + "/graphql", + on( + MethodFilter::GET | MethodFilter::POST, + graphql::>, + ), + ) + .route( + "/subscriptions", + get(ws::>(ConnectionConfig::new(()))), + ) + .route("/graphiql", get(graphiql("/graphql", "/subscriptions"))) + .route("/playground", get(playground("/graphql", "/subscriptions"))) + .route("/", get(homepage)) + .layer(Extension(Arc::new(schema))); + + let addr = SocketAddr::from(([127, 0, 0, 1], 8080)); + tracing::info!("listening on {addr}"); + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .unwrap_or_else(|e| panic!("failed to run `axum::Server`: {e}")); +} diff --git a/juniper_axum/release.toml b/juniper_axum/release.toml new file mode 100644 index 00000000..559c6e3b --- /dev/null +++ b/juniper_axum/release.toml @@ -0,0 +1,12 @@ +[[pre-release-replacements]] +file = "CHANGELOG.md" +max = 1 +min = 0 +search = "## master" +replace = "## [{{version}}] ยท {{date}}\n[{{version}}]: /../../tree/{{crate_name}}-v{{version}}/{{crate_name}}" + +[[pre-release-replacements]] +file = "README.md" +exactly = 4 +search = "graphql-rust/juniper/blob/[^/]+/" +replace = "graphql-rust/juniper/blob/{{crate_name}}-v{{version}}/" diff --git a/juniper_axum/src/extract.rs b/juniper_axum/src/extract.rs new file mode 100644 index 00000000..ef42f384 --- /dev/null +++ b/juniper_axum/src/extract.rs @@ -0,0 +1,293 @@ +//! Types and traits for extracting data from [`Request`]s. + +use std::fmt; + +use axum::{ + async_trait, + body::Body, + extract::{FromRequest, FromRequestParts, Query}, + http::{HeaderValue, Method, Request, StatusCode}, + response::{IntoResponse as _, Response}, + Json, RequestExt as _, +}; +use juniper::{ + http::{GraphQLBatchRequest, GraphQLRequest}, + DefaultScalarValue, ScalarValue, +}; +use serde::Deserialize; + +/// Extractor for [`axum`] to extract a [`JuniperRequest`]. +/// +/// # Example +/// +/// ```rust +/// use std::sync::Arc; +/// +/// use axum::{routing::post, Extension, Json, Router}; +/// use juniper::{ +/// RootNode, EmptySubscription, EmptyMutation, graphql_object, +/// }; +/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse}; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Context; +/// +/// impl juniper::Context for Context {} +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object(context = Context)] +/// impl Query { +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, EmptySubscription>; +/// +/// let schema = Schema::new( +/// Query, +/// EmptyMutation::::new(), +/// EmptySubscription::::new() +/// ); +/// +/// let app: Router = Router::new() +/// .route("/graphql", post(graphql)) +/// .layer(Extension(Arc::new(schema))) +/// .layer(Extension(Context)); +/// +/// # #[axum::debug_handler] +/// async fn graphql( +/// Extension(schema): Extension>, +/// Extension(context): Extension, +/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request` +/// ) -> JuniperResponse { +/// JuniperResponse(req.execute(&*schema, &context).await) +/// } +#[derive(Debug, PartialEq)] +pub struct JuniperRequest(pub GraphQLBatchRequest) +where + S: ScalarValue; + +#[async_trait] +impl FromRequest for JuniperRequest +where + S: ScalarValue, + State: Sync, + Query: FromRequestParts, + Json>: FromRequest, + > as FromRequest>::Rejection: fmt::Display, + String: FromRequest, +{ + type Rejection = Response; + + async fn from_request(mut req: Request, state: &State) -> Result { + let content_type = req + .headers() + .get("content-type") + .map(HeaderValue::to_str) + .transpose() + .map_err(|_| { + ( + StatusCode::BAD_REQUEST, + "`Content-Type` header is not a valid HTTP header string", + ) + .into_response() + })?; + + match (req.method(), content_type) { + (&Method::GET, _) => req + .extract_parts::>() + .await + .map_err(|e| { + ( + StatusCode::BAD_REQUEST, + format!("Invalid request query string: {e}"), + ) + .into_response() + }) + .and_then(|query| { + query + .0 + .try_into() + .map(|q| Self(GraphQLBatchRequest::Single(q))) + .map_err(|e| { + ( + StatusCode::BAD_REQUEST, + format!("Invalid request query `variables`: {e}"), + ) + .into_response() + }) + }), + (&Method::POST, Some("application/json")) => { + Json::>::from_request(req, state) + .await + .map(|req| Self(req.0)) + .map_err(|e| { + (StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response() + }) + } + (&Method::POST, Some("application/graphql")) => String::from_request(req, state) + .await + .map(|body| { + Self(GraphQLBatchRequest::Single(GraphQLRequest::new( + body, None, None, + ))) + }) + .map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response()), + (&Method::POST, _) => Err(( + StatusCode::UNSUPPORTED_MEDIA_TYPE, + "`Content-Type` header is expected to be either `application/json` or \ + `application/graphql`", + ) + .into_response()), + _ => Err(( + StatusCode::METHOD_NOT_ALLOWED, + "HTTP method is expected to be either GET or POST", + ) + .into_response()), + } + } +} + +/// Workaround for a [`GraphQLRequest`] not being [`Deserialize`]d properly from a GET query string, +/// containing `variables` in JSON format. +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +struct GetRequest { + query: String, + #[serde(rename = "operationName")] + operation_name: Option, + variables: Option, +} + +impl TryFrom for GraphQLRequest { + type Error = serde_json::Error; + fn try_from(req: GetRequest) -> Result { + let GetRequest { + query, + operation_name, + variables, + } = req; + Ok(Self::new( + query, + operation_name, + variables.map(|v| serde_json::from_str(&v)).transpose()?, + )) + } +} + +#[cfg(test)] +mod juniper_request_tests { + use std::fmt; + + use axum::{ + body::{Body, Bytes, HttpBody}, + extract::FromRequest as _, + http::Request, + }; + use juniper::{ + graphql_input_value, + http::{GraphQLBatchRequest, GraphQLRequest}, + }; + + use super::JuniperRequest; + + #[tokio::test] + async fn from_get_request() { + let req = Request::get(&format!( + "/?query={}", + urlencoding::encode("{ add(a: 2, b: 3) }") + )) + .body(Body::empty()) + .unwrap_or_else(|e| panic!("cannot build `Request`: {e}")); + + let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new( + "{ add(a: 2, b: 3) }".into(), + None, + None, + ))); + + assert_eq!(do_from_request(req).await, expected); + } + + #[tokio::test] + async fn from_get_request_with_variables() { + let req = Request::get(&format!( + "/?query={}&variables={}", + urlencoding::encode( + "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }", + ), + urlencoding::encode(r#"{"id": "1000"}"#), + )) + .body(Body::empty()) + .unwrap_or_else(|e| panic!("cannot build `Request`: {e}")); + + let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new( + "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }".into(), + None, + Some(graphql_input_value!({"id": "1000"})), + ))); + + assert_eq!(do_from_request(req).await, expected); + } + + #[tokio::test] + async fn from_json_post_request() { + let req = Request::post("/") + .header("content-type", "application/json") + .body(Body::from(r#"{"query": "{ add(a: 2, b: 3) }"}"#)) + .unwrap_or_else(|e| panic!("cannot build `Request`: {e}")); + + let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new( + "{ add(a: 2, b: 3) }".to_string(), + None, + None, + ))); + + assert_eq!(do_from_request(req).await, expected); + } + + #[tokio::test] + async fn from_graphql_post_request() { + let req = Request::post("/") + .header("content-type", "application/graphql") + .body(Body::from(r#"{ add(a: 2, b: 3) }"#)) + .unwrap_or_else(|e| panic!("cannot build `Request`: {e}")); + + let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new( + "{ add(a: 2, b: 3) }".to_string(), + None, + None, + ))); + + assert_eq!(do_from_request(req).await, expected); + } + + /// Performs [`JuniperRequest::from_request()`]. + async fn do_from_request(req: Request) -> JuniperRequest { + match JuniperRequest::from_request(req, &()).await { + Ok(resp) => resp, + Err(resp) => { + panic!( + "`JuniperRequest::from_request()` failed with `{}` status and body:\n{}", + resp.status(), + display_body(resp.into_body()).await, + ) + } + } + } + + /// Converts the provided [`HttpBody`] into a [`String`]. + async fn display_body(body: B) -> String + where + B: HttpBody, + B::Error: fmt::Display, + { + let bytes = hyper::body::to_bytes(body) + .await + .unwrap_or_else(|e| panic!("failed to represent `Body` as `Bytes`: {e}")); + String::from_utf8(bytes.into()).unwrap_or_else(|e| panic!("not UTF-8 body: {e}")) + } +} diff --git a/juniper_axum/src/lib.rs b/juniper_axum/src/lib.rs new file mode 100644 index 00000000..205669a6 --- /dev/null +++ b/juniper_axum/src/lib.rs @@ -0,0 +1,139 @@ +#![doc = include_str!("../README.md")] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![deny(missing_docs)] + +pub mod extract; +pub mod response; +#[cfg(feature = "subscriptions")] +pub mod subscriptions; + +use std::future; + +use axum::{extract::Extension, response::Html}; +use juniper_graphql_ws::Schema; + +use self::{extract::JuniperRequest, response::JuniperResponse}; + +#[cfg(feature = "subscriptions")] +#[doc(inline)] +pub use self::subscriptions::{graphql_transport_ws, graphql_ws, ws}; + +/// [`Handler`], which handles a [`JuniperRequest`] with the specified [`Schema`], by [`extract`]ing +/// it from [`Extension`]s and initializing its fresh [`Schema::Context`] as a [`Default`] one. +/// +/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL requests. If you need +/// > to customize it (for example, extract [`Schema::Context`] from [`Extension`]s +/// > instead initializing a [`Default`] one), create your own [`Handler`] accepting a +/// > [`JuniperRequest`] (see its documentation for examples). +/// +/// # Example +/// +/// ```rust +/// use std::sync::Arc; +/// +/// use axum::{routing::post, Extension, Json, Router}; +/// use juniper::{ +/// RootNode, EmptySubscription, EmptyMutation, graphql_object, +/// }; +/// use juniper_axum::graphql; +/// +/// #[derive(Clone, Copy, Debug, Default)] +/// pub struct Context; +/// +/// impl juniper::Context for Context {} +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object(context = Context)] +/// impl Query { +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, EmptySubscription>; +/// +/// let schema = Schema::new( +/// Query, +/// EmptyMutation::::new(), +/// EmptySubscription::::new() +/// ); +/// +/// let app: Router = Router::new() +/// .route("/graphql", post(graphql::>)) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [`extract`]: axum::extract +/// [`Handler`]: axum::handler::Handler +#[cfg_attr(text, axum::debug_handler)] +pub async fn graphql( + Extension(schema): Extension, + JuniperRequest(req): JuniperRequest, +) -> JuniperResponse +where + S: Schema, // TODO: Refactor in the way we don't depend on `juniper_graphql_ws::Schema` here. + S::Context: Default, +{ + JuniperResponse( + req.execute(schema.root_node(), &S::Context::default()) + .await, + ) +} + +/// Creates a [`Handler`] that replies with an HTML page containing [GraphiQL]. +/// +/// This does not handle routing, so you can mount it on any endpoint. +/// +/// # Example +/// +/// ```rust +/// use axum::{routing::get, Router}; +/// use juniper_axum::graphiql; +/// +/// let app: Router = Router::new() +/// .route("/", get(graphiql("/graphql", "/subscriptions"))); +/// ``` +/// +/// [`Handler`]: axum::handler::Handler +/// [GraphiQL]: https://github.com/graphql/graphiql +pub fn graphiql<'a>( + graphql_endpoint_url: &str, + subscriptions_endpoint_url: impl Into>, +) -> impl FnOnce() -> future::Ready> + Clone + Send { + let html = Html(juniper::http::graphiql::graphiql_source( + graphql_endpoint_url, + subscriptions_endpoint_url.into(), + )); + + || future::ready(html) +} + +/// Creates a [`Handler`] that replies with an HTML page containing [GraphQL Playground]. +/// +/// This does not handle routing, so you can mount it on any endpoint. +/// +/// # Example +/// +/// ```rust +/// use axum::{routing::get, Router}; +/// use juniper_axum::playground; +/// +/// let app: Router = Router::new() +/// .route("/", get(playground("/graphql", "/subscriptions"))); +/// ``` +/// +/// [`Handler`]: axum::handler::Handler +/// [GraphQL Playground]: https://github.com/prisma/graphql-playground +pub fn playground<'a>( + graphql_endpoint_url: &str, + subscriptions_endpoint_url: impl Into>, +) -> impl FnOnce() -> future::Ready> + Clone + Send { + let html = Html(juniper::http::playground::playground_source( + graphql_endpoint_url, + subscriptions_endpoint_url.into(), + )); + + || future::ready(html) +} diff --git a/juniper_axum/src/response.rs b/juniper_axum/src/response.rs new file mode 100644 index 00000000..bd975975 --- /dev/null +++ b/juniper_axum/src/response.rs @@ -0,0 +1,24 @@ +//! [`JuniperResponse`] definition. + +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use juniper::{http::GraphQLBatchResponse, DefaultScalarValue, ScalarValue}; + +/// Wrapper around a [`GraphQLBatchResponse`], implementing [`IntoResponse`], so it can be returned +/// from [`axum`] handlers. +pub struct JuniperResponse(pub GraphQLBatchResponse) +where + S: ScalarValue; + +impl IntoResponse for JuniperResponse { + fn into_response(self) -> Response { + if self.0.is_ok() { + Json(self.0).into_response() + } else { + (StatusCode::BAD_REQUEST, Json(self.0)).into_response() + } + } +} diff --git a/juniper_axum/src/subscriptions.rs b/juniper_axum/src/subscriptions.rs new file mode 100644 index 00000000..cc1ec7d7 --- /dev/null +++ b/juniper_axum/src/subscriptions.rs @@ -0,0 +1,694 @@ +//! Definitions for handling GraphQL subscriptions. + +use std::fmt; + +use axum::{ + extract::{ + ws::{self, WebSocket, WebSocketUpgrade}, + Extension, + }, + response::Response, +}; +use futures::{future, SinkExt as _, StreamExt as _}; +use juniper::ScalarValue; +use juniper_graphql_ws::{graphql_transport_ws, graphql_ws, Init, Schema}; + +/// Creates a [`Handler`] with the specified [`Schema`], which will serve either the +/// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] or the +/// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], by auto-selecting between +/// them, based on the `Sec-Websocket-Protocol` HTTP header value. +/// +/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving GraphQL over WebSocket +/// > Protocol. If you need to customize it (for example, configure [`WebSocketUpgrade`] +/// > parameters), create your own [`Handler`] invoking the [`serve_ws()`] function (see +/// > its documentation for examples). +/// +/// [`Schema`] is [`extract`]ed from [`Extension`]s. +/// +/// The `init` argument is used to provide the custom [`juniper::Context`] and additional +/// configuration for connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the +/// context and configuration are already known, or it can be a closure that gets executed +/// asynchronously whenever a client sends the subscription initialization message. Using a +/// closure allows to perform an authentication based on the parameters provided by a client. +/// +/// # Example +/// +/// ```rust +/// use std::{sync::Arc, time::Duration}; +/// +/// use axum::{routing::get, Extension, Router}; +/// use futures::stream::{BoxStream, StreamExt as _}; +/// use juniper::{ +/// graphql_object, graphql_subscription, EmptyMutation, FieldError, +/// RootNode, +/// }; +/// use juniper_axum::{playground, subscriptions}; +/// use juniper_graphql_ws::ConnectionConfig; +/// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object] +/// impl Query { +/// /// Adds two `a` and `b` numbers. +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Subscription; +/// +/// type NumberStream = BoxStream<'static, Result>; +/// +/// #[graphql_subscription] +/// impl Subscription { +/// /// Counts seconds. +/// async fn count() -> NumberStream { +/// let mut value = 0; +/// let stream = IntervalStream::new(interval(Duration::from_secs(1))).map(move |_| { +/// value += 1; +/// Ok(value) +/// }); +/// Box::pin(stream) +/// } +/// } +/// +/// let schema = Schema::new(Query, EmptyMutation::new(), Subscription); +/// +/// let app: Router = Router::new() +/// .route("/subscriptions", get(subscriptions::ws::>(ConnectionConfig::new(())))) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [`extract`]: axum::extract +/// [`Handler`]: axum::handler::Handler +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md +pub fn ws( + init: impl Init + Clone + Send, +) -> impl FnOnce(Extension, WebSocketUpgrade) -> future::Ready + Clone + Send { + move |Extension(schema), ws| { + future::ready( + ws.protocols(["graphql-transport-ws", "graphql-ws"]) + .on_upgrade(move |socket| serve_ws(socket, schema, init)), + ) + } +} + +/// Creates a [`Handler`] with the specified [`Schema`], which will serve the +/// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new]. +/// +/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving the +/// > [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new]. If you need to +/// > customize it (for example, configure [`WebSocketUpgrade`] parameters), create your +/// > own [`Handler`] invoking the [`serve_graphql_transport_ws()`] function (see its +/// > documentation for examples). +/// +/// [`Schema`] is [`extract`]ed from [`Extension`]s. +/// +/// The `init` argument is used to provide the context and additional configuration for +/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and +/// configuration are already known, or it can be a closure that gets executed asynchronously +/// when the client sends the `ConnectionInit` message. Using a closure allows to perform an +/// authentication based on the parameters provided by a client. +/// +/// # Example +/// +/// ```rust +/// use std::{sync::Arc, time::Duration}; +/// +/// use axum::{routing::get, Extension, Router}; +/// use futures::stream::{BoxStream, StreamExt as _}; +/// use juniper::{ +/// graphql_object, graphql_subscription, EmptyMutation, FieldError, +/// RootNode, +/// }; +/// use juniper_axum::{playground, subscriptions}; +/// use juniper_graphql_ws::ConnectionConfig; +/// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object] +/// impl Query { +/// /// Adds two `a` and `b` numbers. +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Subscription; +/// +/// type NumberStream = BoxStream<'static, Result>; +/// +/// #[graphql_subscription] +/// impl Subscription { +/// /// Counts seconds. +/// async fn count() -> NumberStream { +/// let mut value = 0; +/// let stream = IntervalStream::new(interval(Duration::from_secs(1))).map(move |_| { +/// value += 1; +/// Ok(value) +/// }); +/// Box::pin(stream) +/// } +/// } +/// +/// let schema = Schema::new(Query, EmptyMutation::new(), Subscription); +/// +/// let app: Router = Router::new() +/// .route( +/// "/subscriptions", +/// get(subscriptions::graphql_transport_ws::>(ConnectionConfig::new(()))), +/// ) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [`extract`]: axum::extract +/// [`Handler`]: axum::handler::Handler +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +pub fn graphql_transport_ws( + init: impl Init + Clone + Send, +) -> impl FnOnce(Extension, WebSocketUpgrade) -> future::Ready + Clone + Send { + move |Extension(schema), ws| { + future::ready( + ws.protocols(["graphql-transport-ws"]) + .on_upgrade(move |socket| serve_graphql_transport_ws(socket, schema, init)), + ) + } +} + +/// Creates a [`Handler`] with the specified [`Schema`], which will serve the +/// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old]. +/// +/// > __NOTE__: This is a ready-to-go default [`Handler`] for serving the +/// > [legacy `graphql-ws` GraphQL over WebSocket Protocol][old]. If you need to customize +/// > it (for example, configure [`WebSocketUpgrade`] parameters), create your own +/// > [`Handler`] invoking the [`serve_graphql_ws()`] function (see its documentation for +/// > examples). +/// +/// [`Schema`] is [`extract`]ed from [`Extension`]s. +/// +/// The `init` argument is used to provide the context and additional configuration for +/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and +/// configuration are already known, or it can be a closure that gets executed asynchronously +/// when the client sends the `GQL_CONNECTION_INIT` message. Using a closure allows to perform +/// an authentication based on the parameters provided by a client. +/// +/// > __WARNING__: This protocol has been deprecated in favor of the +/// > [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], which is +/// > provided by the [`graphql_transport_ws()`] function. +/// +/// # Example +/// +/// ```rust +/// use std::{sync::Arc, time::Duration}; +/// +/// use axum::{routing::get, Extension, Router}; +/// use futures::stream::{BoxStream, StreamExt as _}; +/// use juniper::{ +/// graphql_object, graphql_subscription, EmptyMutation, FieldError, +/// RootNode, +/// }; +/// use juniper_axum::{playground, subscriptions}; +/// use juniper_graphql_ws::ConnectionConfig; +/// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object] +/// impl Query { +/// /// Adds two `a` and `b` numbers. +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Subscription; +/// +/// type NumberStream = BoxStream<'static, Result>; +/// +/// #[graphql_subscription] +/// impl Subscription { +/// /// Counts seconds. +/// async fn count() -> NumberStream { +/// let mut value = 0; +/// let stream = IntervalStream::new(interval(Duration::from_secs(1))).map(move |_| { +/// value += 1; +/// Ok(value) +/// }); +/// Box::pin(stream) +/// } +/// } +/// +/// let schema = Schema::new(Query, EmptyMutation::new(), Subscription); +/// +/// let app: Router = Router::new() +/// .route( +/// "/subscriptions", +/// get(subscriptions::graphql_ws::>(ConnectionConfig::new(()))), +/// ) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [`extract`]: axum::extract +/// [`Handler`]: axum::handler::Handler +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md +pub fn graphql_ws( + init: impl Init + Clone + Send, +) -> impl FnOnce(Extension, WebSocketUpgrade) -> future::Ready + Clone + Send { + move |Extension(schema), ws| { + future::ready( + ws.protocols(["graphql-ws"]) + .on_upgrade(move |socket| serve_graphql_ws(socket, schema, init)), + ) + } +} + +/// Serves on the provided [`WebSocket`] by auto-selecting between the +/// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] and the +/// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], based on the +/// `Sec-Websocket-Protocol` HTTP header value. +/// +/// > __WARNING__: This function doesn't set (only checks) the `Sec-Websocket-Protocol` HTTP header +/// > value, so this should be done manually outside (see the example below). +/// > To have fully baked [`axum`] handler, use [`ws()`] handler instead. +/// +/// The `init` argument is used to provide the custom [`juniper::Context`] and additional +/// configuration for connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the +/// context and configuration are already known, or it can be a closure that gets executed +/// asynchronously whenever a client sends the subscription initialization message. Using a +/// closure allows to perform an authentication based on the parameters provided by a client. +/// +/// # Example +/// +/// ```rust +/// use std::{sync::Arc, time::Duration}; +/// +/// use axum::{ +/// extract::WebSocketUpgrade, +/// response::Response, +/// routing::get, +/// Extension, Router, +/// }; +/// use futures::stream::{BoxStream, StreamExt as _}; +/// use juniper::{ +/// graphql_object, graphql_subscription, EmptyMutation, FieldError, +/// RootNode, +/// }; +/// use juniper_axum::{playground, subscriptions}; +/// use juniper_graphql_ws::ConnectionConfig; +/// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object] +/// impl Query { +/// /// Adds two `a` and `b` numbers. +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Subscription; +/// +/// type NumberStream = BoxStream<'static, Result>; +/// +/// #[graphql_subscription] +/// impl Subscription { +/// /// Counts seconds. +/// async fn count() -> NumberStream { +/// let mut value = 0; +/// let stream = IntervalStream::new(interval(Duration::from_secs(1))).map(move |_| { +/// value += 1; +/// Ok(value) +/// }); +/// Box::pin(stream) +/// } +/// } +/// +/// async fn juniper_subscriptions( +/// Extension(schema): Extension>, +/// ws: WebSocketUpgrade, +/// ) -> Response { +/// ws.protocols(["graphql-transport-ws", "graphql-ws"]) +/// .max_frame_size(1024) +/// .max_message_size(1024) +/// .max_write_buffer_size(100) +/// .on_upgrade(move |socket| { +/// subscriptions::serve_ws(socket, schema, ConnectionConfig::new(())) +/// }) +/// } +/// +/// let schema = Schema::new(Query, EmptyMutation::new(), Subscription); +/// +/// let app: Router = Router::new() +/// .route("/subscriptions", get(juniper_subscriptions)) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md +pub async fn serve_ws(socket: WebSocket, schema: S, init: I) +where + S: Schema, + I: Init + Send, +{ + if socket.protocol().map(AsRef::as_ref) == Some("graphql-ws".as_bytes()) { + serve_graphql_ws(socket, schema, init).await; + } else { + serve_graphql_transport_ws(socket, schema, init).await; + } +} + +/// Serves the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new] on the provided +/// [`WebSocket`]. +/// +/// > __WARNING__: This function doesn't check or set the `Sec-Websocket-Protocol` HTTP header value +/// > as `graphql-transport-ws`, so this should be done manually outside (see the +/// > example below). +/// > To have fully baked [`axum`] handler for +/// > [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], use +/// > [`graphql_transport_ws()`] handler instead. +/// +/// The `init` argument is used to provide the context and additional configuration for +/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and +/// configuration are already known, or it can be a closure that gets executed asynchronously +/// when the client sends the `ConnectionInit` message. Using a closure allows to perform an +/// authentication based on the parameters provided by a client. +/// +/// # Example +/// +/// ```rust +/// use std::{sync::Arc, time::Duration}; +/// +/// use axum::{ +/// extract::WebSocketUpgrade, +/// response::Response, +/// routing::get, +/// Extension, Router, +/// }; +/// use futures::stream::{BoxStream, StreamExt as _}; +/// use juniper::{ +/// graphql_object, graphql_subscription, EmptyMutation, FieldError, +/// RootNode, +/// }; +/// use juniper_axum::{playground, subscriptions}; +/// use juniper_graphql_ws::ConnectionConfig; +/// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object] +/// impl Query { +/// /// Adds two `a` and `b` numbers. +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Subscription; +/// +/// type NumberStream = BoxStream<'static, Result>; +/// +/// #[graphql_subscription] +/// impl Subscription { +/// /// Counts seconds. +/// async fn count() -> NumberStream { +/// let mut value = 0; +/// let stream = IntervalStream::new(interval(Duration::from_secs(1))).map(move |_| { +/// value += 1; +/// Ok(value) +/// }); +/// Box::pin(stream) +/// } +/// } +/// +/// async fn juniper_subscriptions( +/// Extension(schema): Extension>, +/// ws: WebSocketUpgrade, +/// ) -> Response { +/// ws.protocols(["graphql-transport-ws"]) +/// .max_frame_size(1024) +/// .max_message_size(1024) +/// .max_write_buffer_size(100) +/// .on_upgrade(move |socket| { +/// subscriptions::serve_graphql_transport_ws(socket, schema, ConnectionConfig::new(())) +/// }) +/// } +/// +/// let schema = Schema::new(Query, EmptyMutation::new(), Subscription); +/// +/// let app: Router = Router::new() +/// .route("/subscriptions", get(juniper_subscriptions)) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +pub async fn serve_graphql_transport_ws(socket: WebSocket, schema: S, init: I) +where + S: Schema, + I: Init + Send, +{ + let (ws_tx, ws_rx) = socket.split(); + let (s_tx, s_rx) = graphql_transport_ws::Connection::new(schema, init).split(); + + let input = ws_rx + .map(|r| r.map(Message)) + .forward(s_tx.sink_map_err(|e| match e {})); + + let output = s_rx + .map(|output| { + Ok(match output { + graphql_transport_ws::Output::Message(msg) => { + serde_json::to_string(&msg) + .map(ws::Message::Text) + .unwrap_or_else(|e| { + ws::Message::Close(Some(ws::CloseFrame { + code: 1011, // CloseCode::Error + reason: format!("error serializing response: {e}").into(), + })) + }) + } + graphql_transport_ws::Output::Close { code, message } => { + ws::Message::Close(Some(ws::CloseFrame { + code, + reason: message.into(), + })) + } + }) + }) + .forward(ws_tx); + + // No errors can be returned here, so ignoring is OK. + _ = future::select(input, output).await; +} + +/// Serves the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] on the provided +/// [`WebSocket`]. +/// +/// > __WARNING__: This function doesn't check or set the `Sec-Websocket-Protocol` HTTP header value +/// > as `graphql-ws`, so this should be done manually outside (see the example below). +/// > To have fully baked [`axum`] handler for +/// > [legacy `graphql-ws` GraphQL over WebSocket Protocol][old], use [`graphql_ws()`] +/// > handler instead. +/// +/// The `init` argument is used to provide the context and additional configuration for +/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and +/// configuration are already known, or it can be a closure that gets executed asynchronously +/// when the client sends the `GQL_CONNECTION_INIT` message. Using a closure allows to perform +/// an authentication based on the parameters provided by a client. +/// +/// > __WARNING__: This protocol has been deprecated in favor of the +/// > [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], which is +/// > provided by the [`serve_graphql_transport_ws()`] function. +/// +/// # Example +/// +/// ```rust +/// use std::{sync::Arc, time::Duration}; +/// +/// use axum::{ +/// extract::WebSocketUpgrade, +/// response::Response, +/// routing::get, +/// Extension, Router, +/// }; +/// use futures::stream::{BoxStream, StreamExt as _}; +/// use juniper::{ +/// graphql_object, graphql_subscription, EmptyMutation, FieldError, +/// RootNode, +/// }; +/// use juniper_axum::{playground, subscriptions}; +/// use juniper_graphql_ws::ConnectionConfig; +/// use tokio::time::interval; +/// use tokio_stream::wrappers::IntervalStream; +/// +/// type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Query; +/// +/// #[graphql_object] +/// impl Query { +/// /// Adds two `a` and `b` numbers. +/// fn add(a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// #[derive(Clone, Copy, Debug)] +/// pub struct Subscription; +/// +/// type NumberStream = BoxStream<'static, Result>; +/// +/// #[graphql_subscription] +/// impl Subscription { +/// /// Counts seconds. +/// async fn count() -> NumberStream { +/// let mut value = 0; +/// let stream = IntervalStream::new(interval(Duration::from_secs(1))).map(move |_| { +/// value += 1; +/// Ok(value) +/// }); +/// Box::pin(stream) +/// } +/// } +/// +/// async fn juniper_subscriptions( +/// Extension(schema): Extension>, +/// ws: WebSocketUpgrade, +/// ) -> Response { +/// ws.protocols(["graphql-ws"]) +/// .max_frame_size(1024) +/// .max_message_size(1024) +/// .max_write_buffer_size(100) +/// .on_upgrade(move |socket| { +/// subscriptions::serve_graphql_ws(socket, schema, ConnectionConfig::new(())) +/// }) +/// } +/// +/// let schema = Schema::new(Query, EmptyMutation::new(), Subscription); +/// +/// let app: Router = Router::new() +/// .route("/subscriptions", get(juniper_subscriptions)) +/// .layer(Extension(Arc::new(schema))); +/// ``` +/// +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md +pub async fn serve_graphql_ws(socket: WebSocket, schema: S, init: I) +where + S: Schema, + I: Init + Send, +{ + let (ws_tx, ws_rx) = socket.split(); + let (s_tx, s_rx) = graphql_ws::Connection::new(schema, init).split(); + + let input = ws_rx + .map(|r| r.map(Message)) + .forward(s_tx.sink_map_err(|e| match e {})); + + let output = s_rx + .map(|msg| { + Ok(serde_json::to_string(&msg) + .map(ws::Message::Text) + .unwrap_or_else(|e| { + ws::Message::Close(Some(ws::CloseFrame { + code: 1011, // CloseCode::Error + reason: format!("error serializing response: {e}").into(), + })) + })) + }) + .forward(ws_tx); + + // No errors can be returned here, so ignoring is OK. + _ = future::select(input, output).await; +} + +/// Wrapper around [`ws::Message`] allowing to define custom conversions. +#[derive(Debug)] +struct Message(ws::Message); + +impl TryFrom for graphql_transport_ws::Input { + type Error = Error; + + fn try_from(msg: Message) -> Result { + match msg.0 { + ws::Message::Text(text) => serde_json::from_slice(text.as_bytes()) + .map(Self::Message) + .map_err(Error::Serde), + ws::Message::Binary(bytes) => serde_json::from_slice(bytes.as_ref()) + .map(Self::Message) + .map_err(Error::Serde), + ws::Message::Close(_) => Ok(Self::Close), + other => Err(Error::UnexpectedClientMessage(other)), + } + } +} + +impl TryFrom for graphql_ws::ClientMessage { + type Error = Error; + + fn try_from(msg: Message) -> Result { + match msg.0 { + ws::Message::Text(text) => { + serde_json::from_slice(text.as_bytes()).map_err(Error::Serde) + } + ws::Message::Binary(bytes) => { + serde_json::from_slice(bytes.as_ref()).map_err(Error::Serde) + } + ws::Message::Close(_) => Ok(Self::ConnectionTerminate), + other => Err(Error::UnexpectedClientMessage(other)), + } + } +} + +/// Possible errors of serving a [`WebSocket`] connection. +#[derive(Debug)] +enum Error { + /// Deserializing of a client [`ws::Message`] failed. + Serde(serde_json::Error), + + /// Unexpected client [`ws::Message`]. + UnexpectedClientMessage(ws::Message), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Serde(e) => write!(f, "`serde` error: {e}"), + Self::UnexpectedClientMessage(m) => { + write!(f, "unexpected message received from client: {m:?}") + } + } + } +} + +impl std::error::Error for Error {} diff --git a/juniper_axum/tests/http_test_suite.rs b/juniper_axum/tests/http_test_suite.rs new file mode 100644 index 00000000..2a93405c --- /dev/null +++ b/juniper_axum/tests/http_test_suite.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use axum::{ + http::Request, + response::Response, + routing::{get, post}, + Extension, Router, +}; +use hyper::{service::Service, Body}; +use juniper::{ + http::tests::{run_http_test_suite, HttpIntegration, TestResponse}, + tests::fixtures::starwars::schema::{Database, Query}, + EmptyMutation, EmptySubscription, RootNode, +}; +use juniper_axum::{extract::JuniperRequest, response::JuniperResponse}; + +type Schema = RootNode<'static, Query, EmptyMutation, EmptySubscription>; + +struct TestApp(Router); + +impl TestApp { + fn new() -> Self { + #[axum::debug_handler] + async fn graphql( + Extension(schema): Extension>, + Extension(database): Extension, + JuniperRequest(request): JuniperRequest, + ) -> JuniperResponse { + JuniperResponse(request.execute(&*schema, &database).await) + } + + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); + let database = Database::new(); + + Self( + Router::new() + .route("/", get(graphql)) + .route("/", post(graphql)) + .layer(Extension(Arc::new(schema))) + .layer(Extension(database)), + ) + } + + fn make_request(&self, req: Request) -> TestResponse { + let mut app = self.0.clone(); + + let task = app.call(req); + + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + // PANIC: Unwrapping is OK here, because `task` is `Infallible`. + let resp = task.await.unwrap(); + into_test_response(resp).await + }) + } +} + +impl HttpIntegration for TestApp { + fn get(&self, url: &str) -> TestResponse { + let req = Request::get(url).body(Body::empty()).unwrap(); + self.make_request(req) + } + + fn post_json(&self, url: &str, body: &str) -> TestResponse { + let req = Request::post(url) + .header("content-type", "application/json") + .body(Body::from(body.to_string())) + .unwrap(); + self.make_request(req) + } + + fn post_graphql(&self, url: &str, body: &str) -> TestResponse { + let req = Request::post(url) + .header("content-type", "application/graphql") + .body(Body::from(body.to_string())) + .unwrap(); + self.make_request(req) + } +} + +/// Converts the provided [`Response`] into to a [`TestResponse`]. +async fn into_test_response(resp: Response) -> TestResponse { + let status_code = resp.status().as_u16().into(); + + let content_type: String = resp + .headers() + .get("content-type") + .map(|header| { + String::from_utf8(header.as_bytes().into()) + .unwrap_or_else(|e| panic!("not UTF-8 header: {e}")) + }) + .unwrap_or_default(); + + let body = hyper::body::to_bytes(resp.into_body()) + .await + .unwrap_or_else(|e| panic!("failed to represent `Body` as `Bytes`: {e}")); + let body = String::from_utf8(body.into()).unwrap_or_else(|e| panic!("not UTF-8 body: {e}")); + + TestResponse { + status_code, + content_type, + body: Some(body), + } +} + +#[test] +fn test_axum_integration() { + run_http_test_suite(&TestApp::new()) +} diff --git a/juniper_axum/tests/ws_test_suite.rs b/juniper_axum/tests/ws_test_suite.rs new file mode 100644 index 00000000..61447906 --- /dev/null +++ b/juniper_axum/tests/ws_test_suite.rs @@ -0,0 +1,142 @@ +#![cfg(not(windows))] + +use std::{ + net::{SocketAddr, TcpListener}, + sync::Arc, +}; + +use anyhow::anyhow; +use axum::{routing::get, Extension, Router}; +use futures::{SinkExt, StreamExt}; +use juniper::{ + http::tests::{graphql_transport_ws, graphql_ws, WsIntegration, WsIntegrationMessage}, + tests::fixtures::starwars::schema::{Database, Query, Subscription}, + EmptyMutation, LocalBoxFuture, RootNode, +}; +use juniper_axum::subscriptions; +use juniper_graphql_ws::ConnectionConfig; +use tokio::{net::TcpStream, time::timeout}; +use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream}; + +type Schema = RootNode<'static, Query, EmptyMutation, Subscription>; + +#[derive(Clone)] +struct TestApp(Router); + +impl TestApp { + fn new(protocol: &'static str) -> Self { + let schema = Schema::new(Query, EmptyMutation::new(), Subscription); + + let mut router = Router::new(); + router = if protocol == "graphql-ws" { + router.route( + "/subscriptions", + get(subscriptions::graphql_ws::>( + ConnectionConfig::new(Database::new()), + )), + ) + } else { + router.route( + "/subscriptions", + get(subscriptions::graphql_transport_ws::>( + ConnectionConfig::new(Database::new()), + )), + ) + }; + router = router.layer(Extension(Arc::new(schema))); + + Self(router) + } + + async fn run(self, messages: Vec) -> Result<(), anyhow::Error> { + let listener = TcpListener::bind("0.0.0.0:0".parse::().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + axum::Server::from_tcp(listener) + .unwrap() + .serve(self.0.into_make_service()) + .await + .unwrap(); + }); + + let (mut websocket, _) = connect_async(format!("ws://{}/subscriptions", addr)) + .await + .unwrap(); + + for msg in messages { + Self::process_message(&mut websocket, msg).await?; + } + + Ok(()) + } + + async fn process_message( + websocket: &mut WebSocketStream>, + message: WsIntegrationMessage, + ) -> Result<(), anyhow::Error> { + match message { + WsIntegrationMessage::Send(msg) => websocket + .send(Message::Text(msg.to_string())) + .await + .map_err(|e| anyhow!("Could not send message: {e}")) + .map(drop), + + WsIntegrationMessage::Expect(expected, duration) => { + let message = timeout(duration, websocket.next()) + .await + .map_err(|e| anyhow!("Timed out receiving message. Elapsed: {e}"))?; + match message { + None => Err(anyhow!("No message received")), + Some(Err(e)) => Err(anyhow!("WebSocket error: {e}")), + Some(Ok(Message::Text(json))) => { + let actual: serde_json::Value = serde_json::from_str(&json) + .map_err(|e| anyhow!("Cannot deserialize received message: {e}"))?; + if actual != expected { + return Err(anyhow!( + "Expected message: {expected}. \ + Received message: {actual}", + )); + } + Ok(()) + } + Some(Ok(Message::Close(Some(frame)))) => { + let actual = serde_json::json!({ + "code": u16::from(frame.code), + "description": frame.reason, + }); + if actual != expected { + return Err(anyhow!( + "Expected message: {expected}. \ + Received message: {actual}", + )); + } + Ok(()) + } + Some(Ok(msg)) => Err(anyhow!("Received non-text message: {msg:?}")), + } + } + } + } +} + +impl WsIntegration for TestApp { + fn run( + &self, + messages: Vec, + ) -> LocalBoxFuture> { + Box::pin(self.clone().run(messages)) + } +} + +#[tokio::test] +async fn test_graphql_ws_integration() { + let app = TestApp::new("graphql-ws"); + graphql_ws::run_test_suite(&app).await; +} + +#[tokio::test] +async fn test_graphql_transport_integration() { + let app = TestApp::new("graphql-transport-ws"); + graphql_transport_ws::run_test_suite(&app).await; +} diff --git a/juniper_graphql_ws/release.toml b/juniper_graphql_ws/release.toml index 5fa8d586..048205a6 100644 --- a/juniper_graphql_ws/release.toml +++ b/juniper_graphql_ws/release.toml @@ -4,6 +4,12 @@ exactly = 1 search = "juniper_graphql_ws = \\{ version = \"[^\"]+\"" replace = "juniper_graphql_ws = { version = \"{{version}}\"" +[[pre-release-replacements]] +file = "../juniper_axum/Cargo.toml" +exactly = 1 +search = "juniper_graphql_ws = \\{ version = \"[^\"]+\"" +replace = "juniper_graphql_ws = { version = \"{{version}}\"" + [[pre-release-replacements]] file = "../juniper_warp/Cargo.toml" exactly = 1 diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index f1599e3b..676f2cd9 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -6,7 +6,7 @@ use std::{collections::HashMap, str, sync::Arc}; use anyhow::anyhow; -use futures::{FutureExt as _, TryFutureExt}; +use futures::{FutureExt as _, TryFutureExt as _}; use juniper::{ http::{GraphQLBatchRequest, GraphQLRequest}, ScalarValue, @@ -341,14 +341,12 @@ fn playground_response( pub mod subscriptions { use std::{convert::Infallible, fmt, sync::Arc}; - use juniper::{ - futures::{ - future::{self, Either}, - sink::SinkExt, - stream::StreamExt, - }, - GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue, + use futures::{ + future::{self, Either}, + sink::SinkExt as _, + stream::StreamExt as _, }; + use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; use juniper_graphql_ws::{graphql_transport_ws, graphql_ws}; use warp::{filters::BoxedFilter, reply::Reply, Filter as _};