diff --git a/juniper_warp/CHANGELOG.md b/juniper_warp/CHANGELOG.md index c328dfc6..74f76c45 100644 --- a/juniper_warp/CHANGELOG.md +++ b/juniper_warp/CHANGELOG.md @@ -11,6 +11,7 @@ All user visible changes to `juniper_warp` crate will be documented in this file ### BC Breaks - Switched to 0.16 version of [`juniper` crate]. +- Removed `JoinError` from public API. ([#1222], [#1177]) ### Added @@ -20,11 +21,18 @@ All user visible changes to `juniper_warp` crate will be documented in this file ### Changed - Made `schema` argument of `make_graphql_filter()` and `make_graphql_filter_sync()` polymorphic, allowing to specify external `Arc`ed `schema`. ([#1136], [#1135]) +- Relaxed requirement for `context_extractor` to be a `BoxedFilter` only. ([#1222], [#1177]) + +### Fixed + +- Excessive `context_extractor` execution in `make_graphql_filter()` and `make_graphql_filter_sync()`. ([#1222], [#1177]) [#1135]: /../../issues/1136 [#1136]: /../../pull/1136 [#1158]: /../../pull/1158 +[#1177]: /../../issues/1177 [#1191]: /../../pull/1191 +[#1222]: /../../pull/1222 diff --git a/juniper_warp/Cargo.toml b/juniper_warp/Cargo.toml index a4615419..6aa4f137 100644 --- a/juniper_warp/Cargo.toml +++ b/juniper_warp/Cargo.toml @@ -5,7 +5,10 @@ edition = "2021" rust-version = "1.73" description = "`juniper` GraphQL integration with `warp`." license = "BSD-2-Clause" -authors = ["Tom Houlé "] +authors = [ + "Tom Houlé ", + "Kai Ren ", +] documentation = "https://docs.rs/juniper_warp" homepage = "https://github.com/graphql-rust/juniper/tree/master/juniper_warp" repository = "https://github.com/graphql-rust/juniper" @@ -20,21 +23,21 @@ rustdoc-args = ["--cfg", "docsrs"] [features] subscriptions = [ + "dep:futures", "dep:juniper_graphql_ws", "dep:log", "warp/websocket", ] [dependencies] -anyhow = "1.0.47" -futures = "0.3.22" +futures = { version = "0.3.22", optional = true } juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false } juniper_graphql_ws = { version = "0.4.0-dev", path = "../juniper_graphql_ws", features = ["graphql-transport-ws", "graphql-ws"], optional = true } log = { version = "0.4", optional = true } serde = { version = "1.0.122", features = ["derive"] } serde_json = "1.0.18" thiserror = "1.0" -tokio = { version = "1.0", features = ["rt-multi-thread"] } +tokio = { version = "1.0", features = ["rt"] } warp = { version = "0.3.2", default-features = false } # Fixes for `minimal-versions` check. @@ -44,6 +47,7 @@ headers = "0.3.8" [dev-dependencies] async-stream = "0.3" env_logger = "0.10" +futures = "0.3.22" juniper = { version = "0.16.0-dev", path = "../juniper", features = ["expose-test-schema"] } log = "0.4" percent-encoding = "2.1" diff --git a/juniper_warp/LICENSE b/juniper_warp/LICENSE index 9374ec2c..b949ce9c 100644 --- a/juniper_warp/LICENSE +++ b/juniper_warp/LICENSE @@ -1,6 +1,6 @@ BSD 2-Clause License -Copyright (c) 2018-2022, Tom Houlé +Copyright (c) 2018-2023, Tom Houlé, Kai Ren All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/juniper_warp/examples/subscription.rs b/juniper_warp/examples/subscription.rs index 73b2c9fe..8ceb9049 100644 --- a/juniper_warp/examples/subscription.rs +++ b/juniper_warp/examples/subscription.rs @@ -160,7 +160,7 @@ async fn main() { .and(warp::path("graphql")) .and(juniper_warp::make_graphql_filter( schema.clone(), - warp::any().map(|| Context).boxed(), + warp::any().map(|| Context), ))) .or( warp::path("subscriptions").and(juniper_warp::subscriptions::make_ws_filter( diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs index 676f2cd9..f0f5bda0 100644 --- a/juniper_warp/src/lib.rs +++ b/juniper_warp/src/lib.rs @@ -1,34 +1,45 @@ #![doc = include_str!("../README.md")] #![cfg_attr(docsrs, feature(doc_cfg))] #![deny(missing_docs)] -#![deny(warnings)] -use std::{collections::HashMap, str, sync::Arc}; +mod response; +#[cfg(feature = "subscriptions")] +pub mod subscriptions; + +use std::{collections::HashMap, fmt, str, sync::Arc}; -use anyhow::anyhow; -use futures::{FutureExt as _, TryFutureExt as _}; use juniper::{ http::{GraphQLBatchRequest, GraphQLRequest}, ScalarValue, }; use tokio::task; -use warp::{body, filters::BoxedFilter, http, hyper::body::Bytes, query, Filter}; +use warp::{ + body::{self, BodyDeserializeError}, + http::{self, StatusCode}, + hyper::body::Bytes, + query, + reject::{self, Reject, Rejection}, + reply::{self, Reply}, + Filter, +}; -/// Makes a filter for GraphQL queries/mutations. +use self::response::JuniperResponse; + +/// Makes a [`Filter`] for handling GraphQL queries/mutations. /// /// The `schema` argument is your [`juniper`] schema. /// -/// The `context_extractor` argument should be a filter that provides the GraphQL context required by the schema. -/// -/// In order to avoid blocking, this helper will use the `tokio_threadpool` threadpool created by hyper to resolve GraphQL requests. +/// The `context_extractor` argument should be a [`Filter`] that provides the GraphQL context, +/// required by the `schema`. /// /// # Example /// /// ```rust /// # use std::sync::Arc; -/// # use warp::Filter; +/// # /// # use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode}; /// # use juniper_warp::make_graphql_filter; +/// # use warp::Filter as _; /// # /// type UserId = String; /// # #[derive(Debug)] @@ -60,19 +71,170 @@ use warp::{body, filters::BoxedFilter, http, hyper::body::Bytes, query, Filter}; /// .map(|auth_header: String, app_state: Arc| { /// let user_id = auth_header; // we believe them /// ExampleContext(app_state, user_id) -/// }) -/// .boxed(); -/// -/// let graphql_filter = make_graphql_filter(schema, context_extractor); +/// }); /// /// let graphql_endpoint = warp::path("graphql") -/// .and(warp::post()) -/// .and(graphql_filter); +/// .and(make_graphql_filter(schema, context_extractor)); /// ``` -pub fn make_graphql_filter( +/// +/// # Fallible `context_extractor` +/// +/// > __WARNING__: In case the `context_extractor` is fallible (e.g. implements +/// > [`Filter`]``), it's error should be handled via +/// > [`Filter::recover()`] to fails fast and avoid switching to other [`Filter`]s +/// > branches, because [`Rejection` doesn't mean to abort the whole request, but +/// > rather to say that a `Filter` couldn't fulfill its preconditions][1]. +/// ```rust +/// # use std::sync::Arc; +/// # +/// # use juniper::{graphql_object, EmptyMutation, EmptySubscription, RootNode}; +/// # use juniper_warp::make_graphql_filter; +/// # use warp::{http, Filter as _, Reply as _}; +/// # +/// # type UserId = String; +/// # #[derive(Debug)] +/// # struct AppState(Vec); +/// # struct ExampleContext(Arc, UserId); +/// # impl juniper::Context for ExampleContext {} +/// # +/// # struct QueryRoot; +/// # +/// # #[graphql_object(context = ExampleContext)] +/// # impl QueryRoot { +/// # fn say_hello(context: &ExampleContext) -> String { +/// # format!( +/// # "good morning {}, the app state is {:?}", +/// # context.1, +/// # context.0, +/// # ) +/// # } +/// # } +/// # +/// #[derive(Clone, Copy, Debug)] +/// struct NotAuthorized; +/// +/// impl warp::reject::Reject for NotAuthorized {} +/// +/// impl warp::Reply for NotAuthorized { +/// fn into_response(self) -> warp::reply::Response { +/// http::StatusCode::FORBIDDEN.into_response() +/// } +/// } +/// +/// let schema = RootNode::new(QueryRoot, EmptyMutation::new(), EmptySubscription::new()); +/// +/// let app_state = Arc::new(AppState(vec![3, 4, 5])); +/// let app_state = warp::any().map(move || app_state.clone()); +/// +/// let context_extractor = warp::any() +/// .and(warp::header::("authorization")) +/// .and(app_state) +/// .and_then(|auth_header: String, app_state: Arc| async move { +/// if auth_header == "correct" { +/// Ok(ExampleContext(app_state, auth_header)) +/// } else { +/// Err(warp::reject::custom(NotAuthorized)) +/// } +/// }); +/// +/// let graphql_endpoint = warp::path("graphql") +/// .and(make_graphql_filter(schema, context_extractor)) +/// .recover(|rejection: warp::reject::Rejection| async move { +/// rejection +/// .find::() +/// .map(|e| e.into_response()) +/// .ok_or(rejection) +/// }); +/// ``` +/// +/// [1]: https://github.com/seanmonstar/warp/issues/388#issuecomment-576453485 +pub fn make_graphql_filter( schema: impl Into>>, - context_extractor: BoxedFilter<(CtxT,)>, -) -> BoxedFilter<(http::Response>,)> + context_extractor: impl Filter + Send + Sync + 'static, +) -> impl Filter + Clone + Send +where + Query: juniper::GraphQLTypeAsync + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLTypeAsync + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLSubscriptionType + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Send + Sync + 'static, + CtxErr: Into, + S: ScalarValue + Send + Sync + 'static, +{ + let schema = schema.into(); + // At the moment, `warp` doesn't allow us to make `context_extractor` filter polymorphic over + // its `Error` type to support both `Error = Infallible` and `Error = Rejection` filters at the + // same time. This is due to the `CombinedRejection` trait and the `FilterBase::map_err()` + // combinator being sealed inside `warp` as private items. The only way to have input type + // polymorphism for `Filter::Error` type is a `BoxedFilter`, which handles it internally. + // See more in the following issues: + // https://github.com/seanmonstar/warp/issues/299 + let context_extractor = context_extractor.boxed(); + + get_query_extractor::() + .or(post_json_extractor::()) + .unify() + .or(post_graphql_extractor::()) + .unify() + .and(warp::any().map(move || schema.clone())) + .and(context_extractor) + .then(graphql_handler::) + .recover(handle_rejects) + .unify() +} + +/// Same as [`make_graphql_filter()`], but for [executing synchronously][1]. +/// +/// > __NOTE__: In order to avoid blocking, this handler will use [`tokio::task::spawn_blocking()`] +/// > on the runtime [`warp`] is running on. +/// +/// [1]: GraphQLBatchRequest::execute_sync +pub fn make_graphql_filter_sync( + schema: impl Into>>, + context_extractor: impl Filter + Send + Sync + 'static, +) -> impl Filter + Clone + Send +where + Query: juniper::GraphQLType + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLType + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Send + Sync + 'static, + CtxErr: Into, + S: ScalarValue + Send + Sync + 'static, +{ + let schema = schema.into(); + // At the moment, `warp` doesn't allow us to make `context_extractor` filter polymorphic over + // its `Error` type to support both `Error = Infallible` and `Error = Rejection` filters at the + // same time. This is due to the `CombinedRejection` trait and the `FilterBase::map_err()` + // combinator being sealed inside `warp` as private items. The only way to have input type + // polymorphism for `Filter::Error` type is a `BoxedFilter`, which handles it internally. + // See more in the following issues: + // https://github.com/seanmonstar/warp/issues/299 + let context_extractor = context_extractor.boxed(); + + get_query_extractor::() + .or(post_json_extractor::()) + .unify() + .or(post_graphql_extractor::()) + .unify() + .and(warp::any().map(move || schema.clone())) + .and(context_extractor) + .then(graphql_handler_sync::) + .recover(handle_rejects) + .unify() +} + +/// Executes the provided [`GraphQLBatchRequest`] against the provided `schema` in the provided +/// `context`. +async fn graphql_handler( + req: GraphQLBatchRequest, + schema: Arc>, + context: CtxT, +) -> reply::Response where Query: juniper::GraphQLTypeAsync + Send + 'static, Query::TypeInfo: Send + Sync, @@ -83,187 +245,140 @@ where CtxT: Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static, { - let schema = schema.into(); - let post_json_schema = schema.clone(); - let post_graphql_schema = schema.clone(); - - let handle_post_json_request = move |context: CtxT, req: GraphQLBatchRequest| { - let schema = post_json_schema.clone(); - async move { - let resp = req.execute(&schema, &context).await; - - Ok::<_, warp::Rejection>(build_response( - serde_json::to_vec(&resp) - .map(|json| (json, resp.is_ok())) - .map_err(Into::into), - )) - } - }; - let post_json_filter = warp::post() - .and(context_extractor.clone()) - .and(body::json()) - .and_then(handle_post_json_request); - - let handle_post_graphql_request = move |context: CtxT, body: Bytes| { - let schema = post_graphql_schema.clone(); - async move { - let query = str::from_utf8(body.as_ref()) - .map_err(|e| anyhow!("Request body query is not a valid UTF-8 string: {e}"))?; - let req = GraphQLRequest::new(query.into(), None, None); - - let resp = req.execute(&schema, &context).await; - - Ok((serde_json::to_vec(&resp)?, resp.is_ok())) - } - .then(|res| async { Ok::<_, warp::Rejection>(build_response(res)) }) - }; - let post_graphql_filter = warp::post() - .and(context_extractor.clone()) - .and(body::bytes()) - .and_then(handle_post_graphql_request); - - let handle_get_request = move |context: CtxT, mut qry: HashMap| { - let schema = schema.clone(); - async move { - let req = GraphQLRequest::new( - qry.remove("query") - .ok_or_else(|| anyhow!("Missing GraphQL query string in query parameters"))?, - qry.remove("operation_name"), - qry.remove("variables") - .map(|vs| serde_json::from_str(&vs)) - .transpose()?, - ); - - let resp = req.execute(&schema, &context).await; - - Ok((serde_json::to_vec(&resp)?, resp.is_ok())) - } - .then(|res| async move { Ok::<_, warp::Rejection>(build_response(res)) }) - }; - let get_filter = warp::get() - .and(context_extractor) - .and(query::query()) - .and_then(handle_get_request); - - get_filter - .or(post_json_filter) - .unify() - .or(post_graphql_filter) - .unify() - .boxed() + let resp = req.execute(&*schema, &context).await; + JuniperResponse(resp).into_response() } -/// Make a synchronous filter for graphql endpoint. -pub fn make_graphql_filter_sync( - schema: impl Into>>, - context_extractor: BoxedFilter<(CtxT,)>, -) -> BoxedFilter<(http::Response>,)> +/// Same as [`graphql_handler()`], but for [executing synchronously][1]. +/// +/// [1]: GraphQLBatchRequest::execute_sync +async fn graphql_handler_sync( + req: GraphQLBatchRequest, + schema: Arc>, + context: CtxT, +) -> reply::Response where - Query: juniper::GraphQLType + Send + Sync + 'static, - Mutation: juniper::GraphQLType + Send + Sync + 'static, - Subscription: juniper::GraphQLType + Send + Sync + 'static, + Query: juniper::GraphQLType + Send + Sync + 'static, + Query::TypeInfo: Send + Sync, + Mutation: juniper::GraphQLType + Send + Sync + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: juniper::GraphQLType + Send + Sync + 'static, + Subscription::TypeInfo: Send + Sync, CtxT: Send + Sync + 'static, S: ScalarValue + Send + Sync + 'static, { - let schema = schema.into(); - let post_json_schema = schema.clone(); - let post_graphql_schema = schema.clone(); - - let handle_post_json_request = move |context: CtxT, req: GraphQLBatchRequest| { - let schema = post_json_schema.clone(); - async move { - let res = task::spawn_blocking(move || { - let resp = req.execute_sync(&schema, &context); - Ok((serde_json::to_vec(&resp)?, resp.is_ok())) - }) - .await?; - - Ok(build_response(res)) - } - .map_err(|e: task::JoinError| warp::reject::custom(JoinError(e))) - }; - let post_json_filter = warp::post() - .and(context_extractor.clone()) - .and(body::json()) - .and_then(handle_post_json_request); - - let handle_post_graphql_request = move |context: CtxT, body: Bytes| { - let schema = post_graphql_schema.clone(); - async move { - let res = task::spawn_blocking(move || { - let query = str::from_utf8(body.as_ref()) - .map_err(|e| anyhow!("Request body is not a valid UTF-8 string: {e}"))?; - let req = GraphQLRequest::new(query.into(), None, None); - - let resp = req.execute_sync(&schema, &context); - Ok((serde_json::to_vec(&resp)?, resp.is_ok())) - }) - .await?; - - Ok(build_response(res)) - } - .map_err(|e: task::JoinError| warp::reject::custom(JoinError(e))) - }; - let post_graphql_filter = warp::post() - .and(context_extractor.clone()) - .and(body::bytes()) - .and_then(handle_post_graphql_request); - - let handle_get_request = move |context: CtxT, mut qry: HashMap| { - let schema = schema.clone(); - async move { - let res = task::spawn_blocking(move || { - let req = GraphQLRequest::new( - qry.remove("query").ok_or_else(|| { - anyhow!("Missing GraphQL query string in query parameters") - })?, - qry.remove("operation_name"), - qry.remove("variables") - .map(|vs| serde_json::from_str(&vs)) - .transpose()?, - ); - - let resp = req.execute_sync(&schema, &context); - Ok((serde_json::to_vec(&resp)?, resp.is_ok())) - }) - .await?; - - Ok(build_response(res)) - } - .map_err(|e: task::JoinError| warp::reject::custom(JoinError(e))) - }; - let get_filter = warp::get() - .and(context_extractor) - .and(query::query()) - .and_then(handle_get_request); - - get_filter - .or(post_json_filter) - .unify() - .or(post_graphql_filter) - .unify() - .boxed() + task::spawn_blocking(move || req.execute_sync(&*schema, &context)) + .await + .map(|resp| JuniperResponse(resp).into_response()) + .unwrap_or_else(|e| BlockingError(e).into_response()) } -/// Error raised by `tokio_threadpool` if the thread pool has been shutdown. -/// -/// Wrapper type is needed as inner type does not implement `warp::reject::Reject`. +/// Extracts a [`GraphQLBatchRequest`] from a POST `application/json` HTTP request. +fn post_json_extractor( +) -> impl Filter,), Error = Rejection> + Clone + Send +where + S: ScalarValue + Send, +{ + warp::post().and(body::json()) +} + +/// Extracts a [`GraphQLBatchRequest`] from a POST `application/graphql` HTTP request. +fn post_graphql_extractor( +) -> impl Filter,), Error = Rejection> + Clone + Send +where + S: ScalarValue + Send, +{ + warp::post() + .and(body::bytes()) + .and_then(|body: Bytes| async move { + let query = str::from_utf8(body.as_ref()) + .map_err(|e| reject::custom(FilterError::NonUtf8Body(e)))?; + let req = GraphQLRequest::new(query.into(), None, None); + Ok::, Rejection>(GraphQLBatchRequest::Single(req)) + }) +} + +/// Extracts a [`GraphQLBatchRequest`] from a GET HTTP request. +fn get_query_extractor( +) -> impl Filter,), Error = Rejection> + Clone + Send +where + S: ScalarValue + Send, +{ + warp::get() + .and(query::query()) + .and_then(|mut qry: HashMap| async move { + let req = GraphQLRequest::new( + qry.remove("query") + .ok_or_else(|| reject::custom(FilterError::MissingPathQuery))?, + qry.remove("operation_name"), + qry.remove("variables") + .map(|vs| serde_json::from_str(&vs)) + .transpose() + .map_err(|e| reject::custom(FilterError::InvalidPathVariables(e)))?, + ); + Ok::, Rejection>(GraphQLBatchRequest::Single(req)) + }) +} + +/// Handles all the [`Rejection`]s happening in [`make_graphql_filter()`] to fail fast, if required. +async fn handle_rejects(rej: Rejection) -> Result { + let (status, msg) = if let Some(e) = rej.find::() { + (StatusCode::BAD_REQUEST, e.to_string()) + } else if let Some(e) = rej.find::() { + (StatusCode::BAD_REQUEST, e.to_string()) + } else if let Some(e) = rej.find::() { + (StatusCode::BAD_REQUEST, e.to_string()) + } else { + return Err(rej); + }; + + Ok(http::Response::builder() + .status(status) + .body(msg.into()) + .unwrap()) +} + +/// Possible errors happening in [`Filter`]s during [`GraphQLBatchRequest`] extraction. #[derive(Debug)] -pub struct JoinError(task::JoinError); +enum FilterError { + /// GET HTTP request misses query parameters. + MissingPathQuery, -impl warp::reject::Reject for JoinError {} + /// GET HTTP request contains ivalid `path` query parameter. + InvalidPathVariables(serde_json::Error), -fn build_response(response: Result<(Vec, bool), anyhow::Error>) -> http::Response> { - match response { - Ok((body, is_ok)) => http::Response::builder() - .status(if is_ok { 200 } else { 400 }) - .header("content-type", "application/json") - .body(body) - .expect("response is valid"), - Err(_) => http::Response::builder() - .status(http::StatusCode::INTERNAL_SERVER_ERROR) - .body(Vec::new()) - .expect("status code is valid"), + /// POST HTTP request contains non-UTF-8 body. + NonUtf8Body(str::Utf8Error), +} +impl Reject for FilterError {} + +impl fmt::Display for FilterError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MissingPathQuery => { + write!(f, "Missing GraphQL `query` string in query parameters") + } + Self::InvalidPathVariables(e) => write!( + f, + "Failed to deserialize GraphQL `variables` from JSON: {e}", + ), + Self::NonUtf8Body(e) => write!(f, "Request body is not a valid UTF-8 string: {e}"), + } + } +} + +/// Error raised by [`tokio::task::spawn_blocking()`] if the thread pool has been shutdown. +#[derive(Debug)] +struct BlockingError(task::JoinError); + +impl Reply for BlockingError { + fn into_response(self) -> reply::Response { + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(format!("Failed to execute synchronous GraphQL request: {}", self.0).into()) + .unwrap_or_else(|e| { + unreachable!("cannot build `reply::Response` out of `BlockingError`: {e}") + }) } } @@ -336,658 +451,284 @@ fn playground_response( .expect("response is valid") } -#[cfg(feature = "subscriptions")] -/// `juniper_warp` subscriptions handler implementation. -pub mod subscriptions { - use std::{convert::Infallible, fmt, sync::Arc}; - - use futures::{ - future::{self, Either}, - sink::SinkExt as _, - stream::StreamExt as _, - }; - use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; - use juniper_graphql_ws::{graphql_transport_ws, graphql_ws}; - use warp::{filters::BoxedFilter, reply::Reply, Filter as _}; - - struct Message(warp::ws::Message); - - impl TryFrom for graphql_ws::ClientMessage { - type Error = serde_json::Error; - - fn try_from(msg: Message) -> serde_json::Result { - if msg.0.is_close() { - Ok(Self::ConnectionTerminate) - } else { - serde_json::from_slice(msg.0.as_bytes()) - } - } - } - - impl TryFrom for graphql_transport_ws::Input { - type Error = serde_json::Error; - - fn try_from(msg: Message) -> serde_json::Result { - if msg.0.is_close() { - Ok(Self::Close) - } else { - serde_json::from_slice(msg.0.as_bytes()).map(Self::Message) - } - } - } - - /// Errors that can happen while serving a connection. - #[derive(Debug)] - pub enum Error { - /// Errors that can happen in Warp while serving a connection. - Warp(warp::Error), - - /// Errors that can happen while serializing outgoing messages. Note that errors that occur - /// while deserializing incoming messages are handled internally by the protocol. - Serde(serde_json::Error), - } - - impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Warp(e) => write!(f, "`warp` error: {e}"), - Self::Serde(e) => write!(f, "`serde` error: {e}"), - } - } - } - - impl std::error::Error for Error {} - - impl From for Error { - fn from(err: warp::Error) -> Self { - Self::Warp(err) - } - } - - impl From for Error { - fn from(_err: Infallible) -> Self { - unreachable!() - } - } - - /// Makes a filter for GraphQL subscriptions. - /// - /// This filter auto-selects between the - /// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] and the - /// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], based on the - /// `Sec-Websocket-Protocol` HTTP header value. - /// - /// The `schema` argument is your [`juniper`] schema. - /// - /// The `init` argument is used to provide the custom [`juniper::Context`] and additional - /// configuration for connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the - /// context and configuration are already known, or it can be a closure that gets executed - /// asynchronously whenever a client sends the subscription initialization message. Using a - /// closure allows to perform an authentication based on the parameters provided by a client. - /// - /// # Example - /// - /// ```rust - /// # use std::{convert::Infallible, pin::Pin, sync::Arc, time::Duration}; - /// # - /// # use futures::Stream; - /// # use juniper::{graphql_object, graphql_subscription, EmptyMutation, RootNode}; - /// # use juniper_graphql_ws::ConnectionConfig; - /// # use juniper_warp::make_graphql_filter; - /// # use warp::Filter as _; - /// # - /// type UserId = String; - /// # #[derive(Debug)] - /// struct AppState(Vec); - /// #[derive(Clone)] - /// struct ExampleContext(Arc, UserId); - /// # impl juniper::Context for ExampleContext {} - /// - /// struct QueryRoot; - /// - /// #[graphql_object(context = ExampleContext)] - /// impl QueryRoot { - /// fn say_hello(context: &ExampleContext) -> String { - /// format!( - /// "good morning {}, the app state is {:?}", - /// context.1, - /// context.0, - /// ) - /// } - /// } - /// - /// type StringsStream = Pin + Send>>; - /// - /// struct SubscriptionRoot; - /// - /// #[graphql_subscription(context = ExampleContext)] - /// impl SubscriptionRoot { - /// async fn say_hellos(context: &ExampleContext) -> StringsStream { - /// let mut interval = tokio::time::interval(Duration::from_secs(1)); - /// let context = context.clone(); - /// Box::pin(async_stream::stream! { - /// let mut counter = 0; - /// while counter < 5 { - /// counter += 1; - /// interval.tick().await; - /// yield format!( - /// "{counter}: good morning {}, the app state is {:?}", - /// context.1, - /// context.0, - /// ) - /// } - /// }) - /// } - /// } - /// - /// let schema = Arc::new(RootNode::new(QueryRoot, EmptyMutation::new(), SubscriptionRoot)); - /// let app_state = Arc::new(AppState(vec![3, 4, 5])); - /// let app_state_for_ws = app_state.clone(); - /// - /// let context_extractor = warp::any() - /// .and(warp::header::("authorization")) - /// .and(warp::any().map(move || app_state.clone())) - /// .map(|auth_header: String, app_state: Arc| { - /// let user_id = auth_header; // we believe them - /// ExampleContext(app_state, user_id) - /// }) - /// .boxed(); - /// - /// let graphql_endpoint = (warp::path("graphql") - /// .and(warp::post()) - /// .and(make_graphql_filter(schema.clone(), context_extractor))) - /// .or(warp::path("subscriptions") - /// .and(juniper_warp::subscriptions::make_ws_filter( - /// schema, - /// move |variables: juniper::Variables| { - /// let user_id = variables - /// .get("authorization") - /// .map(ToString::to_string) - /// .unwrap_or_default(); // we believe them - /// async move { - /// Ok::<_, Infallible>(ConnectionConfig::new( - /// ExampleContext(app_state_for_ws.clone(), user_id), - /// )) - /// } - /// }, - /// ))); - /// ``` - /// - /// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md - /// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md - pub fn make_ws_filter( - schema: impl Into>>, - init: I, - ) -> BoxedFilter<(impl Reply,)> - where - Query: GraphQLTypeAsync + Send + 'static, - Query::TypeInfo: Send + Sync, - Mutation: GraphQLTypeAsync + Send + 'static, - Mutation::TypeInfo: Send + Sync, - Subscription: GraphQLSubscriptionType + Send + 'static, - Subscription::TypeInfo: Send + Sync, - CtxT: Unpin + Send + Sync + 'static, - S: ScalarValue + Send + Sync + 'static, - I: juniper_graphql_ws::Init + Clone + Send + Sync, - { - let schema = schema.into(); - - warp::ws() - .and(warp::filters::header::value("sec-websocket-protocol")) - .map(move |ws: warp::ws::Ws, subproto| { - let schema = schema.clone(); - let init = init.clone(); - - let is_legacy = subproto == "graphql-ws"; - - warp::reply::with_header( - ws.on_upgrade(move |ws| async move { - if is_legacy { - serve_graphql_ws(ws, schema, init).await - } else { - serve_graphql_transport_ws(ws, schema, init).await - } - .unwrap_or_else(|e| { - log::error!("GraphQL over WebSocket Protocol error: {e}"); - }) - }), - "sec-websocket-protocol", - if is_legacy { - "graphql-ws" - } else { - "graphql-transport-ws" - }, - ) - }) - .boxed() - } - - /// Serves the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old]. - /// - /// The `init` argument is used to provide the context and additional configuration for - /// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and - /// configuration are already known, or it can be a closure that gets executed asynchronously - /// when the client sends the `GQL_CONNECTION_INIT` message. Using a closure allows to perform - /// an authentication based on the parameters provided by a client. - /// - /// > __WARNING__: This protocol has been deprecated in favor of the - /// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], which is - /// provided by the [`serve_graphql_transport_ws()`] function. - /// - /// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md - /// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md - pub async fn serve_graphql_ws( - websocket: warp::ws::WebSocket, - root_node: Arc>, - init: I, - ) -> Result<(), Error> - where - Query: GraphQLTypeAsync + Send + 'static, - Query::TypeInfo: Send + Sync, - Mutation: GraphQLTypeAsync + Send + 'static, - Mutation::TypeInfo: Send + Sync, - Subscription: GraphQLSubscriptionType + Send + 'static, - Subscription::TypeInfo: Send + Sync, - CtxT: Unpin + Send + Sync + 'static, - S: ScalarValue + Send + Sync + 'static, - I: juniper_graphql_ws::Init + Send, - { - let (ws_tx, ws_rx) = websocket.split(); - let (s_tx, s_rx) = - graphql_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init).split(); - - let ws_rx = ws_rx.map(|r| r.map(Message)); - let s_rx = s_rx.map(|msg| { - serde_json::to_string(&msg) - .map(warp::ws::Message::text) - .map_err(Error::Serde) - }); - - match future::select( - ws_rx.forward(s_tx.sink_err_into()), - s_rx.forward(ws_tx.sink_err_into()), - ) - .await - { - Either::Left((r, _)) => r.map_err(|e| e.into()), - Either::Right((r, _)) => r, - } - } - - /// Serves the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new]. - /// - /// The `init` argument is used to provide the context and additional configuration for - /// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and - /// configuration are already known, or it can be a closure that gets executed asynchronously - /// when the client sends the `ConnectionInit` message. Using a closure allows to perform an - /// authentication based on the parameters provided by a client. - /// - /// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md - pub async fn serve_graphql_transport_ws( - websocket: warp::ws::WebSocket, - root_node: Arc>, - init: I, - ) -> Result<(), Error> - where - Query: GraphQLTypeAsync + Send + 'static, - Query::TypeInfo: Send + Sync, - Mutation: GraphQLTypeAsync + Send + 'static, - Mutation::TypeInfo: Send + Sync, - Subscription: GraphQLSubscriptionType + Send + 'static, - Subscription::TypeInfo: Send + Sync, - CtxT: Unpin + Send + Sync + 'static, - S: ScalarValue + Send + Sync + 'static, - I: juniper_graphql_ws::Init + Send, - { - let (ws_tx, ws_rx) = websocket.split(); - let (s_tx, s_rx) = - graphql_transport_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init) - .split(); - - let ws_rx = ws_rx.map(|r| r.map(Message)); - let s_rx = s_rx.map(|output| match output { - graphql_transport_ws::Output::Message(msg) => serde_json::to_string(&msg) - .map(warp::ws::Message::text) - .map_err(Error::Serde), - graphql_transport_ws::Output::Close { code, message } => { - Ok(warp::ws::Message::close_with(code, message)) - } - }); - - match future::select( - ws_rx.forward(s_tx.sink_err_into()), - s_rx.forward(ws_tx.sink_err_into()), - ) - .await - { - Either::Left((r, _)) => r.map_err(|e| e.into()), - Either::Right((r, _)) => r, - } - } -} - #[cfg(test)] mod tests { - use super::*; - use warp::{http, test::request}; + mod make_graphql_filter { + use std::future; - #[test] - fn graphiql_response_does_not_panic() { - graphiql_response("/abcd", None); - } - - #[tokio::test] - async fn graphiql_endpoint_matches() { - let filter = warp::get() - .and(warp::path("graphiql")) - .and(graphiql_filter("/graphql", None)); - let result = request() - .method("GET") - .path("/graphiql") - .header("accept", "text/html") - .filter(&filter) - .await; - - assert!(result.is_ok()); - } - - #[tokio::test] - async fn graphiql_endpoint_returns_graphiql_source() { - let filter = warp::get() - .and(warp::path("dogs-api")) - .and(warp::path("graphiql")) - .and(graphiql_filter("/dogs-api/graphql", None)); - let response = request() - .method("GET") - .path("/dogs-api/graphiql") - .header("accept", "text/html") - .reply(&filter) - .await; - - assert_eq!(response.status(), http::StatusCode::OK); - assert_eq!( - response.headers().get("content-type").unwrap(), - "text/html;charset=utf-8" - ); - let body = String::from_utf8(response.body().to_vec()).unwrap(); - - assert!(body.contains("var JUNIPER_URL = '/dogs-api/graphql';")); - } - - #[tokio::test] - async fn graphiql_endpoint_with_subscription_matches() { - let filter = warp::get().and(warp::path("graphiql")).and(graphiql_filter( - "/graphql", - Some("ws:://localhost:8080/subscriptions"), - )); - let result = request() - .method("GET") - .path("/graphiql") - .header("accept", "text/html") - .filter(&filter) - .await; - - assert!(result.is_ok()); - } - - #[tokio::test] - async fn playground_endpoint_matches() { - let filter = warp::get() - .and(warp::path("playground")) - .and(playground_filter("/graphql", Some("/subscripitons"))); - - let result = request() - .method("GET") - .path("/playground") - .header("accept", "text/html") - .filter(&filter) - .await; - - assert!(result.is_ok()); - } - - #[tokio::test] - async fn playground_endpoint_returns_playground_source() { - let filter = warp::get() - .and(warp::path("dogs-api")) - .and(warp::path("playground")) - .and(playground_filter( - "/dogs-api/graphql", - Some("/dogs-api/subscriptions"), - )); - let response = request() - .method("GET") - .path("/dogs-api/playground") - .header("accept", "text/html") - .reply(&filter) - .await; - - assert_eq!(response.status(), http::StatusCode::OK); - assert_eq!( - response.headers().get("content-type").unwrap(), - "text/html;charset=utf-8" - ); - let body = String::from_utf8(response.body().to_vec()).unwrap(); - - assert!(body.contains( - "endpoint: '/dogs-api/graphql', subscriptionEndpoint: '/dogs-api/subscriptions'", - )); - } - - #[tokio::test] - async fn graphql_handler_works_json_post() { use juniper::{ + http::GraphQLBatchRequest, tests::fixtures::starwars::schema::{Database, Query}, - EmptyMutation, EmptySubscription, RootNode, + EmptyMutation, EmptySubscription, + }; + use warp::{ + http, + reject::{self, Reject}, + test::request, + Filter as _, Reply, }; - type Schema = - juniper::RootNode<'static, Query, EmptyMutation, EmptySubscription>; + use super::super::make_graphql_filter; - let schema: Schema = RootNode::new( - Query, - EmptyMutation::::new(), - EmptySubscription::::new(), - ); - - let state = warp::any().map(Database::new); - let filter = warp::path("graphql2").and(make_graphql_filter(schema, state.boxed())); - - let response = request() - .method("POST") - .path("/graphql2") - .header("accept", "application/json") - .header("content-type", "application/json") - .body(r#"{ "variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }" }"#) - .reply(&filter) - .await; - - assert_eq!(response.status(), http::StatusCode::OK); - assert_eq!( - response.headers().get("content-type").unwrap(), - "application/json", - ); - assert_eq!( - String::from_utf8(response.body().to_vec()).unwrap(), - r#"{"data":{"hero":{"name":"R2-D2"}}}"# - ); - } - - #[tokio::test] - async fn batch_requests_work() { - use juniper::{ - tests::fixtures::starwars::schema::{Database, Query}, - EmptyMutation, EmptySubscription, RootNode, - }; - - type Schema = - juniper::RootNode<'static, Query, EmptyMutation, EmptySubscription>; - - let schema: Schema = RootNode::new( - Query, - EmptyMutation::::new(), - EmptySubscription::::new(), - ); - - let state = warp::any().map(Database::new); - let filter = warp::path("graphql2").and(make_graphql_filter(schema, state.boxed())); - - let response = request() - .method("POST") - .path("/graphql2") - .header("accept", "application/json") - .header("content-type", "application/json") - .body( - r#"[ - { "variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }" }, - { "variables": null, "query": "{ hero(episode: EMPIRE) { id name } }" } - ]"#, - ) - .reply(&filter) - .await; - - assert_eq!(response.status(), http::StatusCode::OK); - assert_eq!( - String::from_utf8(response.body().to_vec()).unwrap(), - r#"[{"data":{"hero":{"name":"R2-D2"}}},{"data":{"hero":{"id":"1000","name":"Luke Skywalker"}}}]"# - ); - assert_eq!( - response.headers().get("content-type").unwrap(), - "application/json", - ); - } - - #[test] - fn batch_request_deserialization_can_fail() { - let json = r#"blah"#; - let result: Result = serde_json::from_str(json); - - assert!(result.is_err()); - } -} - -#[cfg(test)] -mod tests_http_harness { - use juniper::{ - http::tests::{run_http_test_suite, HttpIntegration, TestResponse}, - tests::fixtures::starwars::schema::{Database, Query}, - EmptyMutation, EmptySubscription, RootNode, - }; - use warp::{ - filters::{path, BoxedFilter}, - Filter, - }; - - use super::*; - - struct TestWarpIntegration { - filter: BoxedFilter<(http::Response>,)>, - } - - impl TestWarpIntegration { - fn new(is_sync: bool) -> Self { - let schema = RootNode::new( + #[tokio::test] + async fn post_json() { + type Schema = juniper::RootNode< + 'static, Query, - EmptyMutation::::new(), - EmptySubscription::::new(), + EmptyMutation, + EmptySubscription, + >; + + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); + + let db = warp::any().map(Database::new); + let filter = warp::path("graphql2").and(make_graphql_filter(schema, db)); + + let response = request() + .method("POST") + .path("/graphql2") + .header("accept", "application/json") + .header("content-type", "application/json") + .body(r#"{"variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }"}"#) + .reply(&filter) + .await; + + assert_eq!(response.status(), http::StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "application/json", ); - let state = warp::any().map(Database::new); + assert_eq!( + String::from_utf8(response.body().to_vec()).unwrap(), + r#"{"data":{"hero":{"name":"R2-D2"}}}"#, + ); + } - let filter = path::end().and(if is_sync { - make_graphql_filter_sync(schema, state.boxed()) - } else { - make_graphql_filter(schema, state.boxed()) - }); - Self { - filter: filter.boxed(), + #[tokio::test] + async fn rejects_fast_when_context_extractor_fails() { + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; + + #[derive(Clone, Copy, Debug)] + struct ExtractionError; + + impl Reject for ExtractionError {} + + impl warp::Reply for ExtractionError { + fn into_response(self) -> warp::reply::Response { + http::StatusCode::IM_A_TEAPOT.into_response() + } } - } - fn make_request(&self, req: warp::test::RequestBuilder) -> TestResponse { - let rt = tokio::runtime::Runtime::new().expect("Failed to create tokio::Runtime"); - make_test_response(rt.block_on(async move { - req.filter(&self.filter).await.unwrap_or_else(|rejection| { - let code = if rejection.is_not_found() { - http::StatusCode::NOT_FOUND - } else if let Some(body::BodyDeserializeError { .. }) = rejection.find() { - http::StatusCode::BAD_REQUEST - } else { - http::StatusCode::INTERNAL_SERVER_ERROR - }; - http::Response::builder() - .status(code) - .header("content-type", "application/json") - .body(Vec::new()) - .unwrap() + type Schema = juniper::RootNode< + 'static, + Query, + EmptyMutation, + EmptySubscription, + >; + + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); + + // Should error on first extraction only, to check whether it rejects fast and doesn't + // switch to other `.or()` filter branches. See #1177 for details: + // https://github.com/graphql-rust/juniper/issues/1177 + let is_called = Arc::new(AtomicBool::new(false)); + let context_extractor = warp::any().and_then(move || { + future::ready(if is_called.swap(true, Ordering::Relaxed) { + Ok(Database::new()) + } else { + Err(reject::custom(ExtractionError)) }) - })) + }); + + let filter = warp::path("graphql") + .and(make_graphql_filter(schema, context_extractor)) + .recover(|rejection: warp::reject::Rejection| async move { + rejection + .find::() + .map(|e| e.into_response()) + .ok_or(rejection) + }); + + let resp = request() + .method("POST") + .path("/graphql") + .header("accept", "application/json") + .header("content-type", "application/json") + .body(r#"{"variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }"}"#) + .reply(&filter) + .await; + + assert_eq!( + resp.status(), + http::StatusCode::IM_A_TEAPOT, + "response: {resp:#?}", + ); + } + + #[tokio::test] + async fn batch_requests() { + type Schema = juniper::RootNode< + 'static, + Query, + EmptyMutation, + EmptySubscription, + >; + + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); + + let db = warp::any().map(Database::new); + let filter = warp::path("graphql2").and(make_graphql_filter(schema, db)); + + let response = request() + .method("POST") + .path("/graphql2") + .header("accept", "application/json") + .header("content-type", "application/json") + .body( + r#"[ + {"variables": null, "query": "{ hero(episode: NEW_HOPE) { name } }"}, + {"variables": null, "query": "{ hero(episode: EMPIRE) { id name } }"} + ]"#, + ) + .reply(&filter) + .await; + + assert_eq!(response.status(), http::StatusCode::OK); + assert_eq!( + String::from_utf8(response.body().to_vec()).unwrap(), + r#"[{"data":{"hero":{"name":"R2-D2"}}},{"data":{"hero":{"id":"1000","name":"Luke Skywalker"}}}]"#, + ); + assert_eq!( + response.headers().get("content-type").unwrap(), + "application/json", + ); + } + + #[test] + fn batch_request_deserialization_can_fail() { + let json = r#"blah"#; + let result: Result = serde_json::from_str(json); + + assert!(result.is_err()); } } - impl HttpIntegration for TestWarpIntegration { - fn get(&self, url: &str) -> TestResponse { - use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; - use url::Url; + mod graphiql_filter { + use warp::{http, test::request, Filter as _}; - /// https://url.spec.whatwg.org/#query-state - const QUERY_ENCODE_SET: &AsciiSet = - &CONTROLS.add(b' ').add(b'"').add(b'#').add(b'<').add(b'>'); + use super::super::{graphiql_filter, graphiql_response}; - let url = Url::parse(&format!("http://localhost:3000{url}")).expect("url to parse"); - - let url: String = utf8_percent_encode(url.query().unwrap_or(""), QUERY_ENCODE_SET) - .collect::>() - .join(""); - - self.make_request( - warp::test::request() - .method("GET") - .path(&format!("/?{url}")), - ) + #[test] + fn response_does_not_panic() { + graphiql_response("/abcd", None); } - fn post_json(&self, url: &str, body: &str) -> TestResponse { - self.make_request( - warp::test::request() - .method("POST") - .header("content-type", "application/json; charset=utf-8") - .path(url) - .body(body), - ) + #[tokio::test] + async fn endpoint_matches() { + let filter = warp::get() + .and(warp::path("graphiql")) + .and(graphiql_filter("/graphql", None)); + let result = request() + .method("GET") + .path("/graphiql") + .header("accept", "text/html") + .filter(&filter) + .await; + + assert!(result.is_ok()); } - fn post_graphql(&self, url: &str, body: &str) -> TestResponse { - self.make_request( - warp::test::request() - .method("POST") - .header("content-type", "application/graphql; charset=utf-8") - .path(url) - .body(body), - ) + #[tokio::test] + async fn returns_graphiql_source() { + let filter = warp::get() + .and(warp::path("dogs-api")) + .and(warp::path("graphiql")) + .and(graphiql_filter("/dogs-api/graphql", None)); + let response = request() + .method("GET") + .path("/dogs-api/graphiql") + .header("accept", "text/html") + .reply(&filter) + .await; + + assert_eq!(response.status(), http::StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/html;charset=utf-8" + ); + let body = String::from_utf8(response.body().to_vec()).unwrap(); + + assert!(body.contains("var JUNIPER_URL = '/dogs-api/graphql';")); + } + + #[tokio::test] + async fn endpoint_with_subscription_matches() { + let filter = warp::get().and(warp::path("graphiql")).and(graphiql_filter( + "/graphql", + Some("ws:://localhost:8080/subscriptions"), + )); + let result = request() + .method("GET") + .path("/graphiql") + .header("accept", "text/html") + .filter(&filter) + .await; + + assert!(result.is_ok()); } } - fn make_test_response(resp: http::Response>) -> TestResponse { - TestResponse { - status_code: resp.status().as_u16() as i32, - body: Some(String::from_utf8(resp.body().to_owned()).unwrap()), - content_type: resp - .headers() - .get("content-type") - .expect("missing content-type header in warp response") - .to_str() - .expect("invalid content-type string") - .into(), + mod playground_filter { + use warp::{http, test::request, Filter as _}; + + use super::super::playground_filter; + + #[tokio::test] + async fn endpoint_matches() { + let filter = warp::get() + .and(warp::path("playground")) + .and(playground_filter("/graphql", Some("/subscripitons"))); + + let result = request() + .method("GET") + .path("/playground") + .header("accept", "text/html") + .filter(&filter) + .await; + + assert!(result.is_ok()); } - } - #[test] - fn test_warp_integration() { - run_http_test_suite(&TestWarpIntegration::new(false)); - } + #[tokio::test] + async fn returns_playground_source() { + let filter = warp::get() + .and(warp::path("dogs-api")) + .and(warp::path("playground")) + .and(playground_filter( + "/dogs-api/graphql", + Some("/dogs-api/subscriptions"), + )); + let response = request() + .method("GET") + .path("/dogs-api/playground") + .header("accept", "text/html") + .reply(&filter) + .await; - #[test] - fn test_sync_warp_integration() { - run_http_test_suite(&TestWarpIntegration::new(true)); + assert_eq!(response.status(), http::StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/html;charset=utf-8" + ); + + let body = String::from_utf8(response.body().to_vec()).unwrap(); + + assert!(body.contains( + "endpoint: '/dogs-api/graphql', subscriptionEndpoint: '/dogs-api/subscriptions'", + )); + } } } diff --git a/juniper_warp/src/response.rs b/juniper_warp/src/response.rs new file mode 100644 index 00000000..0438f4e1 --- /dev/null +++ b/juniper_warp/src/response.rs @@ -0,0 +1,37 @@ +//! [`JuniperResponse`] definition. + +use juniper::{http::GraphQLBatchResponse, DefaultScalarValue, ScalarValue}; +use warp::{ + http::{self, StatusCode}, + reply::{self, Reply}, +}; + +/// Wrapper around a [`GraphQLBatchResponse`], implementing [`warp::Reply`], so it can be returned +/// from [`warp`] handlers. +pub(crate) struct JuniperResponse(pub(crate) GraphQLBatchResponse) +where + S: ScalarValue; + +impl Reply for JuniperResponse +where + S: ScalarValue + Send, +{ + fn into_response(self) -> reply::Response { + match serde_json::to_vec(&self.0) { + Ok(json) => http::Response::builder() + .status(if self.0.is_ok() { + StatusCode::OK + } else { + StatusCode::BAD_REQUEST + }) + .header("content-type", "application/json") + .body(json.into()), + Err(e) => http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(e.to_string().into()), + } + .unwrap_or_else(|e| { + unreachable!("cannot build `reply::Response` out of `JuniperResponse`: {e}") + }) + } +} diff --git a/juniper_warp/src/subscriptions.rs b/juniper_warp/src/subscriptions.rs new file mode 100644 index 00000000..6511c235 --- /dev/null +++ b/juniper_warp/src/subscriptions.rs @@ -0,0 +1,327 @@ +//! GraphQL subscriptions handler implementation. + +use std::{convert::Infallible, fmt, sync::Arc}; + +use futures::{ + future::{self, Either}, + sink::SinkExt as _, + stream::StreamExt as _, +}; +use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue}; +use juniper_graphql_ws::{graphql_transport_ws, graphql_ws}; +use warp::{filters::BoxedFilter, reply::Reply, Filter as _}; + +struct Message(warp::ws::Message); + +impl TryFrom for graphql_ws::ClientMessage { + type Error = serde_json::Error; + + fn try_from(msg: Message) -> serde_json::Result { + if msg.0.is_close() { + Ok(Self::ConnectionTerminate) + } else { + serde_json::from_slice(msg.0.as_bytes()) + } + } +} + +impl TryFrom for graphql_transport_ws::Input { + type Error = serde_json::Error; + + fn try_from(msg: Message) -> serde_json::Result { + if msg.0.is_close() { + Ok(Self::Close) + } else { + serde_json::from_slice(msg.0.as_bytes()).map(Self::Message) + } + } +} + +/// Errors that can happen while serving a connection. +#[derive(Debug)] +pub enum Error { + /// Errors that can happen in Warp while serving a connection. + Warp(warp::Error), + + /// Errors that can happen while serializing outgoing messages. Note that errors that occur + /// while deserializing incoming messages are handled internally by the protocol. + Serde(serde_json::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Warp(e) => write!(f, "`warp` error: {e}"), + Self::Serde(e) => write!(f, "`serde` error: {e}"), + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(err: warp::Error) -> Self { + Self::Warp(err) + } +} + +impl From for Error { + fn from(_err: Infallible) -> Self { + unreachable!() + } +} + +/// Makes a filter for GraphQL subscriptions. +/// +/// This filter auto-selects between the +/// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] and the +/// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], based on the +/// `Sec-Websocket-Protocol` HTTP header value. +/// +/// The `schema` argument is your [`juniper`] schema. +/// +/// The `init` argument is used to provide the custom [`juniper::Context`] and additional +/// configuration for connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the +/// context and configuration are already known, or it can be a closure that gets executed +/// asynchronously whenever a client sends the subscription initialization message. Using a +/// closure allows to perform an authentication based on the parameters provided by a client. +/// +/// # Example +/// +/// ```rust +/// # use std::{convert::Infallible, pin::Pin, sync::Arc, time::Duration}; +/// # +/// # use futures::Stream; +/// # use juniper::{graphql_object, graphql_subscription, EmptyMutation, RootNode}; +/// # use juniper_graphql_ws::ConnectionConfig; +/// # use juniper_warp::make_graphql_filter; +/// # use warp::Filter as _; +/// # +/// type UserId = String; +/// # #[derive(Debug)] +/// struct AppState(Vec); +/// #[derive(Clone)] +/// struct ExampleContext(Arc, UserId); +/// # impl juniper::Context for ExampleContext {} +/// +/// struct QueryRoot; +/// +/// #[graphql_object(context = ExampleContext)] +/// impl QueryRoot { +/// fn say_hello(context: &ExampleContext) -> String { +/// format!( +/// "good morning {}, the app state is {:?}", +/// context.1, +/// context.0, +/// ) +/// } +/// } +/// +/// type StringsStream = Pin + Send>>; +/// +/// struct SubscriptionRoot; +/// +/// #[graphql_subscription(context = ExampleContext)] +/// impl SubscriptionRoot { +/// async fn say_hellos(context: &ExampleContext) -> StringsStream { +/// let mut interval = tokio::time::interval(Duration::from_secs(1)); +/// let context = context.clone(); +/// Box::pin(async_stream::stream! { +/// let mut counter = 0; +/// while counter < 5 { +/// counter += 1; +/// interval.tick().await; +/// yield format!( +/// "{counter}: good morning {}, the app state is {:?}", +/// context.1, +/// context.0, +/// ) +/// } +/// }) +/// } +/// } +/// +/// let schema = Arc::new(RootNode::new(QueryRoot, EmptyMutation::new(), SubscriptionRoot)); +/// let app_state = Arc::new(AppState(vec![3, 4, 5])); +/// let app_state_for_ws = app_state.clone(); +/// +/// let context_extractor = warp::any() +/// .and(warp::header::("authorization")) +/// .and(warp::any().map(move || app_state.clone())) +/// .map(|auth_header: String, app_state: Arc| { +/// let user_id = auth_header; // we believe them +/// ExampleContext(app_state, user_id) +/// }) +/// .boxed(); +/// +/// let graphql_endpoint = (warp::path("graphql") +/// .and(warp::post()) +/// .and(make_graphql_filter(schema.clone(), context_extractor))) +/// .or(warp::path("subscriptions") +/// .and(juniper_warp::subscriptions::make_ws_filter( +/// schema, +/// move |variables: juniper::Variables| { +/// let user_id = variables +/// .get("authorization") +/// .map(ToString::to_string) +/// .unwrap_or_default(); // we believe them +/// async move { +/// Ok::<_, Infallible>(ConnectionConfig::new( +/// ExampleContext(app_state_for_ws.clone(), user_id), +/// )) +/// } +/// }, +/// ))); +/// ``` +/// +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md +pub fn make_ws_filter( + schema: impl Into>>, + init: I, +) -> BoxedFilter<(impl Reply,)> +where + Query: GraphQLTypeAsync + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: juniper_graphql_ws::Init + Clone + Send + Sync, +{ + let schema = schema.into(); + + warp::ws() + .and(warp::filters::header::value("sec-websocket-protocol")) + .map(move |ws: warp::ws::Ws, subproto| { + let schema = schema.clone(); + let init = init.clone(); + + let is_legacy = subproto == "graphql-ws"; + + warp::reply::with_header( + ws.on_upgrade(move |ws| async move { + if is_legacy { + serve_graphql_ws(ws, schema, init).await + } else { + serve_graphql_transport_ws(ws, schema, init).await + } + .unwrap_or_else(|e| { + log::error!("GraphQL over WebSocket Protocol error: {e}"); + }) + }), + "sec-websocket-protocol", + if is_legacy { + "graphql-ws" + } else { + "graphql-transport-ws" + }, + ) + }) + .boxed() +} + +/// Serves the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old]. +/// +/// The `init` argument is used to provide the context and additional configuration for +/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and +/// configuration are already known, or it can be a closure that gets executed asynchronously +/// when the client sends the `GQL_CONNECTION_INIT` message. Using a closure allows to perform +/// an authentication based on the parameters provided by a client. +/// +/// > __WARNING__: This protocol has been deprecated in favor of the +/// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], which is +/// provided by the [`serve_graphql_transport_ws()`] function. +/// +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md +pub async fn serve_graphql_ws( + websocket: warp::ws::WebSocket, + root_node: Arc>, + init: I, +) -> Result<(), Error> +where + Query: GraphQLTypeAsync + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: juniper_graphql_ws::Init + Send, +{ + let (ws_tx, ws_rx) = websocket.split(); + let (s_tx, s_rx) = + graphql_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init).split(); + + let ws_rx = ws_rx.map(|r| r.map(Message)); + let s_rx = s_rx.map(|msg| { + serde_json::to_string(&msg) + .map(warp::ws::Message::text) + .map_err(Error::Serde) + }); + + match future::select( + ws_rx.forward(s_tx.sink_err_into()), + s_rx.forward(ws_tx.sink_err_into()), + ) + .await + { + Either::Left((r, _)) => r.map_err(|e| e.into()), + Either::Right((r, _)) => r, + } +} + +/// Serves the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new]. +/// +/// The `init` argument is used to provide the context and additional configuration for +/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and +/// configuration are already known, or it can be a closure that gets executed asynchronously +/// when the client sends the `ConnectionInit` message. Using a closure allows to perform an +/// authentication based on the parameters provided by a client. +/// +/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md +pub async fn serve_graphql_transport_ws( + websocket: warp::ws::WebSocket, + root_node: Arc>, + init: I, +) -> Result<(), Error> +where + Query: GraphQLTypeAsync + Send + 'static, + Query::TypeInfo: Send + Sync, + Mutation: GraphQLTypeAsync + Send + 'static, + Mutation::TypeInfo: Send + Sync, + Subscription: GraphQLSubscriptionType + Send + 'static, + Subscription::TypeInfo: Send + Sync, + CtxT: Unpin + Send + Sync + 'static, + S: ScalarValue + Send + Sync + 'static, + I: juniper_graphql_ws::Init + Send, +{ + let (ws_tx, ws_rx) = websocket.split(); + let (s_tx, s_rx) = + graphql_transport_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init) + .split(); + + let ws_rx = ws_rx.map(|r| r.map(Message)); + let s_rx = s_rx.map(|output| match output { + graphql_transport_ws::Output::Message(msg) => serde_json::to_string(&msg) + .map(warp::ws::Message::text) + .map_err(Error::Serde), + graphql_transport_ws::Output::Close { code, message } => { + Ok(warp::ws::Message::close_with(code, message)) + } + }); + + match future::select( + ws_rx.forward(s_tx.sink_err_into()), + s_rx.forward(ws_tx.sink_err_into()), + ) + .await + { + Either::Left((r, _)) => r.map_err(|e| e.into()), + Either::Right((r, _)) => r, + } +} diff --git a/juniper_warp/tests/http_test_suite.rs b/juniper_warp/tests/http_test_suite.rs new file mode 100644 index 00000000..da07eead --- /dev/null +++ b/juniper_warp/tests/http_test_suite.rs @@ -0,0 +1,143 @@ +use futures::TryStreamExt as _; +use juniper::{ + http::tests::{run_http_test_suite, HttpIntegration, TestResponse}, + tests::fixtures::starwars::schema::{Database, Query}, + EmptyMutation, EmptySubscription, RootNode, +}; +use juniper_warp::{make_graphql_filter, make_graphql_filter_sync}; +use warp::{ + body, + filters::{path, BoxedFilter}, + http, reply, Filter as _, +}; + +struct TestWarpIntegration { + filter: BoxedFilter<(reply::Response,)>, +} + +impl TestWarpIntegration { + fn new(is_sync: bool) -> Self { + let schema = RootNode::new( + Query, + EmptyMutation::::new(), + EmptySubscription::::new(), + ); + let db = warp::any().map(Database::new); + + Self { + filter: path::end() + .and(if is_sync { + make_graphql_filter_sync(schema, db).boxed() + } else { + make_graphql_filter(schema, db).boxed() + }) + .boxed(), + } + } + + fn make_request(&self, req: warp::test::RequestBuilder) -> TestResponse { + let rt = tokio::runtime::Runtime::new() + .unwrap_or_else(|e| panic!("failed to create `tokio::Runtime`: {e}")); + rt.block_on(async move { + make_test_response(req.filter(&self.filter).await.unwrap_or_else(|rejection| { + let code = if rejection.is_not_found() { + http::StatusCode::NOT_FOUND + } else if let Some(body::BodyDeserializeError { .. }) = rejection.find() { + http::StatusCode::BAD_REQUEST + } else { + http::StatusCode::INTERNAL_SERVER_ERROR + }; + http::Response::builder() + .status(code) + .header("content-type", "application/json") + .body(Vec::new().into()) + .unwrap() + })) + .await + }) + } +} + +impl HttpIntegration for TestWarpIntegration { + fn get(&self, url: &str) -> TestResponse { + use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; + use url::Url; + + /// https://url.spec.whatwg.org/#query-state + const QUERY_ENCODE_SET: &AsciiSet = + &CONTROLS.add(b' ').add(b'"').add(b'#').add(b'<').add(b'>'); + + let url = Url::parse(&format!("http://localhost:3000{url}")).expect("url to parse"); + + let url: String = utf8_percent_encode(url.query().unwrap_or(""), QUERY_ENCODE_SET) + .collect::>() + .join(""); + + self.make_request( + warp::test::request() + .method("GET") + .path(&format!("/?{url}")), + ) + } + + fn post_json(&self, url: &str, body: &str) -> TestResponse { + self.make_request( + warp::test::request() + .method("POST") + .header("content-type", "application/json; charset=utf-8") + .path(url) + .body(body), + ) + } + + fn post_graphql(&self, url: &str, body: &str) -> TestResponse { + self.make_request( + warp::test::request() + .method("POST") + .header("content-type", "application/graphql; charset=utf-8") + .path(url) + .body(body), + ) + } +} + +async fn make_test_response(resp: reply::Response) -> TestResponse { + let (parts, body) = resp.into_parts(); + + let status_code = parts.status.as_u16().into(); + + let content_type = parts + .headers + .get("content-type") + .map(|header| { + header + .to_str() + .unwrap_or_else(|e| panic!("not UTF-8 header: {e}")) + .to_owned() + }) + .unwrap_or_default(); + + let body = String::from_utf8( + body.map_ok(|bytes| bytes.to_vec()) + .try_concat() + .await + .unwrap(), + ) + .unwrap_or_else(|e| panic!("not UTF-8 body: {e}")); + + TestResponse { + status_code, + content_type, + body: Some(body), + } +} + +#[test] +fn test_warp_integration() { + run_http_test_suite(&TestWarpIntegration::new(false)); +} + +#[test] +fn test_sync_warp_integration() { + run_http_test_suite(&TestWarpIntegration::new(true)); +}