juniper/juniper_axum/src/extract.rs

290 lines
9.3 KiB
Rust

//! Types and traits for extracting data from [`Request`]s.
use std::fmt;
use axum::{
async_trait,
body::Body,
extract::{FromRequest, FromRequestParts, Query},
http::{HeaderValue, Method, Request, StatusCode},
response::{IntoResponse as _, Response},
Json, RequestExt as _,
};
use juniper::{
http::{GraphQLBatchRequest, GraphQLRequest},
DefaultScalarValue, ScalarValue,
};
use serde::Deserialize;
/// Extractor for [`axum`] to extract a [`JuniperRequest`].
///
/// Supported request shapes:
/// - `GET` with `query`, plus optional `operationName` and JSON-encoded `variables`, in the query
///   string;
/// - `POST` with an `application/json` body, deserialized into a [`GraphQLBatchRequest`];
/// - `POST` with an `application/graphql` body holding the raw query text.
///
/// Extraction consumes the whole [`Request`] (including its body), so this extractor has to be the
/// last argument of an [`axum`] handler.
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
///
/// use axum::{routing::post, Extension, Json, Router};
/// use juniper::{
///     RootNode, EmptySubscription, EmptyMutation, graphql_object,
/// };
/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse};
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Context;
///
/// impl juniper::Context for Context {}
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Query;
///
/// #[graphql_object(context = Context)]
/// impl Query {
///     fn add(a: i32, b: i32) -> i32 {
///         a + b
///     }
/// }
///
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
///
/// let schema = Schema::new(
///     Query,
///     EmptyMutation::<Context>::new(),
///     EmptySubscription::<Context>::new()
/// );
///
/// let app: Router = Router::new()
///     .route("/graphql", post(graphql))
///     .layer(Extension(Arc::new(schema)))
///     .layer(Extension(Context));
///
/// # #[axum::debug_handler]
/// async fn graphql(
///     Extension(schema): Extension<Arc<Schema>>,
///     Extension(context): Extension<Context>,
///     JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request`
/// ) -> JuniperResponse {
///     JuniperResponse(req.execute(&*schema, &context).await)
/// }
/// ```
#[derive(Debug, PartialEq)]
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
where
    S: ScalarValue;
#[async_trait]
impl<S, State> FromRequest<State> for JuniperRequest<S>
where
    S: ScalarValue,
    State: Sync,
    Query<GetRequest>: FromRequestParts<State>,
    Json<GraphQLBatchRequest<S>>: FromRequest<State>,
    <Json<GraphQLBatchRequest<S>> as FromRequest<State>>::Rejection: fmt::Display,
    String: FromRequest<State>,
{
    type Rejection = Response;

    /// Extracts a [`JuniperRequest`] from the given HTTP [`Request`].
    ///
    /// `GET` requests are read from the query string. `POST` requests are dispatched on the
    /// media type of their `Content-Type` header, compared case-insensitively and ignoring any
    /// parameters (so `application/json; charset=utf-8` is treated as `application/json`).
    ///
    /// # Errors
    ///
    /// Rejects with a ready-made [`Response`]:
    /// - `400 Bad Request` if the `Content-Type` header cannot be converted to a string, the GET
    ///   query string or its `variables` cannot be parsed, the JSON body is invalid, or the
    ///   `application/graphql` body cannot be read as a [`String`];
    /// - `415 Unsupported Media Type` for a `POST` whose media type is neither
    ///   `application/json` nor `application/graphql`;
    /// - `405 Method Not Allowed` for any method other than `GET` or `POST`.
    async fn from_request(mut req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
        // Media types are case-insensitive and may carry parameters (RFC 9110, section 8.3.1),
        // e.g. `application/json; charset=utf-8`, so only the lowercased `type/subtype` essence
        // is matched below. Owning it as a `String` also releases the borrow of `req`'s headers.
        let content_type = req
            .headers()
            .get("content-type")
            .map(HeaderValue::to_str)
            .transpose()
            .map_err(|_| {
                (
                    StatusCode::BAD_REQUEST,
                    "`Content-Type` header is not a valid HTTP header string",
                )
                    .into_response()
            })?
            .map(|value| {
                value
                    .split_once(';')
                    .map_or(value, |(essence, _params)| essence)
                    .trim()
                    .to_ascii_lowercase()
            });

        // TODO: Move into `match` expression directly once MSRV is bumped higher than 1.74.
        let method = req.method();
        match (method, content_type.as_deref()) {
            (&Method::GET, _) => req
                .extract_parts::<Query<GetRequest>>()
                .await
                .map_err(|e| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("Invalid request query string: {e}"),
                    )
                        .into_response()
                })
                .and_then(|query| {
                    query
                        .0
                        .try_into()
                        .map(|q| Self(GraphQLBatchRequest::Single(q)))
                        .map_err(|e| {
                            (
                                StatusCode::BAD_REQUEST,
                                format!("Invalid request query `variables`: {e}"),
                            )
                                .into_response()
                        })
                }),
            (&Method::POST, Some("application/json")) => {
                Json::<GraphQLBatchRequest<S>>::from_request(req, state)
                    .await
                    .map(|req| Self(req.0))
                    .map_err(|e| {
                        (StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response()
                    })
            }
            (&Method::POST, Some("application/graphql")) => String::from_request(req, state)
                .await
                .map(|body| {
                    Self(GraphQLBatchRequest::Single(GraphQLRequest::new(
                        body, None, None,
                    )))
                })
                .map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response()),
            (&Method::POST, _) => Err((
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                "`Content-Type` header is expected to be either `application/json` or \
                 `application/graphql`",
            )
                .into_response()),
            _ => Err((
                StatusCode::METHOD_NOT_ALLOWED,
                "HTTP method is expected to be either GET or POST",
            )
                .into_response()),
        }
    }
}
/// Workaround for a [`GraphQLRequest`] not being [`Deserialize`]d properly from a GET query string,
/// containing `variables` in JSON format.
///
/// Here `variables` stays a raw JSON [`String`]; it is decoded separately when this value is
/// converted into a [`GraphQLRequest`]. Any query-string key other than the three fields below
/// is rejected.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
struct GetRequest {
    /// GraphQL document to execute.
    query: String,
    /// Operation to run, read from the `operationName` key.
    operation_name: Option<String>,
    /// JSON-encoded variables, if provided.
    variables: Option<String>,
}
impl<S: ScalarValue> TryFrom<GetRequest> for GraphQLRequest<S> {
    type Error = serde_json::Error;

    /// Builds a [`GraphQLRequest`] out of a [`GetRequest`], decoding its `variables` from JSON.
    ///
    /// # Errors
    ///
    /// Returns a [`serde_json::Error`] when `variables` is present but cannot be deserialized
    /// from JSON.
    fn try_from(req: GetRequest) -> Result<Self, Self::Error> {
        let decoded_variables = match req.variables {
            Some(raw_json) => Some(serde_json::from_str(&raw_json)?),
            None => None,
        };
        Ok(Self::new(req.query, req.operation_name, decoded_variables))
    }
}
#[cfg(test)]
mod juniper_request_tests {
    use axum::{body::Body, extract::FromRequest as _, http::Request};
    use futures::TryStreamExt as _;
    use juniper::{
        graphql_input_value,
        http::{GraphQLBatchRequest, GraphQLRequest},
    };

    use super::JuniperRequest;

    /// Query text shared by the tests that add two numbers.
    const ADD_QUERY: &str = "{ add(a: 2, b: 3) }";

    /// Wraps one [`GraphQLRequest`] the way a non-batched [`JuniperRequest`] holds it.
    fn single(request: GraphQLRequest) -> JuniperRequest {
        JuniperRequest(GraphQLBatchRequest::Single(request))
    }

    #[tokio::test]
    async fn from_get_request() {
        let uri = format!("/?query={}", urlencoding::encode(ADD_QUERY));
        let req = Request::get(&uri)
            .body(Body::empty())
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = single(GraphQLRequest::new(ADD_QUERY.into(), None, None));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_get_request_with_variables() {
        let query = "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }";
        let uri = format!(
            "/?query={}&variables={}",
            urlencoding::encode(query),
            urlencoding::encode(r#"{"id": "1000"}"#),
        );
        let req = Request::get(&uri)
            .body(Body::empty())
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = single(GraphQLRequest::new(
            query.into(),
            None,
            Some(graphql_input_value!({"id": "1000"})),
        ));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_json_post_request() {
        let req = Request::post("/")
            .header("content-type", "application/json")
            .body(Body::from(r#"{"query": "{ add(a: 2, b: 3) }"}"#))
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = single(GraphQLRequest::new(ADD_QUERY.into(), None, None));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_graphql_post_request() {
        let req = Request::post("/")
            .header("content-type", "application/graphql")
            .body(Body::from(ADD_QUERY))
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = single(GraphQLRequest::new(ADD_QUERY.into(), None, None));

        assert_eq!(do_from_request(req).await, expected);
    }

    /// Runs [`JuniperRequest::from_request()`] with a unit state, panicking with the rejection's
    /// status and body if extraction fails.
    async fn do_from_request(req: Request<Body>) -> JuniperRequest {
        match JuniperRequest::from_request(req, &()).await {
            Ok(extracted) => extracted,
            Err(rejection) => {
                let status = rejection.status();
                let body = display_body(rejection.into_body()).await;
                panic!(
                    "`JuniperRequest::from_request()` failed with `{status}` status and body:\n{body}",
                );
            }
        }
    }

    /// Collects every chunk of the provided [`Body`] and decodes the result as a UTF-8 [`String`].
    async fn display_body(body: Body) -> String {
        let bytes: Vec<u8> = body
            .into_data_stream()
            .map_ok(|chunk| chunk.to_vec())
            .try_concat()
            .await
            .unwrap();
        String::from_utf8(bytes).unwrap_or_else(|e| panic!("not UTF-8 body: {e}"))
    }
}