From b381c696ffc5c530c75c7d8f13a4dcd60d86ec85 Mon Sep 17 00:00:00 2001
From: tyranron <tyranron@gmail.com>
Date: Fri, 27 May 2022 18:11:40 +0200
Subject: [PATCH] Improve and polish codegen for scalars

---
 juniper/src/types/str.rs                  |   5 +-
 juniper_codegen/src/graphql_scalar/mod.rs | 218 +++++++++++++++++++---
 2 files changed, 192 insertions(+), 31 deletions(-)

diff --git a/juniper/src/types/str.rs b/juniper/src/types/str.rs
index e4f3e1e3..af7c966c 100644
--- a/juniper/src/types/str.rs
+++ b/juniper/src/types/str.rs
@@ -66,7 +66,10 @@ where
     }
 }
 
-impl<S: ScalarValue> resolve::ScalarToken<S> for str {
+impl<S> resolve::ScalarToken<S> for str
+where
+    String: resolve::ScalarToken<S>,
+{
     fn parse_scalar_token(token: ScalarToken<'_>) -> Result<S, ParseError<'_>> {
         <String as resolve::ScalarToken<S>>::parse_scalar_token(token)
     }
diff --git a/juniper_codegen/src/graphql_scalar/mod.rs b/juniper_codegen/src/graphql_scalar/mod.rs
index e7654056..45a34291 100644
--- a/juniper_codegen/src/graphql_scalar/mod.rs
+++ b/juniper_codegen/src/graphql_scalar/mod.rs
@@ -332,6 +332,7 @@ impl ToTokens for Definition {
         self.impl_reflection_traits_tokens().to_tokens(into);
         ////////////////////////////////////////////////////////////////////////
         self.impl_resolve_type_name().to_tokens(into);
+        self.impl_resolve_type().to_tokens(into);
         self.impl_resolve_input_value().to_tokens(into);
         self.impl_resolve_scalar_token().to_tokens(into);
     }
@@ -409,29 +410,76 @@ impl Definition {
     }
 
     /// Returns generated code implementing [`resolve::TypeName`] trait for this
-    /// [GraphQL scalar][1].
+    /// [GraphQL scalar][0].
     ///
     /// [`resolve::TypeName`]: juniper::resolve::TypeName
-    /// [1]: https://spec.graphql.org/October2021#sec-Scalars
+    /// [0]: https://spec.graphql.org/October2021#sec-Scalars
     fn impl_resolve_type_name(&self) -> TokenStream {
-        let name = &self.name;
-
         let (ty, generics) = self.ty_and_generics();
-        let (info_ty, generics) = self.mix_info_ty(generics);
+        let (info, generics) = self.mix_info(generics);
         let (impl_gens, _, where_clause) = generics.split_for_impl();
 
+        let name = &self.name;
+
         quote! {
             #[automatically_derived]
-            impl#impl_gens ::juniper::resolve::TypeName<#info_ty> for #ty
+            impl#impl_gens ::juniper::resolve::TypeName<#info> for #ty
                 #where_clause
             {
-                fn type_name(_: &#info_ty) -> &'static str {
+                fn type_name(_: &#info) -> &'static str {
                     #name
                 }
             }
         }
     }
 
+    /// Returns generated code implementing [`resolve::Type`] trait for this
+    /// [GraphQL scalar][0].
+    ///
+    /// [`resolve::Type`]: juniper::resolve::Type
+    /// [0]: https://spec.graphql.org/October2021#sec-Scalars
+    fn impl_resolve_type(&self) -> TokenStream {
+        let (ty, generics) = self.ty_and_generics();
+        let (info, generics) = self.mix_info(generics);
+        let (scalar, mut generics) = self.mix_scalar(generics);
+        generics.make_where_clause().predicates.push(parse_quote! {
+            Self: ::juniper::resolve::TypeName<#info>
+                  + ::juniper::resolve::ScalarToken<#scalar>
+                  + ::juniper::resolve::InputValueOwned<#scalar>
+        });
+        let (impl_gens, _, where_clause) = generics.split_for_impl();
+
+        let description = self
+            .description
+            .as_ref()
+            .map(|val| quote! { .description(#val) });
+
+        let specified_by_url = self.specified_by_url.as_ref().map(|url| {
+            let url_lit = url.as_str();
+            quote! { .specified_by_url(#url_lit) }
+        });
+
+        quote! {
+            #[automatically_derived]
+            impl#impl_gens ::juniper::resolve::Type<#info, #scalar> for #ty
+                #where_clause
+            {
+                fn meta<'r>(
+                    registry: &mut ::juniper::Registry<'r, #scalar>,
+                    info: &#info,
+                ) -> ::juniper::meta::MetaType<'r, #scalar>
+                where
+                    #scalar: 'r,
+                {
+                    registry.build_scalar_type_new::<Self, _>(info)
+                        #description
+                        #specified_by_url
+                        .into_meta()
+                }
+            }
+        }
+    }
+
     /// Returns generated code implementing [`GraphQLValue`] trait for this
     /// [GraphQL scalar][1].
     ///
@@ -553,19 +601,23 @@ impl Definition {
     }
 
     /// Returns generated code implementing [`resolve::InputValue`] trait for
-    /// this [GraphQL scalar][1].
+    /// this [GraphQL scalar][0].
     ///
     /// [`resolve::InputValue`]: juniper::resolve::InputValue
-    /// [1]: https://spec.graphql.org/October2021#sec-Scalars
+    /// [0]: https://spec.graphql.org/October2021#sec-Scalars
     fn impl_resolve_input_value(&self) -> TokenStream {
-        let conversion = self.methods.expand_try_from_input_value(&self.scalar);
-
         let (ty, generics) = self.ty_and_generics();
-        let (scalar, mut generics) = self.mix_scalar_ty(generics);
+        let (scalar, mut generics) = self.mix_scalar(generics);
         let lt: syn::GenericParam = parse_quote! { '__inp };
         generics.params.push(lt.clone());
+        generics
+            .make_where_clause()
+            .predicates
+            .push(self.methods.bound_try_from_input_value(scalar, &lt));
         let (impl_gens, _, where_clause) = generics.split_for_impl();
 
+        let conversion = self.methods.expand_try_from_input_value(scalar);
+
         quote! {
             #[automatically_derived]
             impl#impl_gens ::juniper::resolve::InputValue<#lt, #scalar> for #ty
@@ -575,7 +627,7 @@ impl Definition {
 
                 fn try_from_input_value(
                     input: &#lt ::juniper::graphql::InputValue<#scalar>,
-                ) -> Result<Self, Self::Error> {
+                ) -> ::std::result::Result<Self, Self::Error> {
                     #conversion.map_err(
                         ::juniper::IntoFieldError::<#scalar>::into_field_error,
                     )
@@ -612,17 +664,21 @@ impl Definition {
     }
 
     /// Returns generated code implementing [`resolve::ScalarToken`] trait for
-    /// this [GraphQL scalar][1].
+    /// this [GraphQL scalar][0].
     ///
     /// [`resolve::ScalarToken`]: juniper::resolve::ScalarToken
-    /// [1]: https://spec.graphql.org/October2021#sec-Scalars
+    /// [0]: https://spec.graphql.org/October2021#sec-Scalars
     fn impl_resolve_scalar_token(&self) -> TokenStream {
-        let body = self.methods.expand_parse_scalar_token(&self.scalar);
-
         let (ty, generics) = self.ty_and_generics();
-        let (scalar, generics) = self.mix_scalar_ty(generics);
+        let (scalar, mut generics) = self.mix_scalar(generics);
+        generics
+            .make_where_clause()
+            .predicates
+            .extend(self.methods.bound_parse_scalar_token(scalar));
         let (impl_gens, _, where_clause) = generics.split_for_impl();
 
+        let body = self.methods.expand_parse_scalar_token(scalar);
+
         quote! {
             #[automatically_derived]
             impl#impl_gens ::juniper::resolve::ScalarToken<#scalar> for #ty
@@ -757,6 +813,8 @@ impl Definition {
         (ty, generics)
     }
 
+    /// Returns prepared self [`syn::Type`] and [`syn::Generics`] for a trait
+    /// implementation.
     #[must_use]
     fn ty_and_generics(&self) -> (syn::Type, syn::Generics) {
         let mut generics = self.generics.clone();
@@ -779,8 +837,10 @@ impl Definition {
         (ty, generics)
     }
 
+    /// Mixes a type info [`syn::GenericParam`] into the provided
+    /// [`syn::Generics`] and returns its [`syn::Ident`].
     #[must_use]
-    fn mix_info_ty(&self, mut generics: syn::Generics) -> (syn::Ident, syn::Generics) {
+    fn mix_info(&self, mut generics: syn::Generics) -> (syn::Ident, syn::Generics) {
         let ty = parse_quote! { __Info };
 
         generics.params.push(parse_quote! { #ty: ?Sized });
@@ -788,19 +848,32 @@ impl Definition {
         (ty, generics)
     }
 
+    /// Mixes a context [`syn::GenericParam`] into the provided
+    /// [`syn::Generics`] and returns its [`syn::Ident`].
     #[must_use]
-    fn mix_scalar_ty(&self, mut generics: syn::Generics) -> (&scalar::Type, syn::Generics) {
+    fn mix_context(&self, mut generics: syn::Generics) -> (syn::Ident, syn::Generics) {
+        let ty = parse_quote! { __Ctx };
+
+        generics.params.push(parse_quote! { #ty: ?Sized });
+
+        (ty, generics)
+    }
+
+    /// Mixes a [`ScalarValue`] [`syn::GenericParam`] into the provided
+    /// [`syn::Generics`] and returns its [`scalar::Type`].
+    ///
+    /// [`ScalarValue`] trait bound is not made here, because some trait
+    /// implementations may not require it depending on the generated code or
+    /// even at all.
+    ///
+    /// [`ScalarValue`]: juniper::ScalarValue
+    #[must_use]
+    fn mix_scalar(&self, mut generics: syn::Generics) -> (&scalar::Type, syn::Generics) {
         let scalar = &self.scalar;
 
         if scalar.is_implicit_generic() {
             generics.params.push(parse_quote! { #scalar });
         }
-        if scalar.is_generic() {
-            generics
-                .make_where_clause()
-                .predicates
-                .push(parse_quote! { #scalar: ::juniper::ScalarValue });
-        }
         if let Some(bound) = scalar.bounds() {
             generics.make_where_clause().predicates.push(bound);
         }
@@ -957,6 +1030,38 @@ impl Methods {
         }
     }
 
+    /// Generates additional trait bounds for [`resolve::InputValue`]
+    /// implementation allowing to execute
+    /// [`resolve::InputValue::try_from_input_value()`][0] method.
+    ///
+    /// [`resolve::InputValue`]: juniper::resolve::InputValue
+    /// [0]: juniper::resolve::InputValue::try_from_input_value
+    fn bound_try_from_input_value(
+        &self,
+        scalar: &scalar::Type,
+        lt: &syn::GenericParam,
+    ) -> syn::WherePredicate {
+        match self {
+            Self::Custom { .. }
+            | Self::Delegated {
+                from_input: Some(_),
+                ..
+            } => {
+                parse_quote! {
+                    #scalar: ::juniper::ScalarValue
+                }
+            }
+
+            Self::Delegated { field, .. } => {
+                let field_ty = field.ty();
+
+                parse_quote! {
+                    #field_ty: ::juniper::resolve::InputValue<#lt, #scalar>
+                }
+            }
+        }
+    }
+
     /// Expands [`ParseScalarValue::from_str`] method.
     ///
     /// [`ParseScalarValue::from_str`]: juniper::ParseScalarValue::from_str
@@ -982,7 +1087,7 @@ impl Methods {
     /// Expands body of [`resolve::ScalarToken::parse_scalar_token()`][0]
     /// method.
     ///
-    /// [0]: resolve::ScalarToken::parse_scalar_token
+    /// [0]: juniper::resolve::ScalarToken::parse_scalar_token
     fn expand_parse_scalar_token(&self, scalar: &scalar::Type) -> TokenStream {
         match self {
             Self::Custom { parse_token, .. }
@@ -990,8 +1095,10 @@ impl Methods {
                 parse_token: Some(parse_token),
                 ..
             } => parse_token.expand_parse_scalar_token(scalar),
+
             Self::Delegated { field, .. } => {
                 let field_ty = field.ty();
+
                 quote! {
                     <#field_ty as ::juniper::resolve::ScalarToken<#scalar>>
                         ::parse_scalar_token(token)
@@ -999,6 +1106,30 @@ impl Methods {
             }
         }
     }
+
+    /// Generates additional trait bounds for [`resolve::ScalarToken`]
+    /// implementation allowing to execute
+    /// [`resolve::ScalarToken::parse_scalar_token()`][0] method.
+    ///
+    /// [`resolve::ScalarToken`]: juniper::resolve::ScalarToken
+    /// [0]: juniper::resolve::ScalarToken::parse_scalar_token
+    fn bound_parse_scalar_token(&self, scalar: &scalar::Type) -> Vec<syn::WherePredicate> {
+        match self {
+            Self::Custom { parse_token, .. }
+            | Self::Delegated {
+                parse_token: Some(parse_token),
+                ..
+            } => parse_token.bound_parse_scalar_token(scalar),
+
+            Self::Delegated { field, .. } => {
+                let field_ty = field.ty();
+
+                vec![parse_quote! {
+                    #field_ty: ::juniper::resolve::ScalarToken<#scalar>
+                }]
+            }
+        }
+    }
 }
 
 /// Representation of [`ParseScalarValue::from_str`] method.
@@ -1046,12 +1177,13 @@ impl ParseToken {
     /// Expands body of [`resolve::ScalarToken::parse_scalar_token()`][0]
     /// method.
     ///
-    /// [0]: resolve::ScalarToken::parse_scalar_token
+    /// [0]: juniper::resolve::ScalarToken::parse_scalar_token
     fn expand_parse_scalar_token(&self, scalar: &scalar::Type) -> TokenStream {
         match self {
             Self::Custom(parse_token) => {
                 quote! { #parse_token(token) }
             }
+
             Self::Delegated(delegated) => delegated
                 .iter()
                 .fold(None, |acc, ty| {
@@ -1075,6 +1207,31 @@ impl ParseToken {
                 .unwrap_or_default(),
         }
     }
+
+    /// Generates additional trait bounds for [`resolve::ScalarToken`]
+    /// implementation allowing to execute
+    /// [`resolve::ScalarToken::parse_scalar_token()`][0] method.
+    ///
+    /// [`resolve::ScalarToken`]: juniper::resolve::ScalarToken
+    /// [0]: juniper::resolve::ScalarToken::parse_scalar_token
+    fn bound_parse_scalar_token(&self, scalar: &scalar::Type) -> Vec<syn::WherePredicate> {
+        match self {
+            Self::Custom(_) => {
+                vec![parse_quote! {
+                    #scalar: ::juniper::ScalarValue
+                }]
+            }
+
+            Self::Delegated(delegated) => delegated
+                .iter()
+                .map(|ty| {
+                    parse_quote! {
+                        #ty: ::juniper::resolve::ScalarToken<#scalar>
+                    }
+                })
+                .collect(),
+        }
+    }
 }
 
 /// Struct field to resolve not provided methods.
@@ -1103,9 +1260,10 @@ impl Field {
         }
     }
 
-    /// Closure to construct [GraphQL scalar][1] struct from [`Field`].
+    /// Generates closure to construct a [GraphQL scalar][0] struct from a
+    /// [`Field`] value.
     ///
-    /// [1]: https://spec.graphql.org/October2021#sec-Scalars
+    /// [0]: https://spec.graphql.org/October2021#sec-Scalars
     fn closure_constructor(&self) -> TokenStream {
         match self {
             Field::Named(syn::Field { ident, .. }) => {