Merge pull request #417 from obmarg/async-await

Update juniper_rocket for async.
This commit is contained in:
theduke 2019-08-22 11:54:33 +02:00 committed by GitHub
commit 35fc8d8e2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 104 additions and 29 deletions

View file

@ -5,6 +5,8 @@ use crate::value::{Object, ScalarRefValue, ScalarValue, Value};
use crate::executor::{ExecutionResult, Executor}; use crate::executor::{ExecutionResult, Executor};
use crate::parser::Spanning; use crate::parser::Spanning;
use crate::BoxFuture;
use super::base::{is_excluded, merge_key_into, Arguments, GraphQLType}; use super::base::{is_excluded, merge_key_into, Arguments, GraphQLType};
pub trait GraphQLTypeAsync<S>: GraphQLType<S> + Send + Sync pub trait GraphQLTypeAsync<S>: GraphQLType<S> + Send + Sync

View file

@ -11,13 +11,17 @@ documentation = "https://docs.rs/juniper_rocket"
repository = "https://github.com/graphql-rust/juniper" repository = "https://github.com/graphql-rust/juniper"
edition = "2018" edition = "2018"
[features]
async = [ "juniper/async" ]
[dependencies] [dependencies]
serde = { version = "1.0.2" } serde = { version = "1.0.2" }
serde_json = { version = "1.0.2" } serde_json = { version = "1.0.2" }
serde_derive = { version = "1.0.2" } serde_derive = { version = "1.0.2" }
juniper = { version = "0.13.1" , default-features = false, path = "../juniper"} juniper = { version = "0.13.1" , default-features = false, path = "../juniper"}
rocket = { version = "0.4.0" } futures03 = { version = "0.3.0-alpha.18", package = "futures-preview", features = ["compat"] }
rocket = { git = "https://github.com/SergioBenitez/Rocket", branch = "async" }
[dev-dependencies.juniper] [dev-dependencies.juniper]
version = "0.13.1" version = "0.13.1"

View file

@ -38,17 +38,18 @@ Check the LICENSE file for details.
#![doc(html_root_url = "https://docs.rs/juniper_rocket/0.2.0")] #![doc(html_root_url = "https://docs.rs/juniper_rocket/0.2.0")]
#![feature(decl_macro, proc_macro_hygiene)] #![feature(decl_macro, proc_macro_hygiene)]
#![cfg_attr(feature = "async", feature(async_await, async_closure))]
use std::{ use std::{
error::Error, error::Error,
io::{Cursor, Read}, io::Cursor,
}; };
use rocket::{ use rocket::{
data::{FromDataSimple, Outcome as FromDataOutcome}, data::{FromDataFuture, FromDataSimple},
http::{ContentType, RawStr, Status}, http::{ContentType, RawStr, Status},
request::{FormItems, FromForm, FromFormValue}, request::{FormItems, FromForm, FromFormValue},
response::{content, Responder, Response}, response::{content, Responder, Response, ResultFuture},
Data, Data,
Outcome::{Failure, Forward, Success}, Outcome::{Failure, Forward, Success},
Request, Request,
@ -61,12 +62,18 @@ use juniper::{
ScalarValue, ScalarValue,
}; };
#[cfg(feature = "async")]
use juniper::GraphQLTypeAsync;
#[cfg(feature = "async")]
use futures03::future::{FutureExt, TryFutureExt};
#[derive(Debug, serde_derive::Deserialize, PartialEq)] #[derive(Debug, serde_derive::Deserialize, PartialEq)]
#[serde(untagged)] #[serde(untagged)]
#[serde(bound = "InputValue<S>: Deserialize<'de>")] #[serde(bound = "InputValue<S>: Deserialize<'de>")]
enum GraphQLBatchRequest<S = DefaultScalarValue> enum GraphQLBatchRequest<S = DefaultScalarValue>
where where
S: ScalarValue, S: ScalarValue + Sync + Send,
{ {
Single(http::GraphQLRequest<S>), Single(http::GraphQLRequest<S>),
Batch(Vec<http::GraphQLRequest<S>>), Batch(Vec<http::GraphQLRequest<S>>),
@ -76,7 +83,7 @@ where
#[serde(untagged)] #[serde(untagged)]
enum GraphQLBatchResponse<'a, S = DefaultScalarValue> enum GraphQLBatchResponse<'a, S = DefaultScalarValue>
where where
S: ScalarValue, S: ScalarValue + Sync + Send,
{ {
Single(http::GraphQLResponse<'a, S>), Single(http::GraphQLResponse<'a, S>),
Batch(Vec<http::GraphQLResponse<'a, S>>), Batch(Vec<http::GraphQLResponse<'a, S>>),
@ -84,7 +91,7 @@ where
impl<S> GraphQLBatchRequest<S> impl<S> GraphQLBatchRequest<S>
where where
S: ScalarValue, S: ScalarValue + Send + Sync,
for<'b> &'b S: ScalarRefValue<'b>, for<'b> &'b S: ScalarRefValue<'b>,
{ {
pub fn execute<'a, CtxT, QueryT, MutationT>( pub fn execute<'a, CtxT, QueryT, MutationT>(
@ -109,6 +116,34 @@ where
} }
} }
#[cfg(feature = "async")]
pub async fn execute_async<'a, CtxT, QueryT, MutationT>(
&'a self,
root_node: &'a RootNode<'_, QueryT, MutationT, S>,
context: &'a CtxT,
) -> GraphQLBatchResponse<'a, S>
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
MutationT::TypeInfo: Send + Sync,
CtxT: Send + Sync,
{
match self {
&GraphQLBatchRequest::Single(ref request) => {
GraphQLBatchResponse::Single(request.execute_async(root_node, context).await)
}
&GraphQLBatchRequest::Batch(ref requests) => {
let futures = requests
.iter()
.map(|request| request.execute_async(root_node, context))
.collect::<Vec<_>>();
GraphQLBatchResponse::Batch(futures03::future::join_all(futures).await)
}
}
}
pub fn operation_names(&self) -> Vec<Option<&str>> { pub fn operation_names(&self) -> Vec<Option<&str>> {
match self { match self {
GraphQLBatchRequest::Single(req) => vec![req.operation_name()], GraphQLBatchRequest::Single(req) => vec![req.operation_name()],
@ -121,7 +156,7 @@ where
impl<'a, S> GraphQLBatchResponse<'a, S> impl<'a, S> GraphQLBatchResponse<'a, S>
where where
S: ScalarValue, S: ScalarValue + Send + Sync,
{ {
fn is_ok(&self) -> bool { fn is_ok(&self) -> bool {
match self { match self {
@ -141,7 +176,7 @@ where
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct GraphQLRequest<S = DefaultScalarValue>(GraphQLBatchRequest<S>) pub struct GraphQLRequest<S = DefaultScalarValue>(GraphQLBatchRequest<S>)
where where
S: ScalarValue; S: ScalarValue + Send + Sync;
/// Simple wrapper around the result of executing a GraphQL query /// Simple wrapper around the result of executing a GraphQL query
pub struct GraphQLResponse(pub Status, pub String); pub struct GraphQLResponse(pub Status, pub String);
@ -160,7 +195,7 @@ pub fn playground_source(graphql_endpoint_url: &str) -> content::Html<String> {
impl<S> GraphQLRequest<S> impl<S> GraphQLRequest<S>
where where
S: ScalarValue, S: ScalarValue + Sync + Send,
for<'b> &'b S: ScalarRefValue<'b>, for<'b> &'b S: ScalarRefValue<'b>,
{ {
/// Execute an incoming GraphQL query /// Execute an incoming GraphQL query
@ -184,6 +219,31 @@ where
GraphQLResponse(status, json) GraphQLResponse(status, json)
} }
/// Asynchronously execute an incoming GraphQL query
#[cfg(feature = "async")]
pub async fn execute_async<CtxT, QueryT, MutationT>(
&self,
root_node: &RootNode<'_, QueryT, MutationT, S>,
context: &CtxT,
) -> GraphQLResponse
where
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
QueryT::TypeInfo: Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
MutationT::TypeInfo: Send + Sync,
CtxT: Send + Sync,
{
let response = self.0.execute_async(root_node, context).await;
let status = if response.is_ok() {
Status::Ok
} else {
Status::BadRequest
};
let json = serde_json::to_string(&response).unwrap();
GraphQLResponse(status, json)
}
/// Returns the operation names associated with this request. /// Returns the operation names associated with this request.
/// ///
/// For batch requests there will be multiple names. /// For batch requests there will be multiple names.
@ -249,7 +309,7 @@ impl GraphQLResponse {
impl<'f, S> FromForm<'f> for GraphQLRequest<S> impl<'f, S> FromForm<'f> for GraphQLRequest<S>
where where
S: ScalarValue, S: ScalarValue + Send + Sync,
{ {
type Error = String; type Error = String;
@ -320,7 +380,7 @@ where
impl<'v, S> FromFormValue<'v> for GraphQLRequest<S> impl<'v, S> FromFormValue<'v> for GraphQLRequest<S>
where where
S: ScalarValue, S: ScalarValue + Send + Sync,
{ {
type Error = String; type Error = String;
@ -331,38 +391,47 @@ where
} }
} }
const BODY_LIMIT: u64 = 1024 * 100;
impl<S> FromDataSimple for GraphQLRequest<S> impl<S> FromDataSimple for GraphQLRequest<S>
where where
S: ScalarValue, S: ScalarValue + Send + Sync,
{ {
type Error = String; type Error = String;
fn from_data(request: &Request, data: Data) -> FromDataOutcome<Self, Self::Error> { fn from_data(request: &Request, data: Data) -> FromDataFuture<'static, Self, Self::Error> {
use futures03::io::AsyncReadExt;
use rocket::AsyncReadExt as _;
if !request.content_type().map_or(false, |ct| ct.is_json()) { if !request.content_type().map_or(false, |ct| ct.is_json()) {
return Forward(data); return Box::pin(async move { Forward(data) });
} }
Box::pin(async move {
let mut body = String::new(); let mut body = String::new();
if let Err(e) = data.open().read_to_string(&mut body) { let mut reader = data.open().take(BODY_LIMIT);
if let Err(e) = reader.read_to_string(&mut body).await {
return Failure((Status::InternalServerError, format!("{:?}", e))); return Failure((Status::InternalServerError, format!("{:?}", e)));
} }
match serde_json::from_str(&body) { match serde_json::from_str(&body) {
Ok(value) => Success(GraphQLRequest(value)), Ok(value) => Success(GraphQLRequest(value)),
Err(failure) => return Failure((Status::BadRequest, format!("{}", failure))), Err(failure) => Failure((Status::BadRequest, format!("{}", failure))),
} }
})
} }
} }
impl<'r> Responder<'r> for GraphQLResponse { impl<'r> Responder<'r> for GraphQLResponse {
fn respond_to(self, _: &Request) -> Result<Response<'r>, Status> { fn respond_to(self, _: &Request) -> ResultFuture<'r> {
let GraphQLResponse(status, body) = self; let GraphQLResponse(status, body) = self;
Box::pin(async move {
Ok(Response::build() Ok(Response::build()
.header(ContentType::new("application", "json")) .header(ContentType::new("application", "json"))
.status(status) .status(status)
.sized_body(Cursor::new(body)) .sized_body(Cursor::new(body))
.finalize()) .finalize())
})
} }
} }