From fea722b1961e0aa2d91aa764e5ab9294a12b4be4 Mon Sep 17 00:00:00 2001
From: tyranron <tyranron@gmail.com>
Date: Thu, 11 Aug 2022 18:06:43 +0300
Subject: [PATCH] Impl input objects, vol.2 [skip ci]

---
 juniper/src/behavior.rs                       |  11 -
 juniper/src/executor/mod.rs                   |  45 +++
 juniper/src/graphql/mod.rs                    |  10 +
 juniper/src/schema/meta.rs                    |  19 ++
 juniper_codegen/src/graphql_enum/mod.rs       |   9 +-
 .../src/graphql_input_object/mod.rs           | 308 +++++++++++++++++-
 juniper_codegen/src/graphql_scalar/mod.rs     |   9 +-
 7 files changed, 386 insertions(+), 25 deletions(-)

diff --git a/juniper/src/behavior.rs b/juniper/src/behavior.rs
index d8f8b4d2..cb1976cd 100644
--- a/juniper/src/behavior.rs
+++ b/juniper/src/behavior.rs
@@ -64,17 +64,6 @@ where
     }
 }
 
-impl<T, SV, B1, B2> resolve::ToInputValue<SV, B1> for Coerce<T, B2>
-where
-    T: resolve::ToInputValue<SV, B2> + ?Sized,
-    B1: ?Sized,
-    B2: ?Sized,
-{
-    fn to_input_value(&self) -> graphql::InputValue<SV> {
-        self.1.to_input_value()
-    }
-}
-
 impl<'i, T, SV, B1, B2> resolve::InputValue<'i, SV, B1> for Coerce<T, B2>
 where
     T: resolve::InputValue<'i, SV, B2>,
diff --git a/juniper/src/executor/mod.rs b/juniper/src/executor/mod.rs
index 4cd8bffa..558d5850 100644
--- a/juniper/src/executor/mod.rs
+++ b/juniper/src/executor/mod.rs
@@ -1263,6 +1263,16 @@ impl<'r, S: 'r> Registry<'r, S> {
         Argument::new(name, self.get_type::<T>(info)).default_value(value.to_input_value())
     }
 
+    /// Creates an [`Argument`] with the provided `name`.
+    pub fn arg_reworked<'ti, T, TI>(&mut self, name: &str, type_info: &'ti TI) -> Argument<'r, S>
+    where
+        T: resolve::Type<TI, S> + resolve::InputValueOwned<S>,
+        TI: ?Sized,
+        'ti: 'r,
+    {
+        Argument::new(name, T::meta(self, type_info).as_type())
+    }
+
     fn insert_placeholder(&mut self, name: Name, of_type: Type<'r>) {
         self.types
             .entry(name)
@@ -1531,4 +1541,39 @@ impl<'r, S: 'r> Registry<'r, S> {
 
         InputObjectMeta::new::<T>(Cow::Owned(name.into()), args)
     }
+
+    /// Builds an [`InputObjectMeta`] information for the specified
+    /// [`graphql::Type`], allowing to `customize` the created [`ScalarMeta`],
+    /// and stores it in this [`Registry`].
+    ///
+    /// # Idempotent
+    ///
+    /// If this [`Registry`] contains a [`MetaType`] with such [`TypeName`]
+    /// already, then just returns it without doing anything.
+    ///
+    /// [`graphql::Type`]: resolve::Type
+    /// [`TypeName`]: resolve::TypeName
+    pub fn register_input_object_with<'ti, T, TI, F>(
+        &mut self,
+        fields: &[Argument<'r, S>],
+        type_info: &'ti TI,
+        customize: F,
+    ) -> MetaType<'r, S>
+    where
+        T: resolve::TypeName<TI> + resolve::InputValueOwned<S>,
+        TI: ?Sized,
+        'ti: 'r,
+        F: FnOnce(InputObjectMeta<'r, S>) -> InputObjectMeta<'r, S>,
+        S: Clone,
+    {
+        self.entry_type::<T, _>(type_info)
+            .or_insert_with(move || {
+                customize(InputObjectMeta::new_reworked::<T, _>(
+                    T::type_name(type_info),
+                    fields,
+                ))
+                .into_meta()
+            })
+            .clone()
+    }
 }
diff --git a/juniper/src/graphql/mod.rs b/juniper/src/graphql/mod.rs
index 097e775d..387690d1 100644
--- a/juniper/src/graphql/mod.rs
+++ b/juniper/src/graphql/mod.rs
@@ -47,6 +47,16 @@ pub trait Object<S>: OutputType<S>
     fn assert_object();
 }*/
 
+pub trait InputObject<
+    'inp,
+    TypeInfo: ?Sized,
+    ScalarValue: 'inp,
+    Behavior: ?Sized = behavior::Standard,
+>: InputType<'inp, TypeInfo, ScalarValue, Behavior>
+{
+    fn assert_input_object();
+}
+
 pub trait Scalar<
     'inp,
     TypeInfo: ?Sized,
diff --git a/juniper/src/schema/meta.rs b/juniper/src/schema/meta.rs
index 717a2383..13d526eb 100644
--- a/juniper/src/schema/meta.rs
+++ b/juniper/src/schema/meta.rs
@@ -769,6 +769,25 @@ impl<'a, S> InputObjectMeta<'a, S> {
         }
     }
 
+    /// Builds a new [`InputObjectMeta`] information with the specified `name`
+    /// and its `fields`.
+    // TODO: Use `impl Into<Cow<'a, str>>` argument once feature
+    //       `explicit_generic_args_with_impl_trait` hits stable:
+    //       https://github.com/rust-lang/rust/issues/83701
+    pub fn new_reworked<T, N>(name: N, fields: &[Argument<'a, S>]) -> Self
+    where
+        T: resolve::InputValueOwned<S>,
+        Cow<'a, str>: From<N>,
+        S: Clone,
+    {
+        Self {
+            name: name.into(),
+            description: None,
+            input_fields: fields.to_vec(),
+            try_parse_fn: try_parse_fn_reworked::<T, S>,
+        }
+    }
+
     /// Set the `description` of this [`InputObjectMeta`] type.
     ///
     /// Overwrites any previously set description.
diff --git a/juniper_codegen/src/graphql_enum/mod.rs b/juniper_codegen/src/graphql_enum/mod.rs
index 85e22be0..3f26625f 100644
--- a/juniper_codegen/src/graphql_enum/mod.rs
+++ b/juniper_codegen/src/graphql_enum/mod.rs
@@ -533,9 +533,14 @@ impl Definition {
         quote! {
             #[automatically_derived]
             impl #impl_gens ::juniper::graphql::Enum<#lt, #inf, #cx, #sv, #bh>
-                for #ty #where_clause
+             for #ty #where_clause
             {
-                fn assert_enum() {}
+                fn assert_enum() {
+                    <Self as ::juniper::graphql::InputType<#lt, #inf, #sv, #bh>>
+                        ::assert_input_type();
+                    <Self as ::juniper::graphql::OutputType<#inf, #cx, #sv, #bh>>
+                        ::assert_output_type();
+                }
             }
         }
     }
diff --git a/juniper_codegen/src/graphql_input_object/mod.rs b/juniper_codegen/src/graphql_input_object/mod.rs
index 9d37269f..ec6ea64c 100644
--- a/juniper_codegen/src/graphql_input_object/mod.rs
+++ b/juniper_codegen/src/graphql_input_object/mod.rs
@@ -386,6 +386,14 @@ struct FieldDefinition {
     ignored: bool,
 }
 
+impl FieldDefinition {
+    /// Indicates whether this [`FieldDefinition`] uses [`Default::default()`]
+    /// ans its [`FieldDefinition::default`] value.
+    fn needs_default_trait_bound(&self) -> bool {
+        matches!(self.default, Some(default::Value::Default))
+    }
+}
+
 /// Representation of [GraphQL input object][0] for code generation.
 ///
 /// [0]: https://spec.graphql.org/October2021#sec-Input-Objects
@@ -455,12 +463,12 @@ impl ToTokens for Definition {
         self.impl_to_input_value_tokens().to_tokens(into);
         self.impl_reflection_traits_tokens().to_tokens(into);
         ////////////////////////////////////////////////////////////////////////
-        //self.impl_resolve_type().to_tokens(into);
+        self.impl_resolve_type().to_tokens(into);
         self.impl_resolve_type_name().to_tokens(into);
         self.impl_resolve_to_input_value().to_tokens(into);
-        //self.impl_resolve_input_value().to_tokens(into);
-        //self.impl_graphql_input_type().to_tokens(into);
-        //self.impl_graphql_input_object().to_tokens(into);
+        self.impl_resolve_input_value().to_tokens(into);
+        self.impl_graphql_input_type().to_tokens(into);
+        self.impl_graphql_input_object().to_tokens(into);
         self.impl_reflect().to_tokens(into);
     }
 }
@@ -503,6 +511,88 @@ impl Definition {
         }
     }
 
+    /// Returns generated code implementing [`graphql::InputType`] trait for
+    /// [GraphQL input object][0].
+    ///
+    /// [`graphql::InputType`]: juniper::graphql::InputType
+    /// [0]: https://spec.graphql.org/October2021#sec-Input-Objects
+    #[must_use]
+    fn impl_graphql_input_type(&self) -> TokenStream {
+        let bh = &self.behavior;
+        let (ty, generics) = self.ty_and_generics();
+        let (inf, generics) = self.mix_type_info(generics);
+        let (sv, generics) = self.mix_scalar_value(generics);
+        let (lt, mut generics) = self.mix_input_lifetime(generics, &sv);
+        generics.make_where_clause().predicates.push(parse_quote! {
+            Self: ::juniper::resolve::Type<#inf, #sv, #bh>
+                  + ::juniper::resolve::ToInputValue<#sv, #bh>
+                  + ::juniper::resolve::InputValue<#lt, #sv, #bh>
+        });
+        for f in self.fields.iter().filter(|f| !f.ignored) {
+            let field_ty = &f.ty;
+            let field_bh = &f.behavior;
+            generics.make_where_clause().predicates.push(parse_quote! {
+                #field_ty:
+                    ::juniper::graphql::InputType<#lt, #inf, #sv, #field_bh>
+            });
+        }
+        let (impl_gens, _, where_clause) = generics.split_for_impl();
+
+        let fields_assertions = self.fields.iter().filter_map(|f| {
+            (!f.ignored).then(|| {
+                let field_ty = &f.ty;
+                let field_bh = &f.behavior;
+
+                quote! {
+                    <#field_ty as
+                     ::juniper::graphql::InputType<#lt, #inf, #sv, #field_bh>>
+                        ::assert_input_type();
+                }
+            })
+        });
+
+        quote! {
+            #[automatically_derived]
+            impl #impl_gens ::juniper::graphql::InputType<#lt, #inf, #sv, #bh>
+             for #ty #where_clause
+            {
+                fn assert_input_type() {
+                    #( #fields_assertions )*
+                }
+            }
+        }
+    }
+
+    /// Returns generated code implementing [`graphql::InputObject`] trait for
+    /// this [GraphQL input object][0].
+    ///
+    /// [`graphql::InputObject`]: juniper::graphql::InputObject
+    /// [0]: https://spec.graphql.org/October2021#sec-Input-Objects
+    #[must_use]
+    fn impl_graphql_input_object(&self) -> TokenStream {
+        let bh = &self.behavior;
+        let (ty, generics) = self.ty_and_generics();
+        let (inf, generics) = self.mix_type_info(generics);
+        let (sv, generics) = self.mix_scalar_value(generics);
+        let (lt, mut generics) = self.mix_input_lifetime(generics, &sv);
+        generics.make_where_clause().predicates.push(parse_quote! {
+            Self: ::juniper::graphql::InputType<#lt, #inf, #sv, #bh>
+        });
+        let (impl_gens, _, where_clause) = generics.split_for_impl();
+
+        quote! {
+            #[automatically_derived]
+            impl #impl_gens ::juniper::graphql::InputObject<#lt, #inf, #sv, #bh>
+             for #ty #where_clause
+            {
+                fn assert_input_object() {
+                    <Self as ::juniper::graphql::InputType<#lt, #inf, #sv, #bh>>
+                        ::assert_input_type();
+                }
+            }
+        }
+    }
+
     /// Returns generated code implementing [`GraphQLType`] trait for this
     /// [GraphQL input object][0].
     ///
@@ -563,6 +653,96 @@ impl Definition {
         }
     }
 
+    /// Returns generated code implementing [`resolve::Type`] trait for this
+    /// [GraphQL input object][0].
+    ///
+    /// [`resolve::Type`]: juniper::resolve::Type
+    /// [0]: https://spec.graphql.org/October2021#sec-Input-Objects
+    fn impl_resolve_type(&self) -> TokenStream {
+        let bh = &self.behavior;
+        let (ty, generics) = self.ty_and_generics();
+        let (inf, generics) = self.mix_type_info(generics);
+        let (sv, mut generics) = self.mix_scalar_value(generics);
+        let preds = &mut generics.make_where_clause().predicates;
+        preds.push(parse_quote! { #sv: Clone });
+        preds.push(parse_quote! {
+            ::juniper::behavior::Coerce<Self>:
+                ::juniper::resolve::TypeName<#inf>
+                + ::juniper::resolve::InputValueOwned<#sv>
+        });
+        for f in self.fields.iter().filter(|f| !f.ignored) {
+            let field_ty = &f.ty;
+            let field_bh = &f.behavior;
+            preds.push(parse_quote! {
+                ::juniper::behavior::Coerce<#field_ty>:
+                    ::juniper::resolve::Type<#inf, #sv>
+                    + ::juniper::resolve::InputValueOwned<#sv>
+            });
+            if f.default.is_some() {
+                preds.push(parse_quote! {
+                    #field_ty: ::juniper::resolve::ToInputValue<#sv, #field_bh>
+                });
+            }
+            if f.needs_default_trait_bound() {
+                preds.push(parse_quote! {
+                    #field_ty: ::std::default::Default
+                });
+            }
+        }
+        let (impl_gens, _, where_clause) = generics.split_for_impl();
+
+        let description = &self.description;
+
+        let fields_meta = self.fields.iter().filter_map(|f| {
+            (!f.ignored).then(|| {
+                let f_ty = &f.ty;
+                let f_bh = &f.behavior;
+                let f_name = &f.name;
+                let f_description = &f.description;
+                let f_default = f.default.as_ref().map(|expr| {
+                    quote! {
+                        .default_value(
+                            <#f_ty as
+                             ::juniper::resolve::ToInputValue<#sv, #f_bh>>
+                                ::to_input_value(&{ #expr }),
+                        )
+                    }
+                });
+
+                quote! {
+                    registry.arg_reworked::<
+                        ::juniper::behavior::Coerce<#f_ty>, _,
+                    >(#f_name, type_info)
+                        #f_description
+                        #f_default
+                }
+            })
+        });
+
+        quote! {
+            #[automatically_derived]
+            impl #impl_gens ::juniper::resolve::Type<#inf, #sv, #bh>
+             for #ty #where_clause
+            {
+                fn meta<'__r, '__ti: '__r>(
+                    registry: &mut ::juniper::Registry<'__r, #sv>,
+                    type_info: &'__ti #inf,
+                ) -> ::juniper::meta::MetaType<'__r, #sv>
+                where
+                    #sv: '__r,
+                {
+                    let fields = [#( #fields_meta ),*];
+
+                    registry.register_input_object_with::<
+                        ::juniper::behavior::Coerce<Self>, _, _,
+                    >(&fields, type_info, |meta| {
+                        meta #description
+                    })
+                }
+            }
+        }
+    }
+
     /// Returns generated code implementing [`resolve::TypeName`] trait for this
     /// [GraphQL input object][0].
     ///
@@ -717,6 +897,96 @@ impl Definition {
         }
     }
 
+    /// Returns generated code implementing [`resolve::InputValue`] trait for
+    /// this [GraphQL input object][0].
+    ///
+    /// [`resolve::InputValue`]: juniper::resolve::InputValue
+    /// [0]: https://spec.graphql.org/October2021#sec-Input-Objects
+    fn impl_resolve_input_value(&self) -> TokenStream {
+        let bh = &self.behavior;
+        let (ty, generics) = self.ty_and_generics();
+        let (sv, generics) = self.mix_scalar_value(generics);
+        let (lt, mut generics) = self.mix_input_lifetime(generics, &sv);
+        generics.make_where_clause().predicates.push(parse_quote! {
+            #sv: ::juniper::ScalarValue
+        });
+        for f in self.fields.iter().filter(|f| !f.ignored) {
+            let field_ty = &f.ty;
+            let field_bh = &f.behavior;
+            generics.make_where_clause().predicates.push(parse_quote! {
+                #field_ty: ::juniper::resolve::InputValue<#lt, #sv, #field_bh>
+            });
+        }
+        for f in self.fields.iter().filter(|f| f.needs_default_trait_bound()) {
+            let field_ty = &f.ty;
+            generics.make_where_clause().predicates.push(parse_quote! {
+                #field_ty: ::std::default::Default,
+            });
+        }
+        let (impl_gens, _, where_clause) = generics.split_for_impl();
+
+        let fields = self.fields.iter().map(|f| {
+            let field = &f.ident;
+            let field_ty = &f.ty;
+            let field_bh = &f.behavior;
+
+            let constructor = if f.ignored {
+                let expr = f.default.clone().unwrap_or_default();
+
+                quote! { #expr }
+            } else {
+                let name = &f.name;
+
+                let fallback = f.default.as_ref().map_or_else(
+                    || {
+                        quote! {
+                            <#field_ty as ::juniper::resolve::InputValue<#lt, #sv, #field_bh>>
+                                ::try_from_implicit_null()
+                                .map_err(::juniper::IntoFieldError::<#sv>::into_field_error)?
+                        }
+                    },
+                    |expr| quote! { #expr },
+                );
+
+                quote! {
+                    match obj.get(#name) {
+                        ::std::option::Option::Some(v) => {
+                            <#field_ty as ::juniper::resolve::InputValue<#lt, #sv, #field_bh>>
+                                ::try_from_input_value(v)
+                                .map_err(::juniper::IntoFieldError::<#sv>::into_field_error)?
+                        }
+                        ::std::option::Option::None => { #fallback }
+                    }
+                }
+            };
+
+            quote! { #field: { #constructor }, }
+        });
+
+        quote! {
+            #[automatically_derived]
+            impl #impl_gens ::juniper::resolve::InputValue<#lt, #sv, #bh>
+             for #ty #where_clause
+            {
+                type Error = ::juniper::FieldError<#sv>;
+
+                fn try_from_input_value(
+                    input: &#lt ::juniper::graphql::InputValue<#sv>,
+                ) -> ::std::result::Result<Self, Self::Error> {
+                    let obj = input
+                        .to_object_value()
+                        .ok_or_else(|| ::std::format!(
+                            "Expected input object, found: {}", input,
+                        ))?;
+
+                    ::std::result::Result::Ok(Self {
+                        #( #fields )*
+                    })
+                }
+            }
+        }
+    }
+
     /// Returns generated code implementing [`ToInputValue`] trait for this
     /// [GraphQL input object][0].
     ///
@@ -764,7 +1034,7 @@ impl Definition {
         let bh = &self.behavior;
         let (ty, generics) = self.ty_and_generics();
         let (sv, mut generics) = self.mix_scalar_value(generics);
-        for f in &self.fields {
+        for f in self.fields.iter().filter(|f| !f.ignored) {
             let field_ty = &f.ty;
             let field_bh = &f.behavior;
             generics.make_where_clause().predicates.push(parse_quote! {
@@ -774,12 +1044,12 @@ impl Definition {
         let (impl_gens, _, where_clause) = generics.split_for_impl();
 
         let fields = self.fields.iter().filter_map(|f| {
-            let field = &f.ident;
-            let field_ty = &f.ty;
-            let field_bh = &f.behavior;
-            let name = &f.name;
-
             (!f.ignored).then(|| {
+                let field = &f.ident;
+                let field_ty = &f.ty;
+                let field_bh = &f.behavior;
+                let name = &f.name;
+
                 quote! {
                     (#name, <#field_ty as
                              ::juniper::resolve::ToInputValue<#sv, #field_bh>>
@@ -973,4 +1243,22 @@ impl Definition {
         generics.params.push(parse_quote! { #sv });
         (sv, generics)
     }
+
+    /// Mixes an [`InputValue`]'s lifetime [`syn::GenericParam`] into the
+    /// provided [`syn::Generics`] and returns it.
+    ///
+    /// [`InputValue`]: juniper::resolve::InputValue
+    fn mix_input_lifetime(
+        &self,
+        mut generics: syn::Generics,
+        sv: &syn::Ident,
+    ) -> (syn::GenericParam, syn::Generics) {
+        let lt: syn::GenericParam = parse_quote! { '__inp };
+        generics.params.push(lt.clone());
+        generics
+            .make_where_clause()
+            .predicates
+            .push(parse_quote! { #sv: #lt });
+        (lt, generics)
+    }
 }
diff --git a/juniper_codegen/src/graphql_scalar/mod.rs b/juniper_codegen/src/graphql_scalar/mod.rs
index 316d79ed..2752c4eb 100644
--- a/juniper_codegen/src/graphql_scalar/mod.rs
+++ b/juniper_codegen/src/graphql_scalar/mod.rs
@@ -467,9 +467,14 @@ impl Definition {
         quote! {
             #[automatically_derived]
             impl #impl_gens ::juniper::graphql::Scalar<#lt, #inf, #cx, #sv, #bh>
-                for #ty #where_clause
+             for #ty #where_clause
             {
-                fn assert_scalar() {}
+                fn assert_scalar() {
+                    <Self as ::juniper::graphql::InputType<#lt, #inf, #sv, #bh>>
+                        ::assert_input_type();
+                    <Self as ::juniper::graphql::OutputType<#inf, #cx, #sv, #bh>>
+                        ::assert_output_type();
+                }
             }
         }
     }