From 4ccb129fa2965c6c82374f242e085b0dc5e2d7f0 Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Thu, 13 Feb 2020 07:48:28 +0100 Subject: [PATCH] Update juniper_hyper to hyper 0.13 and add async resolution (#505) This involves updating to futures 0.3, tokio 0.2 stable --- juniper_hyper/Cargo.toml | 15 +- juniper_hyper/examples/hyper_server.rs | 55 ++-- juniper_hyper/src/lib.rs | 379 ++++++++++++++++--------- 3 files changed, 279 insertions(+), 170 deletions(-) diff --git a/juniper_hyper/Cargo.toml b/juniper_hyper/Cargo.toml index 5255ce7d..a0f232bc 100644 --- a/juniper_hyper/Cargo.toml +++ b/juniper_hyper/Cargo.toml @@ -9,8 +9,7 @@ repository = "https://github.com/graphql-rust/juniper" edition = "2018" [features] -# Fake feature to help CI. -async = [] +async = ["juniper/async", "futures"] [dependencies] serde = "1.0" @@ -18,11 +17,9 @@ serde_json = "1.0" serde_derive = "1.0" url = "2" juniper = { version = "0.14.2", default-features = false, path = "../juniper"} - -futures = "0.1" -tokio = "0.1.8" -hyper = "0.12" -tokio-threadpool = "0.1.7" +tokio = "0.2" +hyper = "0.13" +futures = { version = "0.3", optional = true } [dev-dependencies] pretty_env_logger = "0.2" @@ -32,3 +29,7 @@ reqwest = "0.9" version = "0.14.2" features = ["expose-test-schema", "serde_json"] path = "../juniper" + +[dev-dependencies.tokio] +version = "0.2" +features = ["macros"] \ No newline at end of file diff --git a/juniper_hyper/examples/hyper_server.rs b/juniper_hyper/examples/hyper_server.rs index 802131e5..8df3e6af 100644 --- a/juniper_hyper/examples/hyper_server.rs +++ b/juniper_hyper/examples/hyper_server.rs @@ -1,13 +1,10 @@ -extern crate futures; extern crate hyper; extern crate juniper; extern crate juniper_hyper; extern crate pretty_env_logger; -use futures::future; use hyper::{ - rt::{self, Future}, - service::service_fn, + service::{make_service_fn, service_fn}, Body, Method, Response, Server, StatusCode, }; use juniper::{ @@ -16,7 +13,8 @@ use juniper::{ }; use std::sync::Arc; -fn main() { +#[tokio::main] +async fn main() { pretty_env_logger::init(); let addr = ([127, 0, 0, 1], 3000).into(); @@ -24,30 +22,35 @@ fn main() { let db = Arc::new(Database::new()); let root_node = Arc::new(RootNode::new(Query, EmptyMutation::::new())); - let new_service = move || { + let new_service = make_service_fn(move |_| { let root_node = root_node.clone(); let ctx = db.clone(); - service_fn(move |req| -> Box + Send> { - let root_node = root_node.clone(); - let ctx = ctx.clone(); - match (req.method(), req.uri().path()) { - (&Method::GET, "/") => Box::new(juniper_hyper::graphiql("/graphql")), - (&Method::GET, "/graphql") => Box::new(juniper_hyper::graphql(root_node, ctx, req)), - (&Method::POST, "/graphql") => { - Box::new(juniper_hyper::graphql(root_node, ctx, req)) + + async move { + Ok::<_, hyper::Error>(service_fn(move |req| { + let root_node = root_node.clone(); + let ctx = ctx.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/") => juniper_hyper::graphiql("/graphql").await, + (&Method::GET, "/graphql") | (&Method::POST, "/graphql") => { + juniper_hyper::graphql(root_node, ctx, req).await + } + _ => { + let mut response = Response::new(Body::empty()); + *response.status_mut() = StatusCode::NOT_FOUND; + Ok(response) + } + } } - _ => { - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::NOT_FOUND; - Box::new(future::ok(response)) - } - } - }) - }; - let server = Server::bind(&addr) - .serve(new_service) - .map_err(|e| eprintln!("server error: {}", e)); + })) + } + }); + + let server = Server::bind(&addr).serve(new_service); println!("Listening on http://{}", addr); - rt::run(server); + if let Err(e) = server.await { + eprintln!("server error: {}", e) + } } diff --git a/juniper_hyper/src/lib.rs b/juniper_hyper/src/lib.rs index 6d0a95ba..6755f195 100644 --- a/juniper_hyper/src/lib.rs +++ b/juniper_hyper/src/lib.rs @@ -3,26 +3,27 @@ #[cfg(test)] extern crate reqwest; -use futures::future::Either; +#[cfg(feature = "async")] +use futures; use hyper::{ header::{self, HeaderValue}, - rt::Stream, Body, Method, Request, Response, StatusCode, }; +#[cfg(feature = "async")] +use juniper::GraphQLTypeAsync; use juniper::{ http::GraphQLRequest as JuniperGraphQLRequest, serde::Deserialize, DefaultScalarValue, GraphQLType, InputValue, RootNode, ScalarValue, }; use serde_json::error::Error as SerdeError; use std::{error::Error, fmt, string::FromUtf8Error, sync::Arc}; -use tokio::prelude::*; use url::form_urlencoded; -pub fn graphql( +pub async fn graphql( root_node: Arc>, context: Arc, request: Request, -) -> impl Future, Error = hyper::Error> +) -> Result, hyper::Error> where S: ScalarValue + Send + Sync + 'static, CtxT: Send + Sync + 'static, @@ -32,68 +33,100 @@ where MutationT::TypeInfo: Send + Sync, { match request.method() { - &Method::GET => Either::A(Either::A( - future::done( - request - .uri() - .query() - .map(|q| gql_request_from_get(q).map(GraphQLRequest::Single)) - .unwrap_or_else(|| { - Err(GraphQLRequestError::Invalid( - "'query' parameter is missing".to_string(), - )) - }), - ) - .and_then(move |gql_req| { - execute_request(root_node, context, gql_req).map_err(|_| { - unreachable!("thread pool has shut down?!"); - }) - }) - .or_else(|err| future::ok(render_error(err))), - )), - &Method::POST => Either::A(Either::B( - request - .into_body() - .concat2() - .or_else(|err| future::done(Err(GraphQLRequestError::BodyHyper(err)))) - .and_then(move |chunk| { - future::done({ - String::from_utf8(chunk.iter().cloned().collect::>()) - .map_err(GraphQLRequestError::BodyUtf8) - .and_then(|input| { - serde_json::from_str::>(&input) - .map_err(GraphQLRequestError::BodyJSONError) - }) - }) - }) - .and_then(move |gql_req| { - execute_request(root_node, context, gql_req).map_err(|_| { - unreachable!("thread pool has shut down?!"); - }) - }) - .or_else(|err| future::ok(render_error(err))), - )), - _ => return Either::B(future::ok(new_response(StatusCode::METHOD_NOT_ALLOWED))), + &Method::GET => { + let gql_req = parse_get_req(request); + + match gql_req { + Ok(gql_req) => Ok(execute_request(root_node, context, gql_req).await), + Err(err) => Ok(render_error(err)), + } + } + &Method::POST => { + let gql_req = parse_post_req(request.into_body()).await; + + match gql_req { + Ok(gql_req) => Ok(execute_request(root_node, context, gql_req).await), + Err(err) => Ok(render_error(err)), + } + } + _ => Ok(new_response(StatusCode::METHOD_NOT_ALLOWED)), } } -pub fn graphiql( - graphql_endpoint: &str, -) -> impl Future, Error = hyper::Error> { +#[cfg(feature = "async")] +pub async fn graphql_async( + root_node: Arc>, + context: Arc, + request: Request, +) -> Result, hyper::Error> +where + S: ScalarValue + Send + Sync + 'static, + CtxT: Send + Sync + 'static, + QueryT: GraphQLTypeAsync + Send + Sync + 'static, + MutationT: GraphQLTypeAsync + Send + Sync + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT::TypeInfo: Send + Sync, +{ + match request.method() { + &Method::GET => { + let gql_req = parse_get_req(request); + + match gql_req { + Ok(gql_req) => Ok(execute_request_async(root_node, context, gql_req).await), + Err(err) => Ok(render_error(err)), + } + } + &Method::POST => { + let gql_req = parse_post_req(request.into_body()).await; + + match gql_req { + Ok(gql_req) => Ok(execute_request_async(root_node, context, gql_req).await), + Err(err) => Ok(render_error(err)), + } + } + _ => Ok(new_response(StatusCode::METHOD_NOT_ALLOWED)), + } +} + +fn parse_get_req( + req: Request, +) -> Result, GraphQLRequestError> { + req.uri() + .query() + .map(|q| gql_request_from_get(q).map(GraphQLRequest::Single)) + .unwrap_or_else(|| { + Err(GraphQLRequestError::Invalid( + "'query' parameter is missing".to_string(), + )) + }) +} + +async fn parse_post_req( + body: Body, +) -> Result, GraphQLRequestError> { + let chunk = hyper::body::to_bytes(body) + .await + .map_err(|err| GraphQLRequestError::BodyHyper(err))?; + + let input = String::from_utf8(chunk.iter().cloned().collect()) + .map_err(GraphQLRequestError::BodyUtf8)?; + + serde_json::from_str::>(&input).map_err(GraphQLRequestError::BodyJSONError) +} + +pub async fn graphiql(graphql_endpoint: &str) -> Result, hyper::Error> { let mut resp = new_html_response(StatusCode::OK); // XXX: is the call to graphiql_source blocking? *resp.body_mut() = Body::from(juniper::graphiql::graphiql_source(graphql_endpoint)); - future::ok(resp) + Ok(resp) } -pub fn playground( - graphql_endpoint: &str, -) -> impl Future, Error = hyper::Error> { +pub async fn playground(graphql_endpoint: &str) -> Result, hyper::Error> { let mut resp = new_html_response(StatusCode::OK); *resp.body_mut() = Body::from(juniper::http::playground::playground_source( graphql_endpoint, )); - future::ok(resp) + Ok(resp) } fn render_error(err: GraphQLRequestError) -> Response { @@ -103,11 +136,11 @@ fn render_error(err: GraphQLRequestError) -> Response { resp } -fn execute_request( +async fn execute_request( root_node: Arc>, context: Arc, request: GraphQLRequest, -) -> impl Future, Error = tokio_threadpool::BlockingError> +) -> Response where S: ScalarValue + Send + Sync + 'static, CtxT: Send + Sync + 'static, @@ -116,20 +149,48 @@ where QueryT::TypeInfo: Send + Sync, MutationT::TypeInfo: Send + Sync, { - request.execute(root_node, context).map(|(is_ok, body)| { - let code = if is_ok { - StatusCode::OK - } else { - StatusCode::BAD_REQUEST - }; - let mut resp = new_response(code); - resp.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - *resp.body_mut() = body; - resp - }) + let (is_ok, body) = request.execute(root_node, context); + let code = if is_ok { + StatusCode::OK + } else { + StatusCode::BAD_REQUEST + }; + let mut resp = new_response(code); + resp.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + *resp.body_mut() = body; + resp +} + +#[cfg(feature = "async")] +async fn execute_request_async( + root_node: Arc>, + context: Arc, + request: GraphQLRequest, +) -> Response +where + S: ScalarValue + Send + Sync + 'static, + CtxT: Send + Sync + 'static, + QueryT: GraphQLTypeAsync + Send + Sync + 'static, + MutationT: GraphQLTypeAsync + Send + Sync + 'static, + QueryT::TypeInfo: Send + Sync, + MutationT::TypeInfo: Send + Sync, +{ + let (is_ok, body) = request.execute_async(root_node, context).await; + let code = if is_ok { + StatusCode::OK + } else { + StatusCode::BAD_REQUEST + }; + let mut resp = new_response(code); + resp.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + *resp.body_mut() = body; + resp } fn gql_request_from_get(input: &str) -> Result, GraphQLRequestError> @@ -215,45 +276,74 @@ where self, root_node: Arc>, context: Arc, - ) -> impl Future + 'a + ) -> (bool, hyper::Body) where - S: 'a, + S: 'a + Send + Sync, QueryT: GraphQLType + 'a, MutationT: GraphQLType + 'a, { match self { - GraphQLRequest::Single(request) => Either::A(future::poll_fn(move || { - let res = futures::try_ready!(tokio_threadpool::blocking( - || request.execute(&root_node, &context) - )); + GraphQLRequest::Single(request) => { + let res = request.execute(&root_node, &context); let is_ok = res.is_ok(); let body = Body::from(serde_json::to_string_pretty(&res).unwrap()); - Ok(Async::Ready((is_ok, body))) - })), + (is_ok, body) + } GraphQLRequest::Batch(requests) => { - Either::B( - future::join_all(requests.into_iter().map(move |request| { - // TODO: these clones are sad + let results: Vec<_> = requests + .into_iter() + .map(move |request| { let root_node = root_node.clone(); - let context = context.clone(); - future::poll_fn(move || { - let res = futures::try_ready!(tokio_threadpool::blocking( - || request.execute(&root_node, &context) - )); - let is_ok = res.is_ok(); - let body = serde_json::to_string_pretty(&res).unwrap(); - Ok(Async::Ready((is_ok, body))) - }) - })) - .map(|results| { - let is_ok = results.iter().all(|&(is_ok, _)| is_ok); - // concatenate json bodies as array - // TODO: maybe use Body chunks instead? - let bodies: Vec<_> = results.into_iter().map(|(_, body)| body).collect(); - let body = hyper::Body::from(format!("[{}]", bodies.join(","))); + let res = request.execute(&root_node, &context); + let is_ok = res.is_ok(); + let body = serde_json::to_string_pretty(&res).unwrap(); (is_ok, body) - }), - ) + }) + .collect(); + + let is_ok = !results.iter().any(|&(is_ok, _)| !is_ok); + let bodies: Vec<_> = results.into_iter().map(|(_, body)| body).collect(); + let body = hyper::Body::from(format!("[{}]", bodies.join(","))); + (is_ok, body) + } + } + } + + #[cfg(feature = "async")] + async fn execute_async<'a, CtxT: 'a, QueryT, MutationT>( + self, + root_node: Arc>, + context: Arc, + ) -> (bool, hyper::Body) + where + S: Send + Sync, + QueryT: GraphQLTypeAsync + Send + Sync, + MutationT: GraphQLTypeAsync + Send + Sync, + QueryT::TypeInfo: Send + Sync, + MutationT::TypeInfo: Send + Sync, + CtxT: Send + Sync, + { + match self { + GraphQLRequest::Single(request) => { + let res = request.execute_async(&root_node, &context).await; + let is_ok = res.is_ok(); + let body = Body::from(serde_json::to_string_pretty(&res).unwrap()); + (is_ok, body) + } + GraphQLRequest::Batch(requests) => { + let futures = requests + .iter() + .map(|request| request.execute_async(&root_node, &context)) + .collect::>(); + let results = futures::future::join_all(futures).await; + + let is_ok = results.iter().all(|res| res.is_ok()); + let bodies: Vec<_> = results + .into_iter() + .map(|res| serde_json::to_string_pretty(&res).unwrap()) + .collect(); + let body = hyper::Body::from(format!("[{}]", bodies.join(","))); + (is_ok, body) } } } @@ -301,22 +391,21 @@ impl Error for GraphQLRequestError { } } } - +#[cfg(feature = "async")] #[cfg(test)] mod tests { - use futures::{ - future::{self, Either}, - Future, + use futures; + use hyper::{ + service::{make_service_fn, service_fn}, + Body, Method, Response, Server, StatusCode, }; - use hyper::{service::service_fn, Body, Method, Response, Server, StatusCode}; use juniper::{ http::tests as http_tests, tests::{model::Database, schema::Query}, EmptyMutation, RootNode, }; use reqwest::{self, Response as ReqwestResponse}; - use std::{sync::Arc, thread, time}; - use tokio::runtime::Runtime; + use std::{net::SocketAddr, sync::Arc, thread, time::Duration}; struct TestHyperIntegration; @@ -355,46 +444,62 @@ mod tests { } } - #[test] - fn test_hyper_integration() { - let addr = ([127, 0, 0, 1], 3001).into(); + #[tokio::test] + async fn test_hyper_integration() { + let addr: SocketAddr = ([127, 0, 0, 1], 3001).into(); let db = Arc::new(Database::new()); let root_node = Arc::new(RootNode::new(Query, EmptyMutation::::new())); - let new_service = move || { + let new_service = make_service_fn(move |_| { let root_node = root_node.clone(); let ctx = db.clone(); - service_fn(move |req| { - let root_node = root_node.clone(); - let ctx = ctx.clone(); - let matches = { - let path = req.uri().path(); - match req.method() { - &Method::POST | &Method::GET => path == "/graphql" || path == "/graphql/", - _ => false, + + async move { + Ok::<_, hyper::Error>(service_fn(move |req| { + let root_node = root_node.clone(); + let ctx = ctx.clone(); + let matches = { + let path = req.uri().path(); + match req.method() { + &Method::POST | &Method::GET => { + path == "/graphql" || path == "/graphql/" + } + _ => false, + } + }; + async move { + if matches { + super::graphql(root_node, ctx, req).await + } else { + let mut response = Response::new(Body::empty()); + *response.status_mut() = StatusCode::NOT_FOUND; + Ok(response) + } } - }; - if matches { - Either::A(super::graphql(root_node, ctx, req)) - } else { - let mut response = Response::new(Body::empty()); - *response.status_mut() = StatusCode::NOT_FOUND; - Either::B(future::ok(response)) - } - }) - }; + })) + } + }); + + let (shutdown_fut, shutdown) = futures::future::abortable(async { + tokio::time::delay_for(Duration::from_secs(60)).await; + }); + let server = Server::bind(&addr) .serve(new_service) - .map_err(|e| eprintln!("server error: {}", e)); + .with_graceful_shutdown(async { + shutdown_fut.await.unwrap_err(); + }); - let mut runtime = Runtime::new().unwrap(); - runtime.spawn(server); - thread::sleep(time::Duration::from_millis(10)); // wait 10ms for server to bind + tokio::task::spawn_blocking(move || { + thread::sleep(Duration::from_millis(10)); // wait 10ms for server to bind + let integration = TestHyperIntegration; + http_tests::run_http_test_suite(&integration); + shutdown.abort(); + }); - let integration = TestHyperIntegration; - http_tests::run_http_test_suite(&integration); - - runtime.shutdown_now().wait().unwrap(); + if let Err(e) = server.await { + eprintln!("server error: {}", e); + } } }