Fix Rejection handling in juniper_warp filters (#1222, #1177)

- rework `make_graphql_filter()` and `make_graphql_filter_sync()` to execute `context_extractor` only once
- handle all non-recoverable `Rejection`s in `make_graphql_filter()` and `make_graphql_filter_sync()`
- relax requirement for `context_extractor` to be a `BoxedFilter` only
- remove `JoinError` from public API
- provide example of fallible `context_extractor` in `make_graphql_filter()` API docs

Additionally:
- split  `juniper_warp` modules into separate files
- add @tyranron as `juniper_warp` co-author
This commit is contained in:
Kai Ren 2023-11-23 18:39:12 +01:00 committed by GitHub
parent 904d9cdbb6
commit 4ef8cf7de9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 1079 additions and 819 deletions

View file

@ -11,6 +11,7 @@ All user visible changes to `juniper_warp` crate will be documented in this file
### BC Breaks ### BC Breaks
- Switched to 0.16 version of [`juniper` crate]. - Switched to 0.16 version of [`juniper` crate].
- Removed `JoinError` from public API. ([#1222], [#1177])
### Added ### Added
@ -20,11 +21,18 @@ All user visible changes to `juniper_warp` crate will be documented in this file
### Changed ### Changed
- Made `schema` argument of `make_graphql_filter()` and `make_graphql_filter_sync()` polymorphic, allowing to specify external `Arc`ed `schema`. ([#1136], [#1135]) - Made `schema` argument of `make_graphql_filter()` and `make_graphql_filter_sync()` polymorphic, allowing to specify external `Arc`ed `schema`. ([#1136], [#1135])
- Relaxed requirement for `context_extractor` to be a `BoxedFilter` only. ([#1222], [#1177])
### Fixed
- Excessive `context_extractor` execution in `make_graphql_filter()` and `make_graphql_filter_sync()`. ([#1222], [#1177])
[#1135]: /../../issues/1136 [#1135]: /../../issues/1136
[#1136]: /../../pull/1136 [#1136]: /../../pull/1136
[#1158]: /../../pull/1158 [#1158]: /../../pull/1158
[#1177]: /../../issues/1177
[#1191]: /../../pull/1191 [#1191]: /../../pull/1191
[#1222]: /../../pull/1222

View file

@ -5,7 +5,10 @@ edition = "2021"
rust-version = "1.73" rust-version = "1.73"
description = "`juniper` GraphQL integration with `warp`." description = "`juniper` GraphQL integration with `warp`."
license = "BSD-2-Clause" license = "BSD-2-Clause"
authors = ["Tom Houlé <tom@tomhoule.com>"] authors = [
"Tom Houlé <tom@tomhoule.com>",
"Kai Ren <tyranron@gmail.com>",
]
documentation = "https://docs.rs/juniper_warp" documentation = "https://docs.rs/juniper_warp"
homepage = "https://github.com/graphql-rust/juniper/tree/master/juniper_warp" homepage = "https://github.com/graphql-rust/juniper/tree/master/juniper_warp"
repository = "https://github.com/graphql-rust/juniper" repository = "https://github.com/graphql-rust/juniper"
@ -20,21 +23,21 @@ rustdoc-args = ["--cfg", "docsrs"]
[features] [features]
subscriptions = [ subscriptions = [
"dep:futures",
"dep:juniper_graphql_ws", "dep:juniper_graphql_ws",
"dep:log", "dep:log",
"warp/websocket", "warp/websocket",
] ]
[dependencies] [dependencies]
anyhow = "1.0.47" futures = { version = "0.3.22", optional = true }
futures = "0.3.22"
juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false } juniper = { version = "0.16.0-dev", path = "../juniper", default-features = false }
juniper_graphql_ws = { version = "0.4.0-dev", path = "../juniper_graphql_ws", features = ["graphql-transport-ws", "graphql-ws"], optional = true } juniper_graphql_ws = { version = "0.4.0-dev", path = "../juniper_graphql_ws", features = ["graphql-transport-ws", "graphql-ws"], optional = true }
log = { version = "0.4", optional = true } log = { version = "0.4", optional = true }
serde = { version = "1.0.122", features = ["derive"] } serde = { version = "1.0.122", features = ["derive"] }
serde_json = "1.0.18" serde_json = "1.0.18"
thiserror = "1.0" thiserror = "1.0"
tokio = { version = "1.0", features = ["rt-multi-thread"] } tokio = { version = "1.0", features = ["rt"] }
warp = { version = "0.3.2", default-features = false } warp = { version = "0.3.2", default-features = false }
# Fixes for `minimal-versions` check. # Fixes for `minimal-versions` check.
@ -44,6 +47,7 @@ headers = "0.3.8"
[dev-dependencies] [dev-dependencies]
async-stream = "0.3" async-stream = "0.3"
env_logger = "0.10" env_logger = "0.10"
futures = "0.3.22"
juniper = { version = "0.16.0-dev", path = "../juniper", features = ["expose-test-schema"] } juniper = { version = "0.16.0-dev", path = "../juniper", features = ["expose-test-schema"] }
log = "0.4" log = "0.4"
percent-encoding = "2.1" percent-encoding = "2.1"

View file

@ -1,6 +1,6 @@
BSD 2-Clause License BSD 2-Clause License
Copyright (c) 2018-2022, Tom Houlé Copyright (c) 2018-2023, Tom Houlé, Kai Ren
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without

View file

@ -160,7 +160,7 @@ async fn main() {
.and(warp::path("graphql")) .and(warp::path("graphql"))
.and(juniper_warp::make_graphql_filter( .and(juniper_warp::make_graphql_filter(
schema.clone(), schema.clone(),
warp::any().map(|| Context).boxed(), warp::any().map(|| Context),
))) )))
.or( .or(
warp::path("subscriptions").and(juniper_warp::subscriptions::make_ws_filter( warp::path("subscriptions").and(juniper_warp::subscriptions::make_ws_filter(

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,37 @@
//! [`JuniperResponse`] definition.
use juniper::{http::GraphQLBatchResponse, DefaultScalarValue, ScalarValue};
use warp::{
http::{self, StatusCode},
reply::{self, Reply},
};
/// Wrapper around a [`GraphQLBatchResponse`], implementing [`warp::Reply`], so it can be returned
/// from [`warp`] handlers.
pub(crate) struct JuniperResponse<S = DefaultScalarValue>(pub(crate) GraphQLBatchResponse<S>)
where
S: ScalarValue;
impl<S> Reply for JuniperResponse<S>
where
S: ScalarValue + Send,
{
fn into_response(self) -> reply::Response {
match serde_json::to_vec(&self.0) {
Ok(json) => http::Response::builder()
.status(if self.0.is_ok() {
StatusCode::OK
} else {
StatusCode::BAD_REQUEST
})
.header("content-type", "application/json")
.body(json.into()),
Err(e) => http::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(e.to_string().into()),
}
.unwrap_or_else(|e| {
unreachable!("cannot build `reply::Response` out of `JuniperResponse`: {e}")
})
}
}

View file

@ -0,0 +1,327 @@
//! GraphQL subscriptions handler implementation.
use std::{convert::Infallible, fmt, sync::Arc};
use futures::{
future::{self, Either},
sink::SinkExt as _,
stream::StreamExt as _,
};
use juniper::{GraphQLSubscriptionType, GraphQLTypeAsync, RootNode, ScalarValue};
use juniper_graphql_ws::{graphql_transport_ws, graphql_ws};
use warp::{filters::BoxedFilter, reply::Reply, Filter as _};
struct Message(warp::ws::Message);
impl<S: ScalarValue> TryFrom<Message> for graphql_ws::ClientMessage<S> {
type Error = serde_json::Error;
fn try_from(msg: Message) -> serde_json::Result<Self> {
if msg.0.is_close() {
Ok(Self::ConnectionTerminate)
} else {
serde_json::from_slice(msg.0.as_bytes())
}
}
}
impl<S: ScalarValue> TryFrom<Message> for graphql_transport_ws::Input<S> {
type Error = serde_json::Error;
fn try_from(msg: Message) -> serde_json::Result<Self> {
if msg.0.is_close() {
Ok(Self::Close)
} else {
serde_json::from_slice(msg.0.as_bytes()).map(Self::Message)
}
}
}
/// Errors that can happen while serving a connection.
#[derive(Debug)]
pub enum Error {
/// Errors that can happen in Warp while serving a connection.
Warp(warp::Error),
/// Errors that can happen while serializing outgoing messages. Note that errors that occur
/// while deserializing incoming messages are handled internally by the protocol.
Serde(serde_json::Error),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Warp(e) => write!(f, "`warp` error: {e}"),
Self::Serde(e) => write!(f, "`serde` error: {e}"),
}
}
}
impl std::error::Error for Error {}
impl From<warp::Error> for Error {
fn from(err: warp::Error) -> Self {
Self::Warp(err)
}
}
impl From<Infallible> for Error {
fn from(_err: Infallible) -> Self {
unreachable!()
}
}
/// Makes a filter for GraphQL subscriptions.
///
/// This filter auto-selects between the
/// [legacy `graphql-ws` GraphQL over WebSocket Protocol][old] and the
/// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], based on the
/// `Sec-Websocket-Protocol` HTTP header value.
///
/// The `schema` argument is your [`juniper`] schema.
///
/// The `init` argument is used to provide the custom [`juniper::Context`] and additional
/// configuration for connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the
/// context and configuration are already known, or it can be a closure that gets executed
/// asynchronously whenever a client sends the subscription initialization message. Using a
/// closure allows to perform an authentication based on the parameters provided by a client.
///
/// # Example
///
/// ```rust
/// # use std::{convert::Infallible, pin::Pin, sync::Arc, time::Duration};
/// #
/// # use futures::Stream;
/// # use juniper::{graphql_object, graphql_subscription, EmptyMutation, RootNode};
/// # use juniper_graphql_ws::ConnectionConfig;
/// # use juniper_warp::make_graphql_filter;
/// # use warp::Filter as _;
/// #
/// type UserId = String;
/// # #[derive(Debug)]
/// struct AppState(Vec<i64>);
/// #[derive(Clone)]
/// struct ExampleContext(Arc<AppState>, UserId);
/// # impl juniper::Context for ExampleContext {}
///
/// struct QueryRoot;
///
/// #[graphql_object(context = ExampleContext)]
/// impl QueryRoot {
/// fn say_hello(context: &ExampleContext) -> String {
/// format!(
/// "good morning {}, the app state is {:?}",
/// context.1,
/// context.0,
/// )
/// }
/// }
///
/// type StringsStream = Pin<Box<dyn Stream<Item = String> + Send>>;
///
/// struct SubscriptionRoot;
///
/// #[graphql_subscription(context = ExampleContext)]
/// impl SubscriptionRoot {
/// async fn say_hellos(context: &ExampleContext) -> StringsStream {
/// let mut interval = tokio::time::interval(Duration::from_secs(1));
/// let context = context.clone();
/// Box::pin(async_stream::stream! {
/// let mut counter = 0;
/// while counter < 5 {
/// counter += 1;
/// interval.tick().await;
/// yield format!(
/// "{counter}: good morning {}, the app state is {:?}",
/// context.1,
/// context.0,
/// )
/// }
/// })
/// }
/// }
///
/// let schema = Arc::new(RootNode::new(QueryRoot, EmptyMutation::new(), SubscriptionRoot));
/// let app_state = Arc::new(AppState(vec![3, 4, 5]));
/// let app_state_for_ws = app_state.clone();
///
/// let context_extractor = warp::any()
/// .and(warp::header::<String>("authorization"))
/// .and(warp::any().map(move || app_state.clone()))
/// .map(|auth_header: String, app_state: Arc<AppState>| {
/// let user_id = auth_header; // we believe them
/// ExampleContext(app_state, user_id)
/// })
/// .boxed();
///
/// let graphql_endpoint = (warp::path("graphql")
/// .and(warp::post())
/// .and(make_graphql_filter(schema.clone(), context_extractor)))
/// .or(warp::path("subscriptions")
/// .and(juniper_warp::subscriptions::make_ws_filter(
/// schema,
/// move |variables: juniper::Variables| {
/// let user_id = variables
/// .get("authorization")
/// .map(ToString::to_string)
/// .unwrap_or_default(); // we believe them
/// async move {
/// Ok::<_, Infallible>(ConnectionConfig::new(
/// ExampleContext(app_state_for_ws.clone(), user_id),
/// ))
/// }
/// },
/// )));
/// ```
///
/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md
pub fn make_ws_filter<Query, Mutation, Subscription, CtxT, S, I>(
schema: impl Into<Arc<RootNode<'static, Query, Mutation, Subscription, S>>>,
init: I,
) -> BoxedFilter<(impl Reply,)>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: juniper_graphql_ws::Init<S, CtxT> + Clone + Send + Sync,
{
let schema = schema.into();
warp::ws()
.and(warp::filters::header::value("sec-websocket-protocol"))
.map(move |ws: warp::ws::Ws, subproto| {
let schema = schema.clone();
let init = init.clone();
let is_legacy = subproto == "graphql-ws";
warp::reply::with_header(
ws.on_upgrade(move |ws| async move {
if is_legacy {
serve_graphql_ws(ws, schema, init).await
} else {
serve_graphql_transport_ws(ws, schema, init).await
}
.unwrap_or_else(|e| {
log::error!("GraphQL over WebSocket Protocol error: {e}");
})
}),
"sec-websocket-protocol",
if is_legacy {
"graphql-ws"
} else {
"graphql-transport-ws"
},
)
})
.boxed()
}
/// Serves the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old].
///
/// The `init` argument is used to provide the context and additional configuration for
/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and
/// configuration are already known, or it can be a closure that gets executed asynchronously
/// when the client sends the `GQL_CONNECTION_INIT` message. Using a closure allows to perform
/// an authentication based on the parameters provided by a client.
///
/// > __WARNING__: This protocol has been deprecated in favor of the
/// [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new], which is
/// provided by the [`serve_graphql_transport_ws()`] function.
///
/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md
pub async fn serve_graphql_ws<Query, Mutation, Subscription, CtxT, S, I>(
websocket: warp::ws::WebSocket,
root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
init: I,
) -> Result<(), Error>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: juniper_graphql_ws::Init<S, CtxT> + Send,
{
let (ws_tx, ws_rx) = websocket.split();
let (s_tx, s_rx) =
graphql_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init).split();
let ws_rx = ws_rx.map(|r| r.map(Message));
let s_rx = s_rx.map(|msg| {
serde_json::to_string(&msg)
.map(warp::ws::Message::text)
.map_err(Error::Serde)
});
match future::select(
ws_rx.forward(s_tx.sink_err_into()),
s_rx.forward(ws_tx.sink_err_into()),
)
.await
{
Either::Left((r, _)) => r.map_err(|e| e.into()),
Either::Right((r, _)) => r,
}
}
/// Serves the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new].
///
/// The `init` argument is used to provide the context and additional configuration for
/// connections. This can be a [`juniper_graphql_ws::ConnectionConfig`] if the context and
/// configuration are already known, or it can be a closure that gets executed asynchronously
/// when the client sends the `ConnectionInit` message. Using a closure allows to perform an
/// authentication based on the parameters provided by a client.
///
/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
pub async fn serve_graphql_transport_ws<Query, Mutation, Subscription, CtxT, S, I>(
websocket: warp::ws::WebSocket,
root_node: Arc<RootNode<'static, Query, Mutation, Subscription, S>>,
init: I,
) -> Result<(), Error>
where
Query: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Query::TypeInfo: Send + Sync,
Mutation: GraphQLTypeAsync<S, Context = CtxT> + Send + 'static,
Mutation::TypeInfo: Send + Sync,
Subscription: GraphQLSubscriptionType<S, Context = CtxT> + Send + 'static,
Subscription::TypeInfo: Send + Sync,
CtxT: Unpin + Send + Sync + 'static,
S: ScalarValue + Send + Sync + 'static,
I: juniper_graphql_ws::Init<S, CtxT> + Send,
{
let (ws_tx, ws_rx) = websocket.split();
let (s_tx, s_rx) =
graphql_transport_ws::Connection::new(juniper_graphql_ws::ArcSchema(root_node), init)
.split();
let ws_rx = ws_rx.map(|r| r.map(Message));
let s_rx = s_rx.map(|output| match output {
graphql_transport_ws::Output::Message(msg) => serde_json::to_string(&msg)
.map(warp::ws::Message::text)
.map_err(Error::Serde),
graphql_transport_ws::Output::Close { code, message } => {
Ok(warp::ws::Message::close_with(code, message))
}
});
match future::select(
ws_rx.forward(s_tx.sink_err_into()),
s_rx.forward(ws_tx.sink_err_into()),
)
.await
{
Either::Left((r, _)) => r.map_err(|e| e.into()),
Either::Right((r, _)) => r,
}
}

View file

@ -0,0 +1,143 @@
use futures::TryStreamExt as _;
use juniper::{
http::tests::{run_http_test_suite, HttpIntegration, TestResponse},
tests::fixtures::starwars::schema::{Database, Query},
EmptyMutation, EmptySubscription, RootNode,
};
use juniper_warp::{make_graphql_filter, make_graphql_filter_sync};
use warp::{
body,
filters::{path, BoxedFilter},
http, reply, Filter as _,
};
struct TestWarpIntegration {
filter: BoxedFilter<(reply::Response,)>,
}
impl TestWarpIntegration {
fn new(is_sync: bool) -> Self {
let schema = RootNode::new(
Query,
EmptyMutation::<Database>::new(),
EmptySubscription::<Database>::new(),
);
let db = warp::any().map(Database::new);
Self {
filter: path::end()
.and(if is_sync {
make_graphql_filter_sync(schema, db).boxed()
} else {
make_graphql_filter(schema, db).boxed()
})
.boxed(),
}
}
fn make_request(&self, req: warp::test::RequestBuilder) -> TestResponse {
let rt = tokio::runtime::Runtime::new()
.unwrap_or_else(|e| panic!("failed to create `tokio::Runtime`: {e}"));
rt.block_on(async move {
make_test_response(req.filter(&self.filter).await.unwrap_or_else(|rejection| {
let code = if rejection.is_not_found() {
http::StatusCode::NOT_FOUND
} else if let Some(body::BodyDeserializeError { .. }) = rejection.find() {
http::StatusCode::BAD_REQUEST
} else {
http::StatusCode::INTERNAL_SERVER_ERROR
};
http::Response::builder()
.status(code)
.header("content-type", "application/json")
.body(Vec::new().into())
.unwrap()
}))
.await
})
}
}
impl HttpIntegration for TestWarpIntegration {
fn get(&self, url: &str) -> TestResponse {
use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
use url::Url;
/// https://url.spec.whatwg.org/#query-state
const QUERY_ENCODE_SET: &AsciiSet =
&CONTROLS.add(b' ').add(b'"').add(b'#').add(b'<').add(b'>');
let url = Url::parse(&format!("http://localhost:3000{url}")).expect("url to parse");
let url: String = utf8_percent_encode(url.query().unwrap_or(""), QUERY_ENCODE_SET)
.collect::<Vec<_>>()
.join("");
self.make_request(
warp::test::request()
.method("GET")
.path(&format!("/?{url}")),
)
}
fn post_json(&self, url: &str, body: &str) -> TestResponse {
self.make_request(
warp::test::request()
.method("POST")
.header("content-type", "application/json; charset=utf-8")
.path(url)
.body(body),
)
}
fn post_graphql(&self, url: &str, body: &str) -> TestResponse {
self.make_request(
warp::test::request()
.method("POST")
.header("content-type", "application/graphql; charset=utf-8")
.path(url)
.body(body),
)
}
}
async fn make_test_response(resp: reply::Response) -> TestResponse {
let (parts, body) = resp.into_parts();
let status_code = parts.status.as_u16().into();
let content_type = parts
.headers
.get("content-type")
.map(|header| {
header
.to_str()
.unwrap_or_else(|e| panic!("not UTF-8 header: {e}"))
.to_owned()
})
.unwrap_or_default();
let body = String::from_utf8(
body.map_ok(|bytes| bytes.to_vec())
.try_concat()
.await
.unwrap(),
)
.unwrap_or_else(|e| panic!("not UTF-8 body: {e}"));
TestResponse {
status_code,
content_type,
body: Some(body),
}
}
#[test]
fn test_warp_integration() {
run_http_test_suite(&TestWarpIntegration::new(false));
}
#[test]
fn test_sync_warp_integration() {
run_http_test_suite(&TestWarpIntegration::new(true));
}