#![doc(html_root_url = "https://docs.rs/juniper_hyper/0.2.0")] #[cfg(test)] extern crate reqwest; use hyper::{ header::{self, HeaderValue}, Body, Method, Request, Response, StatusCode, }; use juniper::{ http::{GraphQLBatchRequest, GraphQLRequest as JuniperGraphQLRequest, GraphQLRequest}, GraphQLSubscriptionType, GraphQLType, GraphQLTypeAsync, InputValue, RootNode, ScalarValue, }; use serde_json::error::Error as SerdeError; use std::{error::Error, fmt, string::FromUtf8Error, sync::Arc}; use url::form_urlencoded; pub async fn graphql_sync( root_node: Arc>, context: Arc, req: Request, ) -> Result, hyper::Error> where S: ScalarValue + Send + Sync + 'static, CtxT: Send + Sync + 'static, QueryT: GraphQLType + Send + Sync + 'static, MutationT: GraphQLType + Send + Sync + 'static, SubscriptionT: GraphQLType + Send + Sync + 'static, QueryT::TypeInfo: Send + Sync, MutationT::TypeInfo: Send + Sync, SubscriptionT::TypeInfo: Send + Sync, { Ok(match parse_req(req).await { Ok(req) => execute_request_sync(root_node, context, req).await, Err(resp) => resp, }) } pub async fn graphql( root_node: Arc>, context: Arc, req: Request, ) -> Result, hyper::Error> where S: ScalarValue + Send + Sync + 'static, CtxT: Send + Sync + 'static, QueryT: GraphQLTypeAsync + Send + Sync + 'static, MutationT: GraphQLTypeAsync + Send + Sync + 'static, SubscriptionT: GraphQLSubscriptionType + Send + Sync, QueryT::TypeInfo: Send + Sync, MutationT::TypeInfo: Send + Sync, SubscriptionT::TypeInfo: Send + Sync, { Ok(match parse_req(req).await { Ok(req) => execute_request(root_node, context, req).await, Err(resp) => resp, }) } async fn parse_req( req: Request, ) -> Result, Response> { match *req.method() { Method::GET => parse_get_req(req), Method::POST => { let content_type = req .headers() .get(header::CONTENT_TYPE) .map(HeaderValue::to_str); match content_type { Some(Ok("application/json")) => parse_post_json_req(req.into_body()).await, Some(Ok("application/graphql")) => parse_post_graphql_req(req.into_body()).await, _ => return Err(new_response(StatusCode::BAD_REQUEST)), } } _ => return Err(new_response(StatusCode::METHOD_NOT_ALLOWED)), } .map_err(|e| render_error(e)) } fn parse_get_req( req: Request, ) -> Result, GraphQLRequestError> { req.uri() .query() .map(|q| gql_request_from_get(q).map(GraphQLBatchRequest::Single)) .unwrap_or_else(|| { Err(GraphQLRequestError::Invalid( "'query' parameter is missing".to_string(), )) }) } async fn parse_post_json_req( body: Body, ) -> Result, GraphQLRequestError> { let chunk = hyper::body::to_bytes(body) .await .map_err(GraphQLRequestError::BodyHyper)?; let input = String::from_utf8(chunk.iter().cloned().collect()) .map_err(GraphQLRequestError::BodyUtf8)?; serde_json::from_str::>(&input) .map_err(GraphQLRequestError::BodyJSONError) } async fn parse_post_graphql_req( body: Body, ) -> Result, GraphQLRequestError> { let chunk = hyper::body::to_bytes(body) .await .map_err(GraphQLRequestError::BodyHyper)?; let query = String::from_utf8(chunk.iter().cloned().collect()) .map_err(GraphQLRequestError::BodyUtf8)?; Ok(GraphQLBatchRequest::Single(GraphQLRequest::new( query, None, None, ))) } pub async fn graphiql( graphql_endpoint: &str, subscriptions_endpoint: Option<&str>, ) -> Result, hyper::Error> { let mut resp = new_html_response(StatusCode::OK); // XXX: is the call to graphiql_source blocking? *resp.body_mut() = Body::from(juniper::http::graphiql::graphiql_source( graphql_endpoint, subscriptions_endpoint, )); Ok(resp) } pub async fn playground( graphql_endpoint: &str, subscriptions_endpoint: Option<&str>, ) -> Result, hyper::Error> { let mut resp = new_html_response(StatusCode::OK); *resp.body_mut() = Body::from(juniper::http::playground::playground_source( graphql_endpoint, subscriptions_endpoint, )); Ok(resp) } fn render_error(err: GraphQLRequestError) -> Response { let message = format!("{}", err); let mut resp = new_response(StatusCode::BAD_REQUEST); *resp.body_mut() = Body::from(message); resp } async fn execute_request_sync( root_node: Arc>, context: Arc, request: GraphQLBatchRequest, ) -> Response where S: ScalarValue + Send + Sync + 'static, CtxT: Send + Sync + 'static, QueryT: GraphQLType + Send + Sync + 'static, MutationT: GraphQLType + Send + Sync + 'static, SubscriptionT: GraphQLType + Send + Sync + 'static, QueryT::TypeInfo: Send + Sync, MutationT::TypeInfo: Send + Sync, SubscriptionT::TypeInfo: Send + Sync, { let res = request.execute_sync(&*root_node, &context); let body = Body::from(serde_json::to_string_pretty(&res).unwrap()); let code = if res.is_ok() { StatusCode::OK } else { StatusCode::BAD_REQUEST }; let mut resp = new_response(code); resp.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); *resp.body_mut() = body; resp } async fn execute_request( root_node: Arc>, context: Arc, request: GraphQLBatchRequest, ) -> Response where S: ScalarValue + Send + Sync + 'static, CtxT: Send + Sync + 'static, QueryT: GraphQLTypeAsync + Send + Sync + 'static, MutationT: GraphQLTypeAsync + Send + Sync + 'static, SubscriptionT: GraphQLSubscriptionType + Send + Sync, QueryT::TypeInfo: Send + Sync, MutationT::TypeInfo: Send + Sync, SubscriptionT::TypeInfo: Send + Sync, { let res = request.execute(&*root_node, &context).await; let body = Body::from(serde_json::to_string_pretty(&res).unwrap()); let code = if res.is_ok() { StatusCode::OK } else { StatusCode::BAD_REQUEST }; let mut resp = new_response(code); resp.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); *resp.body_mut() = body; resp } fn gql_request_from_get(input: &str) -> Result, GraphQLRequestError> where S: ScalarValue, { let mut query = None; let operation_name = None; let mut variables = None; for (key, value) in form_urlencoded::parse(input.as_bytes()).into_owned() { match key.as_ref() { "query" => { if query.is_some() { return Err(invalid_err("query")); } query = Some(value) } "operationName" => { if operation_name.is_some() { return Err(invalid_err("operationName")); } } "variables" => { if variables.is_some() { return Err(invalid_err("variables")); } match serde_json::from_str::>(&value) .map_err(GraphQLRequestError::Variables) { Ok(parsed_variables) => variables = Some(parsed_variables), Err(e) => return Err(e), } } _ => continue, } } match query { Some(query) => Ok(JuniperGraphQLRequest::new(query, operation_name, variables)), None => Err(GraphQLRequestError::Invalid( "'query' parameter is missing".to_string(), )), } } fn invalid_err(parameter_name: &str) -> GraphQLRequestError { GraphQLRequestError::Invalid(format!( "'{}' parameter is specified multiple times", parameter_name )) } fn new_response(code: StatusCode) -> Response { let mut r = Response::new(Body::empty()); *r.status_mut() = code; r } fn new_html_response(code: StatusCode) -> Response { let mut resp = new_response(code); resp.headers_mut().insert( header::CONTENT_TYPE, HeaderValue::from_static("text/html; charset=utf-8"), ); resp } #[derive(Debug)] enum GraphQLRequestError { BodyHyper(hyper::Error), BodyUtf8(FromUtf8Error), BodyJSONError(SerdeError), Variables(SerdeError), Invalid(String), } impl fmt::Display for GraphQLRequestError { fn fmt(&self, mut f: &mut fmt::Formatter) -> fmt::Result { match *self { GraphQLRequestError::BodyHyper(ref err) => fmt::Display::fmt(err, &mut f), GraphQLRequestError::BodyUtf8(ref err) => fmt::Display::fmt(err, &mut f), GraphQLRequestError::BodyJSONError(ref err) => fmt::Display::fmt(err, &mut f), GraphQLRequestError::Variables(ref err) => fmt::Display::fmt(err, &mut f), GraphQLRequestError::Invalid(ref err) => fmt::Display::fmt(err, &mut f), } } } impl Error for GraphQLRequestError { fn source(&self) -> Option<&(dyn Error + 'static)> { match *self { GraphQLRequestError::BodyHyper(ref err) => Some(err), GraphQLRequestError::BodyUtf8(ref err) => Some(err), GraphQLRequestError::BodyJSONError(ref err) => Some(err), GraphQLRequestError::Variables(ref err) => Some(err), GraphQLRequestError::Invalid(_) => None, } } } #[cfg(test)] mod tests { use hyper::{ service::{make_service_fn, service_fn}, Body, Method, Response, Server, StatusCode, }; use juniper::{ http::tests as http_tests, tests::{model::Database, schema::Query}, EmptyMutation, EmptySubscription, RootNode, }; use reqwest::{self, Response as ReqwestResponse}; use std::{net::SocketAddr, sync::Arc, thread, time::Duration}; struct TestHyperIntegration { port: u16, } impl http_tests::HttpIntegration for TestHyperIntegration { fn get(&self, url: &str) -> http_tests::TestResponse { let url = format!("http://127.0.0.1:{}/graphql{}", self.port, url); make_test_response(reqwest::get(&url).expect(&format!("failed GET {}", url))) } fn post_json(&self, url: &str, body: &str) -> http_tests::TestResponse { let url = format!("http://127.0.0.1:{}/graphql{}", self.port, url); let client = reqwest::Client::new(); let res = client .post(&url) .header(reqwest::header::CONTENT_TYPE, "application/json") .body(body.to_string()) .send() .expect(&format!("failed POST {}", url)); make_test_response(res) } fn post_graphql(&self, url: &str, body: &str) -> http_tests::TestResponse { let url = format!("http://127.0.0.1:{}/graphql{}", self.port, url); let client = reqwest::Client::new(); let res = client .post(&url) .header(reqwest::header::CONTENT_TYPE, "application/graphql") .body(body.to_string()) .send() .expect(&format!("failed POST {}", url)); make_test_response(res) } } fn make_test_response(mut response: ReqwestResponse) -> http_tests::TestResponse { let status_code = response.status().as_u16() as i32; let body = response.text().unwrap(); let content_type_header = response.headers().get(reqwest::header::CONTENT_TYPE); let content_type = if let Some(ct) = content_type_header { format!("{}", ct.to_str().unwrap()) } else { String::default() }; http_tests::TestResponse { status_code, body: Some(body), content_type, } } async fn run_hyper_integration(is_sync: bool) { let port = if is_sync { 3002 } else { 3001 }; let addr: SocketAddr = ([127, 0, 0, 1], port).into(); let db = Arc::new(Database::new()); let root_node = Arc::new(RootNode::new( Query, EmptyMutation::::new(), EmptySubscription::::new(), )); let new_service = make_service_fn(move |_| { let root_node = root_node.clone(); let ctx = db.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |req| { let root_node = root_node.clone(); let ctx = ctx.clone(); let matches = { let path = req.uri().path(); match req.method() { &Method::POST | &Method::GET => { path == "/graphql" || path == "/graphql/" } _ => false, } }; async move { if matches { if is_sync { super::graphql_sync(root_node, ctx, req).await } else { super::graphql(root_node, ctx, req).await } } else { let mut resp = Response::new(Body::empty()); *resp.status_mut() = StatusCode::NOT_FOUND; Ok(resp) } } })) } }); let (shutdown_fut, shutdown) = futures::future::abortable(async { tokio::time::delay_for(Duration::from_secs(60)).await; }); let server = Server::bind(&addr) .serve(new_service) .with_graceful_shutdown(async { shutdown_fut.await.unwrap_err(); }); tokio::task::spawn_blocking(move || { thread::sleep(Duration::from_millis(10)); // wait 10ms for server to bind let integration = TestHyperIntegration { port }; http_tests::run_http_test_suite(&integration); shutdown.abort(); }); if let Err(e) = server.await { eprintln!("server error: {}", e); } } #[tokio::test] async fn test_hyper_integration() { run_hyper_integration(false).await } #[tokio::test] async fn test_sync_hyper_integration() { run_hyper_integration(true).await } }