From c09be69b7d9bb0eb386a117e529dc27f2d6bd391 Mon Sep 17 00:00:00 2001
From: Christian Legnitto <LegNeato@users.noreply.github.com>
Date: Thu, 9 Apr 2020 22:10:24 -1000
Subject: [PATCH] Update rocket_async to central `GraphQLBatch*` enums (#612)

---
 Cargo.toml                                    |   3 +-
 juniper/release.toml                          |   3 +
 juniper/src/http/mod.rs                       |  10 +
 juniper_rocket_async/Cargo.toml               |   7 +-
 .../examples/rocket_server.rs                 |  23 +-
 juniper_rocket_async/src/lib.rs               | 227 ++++++------------
 .../tests/custom_response_tests.rs            |   6 +-
 7 files changed, 109 insertions(+), 170 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 6f10451a..4ed0d26f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,7 @@ members = [
   "juniper_hyper",
   "juniper_iron",
   "juniper_rocket",
+  "juniper_rocket_async",
   "juniper_subscriptions",
   "juniper_warp",
 ]
@@ -16,6 +17,4 @@ exclude = [
   "docs/book/tests",
   "examples/warp_async",
   "examples/warp_subscriptions",
-  # TODO enable async tests
-  "juniper_rocket_async",
 ]
diff --git a/juniper/release.toml b/juniper/release.toml
index 2ae9555f..8ef4cd80 100644
--- a/juniper/release.toml
+++ b/juniper/release.toml
@@ -22,6 +22,9 @@ pre-release-replacements = [
   # Rocket
   {file="../juniper_rocket/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
   {file="../juniper_rocket/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
+  # Rocket Async
+  {file="../juniper_rocket_async/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
+  {file="../juniper_rocket_async/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
   # Warp
   {file="../juniper_warp/Cargo.toml", search="juniper = \\{ version = \"[^\"]+\"", replace="juniper = { version = \"{{version}}\""},
   {file="../juniper_warp/Cargo.toml", search="\\[dev-dependencies\\.juniper\\]\nversion = \"[^\"]+\"", replace="[dev-dependencies.juniper]\nversion = \"{{version}}\""},
diff --git a/juniper/src/http/mod.rs b/juniper/src/http/mod.rs
index d49f286c..15fc695b 100644
--- a/juniper/src/http/mod.rs
+++ b/juniper/src/http/mod.rs
@@ -298,6 +298,16 @@ where
             }
         }
     }
+
+    /// The operation names of the request.
+    pub fn operation_names(&self) -> Vec<Option<&str>> {
+        match self {
+            GraphQLBatchRequest::Single(req) => vec![req.operation_name()],
+            GraphQLBatchRequest::Batch(reqs) => {
+                reqs.iter().map(|req| req.operation_name()).collect()
+            }
+        }
+    }
 }
 
 /// Simple wrapper around the result (GraphQLResponse) from executing a GraphQLBatchRequest
diff --git a/juniper_rocket_async/Cargo.toml b/juniper_rocket_async/Cargo.toml
index 0e896fdd..8cbe055c 100644
--- a/juniper_rocket_async/Cargo.toml
+++ b/juniper_rocket_async/Cargo.toml
@@ -1,5 +1,5 @@
 [package]
-name = "juniper_rocket"
+name = "juniper_rocket_async"
 version = "0.5.1"
 authors = [
     "Magnus Hallin <mhallin@fastmail.com>",
@@ -15,11 +15,10 @@ edition = "2018"
 serde = { version = "1.0.2" }
 serde_json = { version = "1.0.2" }
 serde_derive = { version = "1.0.2" }
-juniper = { version = "0.14.1", default-features = false, path = "../juniper"}
-
+juniper = { version = "0.14.2", default-features = false, path = "../juniper" }
 futures = { version = "0.3.1", features = ["compat"] }
 rocket = { git = "https://github.com/SergioBenitez/Rocket", branch = "async", default-features = false }
-tokio = "0.2"
+tokio = { version = "0.2", features = ["rt-core", "macros"] }
 
 [dev-dependencies.juniper]
 version = "0.14.1"
diff --git a/juniper_rocket_async/examples/rocket_server.rs b/juniper_rocket_async/examples/rocket_server.rs
index 096b98cd..d268a076 100644
--- a/juniper_rocket_async/examples/rocket_server.rs
+++ b/juniper_rocket_async/examples/rocket_server.rs
@@ -4,41 +4,46 @@ use rocket::{response::content, State};
 
 use juniper::{
     tests::{model::Database, schema::Query},
-    EmptyMutation, RootNode,
+    EmptyMutation, EmptySubscription, RootNode,
 };
 
-type Schema = RootNode<'static, Query, EmptyMutation<Database>>;
+type Schema = RootNode<'static, Query, EmptyMutation<Database>, EmptySubscription<Database>>;
 
 #[rocket::get("/")]
 fn graphiql() -> content::Html<String> {
-    juniper_rocket::graphiql_source("/graphql")
+    juniper_rocket_async::graphiql_source("/graphql")
 }
 
 #[rocket::get("/graphql?<request>")]
 fn get_graphql_handler(
     context: State<Database>,
-    request: juniper_rocket::GraphQLRequest,
+    request: juniper_rocket_async::GraphQLRequest,
     schema: State<Schema>,
-) -> juniper_rocket::GraphQLResponse {
+) -> juniper_rocket_async::GraphQLResponse {
     request.execute_sync(&schema, &context)
 }
 
 #[rocket::post("/graphql", data = "<request>")]
 fn post_graphql_handler(
     context: State<Database>,
-    request: juniper_rocket::GraphQLRequest,
+    request: juniper_rocket_async::GraphQLRequest,
     schema: State<Schema>,
-) -> juniper_rocket::GraphQLResponse {
+) -> juniper_rocket_async::GraphQLResponse {
     request.execute_sync(&schema, &context)
 }
 
 fn main() {
     rocket::ignite()
         .manage(Database::new())
-        .manage(Schema::new(Query, EmptyMutation::<Database>::new()))
+        .manage(Schema::new(
+            Query,
+            EmptyMutation::<Database>::new(),
+            EmptySubscription::<Database>::new(),
+        ))
         .mount(
             "/",
             rocket::routes![graphiql, get_graphql_handler, post_graphql_handler],
         )
-        .launch();
+        .launch()
+        .expect("server to launch");
 }
diff --git a/juniper_rocket_async/src/lib.rs b/juniper_rocket_async/src/lib.rs
index 8c911148..28341f15 100644
--- a/juniper_rocket_async/src/lib.rs
+++ b/juniper_rocket_async/src/lib.rs
@@ -1,6 +1,6 @@
 /*!
 
-# juniper_rocket
+# juniper_rocket_async
 
 This repository contains the [Rocket][Rocket] web server integration for
 [Juniper][Juniper], a [GraphQL][GraphQL] implementation for Rust.
@@ -31,130 +31,31 @@ Check the LICENSE file for details.
 [Rocket]: https://rocket.rs
 [Juniper]: https://github.com/graphql-rust/juniper
 [GraphQL]: http://graphql.org
-[documentation]: https://docs.rs/juniper_rocket
-[example]: https://github.com/graphql-rust/juniper_rocket/blob/master/examples/rocket_server.rs
+[documentation]: https://docs.rs/juniper_rocket_async
+[example]: https://github.com/graphql-rust/juniper_rocket_async/blob/master/examples/rocket_server.rs
 
 */
 
-#![doc(html_root_url = "https://docs.rs/juniper_rocket/0.2.0")]
+#![doc(html_root_url = "https://docs.rs/juniper_rocket_async/0.2.0")]
 #![feature(decl_macro, proc_macro_hygiene)]
 
-use std::{error::Error, io::Cursor};
+use std::io::Cursor;
 
 use rocket::{
     data::{FromDataFuture, FromDataSimple},
     http::{ContentType, RawStr, Status},
     request::{FormItems, FromForm, FromFormValue},
-    response::{content, Responder, Response, ResultFuture},
+    response::{self, content, Responder, Response},
     Data,
     Outcome::{Failure, Forward, Success},
     Request,
 };
 
-use juniper::{http, InputValue};
 use juniper::{
-    serde::Deserialize, DefaultScalarValue, FieldError, GraphQLType, RootNode, ScalarValue,
+    http::{self, GraphQLBatchRequest},
+    DefaultScalarValue, FieldError, GraphQLSubscriptionType, GraphQLType, GraphQLTypeAsync,
+    InputValue, RootNode, ScalarValue,
 };
-use juniper::GraphQLTypeAsync;
-use futures::future::{FutureExt, TryFutureExt};
-
-#[derive(Debug, serde_derive::Deserialize, PartialEq)]
-#[serde(untagged)]
-#[serde(bound = "InputValue<S>: Deserialize<'de>")]
-enum GraphQLBatchRequest<S = DefaultScalarValue>
-where
-    S: ScalarValue + Sync + Send,
-{
-    Single(http::GraphQLRequest<S>),
-    Batch(Vec<http::GraphQLRequest<S>>),
-}
-
-#[derive(serde_derive::Serialize)]
-#[serde(untagged)]
-enum GraphQLBatchResponse<'a, S = DefaultScalarValue>
-where
-    S: ScalarValue + Sync + Send,
-{
-    Single(http::GraphQLResponse<'a, S>),
-    Batch(Vec<http::GraphQLResponse<'a, S>>),
-}
-
-impl<S> GraphQLBatchRequest<S>
-where
-    S: ScalarValue + Send + Sync,
-{
-    pub fn execute<'a, CtxT, QueryT, MutationT>(
-        &'a self,
-        root_node: &'a RootNode<QueryT, MutationT, S>,
-        context: &CtxT,
-    ) -> GraphQLBatchResponse<'a, S>
-    where
-        QueryT: GraphQLType<S, Context = CtxT>,
-        MutationT: GraphQLType<S, Context = CtxT>,
-    {
-        match self {
-            &GraphQLBatchRequest::Single(ref request) => {
-                GraphQLBatchResponse::Single(request.execute_sync(root_node, context))
-            }
-            &GraphQLBatchRequest::Batch(ref requests) => GraphQLBatchResponse::Batch(
-                requests
-                    .iter()
-                    .map(|request| request.execute_sync(root_node, context))
-                    .collect(),
-            ),
-        }
-    }
-
-    pub async fn execute<'a, CtxT, QueryT, MutationT>(
-        &'a self,
-        root_node: &'a RootNode<'_, QueryT, MutationT, S>,
-        context: &'a CtxT,
-    ) -> GraphQLBatchResponse<'a, S>
-    where
-        QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
-        QueryT::TypeInfo: Send + Sync,
-        MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
-        MutationT::TypeInfo: Send + Sync,
-        CtxT: Send + Sync,
-    {
-        match self {
-            &GraphQLBatchRequest::Single(ref request) => {
-                GraphQLBatchResponse::Single(request.execute(root_node, context).await)
-            }
-            &GraphQLBatchRequest::Batch(ref requests) => {
-                let futures = requests
-                    .iter()
-                    .map(|request| request.execute(root_node, context))
-                    .collect::<Vec<_>>();
-
-                GraphQLBatchResponse::Batch(futures::future::join_all(futures).await)
-            }
-        }
-    }
-
-    pub fn operation_names(&self) -> Vec<Option<&str>> {
-        match self {
-            GraphQLBatchRequest::Single(req) => vec![req.operation_name()],
-            GraphQLBatchRequest::Batch(reqs) => {
-                reqs.iter().map(|req| req.operation_name()).collect()
-            }
-        }
-    }
-}
-
-impl<'a, S> GraphQLBatchResponse<'a, S>
-where
-    S: ScalarValue + Send + Sync,
-{
-    fn is_ok(&self) -> bool {
-        match self {
-            &GraphQLBatchResponse::Single(ref response) => response.is_ok(),
-            &GraphQLBatchResponse::Batch(ref responses) => responses
-                .iter()
-                .fold(true, |ok, response| ok && response.is_ok()),
-        }
-    }
-}
 
 /// Simple wrapper around an incoming GraphQL request
 ///
@@ -178,6 +79,7 @@ pub fn graphiql_source(graphql_endpoint_url: &str) -> content::Html<String> {
 pub fn playground_source(graphql_endpoint_url: &str) -> content::Html<String> {
     content::Html(juniper::http::playground::playground_source(
         graphql_endpoint_url,
+        None,
     ))
 }
 
@@ -185,15 +87,18 @@ impl<S> GraphQLRequest<S>
 where
     S: ScalarValue + Sync + Send,
 {
-    /// Execute an incoming GraphQL query
-    pub fn execute<CtxT, QueryT, MutationT>(
+    /// Execute an incoming GraphQL query synchronously.
+    pub fn execute_sync<CtxT, QueryT, MutationT, SubscriptionT>(
         &self,
-        root_node: &RootNode<QueryT, MutationT, S>,
+        root_node: &RootNode<QueryT, MutationT, SubscriptionT, S>,
         context: &CtxT,
     ) -> GraphQLResponse
     where
         QueryT: GraphQLType<S, Context = CtxT>,
         MutationT: GraphQLType<S, Context = CtxT>,
+        SubscriptionT: GraphQLType<S, Context = CtxT>,
+        SubscriptionT::TypeInfo: Send + Sync,
+        CtxT: Send + Sync,
     {
         let response = self.0.execute_sync(root_node, context);
         let status = if response.is_ok() {
@@ -207,9 +112,9 @@ where
     }
 
     /// Asynchronously execute an incoming GraphQL query
-    pub async fn execute<CtxT, QueryT, MutationT>(
+    pub async fn execute<CtxT, QueryT, MutationT, SubscriptionT>(
         &self,
-        root_node: &RootNode<'_, QueryT, MutationT, S>,
+        root_node: &RootNode<'_, QueryT, MutationT, SubscriptionT, S>,
         context: &CtxT,
     ) -> GraphQLResponse
     where
@@ -217,7 +122,10 @@ where
         QueryT::TypeInfo: Send + Sync,
         MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
         MutationT::TypeInfo: Send + Sync,
+        SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + Sync,
+        SubscriptionT::TypeInfo: Send + Sync,
         CtxT: Send + Sync,
+        S: Send + Sync,
     {
         let response = self.0.execute(root_node, context).await;
         let status = if response.is_ok() {
@@ -247,7 +155,7 @@ impl GraphQLResponse {
     /// # #![feature(decl_macro, proc_macro_hygiene)]
     /// #
     /// # extern crate juniper;
-    /// # extern crate juniper_rocket;
+    /// # extern crate juniper_rocket_async;
     /// # extern crate rocket;
     /// #
     /// # use rocket::http::Cookies;
@@ -257,20 +165,20 @@ impl GraphQLResponse {
     /// #
     /// # use juniper::tests::schema::Query;
     /// # use juniper::tests::model::Database;
-    /// # use juniper::{EmptyMutation, FieldError, RootNode, Value};
+    /// # use juniper::{EmptyMutation, EmptySubscription, FieldError, RootNode, Value};
     /// #
-    /// # type Schema = RootNode<'static, Query, EmptyMutation<Database>>;
+    /// # type Schema = RootNode<'static, Query, EmptyMutation<Database>, EmptySubscription<Database>>;
     /// #
     /// #[rocket::get("/graphql?<request..>")]
     /// fn get_graphql_handler(
     ///     mut cookies: Cookies,
     ///     context: State<Database>,
-    ///     request: Form<juniper_rocket::GraphQLRequest>,
+    ///     request: Form<juniper_rocket_async::GraphQLRequest>,
     ///     schema: State<Schema>,
-    /// ) -> juniper_rocket::GraphQLResponse {
-    ///     if cookies.get_private("user_id").is_none() {
+    /// ) -> juniper_rocket_async::GraphQLResponse {
+    ///     if cookies.get("user_id").is_none() {
     ///         let err = FieldError::new("User is not logged in", Value::null());
-    ///         return juniper_rocket::GraphQLResponse::error(err);
+    ///         return juniper_rocket_async::GraphQLResponse::error(err);
     ///     }
     ///
     ///     request.execute_sync(&schema, &context)
@@ -315,7 +223,7 @@ where
                     } else {
                         match value.url_decode() {
                             Ok(v) => query = Some(v),
-                            Err(e) => return Err(e.description().to_string()),
+                            Err(e) => return Err(e.to_string()),
                         }
                     }
                 }
@@ -327,7 +235,7 @@ where
                     } else {
                         match value.url_decode() {
                             Ok(v) => operation_name = Some(v),
-                            Err(e) => return Err(e.description().to_string()),
+                            Err(e) => return Err(e.to_string()),
                         }
                     }
                 }
@@ -338,11 +246,11 @@ where
                         let decoded;
                         match value.url_decode() {
                             Ok(v) => decoded = v,
-                            Err(e) => return Err(e.description().to_string()),
+                            Err(e) => return Err(e.to_string()),
                         }
                         variables = Some(
                             serde_json::from_str::<InputValue<_>>(&decoded)
-                                .map_err(|err| err.description().to_owned())?,
+                                .map_err(|err| err.to_string())?,
                         );
                     }
                 }
@@ -407,8 +315,15 @@ where
     }
 }
 
+use rocket::futures::future::BoxFuture;
+
 impl<'r> Responder<'r> for GraphQLResponse {
-    fn respond_to(self, _: &Request) -> ResultFuture<'r> {
+    fn respond_to<'a, 'x>(self, _req: &'r Request<'a>) -> BoxFuture<'x, response::Result<'r>>
+    where
+        'a: 'x,
+        'r: 'x,
+        Self: 'x,
+    {
         let GraphQLResponse(status, body) = self;
 
         Box::pin(async move {
@@ -416,6 +331,7 @@ impl<'r> Responder<'r> for GraphQLResponse {
                 .header(ContentType::new("application", "json"))
                 .status(status)
                 .sized_body(Cursor::new(body))
+                .await
                 .finalize())
         })
     }
@@ -483,7 +399,11 @@ mod fromform_tests {
 
     #[test]
     fn test_variables_invalid_json() {
-        check_error("query=test&variables=NOT_JSON", "JSON error", false);
+        check_error(
+            "query=test&variables=NOT_JSON",
+            "expected value at line 1 column 1",
+            false,
+        );
     }
 
     #[test]
@@ -534,22 +454,23 @@ mod fromform_tests {
 #[cfg(test)]
 mod tests {
 
+    use futures;
+
+    use juniper::{
+        http::tests as http_tests,
+        tests::{model::Database, schema::Query},
+        EmptyMutation, EmptySubscription, RootNode,
+    };
     use rocket::{
         self, get,
         http::ContentType,
-        local::{Client, LocalRequest},
+        local::{Client, LocalResponse},
         post,
         request::Form,
         routes, Rocket, State,
     };
 
-    use juniper::{
-        http::tests as http_tests,
-        tests::{model::Database, schema::Query},
-        EmptyMutation, RootNode,
-    };
-
-    type Schema = RootNode<'static, Query, EmptyMutation<Database>>;
+    type Schema = RootNode<'static, Query, EmptyMutation<Database>, EmptySubscription<Database>>;
 
     #[get("/?<request..>")]
     fn get_graphql_handler(
@@ -575,13 +496,15 @@ mod tests {
 
     impl http_tests::HTTPIntegration for TestRocketIntegration {
         fn get(&self, url: &str) -> http_tests::TestResponse {
-            let req = &self.client.get(url);
-            make_test_response(req)
+            let req = self.client.get(url);
+            let req = futures::executor::block_on(req.dispatch());
+            futures::executor::block_on(make_test_response(req))
         }
 
         fn post(&self, url: &str, body: &str) -> http_tests::TestResponse {
-            let req = &self.client.post(url).header(ContentType::JSON).body(body);
-            make_test_response(req)
+            let req = self.client.post(url).header(ContentType::JSON).body(body);
+            let req = futures::executor::block_on(req.dispatch());
+            futures::executor::block_on(make_test_response(req))
         }
     }
 
@@ -594,8 +517,8 @@ mod tests {
         http_tests::run_http_test_suite(&integration);
     }
 
-    #[test]
-    fn test_operation_names() {
+    #[tokio::test]
+    async fn test_operation_names() {
         #[post("/", data = "<request>")]
         fn post_graphql_assert_operation_name_handler(
             context: State<Database>,
@@ -610,13 +533,15 @@ mod tests {
             .mount("/", routes![post_graphql_assert_operation_name_handler]);
         let client = Client::new(rocket).expect("valid rocket");
 
-        let req = client
+        let resp = client
             .post("/")
             .header(ContentType::JSON)
-            .body(r#"{"query": "query TestQuery {hero{name}}", "operationName": "TestQuery"}"#);
-        let resp = make_test_response(&req);
+            .body(r#"{"query": "query TestQuery {hero{name}}", "operationName": "TestQuery"}"#)
+            .dispatch()
+            .await;
+        let resp = make_test_response(resp);
 
-        assert_eq!(resp.status_code, 200);
+        assert_eq!(resp.await.status_code, 200);
     }
 
     fn make_rocket() -> Rocket {
@@ -624,13 +549,14 @@ mod tests {
     }
 
     fn make_rocket_without_routes() -> Rocket {
-        rocket::ignite()
-            .manage(Database::new())
-            .manage(Schema::new(Query, EmptyMutation::<Database>::new()))
+        rocket::ignite().manage(Database::new()).manage(Schema::new(
+            Query,
+            EmptyMutation::<Database>::new(),
+            EmptySubscription::<Database>::new(),
+        ))
     }
 
-    fn make_test_response(request: &LocalRequest) -> http_tests::TestResponse {
-        let mut response = request.clone().dispatch();
+    async fn make_test_response(mut response: LocalResponse<'_>) -> http_tests::TestResponse {
         let status_code = response.status().code as i32;
         let content_type = response
             .content_type()
@@ -639,7 +565,8 @@ mod tests {
         let body = response
             .body()
             .expect("No body returned from GraphQL handler")
-            .into_string();
+            .into_string()
+            .await;
 
         http_tests::TestResponse {
             status_code,
diff --git a/juniper_rocket_async/tests/custom_response_tests.rs b/juniper_rocket_async/tests/custom_response_tests.rs
index aeb29d4b..3cee40c9 100644
--- a/juniper_rocket_async/tests/custom_response_tests.rs
+++ b/juniper_rocket_async/tests/custom_response_tests.rs
@@ -1,10 +1,6 @@
-extern crate juniper_rocket;
-extern crate rocket;
-
+use juniper_rocket_async::GraphQLResponse;
 use rocket::http::Status;
 
-use juniper_rocket::GraphQLResponse;
-
 #[test]
 fn test_graphql_response_is_public() {
     let _ = GraphQLResponse(Status::Unauthorized, "Unauthorized".to_string());