From f730829c1b60a0b96c2cd0909ae961b52a006893 Mon Sep 17 00:00:00 2001
From: Christoph Herzog <chris@theduke.at>
Date: Mon, 19 Aug 2019 22:17:05 +0200
Subject: [PATCH] Update warp for async

---
 juniper/src/http/mod.rs      | 27 +++++++++++++++++
 juniper/src/types/scalars.rs | 12 ++++++++
 juniper_warp/Cargo.toml      |  2 +-
 juniper_warp/src/lib.rs      | 58 +++++++++++++++++++++++++++++-------
 4 files changed, 88 insertions(+), 11 deletions(-)

diff --git a/juniper/src/http/mod.rs b/juniper/src/http/mod.rs
index ab3c2588..e6fbf21e 100644
--- a/juniper/src/http/mod.rs
+++ b/juniper/src/http/mod.rs
@@ -93,6 +93,33 @@ where
             context,
         ))
     }
+
+    #[cfg(feature = "async")]
+    pub async fn execute_async<'a, CtxT, QueryT, MutationT>(
+        &'a self,
+        root_node: &'a RootNode<'a, QueryT, MutationT, S>,
+        context: &'a CtxT,
+    ) -> GraphQLResponse<'a, S>
+    where
+        S: ScalarValue + Send + Sync,
+        QueryT: crate::GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
+        QueryT::TypeInfo: Send + Sync,
+        MutationT: crate::GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
+        MutationT::TypeInfo: Send + Sync,
+        CtxT: Send + Sync,
+        for<'b> &'b S: ScalarRefValue<'b>,
+    {
+        let op = self.operation_name();
+        let vars = &self.variables();
+        let res = crate::execute_async(
+            &self.query,
+            op,
+            root_node,
+            vars,
+            context,
+        ).await;
+        GraphQLResponse(res)
+    }
 }
 
 /// Simple wrapper around the result from executing a GraphQL query
diff --git a/juniper/src/types/scalars.rs b/juniper/src/types/scalars.rs
index dc30c39a..a2df7eb3 100644
--- a/juniper/src/types/scalars.rs
+++ b/juniper/src/types/scalars.rs
@@ -329,6 +329,18 @@ where
     }
 }
 
+#[cfg(feature = "async")]
+impl<S, T> crate::GraphQLTypeAsync<S> for EmptyMutation<T>
+where
+    S: ScalarValue + Send + Sync,
+    Self: GraphQLType<S> + Send + Sync,
+    Self::TypeInfo: Send + Sync,
+    Self::Context: Send + Sync,
+    T: Send + Sync,
+    for<'b> &'b S: ScalarRefValue<'b>,
+{
+}
+
 #[cfg(test)]
 mod tests {
     use super::ID;
diff --git a/juniper_warp/Cargo.toml b/juniper_warp/Cargo.toml
index c6af3d17..04805c21 100644
--- a/juniper_warp/Cargo.toml
+++ b/juniper_warp/Cargo.toml
@@ -21,7 +21,7 @@ futures = "0.1.23"
 serde = "1.0.75"
 tokio-threadpool = "0.1.7"
 
-futures03 = { version = "0.3.0-alpha.18", optional = true, package = "futures-preview" }
+futures03 = { version = "0.3.0-alpha.18", optional = true, package = "futures-preview", features = ["compat"] }
 
 [dev-dependencies]
 juniper = { version = "0.13.1", path = "../juniper", features = ["expose-test-schema", "serde_json"] }
diff --git a/juniper_warp/src/lib.rs b/juniper_warp/src/lib.rs
index e0bda8c2..6f6b7711 100644
--- a/juniper_warp/src/lib.rs
+++ b/juniper_warp/src/lib.rs
@@ -40,11 +40,16 @@ Check the LICENSE file for details.
 #![deny(warnings)]
 #![doc(html_root_url = "https://docs.rs/juniper_warp/0.2.0")]
 
+#![cfg_attr(feature = "async", feature(async_await, async_closure))]
+
 use futures::{future::poll_fn, Future};
 use serde::Deserialize;
 use std::sync::Arc;
 use warp::{filters::BoxedFilter, Filter};
 
+#[cfg(feature = "async")]
+use futures03::future::{FutureExt, TryFutureExt};
+
 use juniper::{DefaultScalarValue, InputValue, ScalarRefValue, ScalarValue};
 
 #[derive(Debug, serde_derive::Deserialize, PartialEq)]
@@ -84,6 +89,37 @@ where
             ),
         }
     }
+
+    #[cfg(feature = "async")]
+    pub async fn execute_async<'a, CtxT, QueryT, MutationT>(
+        &'a self,
+        root_node: &'a juniper::RootNode<'a, QueryT, MutationT, S>,
+        context: &'a CtxT,
+    ) -> GraphQLBatchResponse<'a, S>
+    where
+        QueryT: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
+        QueryT::TypeInfo: Send + Sync,
+        MutationT: juniper::GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
+        MutationT::TypeInfo: Send + Sync,
+        CtxT: Send + Sync,
+        S: Send + Sync,
+    {
+        match self {
+            &GraphQLBatchRequest::Single(ref request) => {
+                let res = request.execute_async(root_node, context).await;
+                GraphQLBatchResponse::Single(res)
+            }
+            &GraphQLBatchRequest::Batch(ref requests) => {
+                let futures = requests
+                    .iter()
+                    .map(|request| request.execute_async(root_node, context))
+                    .collect::<Vec<_>>();
+                let responses =  futures03::future::join_all(futures).await;
+
+                GraphQLBatchResponse::Batch(responses)
+            }
+        }
+    }
 }
 
 #[derive(serde_derive::Serialize)]
@@ -242,6 +278,7 @@ where
 }
 
 /// FIXME: docs
+#[cfg(feature = "async")]
 pub fn make_graphql_filter_async<Query, Mutation, Context, S>(
     schema: juniper::RootNode<'static, Query, Mutation, S>,
     context_extractor: BoxedFilter<(Context,)>,
@@ -261,16 +298,17 @@ where
     let handle_post_request =
         move |context: Context, request: GraphQLBatchRequest<S>| -> Response {
             let schema = post_schema.clone();
-            Box::new(
-                poll_fn(move || {
-                    tokio_threadpool::blocking(|| {
-                        let response = request.execute(&schema, &context);
-                        Ok((serde_json::to_vec(&response)?, response.is_ok()))
-                    })
-                })
-                .and_then(|result| ::futures::future::done(Ok(build_response(result))))
-                .map_err(|e: tokio_threadpool::BlockingError| warp::reject::custom(e)),
-            )
+
+            let f = async move {
+                let res = request.execute_async(&schema, &context).await;
+
+                match serde_json::to_vec(&res) {
+                    Ok(json) => Ok(build_response(Ok((json, res.is_ok())))),
+                    Err(e) => Err(warp::reject::custom(e)),
+                }
+            };
+
+            Box::new(f.boxed().compat())
         };
 
     let post_filter = warp::post2()