diff --git a/juniper/src/types/async_await.rs b/juniper/src/types/async_await.rs index 12369bef..4297c630 100644 --- a/juniper/src/types/async_await.rs +++ b/juniper/src/types/async_await.rs @@ -5,6 +5,8 @@ use crate::value::{Object, ScalarRefValue, ScalarValue, Value}; use crate::executor::{ExecutionResult, Executor}; use crate::parser::Spanning; +use crate::BoxFuture; + use super::base::{is_excluded, merge_key_into, Arguments, GraphQLType}; pub trait GraphQLTypeAsync: GraphQLType + Send + Sync diff --git a/juniper_rocket/Cargo.toml b/juniper_rocket/Cargo.toml index 8fd1a618..737aca70 100644 --- a/juniper_rocket/Cargo.toml +++ b/juniper_rocket/Cargo.toml @@ -11,13 +11,17 @@ documentation = "https://docs.rs/juniper_rocket" repository = "https://github.com/graphql-rust/juniper" edition = "2018" +[features] +async = [ "juniper/async" ] + [dependencies] serde = { version = "1.0.2" } serde_json = { version = "1.0.2" } serde_derive = { version = "1.0.2" } juniper = { version = "0.13.1" , default-features = false, path = "../juniper"} -rocket = { version = "0.4.0" } +futures03 = { version = "0.3.0-alpha.18", package = "futures-preview", features = ["compat"] } +rocket = { git = "https://github.com/SergioBenitez/Rocket", branch = "async" } [dev-dependencies.juniper] version = "0.13.1" diff --git a/juniper_rocket/src/lib.rs b/juniper_rocket/src/lib.rs index a016f132..891a3723 100644 --- a/juniper_rocket/src/lib.rs +++ b/juniper_rocket/src/lib.rs @@ -38,17 +38,18 @@ Check the LICENSE file for details. #![doc(html_root_url = "https://docs.rs/juniper_rocket/0.2.0")] #![feature(decl_macro, proc_macro_hygiene)] +#![cfg_attr(feature = "async", feature(async_await, async_closure))] use std::{ error::Error, - io::{Cursor, Read}, + io::Cursor, }; use rocket::{ - data::{FromDataSimple, Outcome as FromDataOutcome}, + data::{FromDataFuture, FromDataSimple}, http::{ContentType, RawStr, Status}, request::{FormItems, FromForm, FromFormValue}, - response::{content, Responder, Response}, + response::{content, Responder, Response, ResultFuture}, Data, Outcome::{Failure, Forward, Success}, Request, @@ -61,12 +62,18 @@ use juniper::{ ScalarValue, }; +#[cfg(feature = "async")] +use juniper::GraphQLTypeAsync; + +#[cfg(feature = "async")] +use futures03::future::{FutureExt, TryFutureExt}; + #[derive(Debug, serde_derive::Deserialize, PartialEq)] #[serde(untagged)] #[serde(bound = "InputValue: Deserialize<'de>")] enum GraphQLBatchRequest where - S: ScalarValue, + S: ScalarValue + Sync + Send, { Single(http::GraphQLRequest), Batch(Vec>), @@ -76,7 +83,7 @@ where #[serde(untagged)] enum GraphQLBatchResponse<'a, S = DefaultScalarValue> where - S: ScalarValue, + S: ScalarValue + Sync + Send, { Single(http::GraphQLResponse<'a, S>), Batch(Vec>), @@ -84,7 +91,7 @@ where impl GraphQLBatchRequest where - S: ScalarValue, + S: ScalarValue + Send + Sync, for<'b> &'b S: ScalarRefValue<'b>, { pub fn execute<'a, CtxT, QueryT, MutationT>( @@ -109,6 +116,34 @@ where } } + #[cfg(feature = "async")] + pub async fn execute_async<'a, CtxT, QueryT, MutationT>( + &'a self, + root_node: &'a RootNode<'_, QueryT, MutationT, S>, + context: &'a CtxT, + ) -> GraphQLBatchResponse<'a, S> + where + QueryT: GraphQLTypeAsync + Send + Sync, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + Sync, + MutationT::TypeInfo: Send + Sync, + CtxT: Send + Sync, + { + match self { + &GraphQLBatchRequest::Single(ref request) => { + GraphQLBatchResponse::Single(request.execute_async(root_node, context).await) + } + &GraphQLBatchRequest::Batch(ref requests) => { + let futures = requests + .iter() + .map(|request| request.execute_async(root_node, context)) + .collect::>(); + + GraphQLBatchResponse::Batch(futures03::future::join_all(futures).await) + } + } + } + pub fn operation_names(&self) -> Vec> { match self { GraphQLBatchRequest::Single(req) => vec![req.operation_name()], @@ -121,7 +156,7 @@ where impl<'a, S> GraphQLBatchResponse<'a, S> where - S: ScalarValue, + S: ScalarValue + Send + Sync, { fn is_ok(&self) -> bool { match self { @@ -141,7 +176,7 @@ where #[derive(Debug, PartialEq)] pub struct GraphQLRequest(GraphQLBatchRequest) where - S: ScalarValue; + S: ScalarValue + Send + Sync; /// Simple wrapper around the result of executing a GraphQL query pub struct GraphQLResponse(pub Status, pub String); @@ -160,7 +195,7 @@ pub fn playground_source(graphql_endpoint_url: &str) -> content::Html { impl GraphQLRequest where - S: ScalarValue, + S: ScalarValue + Sync + Send, for<'b> &'b S: ScalarRefValue<'b>, { /// Execute an incoming GraphQL query @@ -184,6 +219,31 @@ where GraphQLResponse(status, json) } + /// Asynchronously execute an incoming GraphQL query + #[cfg(feature = "async")] + pub async fn execute_async( + &self, + root_node: &RootNode<'_, QueryT, MutationT, S>, + context: &CtxT, + ) -> GraphQLResponse + where + QueryT: GraphQLTypeAsync + Send + Sync, + QueryT::TypeInfo: Send + Sync, + MutationT: GraphQLTypeAsync + Send + Sync, + MutationT::TypeInfo: Send + Sync, + CtxT: Send + Sync, + { + let response = self.0.execute_async(root_node, context).await; + let status = if response.is_ok() { + Status::Ok + } else { + Status::BadRequest + }; + let json = serde_json::to_string(&response).unwrap(); + + GraphQLResponse(status, json) + } + /// Returns the operation names associated with this request. /// /// For batch requests there will be multiple names. @@ -249,7 +309,7 @@ impl GraphQLResponse { impl<'f, S> FromForm<'f> for GraphQLRequest where - S: ScalarValue, + S: ScalarValue + Send + Sync, { type Error = String; @@ -320,7 +380,7 @@ where impl<'v, S> FromFormValue<'v> for GraphQLRequest where - S: ScalarValue, + S: ScalarValue + Send + Sync, { type Error = String; @@ -331,38 +391,47 @@ where } } +const BODY_LIMIT: u64 = 1024 * 100; + impl FromDataSimple for GraphQLRequest where - S: ScalarValue, + S: ScalarValue + Send + Sync, { type Error = String; - fn from_data(request: &Request, data: Data) -> FromDataOutcome { + fn from_data(request: &Request, data: Data) -> FromDataFuture<'static, Self, Self::Error> { + use futures03::io::AsyncReadExt; + use rocket::AsyncReadExt as _; if !request.content_type().map_or(false, |ct| ct.is_json()) { - return Forward(data); + return Box::pin(async move { Forward(data) }); } - let mut body = String::new(); - if let Err(e) = data.open().read_to_string(&mut body) { - return Failure((Status::InternalServerError, format!("{:?}", e))); - } + Box::pin(async move { + let mut body = String::new(); + let mut reader = data.open().take(BODY_LIMIT); + if let Err(e) = reader.read_to_string(&mut body).await { + return Failure((Status::InternalServerError, format!("{:?}", e))); + } - match serde_json::from_str(&body) { - Ok(value) => Success(GraphQLRequest(value)), - Err(failure) => return Failure((Status::BadRequest, format!("{}", failure))), - } + match serde_json::from_str(&body) { + Ok(value) => Success(GraphQLRequest(value)), + Err(failure) => Failure((Status::BadRequest, format!("{}", failure))), + } + }) } } impl<'r> Responder<'r> for GraphQLResponse { - fn respond_to(self, _: &Request) -> Result, Status> { + fn respond_to(self, _: &Request) -> ResultFuture<'r> { let GraphQLResponse(status, body) = self; - Ok(Response::build() - .header(ContentType::new("application", "json")) - .status(status) - .sized_body(Cursor::new(body)) - .finalize()) + Box::pin(async move { + Ok(Response::build() + .header(ContentType::new("application", "json")) + .status(status) + .sized_body(Cursor::new(body)) + .finalize()) + }) } }