diff --git a/juniper_rocket/Cargo.toml b/juniper_rocket/Cargo.toml index bc1e8686..e1bc0564 100644 --- a/juniper_rocket/Cargo.toml +++ b/juniper_rocket/Cargo.toml @@ -11,13 +11,16 @@ documentation = "https://docs.rs/juniper_rocket" repository = "https://github.com/graphql-rust/juniper" edition = "2018" +[features] +async = [ "juniper/async", "futures03" ] + [dependencies] serde = { version = "1.0.2" } serde_json = { version = "1.0.2" } serde_derive = { version = "1.0.2" } juniper = { version = "0.14.0", default-features = false, path = "../juniper"} -futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] } +futures03 = { version = "0.3.0-alpha.18", optional = true, package = "futures-preview", features = ["compat"] } rocket = { git = "https://github.com/SergioBenitez/Rocket", branch = "async" } [dev-dependencies.juniper] diff --git a/juniper_rocket/src/lib.rs b/juniper_rocket/src/lib.rs index a3760aa3..c50f747d 100644 --- a/juniper_rocket/src/lib.rs +++ b/juniper_rocket/src/lib.rs @@ -37,7 +37,9 @@ Check the LICENSE file for details. */ #![doc(html_root_url = "https://docs.rs/juniper_rocket/0.2.0")] -#![feature(decl_macro, proc_macro_hygiene, async_await)] +#![feature(decl_macro, proc_macro_hygiene)] + +#![cfg_attr(feature = "async", feature(async_await, async_closure))] use std::{ error::Error, @@ -61,6 +63,9 @@ use juniper::{ ScalarValue, }; +#[cfg(feature = "async")] +use futures03::future::{FutureExt, TryFutureExt}; + #[derive(Debug, serde_derive::Deserialize, PartialEq)] #[serde(untagged)] #[serde(bound = "InputValue: Deserialize<'de>")] @@ -109,6 +114,31 @@ where } } + #[cfg(feature = "async")] + pub async fn execute_async<'a, CtxT, QueryT, MutationT>( + &'a self, + root_node: &'a RootNode, + context: &CtxT, + ) -> GraphQLBatchResponse<'a, S> + where + QueryT: GraphQLType, + MutationT: GraphQLType, + { + match self { + &GraphQLBatchRequest::Single(ref request) => { + GraphQLBatchResponse::Single(request.execute_async(root_node, context).await) + } + &GraphQLBatchRequest::Batch(ref requests) => GraphQLBatchResponse::Batch( + let futures = requests + .iter() + .map(|request| request.execute(root_node, context)) + .collect::>(), + + let responses = futures03::future::join_all(futures).await; + ), + } + } + pub fn operation_names(&self) -> Vec> { match self { GraphQLBatchRequest::Single(req) => vec![req.operation_name()], @@ -184,6 +214,28 @@ where GraphQLResponse(status, json) } + /// Asynchronously execute an incoming GraphQL query + #[cfg(feature = "async")] + pub async fn execute_async( + &self, + root_node: &RootNode, + context: &CtxT, + ) -> GraphQLResponse + where + QueryT: GraphQLType, + MutationT: GraphQLType, + { + let response = self.0.execute_async(root_node, context).await; + let status = if response.is_ok() { + Status::Ok + } else { + Status::BadRequest + }; + let json = serde_json::to_string(&response).unwrap(); + + GraphQLResponse(status, json) + } + /// Returns the operation names associated with this request. /// /// For batch requests there will be multiple names. @@ -340,7 +392,7 @@ where type Error = String; fn from_data(request: &Request, data: Data) -> FromDataFuture<'static, Self, Self::Error> { - use futures::io::AsyncReadExt; + use futures03::io::AsyncReadExt; use rocket::AsyncReadExt as _; if !request.content_type().map_or(false, |ct| ct.is_json()) { return Box::pin(async move { Forward(data) });