//! Types and traits for extracting data from [`Request`]s.

use std::fmt;

use axum::{
    async_trait,
    body::Body,
    extract::{FromRequest, FromRequestParts, Query},
    http::{HeaderValue, Method, Request, StatusCode},
    response::{IntoResponse as _, Response},
    Json, RequestExt as _,
};
use juniper::{
    http::{GraphQLBatchRequest, GraphQLRequest},
    DefaultScalarValue, ScalarValue,
};
use serde::Deserialize;

/// Extractor for [`axum`] to extract a [`JuniperRequest`].
///
/// # Example
///
/// ```rust
/// use std::sync::Arc;
///
/// use axum::{routing::post, Extension, Json, Router};
/// use juniper::{
///     RootNode, EmptySubscription, EmptyMutation, graphql_object,
/// };
/// use juniper_axum::{extract::JuniperRequest, response::JuniperResponse};
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Context;
///
/// impl juniper::Context for Context {}
///
/// #[derive(Clone, Copy, Debug)]
/// pub struct Query;
///
/// #[graphql_object(context = Context)]
/// impl Query {
///     fn add(a: i32, b: i32) -> i32 {
///         a + b
///     }
/// }
///
/// type Schema = RootNode<'static, Query, EmptyMutation<Context>, EmptySubscription<Context>>;
///
/// let schema = Schema::new(
///    Query,
///    EmptyMutation::<Context>::new(),
///    EmptySubscription::<Context>::new()
/// );
///
/// let app: Router = Router::new()
///     .route("/graphql", post(graphql))
///     .layer(Extension(Arc::new(schema)))
///     .layer(Extension(Context));
///
/// # #[axum::debug_handler]
/// async fn graphql(
///     Extension(schema): Extension<Arc<Schema>>,
///     Extension(context): Extension<Context>,
///     JuniperRequest(req): JuniperRequest, // should be the last argument as consumes `Request`
/// ) -> JuniperResponse {
///     JuniperResponse(req.execute(&*schema, &context).await)
/// }
#[derive(Debug, PartialEq)]
pub struct JuniperRequest<S = DefaultScalarValue>(pub GraphQLBatchRequest<S>)
where
    S: ScalarValue;

#[async_trait]
impl<S, State> FromRequest<State> for JuniperRequest<S>
where
    S: ScalarValue,
    State: Sync,
    Query<GetRequest>: FromRequestParts<State>,
    Json<GraphQLBatchRequest<S>>: FromRequest<State>,
    <Json<GraphQLBatchRequest<S>> as FromRequest<State>>::Rejection: fmt::Display,
    String: FromRequest<State>,
{
    type Rejection = Response;

    async fn from_request(mut req: Request<Body>, state: &State) -> Result<Self, Self::Rejection> {
        let content_type = req
            .headers()
            .get("content-type")
            .map(HeaderValue::to_str)
            .transpose()
            .map_err(|_| {
                (
                    StatusCode::BAD_REQUEST,
                    "`Content-Type` header is not a valid HTTP header string",
                )
                    .into_response()
            })?;

        // TODO: Move into `match` expression directly once MSRV is bumped higher than 1.74.
        let method = req.method();
        match (method, content_type) {
            (&Method::GET, _) => req
                .extract_parts::<Query<GetRequest>>()
                .await
                .map_err(|e| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("Invalid request query string: {e}"),
                    )
                        .into_response()
                })
                .and_then(|query| {
                    query
                        .0
                        .try_into()
                        .map(|q| Self(GraphQLBatchRequest::Single(q)))
                        .map_err(|e| {
                            (
                                StatusCode::BAD_REQUEST,
                                format!("Invalid request query `variables`: {e}"),
                            )
                                .into_response()
                        })
                }),
            (&Method::POST, Some("application/json")) => {
                Json::<GraphQLBatchRequest<S>>::from_request(req, state)
                    .await
                    .map(|req| Self(req.0))
                    .map_err(|e| {
                        (StatusCode::BAD_REQUEST, format!("Invalid JSON body: {e}")).into_response()
                    })
            }
            (&Method::POST, Some("application/graphql")) => String::from_request(req, state)
                .await
                .map(|body| {
                    Self(GraphQLBatchRequest::Single(GraphQLRequest::new(
                        body, None, None,
                    )))
                })
                .map_err(|_| (StatusCode::BAD_REQUEST, "Not valid UTF-8 body").into_response()),
            (&Method::POST, _) => Err((
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                "`Content-Type` header is expected to be either `application/json` or \
                 `application/graphql`",
            )
                .into_response()),
            _ => Err((
                StatusCode::METHOD_NOT_ALLOWED,
                "HTTP method is expected to be either GET or POST",
            )
                .into_response()),
        }
    }
}

/// Workaround for a [`GraphQLRequest`] not being [`Deserialize`]d properly from a GET query string,
/// containing `variables` in JSON format.
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
struct GetRequest {
    query: String,
    #[serde(rename = "operationName")]
    operation_name: Option<String>,
    variables: Option<String>,
}

impl<S: ScalarValue> TryFrom<GetRequest> for GraphQLRequest<S> {
    type Error = serde_json::Error;
    fn try_from(req: GetRequest) -> Result<Self, Self::Error> {
        let GetRequest {
            query,
            operation_name,
            variables,
        } = req;
        Ok(Self::new(
            query,
            operation_name,
            variables.map(|v| serde_json::from_str(&v)).transpose()?,
        ))
    }
}

#[cfg(test)]
mod juniper_request_tests {
    use axum::{body::Body, extract::FromRequest as _, http::Request};
    use futures::TryStreamExt as _;
    use juniper::{
        graphql_input_value,
        http::{GraphQLBatchRequest, GraphQLRequest},
    };

    use super::JuniperRequest;

    #[tokio::test]
    async fn from_get_request() {
        let req = Request::get(&format!(
            "/?query={}",
            urlencoding::encode("{ add(a: 2, b: 3) }")
        ))
        .body(Body::empty())
        .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            "{ add(a: 2, b: 3) }".into(),
            None,
            None,
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_get_request_with_variables() {
        let req = Request::get(&format!(
            "/?query={}&variables={}",
            urlencoding::encode(
                "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }",
            ),
            urlencoding::encode(r#"{"id": "1000"}"#),
        ))
        .body(Body::empty())
        .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            "query($id: String!) { human(id: $id) { id, name, appearsIn, homePlanet } }".into(),
            None,
            Some(graphql_input_value!({"id": "1000"})),
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_json_post_request() {
        let req = Request::post("/")
            .header("content-type", "application/json")
            .body(Body::from(r#"{"query": "{ add(a: 2, b: 3) }"}"#))
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            "{ add(a: 2, b: 3) }".to_string(),
            None,
            None,
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    #[tokio::test]
    async fn from_graphql_post_request() {
        let req = Request::post("/")
            .header("content-type", "application/graphql")
            .body(Body::from(r#"{ add(a: 2, b: 3) }"#))
            .unwrap_or_else(|e| panic!("cannot build `Request`: {e}"));

        let expected = JuniperRequest(GraphQLBatchRequest::Single(GraphQLRequest::new(
            "{ add(a: 2, b: 3) }".to_string(),
            None,
            None,
        )));

        assert_eq!(do_from_request(req).await, expected);
    }

    /// Performs [`JuniperRequest::from_request()`].
    async fn do_from_request(req: Request<Body>) -> JuniperRequest {
        match JuniperRequest::from_request(req, &()).await {
            Ok(resp) => resp,
            Err(resp) => {
                panic!(
                    "`JuniperRequest::from_request()` failed with `{}` status and body:\n{}",
                    resp.status(),
                    display_body(resp.into_body()).await,
                )
            }
        }
    }

    /// Converts the provided [`Body`] into a [`String`].
    async fn display_body(body: Body) -> String {
        String::from_utf8(
            body.into_data_stream()
                .map_ok(|bytes| bytes.to_vec())
                .try_concat()
                .await
                .unwrap(),
        )
        .unwrap_or_else(|e| panic!("not UTF-8 body: {e}"))
    }
}