Use only a single thread pool for juniper_hyper (#256)

The previous implementation used a futures_cpupool for executing
blocking juniper operations. This pool comes in addition to the
thread pool started by hyper through tokio for executing hyper's futures.
This patch uses tokio::blocking to perform the blocking juniper
operations while re-using the same thread pool as hyper, which
simplifies the API.
This commit is contained in:
Jon Gjengset 2018-09-30 14:07:44 -04:00 committed by Christian Legnitto
parent 50a9fa31b6
commit ec963a6e71
3 changed files with 91 additions and 96 deletions

View file

@ -15,13 +15,13 @@ url = "1.7"
juniper = { version = ">=0.9, 0.10.0" , default-features = false, path = "../juniper"} juniper = { version = ">=0.9, 0.10.0" , default-features = false, path = "../juniper"}
futures = "0.1" futures = "0.1"
futures-cpupool = "0.1" tokio = "0.1.8"
hyper = "0.12" hyper = "0.12"
tokio-threadpool = "0.1.7"
[dev-dependencies] [dev-dependencies]
pretty_env_logger = "0.2" pretty_env_logger = "0.2"
tokio = "0.1.8" reqwest = "0.9"
reqwest = "0.8"
[dev-dependencies.juniper] [dev-dependencies.juniper]
version = "0.10.0" version = "0.10.0"

View file

@ -1,12 +1,10 @@
extern crate futures; extern crate futures;
extern crate futures_cpupool;
extern crate hyper; extern crate hyper;
extern crate juniper; extern crate juniper;
extern crate juniper_hyper; extern crate juniper_hyper;
extern crate pretty_env_logger; extern crate pretty_env_logger;
use futures::future; use futures::future;
use futures_cpupool::Builder as CpuPoolBuilder;
use hyper::rt::{self, Future}; use hyper::rt::{self, Future};
use hyper::service::service_fn; use hyper::service::service_fn;
use hyper::Method; use hyper::Method;
@ -21,22 +19,21 @@ fn main() {
let addr = ([127, 0, 0, 1], 3000).into(); let addr = ([127, 0, 0, 1], 3000).into();
let pool = CpuPoolBuilder::new().create();
let db = Arc::new(Database::new()); let db = Arc::new(Database::new());
let root_node = Arc::new(RootNode::new(db.clone(), EmptyMutation::<Database>::new())); let root_node = Arc::new(RootNode::new(db.clone(), EmptyMutation::<Database>::new()));
let new_service = move || { let new_service = move || {
let pool = pool.clone();
let root_node = root_node.clone(); let root_node = root_node.clone();
let ctx = db.clone(); let ctx = db.clone();
service_fn(move |req| { service_fn(move |req| -> Box<Future<Item = _, Error = _> + Send> {
let pool = pool.clone();
let root_node = root_node.clone(); let root_node = root_node.clone();
let ctx = ctx.clone(); let ctx = ctx.clone();
match (req.method(), req.uri().path()) { match (req.method(), req.uri().path()) {
(&Method::GET, "/") => juniper_hyper::graphiql("/graphql"), (&Method::GET, "/") => Box::new(juniper_hyper::graphiql("/graphql")),
(&Method::GET, "/graphql") => juniper_hyper::graphql(pool, root_node, ctx, req), (&Method::GET, "/graphql") => Box::new(juniper_hyper::graphql(root_node, ctx, req)),
(&Method::POST, "/graphql") => juniper_hyper::graphql(pool, root_node, ctx, req), (&Method::POST, "/graphql") => {
Box::new(juniper_hyper::graphql(root_node, ctx, req))
}
_ => { _ => {
let mut response = Response::new(Body::empty()); let mut response = Response::new(Body::empty());
*response.status_mut() = StatusCode::NOT_FOUND; *response.status_mut() = StatusCode::NOT_FOUND;

View file

@ -1,5 +1,5 @@
#[macro_use]
extern crate futures; extern crate futures;
extern crate futures_cpupool;
extern crate hyper; extern crate hyper;
extern crate juniper; extern crate juniper;
#[macro_use] #[macro_use]
@ -7,32 +7,29 @@ extern crate serde_derive;
#[cfg(test)] #[cfg(test)]
extern crate reqwest; extern crate reqwest;
extern crate serde_json; extern crate serde_json;
#[cfg(test)]
extern crate tokio; extern crate tokio;
extern crate tokio_threadpool;
extern crate url; extern crate url;
use futures::{future, Future}; use futures::future::Either;
use futures_cpupool::CpuPool;
use hyper::header::HeaderValue; use hyper::header::HeaderValue;
use hyper::rt::Stream; use hyper::rt::Stream;
use hyper::{header, Body, Method, Request, Response, StatusCode}; use hyper::{header, Body, Method, Request, Response, StatusCode};
use juniper::http::{ use juniper::http::GraphQLRequest as JuniperGraphQLRequest;
GraphQLRequest as JuniperGraphQLRequest, GraphQLResponse as JuniperGraphQLResponse,
};
use juniper::{GraphQLType, InputValue, RootNode}; use juniper::{GraphQLType, InputValue, RootNode};
use serde_json::error::Error as SerdeError; use serde_json::error::Error as SerdeError;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::string::FromUtf8Error; use std::string::FromUtf8Error;
use std::sync::Arc; use std::sync::Arc;
use tokio::prelude::*;
use url::form_urlencoded; use url::form_urlencoded;
pub fn graphql<CtxT, QueryT, MutationT>( pub fn graphql<CtxT, QueryT, MutationT>(
pool: CpuPool,
root_node: Arc<RootNode<'static, QueryT, MutationT>>, root_node: Arc<RootNode<'static, QueryT, MutationT>>,
context: Arc<CtxT>, context: Arc<CtxT>,
request: Request<Body>, request: Request<Body>,
) -> Box<Future<Item = Response<Body>, Error = hyper::Error> + Send> ) -> impl Future<Item = Response<Body>, Error = hyper::Error>
where where
CtxT: Send + Sync + 'static, CtxT: Send + Sync + 'static,
QueryT: GraphQLType<Context = CtxT> + Send + Sync + 'static, QueryT: GraphQLType<Context = CtxT> + Send + Sync + 'static,
@ -41,7 +38,7 @@ where
MutationT::TypeInfo: Send + Sync, MutationT::TypeInfo: Send + Sync,
{ {
match request.method() { match request.method() {
&Method::GET => Box::new( &Method::GET => Either::A(Either::A(
future::done( future::done(
request request
.uri() .uri()
@ -50,10 +47,13 @@ where
.unwrap_or(Err(GraphQLRequestError::Invalid( .unwrap_or(Err(GraphQLRequestError::Invalid(
"'query' parameter is missing".to_string(), "'query' parameter is missing".to_string(),
))), ))),
).and_then(move |gql_req| execute_request(pool, root_node, context, gql_req)) ).and_then(move |gql_req| {
.or_else(|err| future::ok(render_error(err))), execute_request(root_node, context, gql_req).map_err(|_| {
), unreachable!("thread pool has shut down?!");
&Method::POST => Box::new( })
}).or_else(|err| future::ok(render_error(err))),
)),
&Method::POST => Either::A(Either::B(
request request
.into_body() .into_body()
.concat2() .concat2()
@ -67,19 +67,23 @@ where
.map_err(GraphQLRequestError::BodyJSONError) .map_err(GraphQLRequestError::BodyJSONError)
}) })
}) })
}).and_then(move |gql_req| execute_request(pool, root_node, context, gql_req)) }).and_then(move |gql_req| {
.or_else(|err| future::ok(render_error(err))), execute_request(root_node, context, gql_req).map_err(|_| {
), unreachable!("thread pool has shut down?!");
_ => return Box::new(future::ok(new_response(StatusCode::METHOD_NOT_ALLOWED))), })
}).or_else(|err| future::ok(render_error(err))),
)),
_ => return Either::B(future::ok(new_response(StatusCode::METHOD_NOT_ALLOWED))),
} }
} }
pub fn graphiql( pub fn graphiql(
graphql_endpoint: &str, graphql_endpoint: &str,
) -> Box<Future<Item = Response<Body>, Error = hyper::Error> + Send> { ) -> impl Future<Item = Response<Body>, Error = hyper::Error> {
let mut resp = new_html_response(StatusCode::OK); let mut resp = new_html_response(StatusCode::OK);
// XXX: is the call to graphiql_source blocking?
*resp.body_mut() = Body::from(juniper::graphiql::graphiql_source(graphql_endpoint)); *resp.body_mut() = Body::from(juniper::graphiql::graphiql_source(graphql_endpoint));
Box::new(future::ok(resp)) future::ok(resp)
} }
fn render_error(err: GraphQLRequestError) -> Response<Body> { fn render_error(err: GraphQLRequestError) -> Response<Body> {
@ -89,36 +93,31 @@ fn render_error(err: GraphQLRequestError) -> Response<Body> {
resp resp
} }
fn execute_request<CtxT, QueryT, MutationT, Err>( fn execute_request<CtxT, QueryT, MutationT>(
pool: CpuPool,
root_node: Arc<RootNode<'static, QueryT, MutationT>>, root_node: Arc<RootNode<'static, QueryT, MutationT>>,
context: Arc<CtxT>, context: Arc<CtxT>,
request: GraphQLRequest, request: GraphQLRequest,
) -> impl Future<Item = Response<Body>, Error = Err> ) -> impl Future<Item = Response<Body>, Error = tokio_threadpool::BlockingError>
where where
CtxT: Send + Sync + 'static, CtxT: Send + Sync + 'static,
QueryT: GraphQLType<Context = CtxT> + Send + Sync + 'static, QueryT: GraphQLType<Context = CtxT> + Send + Sync + 'static,
MutationT: GraphQLType<Context = CtxT> + Send + Sync + 'static, MutationT: GraphQLType<Context = CtxT> + Send + Sync + 'static,
QueryT::TypeInfo: Send + Sync, QueryT::TypeInfo: Send + Sync,
MutationT::TypeInfo: Send + Sync, MutationT::TypeInfo: Send + Sync,
Err: Send + Sync + 'static,
{ {
pool.spawn_fn(move || { request.execute(root_node, context).map(|(is_ok, body)| {
future::lazy(move || { let code = if is_ok {
let res = request.execute(&root_node, &context); StatusCode::OK
let code = if res.is_ok() { } else {
StatusCode::OK StatusCode::BAD_REQUEST
} else { };
StatusCode::BAD_REQUEST let mut resp = new_response(code);
}; resp.headers_mut().insert(
let mut resp = new_response(code); header::CONTENT_TYPE,
resp.headers_mut().insert( HeaderValue::from_static("application/json"),
header::CONTENT_TYPE, );
HeaderValue::from_static("application/json"), *resp.body_mut() = body;
); resp
*resp.body_mut() = Body::from(serde_json::to_string_pretty(&res).unwrap());
future::ok(resp)
})
}) })
} }
@ -191,43 +190,48 @@ enum GraphQLRequest {
} }
impl GraphQLRequest { impl GraphQLRequest {
pub fn execute<'a, CtxT, QueryT, MutationT>( fn execute<'a, CtxT: 'a, QueryT, MutationT>(
&'a self, self,
root_node: &RootNode<QueryT, MutationT>, root_node: Arc<RootNode<'a, QueryT, MutationT>>,
context: &CtxT, context: Arc<CtxT>,
) -> GraphQLResponse<'a> ) -> impl Future<Item = (bool, hyper::Body), Error = tokio_threadpool::BlockingError> + 'a
where where
QueryT: GraphQLType<Context = CtxT>, QueryT: GraphQLType<Context = CtxT> + 'a,
MutationT: GraphQLType<Context = CtxT>, MutationT: GraphQLType<Context = CtxT> + 'a,
{ {
match self { match self {
&GraphQLRequest::Single(ref request) => { GraphQLRequest::Single(request) => Either::A(future::poll_fn(move || {
GraphQLResponse::Single(request.execute(root_node, context)) let res = try_ready!(tokio_threadpool::blocking(
|| request.execute(&root_node, &context)
));
let is_ok = res.is_ok();
let body = Body::from(serde_json::to_string_pretty(&res).unwrap());
Ok(Async::Ready((is_ok, body)))
})),
GraphQLRequest::Batch(requests) => {
Either::B(
future::join_all(requests.into_iter().map(move |request| {
// TODO: these clones are sad
let root_node = root_node.clone();
let context = context.clone();
future::poll_fn(move || {
let res = try_ready!(tokio_threadpool::blocking(
|| request.execute(&root_node, &context)
));
let is_ok = res.is_ok();
let body = serde_json::to_string_pretty(&res).unwrap();
Ok(Async::Ready((is_ok, body)))
})
})).map(|results| {
let is_ok = results.iter().all(|&(is_ok, _)| is_ok);
// concatenate json bodies as array
// TODO: maybe use Body chunks instead?
let bodies: Vec<_> = results.into_iter().map(|(_, body)| body).collect();
let body = hyper::Body::from(format!("[{}]", bodies.join(",")));
(is_ok, body)
}),
)
} }
&GraphQLRequest::Batch(ref requests) => GraphQLResponse::Batch(
requests
.iter()
.map(|request| request.execute(root_node, context))
.collect(),
),
}
}
}
#[derive(Serialize)]
#[serde(untagged)]
enum GraphQLResponse<'a> {
Single(JuniperGraphQLResponse<'a>),
Batch(Vec<JuniperGraphQLResponse<'a>>),
}
impl<'a> GraphQLResponse<'a> {
fn is_ok(&self) -> bool {
match self {
&GraphQLResponse::Single(ref response) => response.is_ok(),
&GraphQLResponse::Batch(ref responses) => responses
.iter()
.fold(true, |ok, response| ok && response.is_ok()),
} }
} }
} }
@ -277,8 +281,7 @@ impl Error for GraphQLRequestError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::{future, Future}; use futures::{future, future::Either, Future};
use futures_cpupool::Builder;
use hyper::service::service_fn; use hyper::service::service_fn;
use hyper::Method; use hyper::Method;
use hyper::{Body, Response, Server, StatusCode}; use hyper::{Body, Response, Server, StatusCode};
@ -316,11 +319,9 @@ mod tests {
fn make_test_response(mut response: ReqwestResponse) -> http_tests::TestResponse { fn make_test_response(mut response: ReqwestResponse) -> http_tests::TestResponse {
let status_code = response.status().as_u16() as i32; let status_code = response.status().as_u16() as i32;
let body = response.text().unwrap(); let body = response.text().unwrap();
let content_type_header = response let content_type_header = response.headers().get(reqwest::header::CONTENT_TYPE);
.headers()
.get::<reqwest::header::ContentType>();
let content_type = if let Some(ct) = content_type_header { let content_type = if let Some(ct) = content_type_header {
format!("{}", ct) format!("{}", ct.to_str().unwrap())
} else { } else {
String::default() String::default()
}; };
@ -336,16 +337,13 @@ mod tests {
fn test_hyper_integration() { fn test_hyper_integration() {
let addr = ([127, 0, 0, 1], 3001).into(); let addr = ([127, 0, 0, 1], 3001).into();
let pool = Builder::new().create();
let db = Arc::new(Database::new()); let db = Arc::new(Database::new());
let root_node = Arc::new(RootNode::new(db.clone(), EmptyMutation::<Database>::new())); let root_node = Arc::new(RootNode::new(db.clone(), EmptyMutation::<Database>::new()));
let new_service = move || { let new_service = move || {
let pool = pool.clone();
let root_node = root_node.clone(); let root_node = root_node.clone();
let ctx = db.clone(); let ctx = db.clone();
service_fn(move |req| { service_fn(move |req| {
let pool = pool.clone();
let root_node = root_node.clone(); let root_node = root_node.clone();
let ctx = ctx.clone(); let ctx = ctx.clone();
let matches = { let matches = {
@ -356,11 +354,11 @@ mod tests {
} }
}; };
if matches { if matches {
super::graphql(pool, root_node, ctx, req) Either::A(super::graphql(root_node, ctx, req))
} else { } else {
let mut response = Response::new(Body::empty()); let mut response = Response::new(Body::empty());
*response.status_mut() = StatusCode::NOT_FOUND; *response.status_mut() = StatusCode::NOT_FOUND;
Box::new(future::ok(response)) Either::B(future::ok(response))
} }
}) })
}; };