290 lines
9.3 KiB
Rust
290 lines
9.3 KiB
Rust
//! Types and traits for extracting data from [`Request`]s.
|
|
|
|
use std::fmt;
|
|
|
|
use axum::{
|
|
async_trait,
|
|
body::Body,
|
|
extract::{FromRequest, FromRequestParts, Query},
|
|
http::{HeaderValue, Method, Request, StatusCode},
|
|
response::{IntoResponse as _, Response},
|
|
Json, RequestExt as _,
|
|
};
|
|
use juniper::{
|
|
http::{GraphQLBatchRequest, GraphQLRequest},
|
|
DefaultScalarValue, ScalarValue,
|
|
};
|
|
use serde::Deserialize;
|
|
|
|
/// Extractor for [`axum`] to extract a [`JuniperRequest`].
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```rust
|
|
/// use std::sync::Arc;
|
|
///
|
|
/// use axum::{routing::post, Extension, Json, Router};
|
|
/// use juniper::{
|
|
/// RootNode, EmptySubscription, EmptyMutation, graphql_object,
|
|
/// };
|
|
/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse};
|
|
///
|
|
/// #[derive(Clone, Copy, Debug)]
|
|
/// pub struct Context;
|
|
///
|
|
/// impl juniper::Context for Context {}
|
|
///
|
|
/// #[derive(Clone, Copy, Debug)]
|
|
/// pub struct Query;
|
|
///
|
|
/// #[graphql_object(context = Context)]
|
|
/// impl Query {
|
|
/// fn add(a: i32, b: i32) -> i32 {
|
|
/// a + b
|
|
/// }
|
|
/// }
|
|
///
|
|
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
|
|
///
|
|
/// let schema = Schema::new(
|
|
/// Query,
|
|
/// EmptyMutation::<Context>::new(),
|
|
/// EmptySubscription::<Context>::new()
|
|
/// );
|
|
///
|
|
/// let app: Router = Router::new()
|
|
/// .route("/graphql", post(graphql))
|
|
/// .layer(Extension(Arc::new(schema)))
|
|
/// .layer(Extension(Context));
|
|
///
|
|
/// # #[axum::debug_handler]
|
|
/// async fn graphql(
|
|
/// Extension(schema): Extension<Arc<Schema>>,
|
|
/// Extension(context): Extension<Context>,
|
|
/// JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request`
|
|
/// ) -> JuniperResponse {
|
|
/// JuniperResponse(req.execute(&*schema, &context).await)
|
|
/// }
|
|
#[derive(Debug, PartialEq)]
|
|
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
|
|
where
|
|
S: ScalarValue;
|
|
|
|
#[async_trait]
|
|
impl<S, State> FromRequest<State> for JuniperRequest<S>
|
|
where
|
|
S: ScalarValue,
|
|
State: Sync,
|
|
Query<GetRequest>: FromRequestParts<State>,
|
|
Json<GraphQLBatchRequest<S>>: FromRequest<State>,
|
|
<Json<GraphQLBatchRequest<S>> as FromRequest<State>>::Rejection: fmt::Display,
|
|
String: FromRequest<State>,
|
|
{
|
|
type Rejection = Response;
|
|
|
|
async fn from_request(mut req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
|
|
let content_type = req
|
|
.headers()
|
|
.get("content-type")
|
|
.map(HeaderValue::to_str)
|
|
.transpose()
|
|
.map_err(|_| {
|
|
(
|
|
StatusCode::BAD_REQUEST,
|
|
"`Content-Type` header is not a valid HTTP header string",
|
|
)
|
|
.into_response()
|
|
})?;
|
|
|
|
// TODO: Move into `match` expression directly once MSRV is bumped higher than 1.74.
|
|
let method = req.method();
|
|
match (method, content_type) {
|
|
(&Method::GET, _) => req
|
|
.extract_parts::<Query<GetRequest>>()
|
|
.await
|
|
.map_err(|e| {
|
|
(
|
|
StatusCode::BAD_REQUEST,
|
|
format!("Invalid request query string: {e}"),
|
|
)
|
|
.into_response()
|
|
})
|
|
.and_then(|query| {
|
|
query
|
|
.0
|
|
.try_into()
|
|
.map(|q| Self(GraphQLBatchRequest::Single(q)))
|
|
.map_err(|e| {
|
|
(
|
|
StatusCode::BAD_REQUEST,
|
|
format!("Invalid request query `variables`: {e}"),
|
|
)
|
|
.into_response()
|
|
})
|
|
}),
|
|
(&Method::POST, Some("application/json")) => {
|
|
Json::<GraphQLBatchRequest<S>>::from_request(req, state)
|
|
.await
|
|
.map(|req| Self(req.0))
|
|
.map_err(|e| {
|
|
(StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response()
|
|
})
|
|
}
|
|
(&Method::POST, Some("application/graphql")) => String::from_request(req, state)
|
|
.await
|
|
.map(|body| {
|
|
Self(GraphQLBatchRequest::Single(GraphQLRequest::new(
|
|
body, None, None,
|
|
)))
|
|
})
|
|
.map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response()),
|
|
(&Method::POST, _) => Err((
|
|
StatusCode::UNSUPPORTED_MEDIA_TYPE,
|
|
"`Content-Type` header is expected to be either `application/json` or \
|
|
`application/graphql`",
|
|
)
|
|
.into_response()),
|
|
_ => Err((
|
|
StatusCode::METHOD_NOT_ALLOWED,
|
|
"HTTP method is expected to be either GET or POST",
|
|
)
|
|
.into_response()),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Workaround for a [`GraphQLRequest`] not being [`Deserialize`]d properly from a GET query string,
|
|
/// containing `variables` in JSON format.
|
|
#[derive(Deserialize, Debug)]
|
|
#[serde(deny_unknown_fields)]
|
|
struct GetRequest {
|
|
query: String,
|
|
#[serde(rename = "operationName")]
|
|
operation_name: Option<String>,
|
|
variables: Option<String>,
|
|
}
|
|
|
|
impl<S: ScalarValue> TryFrom<GetRequest> for GraphQLRequest<S> {
|
|
type Error = serde_json::Error;
|
|
fn try_from(req: GetRequest) -> Result<Self, Self::Error> {
|
|
let GetRequest {
|
|
query,
|
|
operation_name,
|
|
variables,
|
|
} = req;
|
|
Ok(Self::new(
|
|
query,
|
|
operation_name,
|
|
variables.map(|v| serde_json::from_str(&v)).transpose()?,
|
|
))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod juniper_request_tests {
    use axum::{body::Body, extract::FromRequest as _, http::Request};
    use futures::TryStreamExt as _;
    use juniper::{
        graphql_input_value,
        http::{GraphQLBatchRequest, GraphQLRequest},
    };

    use super::JuniperRequest;

    #[tokio::test]
    async fn from_get_request() {
        let uri = format!("/?query={}", urlencoding::encode("{ add(a: 2, b: 3) }"));
        let req = Request::get(&uri)
            .body(Body::empty())
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            String::from("{ add(a: 2, b: 3) }"),
            None,
            None,
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_get_request_with_variables() {
        let query = "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }";
        let uri = format!(
            "/?query={}&variables={}",
            urlencoding::encode(query),
            urlencoding::encode(r#"{"id": "1000"}"#),
        );
        let req = Request::get(&uri)
            .body(Body::empty())
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            String::from(query),
            None,
            Some(graphql_input_value!({"id": "1000"})),
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_json_post_request() {
        let body = Body::from(r#"{"query": "{ add(a: 2, b: 3) }"}"#);
        let req = Request::post("/")
            .header("content-type", "application/json")
            .body(body)
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            String::from("{ add(a: 2, b: 3) }"),
            None,
            None,
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_graphql_post_request() {
        let body = Body::from(r#"{ add(a: 2, b: 3) }"#);
        let req = Request::post("/")
            .header("content-type", "application/graphql")
            .body(body)
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            String::from("{ add(a: 2, b: 3) }"),
            None,
            None,
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    /// Runs [`JuniperRequest::from_request()`] on the given [`Request`], panicking with the
    /// rejection's status and body if extraction fails.
    async fn do_from_request(req: Request<Body>) -> JuniperRequest {
        let outcome = JuniperRequest::from_request(req, &()).await;
        match outcome {
            Ok(extracted) => extracted,
            Err(rejection) => {
                let status = rejection.status();
                let body = display_body(rejection.into_body()).await;
                panic!(
                    "`JuniperRequest::from_request()` failed with `{}` status and body:\n{}",
                    status, body,
                )
            }
        }
    }

    /// Collects the whole [`Body`] and decodes it as a UTF-8 [`String`].
    async fn display_body(body: Body) -> String {
        let bytes = body
            .into_data_stream()
            .map_ok(|chunk| chunk.to_vec())
            .try_concat()
            .await
            .unwrap();
        String::from_utf8(bytes).unwrap_or_else(|e| panic!("not UTF-8 body: {e}"))
    }
}
|