diff --git a/juniper/src/validation/rules/no_fragment_cycles.rs b/juniper/src/validation/rules/no_fragment_cycles.rs index 3a6782a7..0845f820 100644 --- a/juniper/src/validation/rules/no_fragment_cycles.rs +++ b/juniper/src/validation/rules/no_fragment_cycles.rs @@ -7,19 +7,6 @@ use crate::{ value::ScalarValue, }; -pub struct NoFragmentCycles<'a> { - current_fragment: Option<&'a str>, - spreads: HashMap<&'a str, Vec>>, - fragment_order: Vec<&'a str>, -} - -struct CycleDetector<'a> { - visited: HashSet<&'a str>, - spreads: &'a HashMap<&'a str, Vec>>, - path_indices: HashMap<&'a str, usize>, - errors: Vec, -} - pub fn factory<'a>() -> NoFragmentCycles<'a> { NoFragmentCycles { current_fragment: None, @@ -28,6 +15,12 @@ pub fn factory<'a>() -> NoFragmentCycles<'a> { } } +pub struct NoFragmentCycles<'a> { + current_fragment: Option<&'a str>, + spreads: HashMap<&'a str, Vec>>, + fragment_order: Vec<&'a str>, +} + impl<'a, S> Visitor<'a, S> for NoFragmentCycles<'a> where S: ScalarValue, @@ -38,14 +31,12 @@ where let mut detector = CycleDetector { visited: HashSet::new(), spreads: &self.spreads, - path_indices: HashMap::new(), errors: Vec::new(), }; for frag in &self.fragment_order { if !detector.visited.contains(frag) { - let mut path = Vec::new(); - detector.detect_from(frag, &mut path); + detector.detect_from(frag); } } @@ -91,19 +82,46 @@ where } } +type CycleDetectorState<'a> = (&'a str, Vec<&'a Spanning<&'a str>>, HashMap<&'a str, usize>); + +struct CycleDetector<'a> { + visited: HashSet<&'a str>, + spreads: &'a HashMap<&'a str, Vec>>, + errors: Vec, +} + impl<'a> CycleDetector<'a> { - fn detect_from(&mut self, from: &'a str, path: &mut Vec<&'a Spanning<&'a str>>) { + fn detect_from(&mut self, from: &'a str) { + let mut to_visit = Vec::new(); + to_visit.push((from, Vec::new(), HashMap::new())); + + while let Some((from, path, path_indices)) = to_visit.pop() { + to_visit.extend(self.detect_from_inner(from, path, path_indices)); + } + } + + /// This function should be called only inside [`Self::detect_from()`], as + /// it's a recursive function using heap instead of a stack. So, instead of + /// the recursive call, we return a [`Vec`] that is visited inside + /// [`Self::detect_from()`]. + fn detect_from_inner( + &mut self, + from: &'a str, + path: Vec<&'a Spanning<&'a str>>, + mut path_indices: HashMap<&'a str, usize>, + ) -> Vec> { self.visited.insert(from); if !self.spreads.contains_key(from) { - return; + return Vec::new(); } - self.path_indices.insert(from, path.len()); + path_indices.insert(from, path.len()); + let mut to_visit = Vec::new(); for node in &self.spreads[from] { - let name = &node.item; - let index = self.path_indices.get(name).cloned(); + let name = node.item; + let index = path_indices.get(name).cloned(); if let Some(index) = index { let err_pos = if index < path.len() { @@ -114,14 +132,14 @@ impl<'a> CycleDetector<'a> { self.errors .push(RuleError::new(&error_message(name), &[err_pos.start])); - } else if !self.visited.contains(name) { + } else { + let mut path = path.clone(); path.push(node); - self.detect_from(name, path); - path.pop(); + to_visit.push((name, path, path_indices.clone())); } } - self.path_indices.remove(from); + to_visit } } diff --git a/juniper/src/validation/rules/no_undefined_variables.rs b/juniper/src/validation/rules/no_undefined_variables.rs index 18287c49..6e382b23 100644 --- a/juniper/src/validation/rules/no_undefined_variables.rs +++ b/juniper/src/validation/rules/no_undefined_variables.rs @@ -12,13 +12,6 @@ pub enum Scope<'a> { Fragment(&'a str), } -pub struct NoUndefinedVariables<'a> { - defined_variables: HashMap, (SourcePosition, HashSet<&'a str>)>, - used_variables: HashMap, Vec>>, - current_scope: Option>, - spreads: HashMap, Vec<&'a str>>, -} - pub fn factory<'a>() -> NoUndefinedVariables<'a> { NoUndefinedVariables { defined_variables: HashMap::new(), @@ -28,6 +21,13 @@ pub fn factory<'a>() -> NoUndefinedVariables<'a> { } } +pub struct NoUndefinedVariables<'a> { + defined_variables: HashMap, (SourcePosition, HashSet<&'a str>)>, + used_variables: HashMap, Vec>>, + current_scope: Option>, + spreads: HashMap, Vec<&'a str>>, +} + impl<'a> NoUndefinedVariables<'a> { fn find_undef_vars( &'a self, @@ -36,8 +36,34 @@ impl<'a> NoUndefinedVariables<'a> { unused: &mut Vec<&'a Spanning<&'a str>>, visited: &mut HashSet>, ) { + let mut to_visit = Vec::new(); + if let Some(spreads) = self.find_undef_vars_inner(scope, defined, unused, visited) { + to_visit.push(spreads); + } + while let Some(spreads) = to_visit.pop() { + for spread in spreads { + if let Some(spreads) = + self.find_undef_vars_inner(&Scope::Fragment(spread), defined, unused, visited) + { + to_visit.push(spreads); + } + } + } + } + + /// This function should be called only inside [`Self::find_undef_vars()`], + /// as it's a recursive function using heap instead of a stack. So, instead + /// of the recursive call, we return a [`Vec`] that is visited inside + /// [`Self::find_undef_vars()`]. + fn find_undef_vars_inner( + &'a self, + scope: &Scope<'a>, + defined: &HashSet<&'a str>, + unused: &mut Vec<&'a Spanning<&'a str>>, + visited: &mut HashSet>, + ) -> Option<&'a Vec<&'a str>> { if visited.contains(scope) { - return; + return None; } visited.insert(scope.clone()); @@ -50,11 +76,7 @@ impl<'a> NoUndefinedVariables<'a> { } } - if let Some(spreads) = self.spreads.get(scope) { - for spread in spreads { - self.find_undef_vars(&Scope::Fragment(spread), defined, unused, visited); - } - } + self.spreads.get(scope) } } diff --git a/juniper/src/validation/rules/no_unused_fragments.rs b/juniper/src/validation/rules/no_unused_fragments.rs index 38a360b5..8ac7f8ae 100644 --- a/juniper/src/validation/rules/no_unused_fragments.rs +++ b/juniper/src/validation/rules/no_unused_fragments.rs @@ -13,12 +13,6 @@ pub enum Scope<'a> { Fragment(&'a str), } -pub struct NoUnusedFragments<'a> { - spreads: HashMap, Vec<&'a str>>, - defined_fragments: HashSet>, - current_scope: Option>, -} - pub fn factory<'a>() -> NoUnusedFragments<'a> { NoUnusedFragments { spreads: HashMap::new(), @@ -27,22 +21,43 @@ pub fn factory<'a>() -> NoUnusedFragments<'a> { } } +pub struct NoUnusedFragments<'a> { + spreads: HashMap, Vec<&'a str>>, + defined_fragments: HashSet>, + current_scope: Option>, +} + impl<'a> NoUnusedFragments<'a> { - fn find_reachable_fragments(&self, from: &Scope<'a>, result: &mut HashSet<&'a str>) { + fn find_reachable_fragments(&'a self, from: &Scope<'a>, result: &mut HashSet<&'a str>) { + let mut to_visit = Vec::new(); if let Scope::Fragment(name) = *from { - if result.contains(name) { - return; - } else { - result.insert(name); - } + to_visit.push(name); } - if let Some(spreads) = self.spreads.get(from) { - for spread in spreads { - self.find_reachable_fragments(&Scope::Fragment(spread), result) + while let Some(from) = to_visit.pop() { + if let Some(next) = self.find_reachable_fragments_inner(from, result) { + to_visit.extend(next); } } } + + /// This function should be called only inside + /// [`Self::find_reachable_fragments()`], as it's a recursive function using + /// heap instead of a stack. So, instead of the recursive call, we return a + /// [`Vec`] that is visited inside [`Self::find_reachable_fragments()`]. + fn find_reachable_fragments_inner( + &'a self, + from: &'a str, + result: &mut HashSet<&'a str>, + ) -> Option<&'a Vec<&'a str>> { + if result.contains(from) { + return None; + } else { + result.insert(from); + } + + self.spreads.get(&Scope::Fragment(from)) + } } impl<'a, S> Visitor<'a, S> for NoUnusedFragments<'a> diff --git a/juniper/src/validation/rules/no_unused_variables.rs b/juniper/src/validation/rules/no_unused_variables.rs index 811acde5..81b7dfb3 100644 --- a/juniper/src/validation/rules/no_unused_variables.rs +++ b/juniper/src/validation/rules/no_unused_variables.rs @@ -12,13 +12,6 @@ pub enum Scope<'a> { Fragment(&'a str), } -pub struct NoUnusedVariables<'a> { - defined_variables: HashMap, HashSet<&'a Spanning<&'a str>>>, - used_variables: HashMap, Vec<&'a str>>, - current_scope: Option>, - spreads: HashMap, Vec<&'a str>>, -} - pub fn factory<'a>() -> NoUnusedVariables<'a> { NoUnusedVariables { defined_variables: HashMap::new(), @@ -28,16 +21,49 @@ pub fn factory<'a>() -> NoUnusedVariables<'a> { } } +pub struct NoUnusedVariables<'a> { + defined_variables: HashMap, HashSet<&'a Spanning<&'a str>>>, + used_variables: HashMap, Vec<&'a str>>, + current_scope: Option>, + spreads: HashMap, Vec<&'a str>>, +} + impl<'a> NoUnusedVariables<'a> { fn find_used_vars( - &self, + &'a self, from: &Scope<'a>, defined: &HashSet<&'a str>, used: &mut HashSet<&'a str>, visited: &mut HashSet>, ) { + let mut to_visit = Vec::new(); + if let Some(spreads) = self.find_used_vars_inner(from, defined, used, visited) { + to_visit.push(spreads); + } + while let Some(spreads) = to_visit.pop() { + for spread in spreads { + if let Some(spreads) = + self.find_used_vars_inner(&Scope::Fragment(spread), defined, used, visited) + { + to_visit.push(spreads); + } + } + } + } + + /// This function should be called only inside [`Self::find_used_vars()`], + /// as it's a recursive function using heap instead of a stack. So, instead + /// of the recursive call, we return a [`Vec`] that is visited inside + /// [`Self::find_used_vars()`]. + fn find_used_vars_inner( + &'a self, + from: &Scope<'a>, + defined: &HashSet<&'a str>, + used: &mut HashSet<&'a str>, + visited: &mut HashSet>, + ) -> Option<&'a Vec<&'a str>> { if visited.contains(from) { - return; + return None; } visited.insert(from.clone()); @@ -50,11 +76,7 @@ impl<'a> NoUnusedVariables<'a> { } } - if let Some(spreads) = self.spreads.get(from) { - for spread in spreads { - self.find_used_vars(&Scope::Fragment(spread), defined, used, visited); - } - } + self.spreads.get(from) } } diff --git a/juniper/src/validation/rules/overlapping_fields_can_be_merged.rs b/juniper/src/validation/rules/overlapping_fields_can_be_merged.rs index 5a9260b1..efacd1ce 100644 --- a/juniper/src/validation/rules/overlapping_fields_can_be_merged.rs +++ b/juniper/src/validation/rules/overlapping_fields_can_be_merged.rs @@ -274,30 +274,61 @@ impl<'a, S: Debug> OverlappingFieldsCanBeMerged<'a, S> { ) where S: ScalarValue, { - let fragment = match self.named_fragments.get(fragment_name) { - Some(f) => f, - None => return, - }; + let mut to_check = Vec::new(); + if let Some(fragments) = self.collect_conflicts_between_fields_and_fragment_inner( + conflicts, + field_map, + fragment_name, + mutually_exclusive, + ctx, + ) { + to_check.push((fragment_name, fragments)) + } + + while let Some((fragment_name, fragment_names2)) = to_check.pop() { + for fragment_name2 in fragment_names2 { + // Early return on fragment recursion, as it makes no sense. + // Fragment recursions are prevented by `no_fragment_cycles` validator. + if fragment_name == fragment_name2 { + return; + } + if let Some(fragments) = self.collect_conflicts_between_fields_and_fragment_inner( + conflicts, + field_map, + fragment_name2, + mutually_exclusive, + ctx, + ) { + to_check.push((fragment_name2, fragments)); + }; + } + } + } + + /// This function should be called only inside + /// [`Self::collect_conflicts_between_fields_and_fragment()`], as it's a + /// recursive function using heap instead of a stack. So, instead of the + /// recursive call, we return a [`Vec`] that is visited inside + /// [`Self::collect_conflicts_between_fields_and_fragment()`]. + fn collect_conflicts_between_fields_and_fragment_inner( + &self, + conflicts: &mut Vec, + field_map: &AstAndDefCollection<'a, S>, + fragment_name: &str, + mutually_exclusive: bool, + ctx: &ValidatorContext<'a, S>, + ) -> Option> + where + S: ScalarValue, + { + let fragment = self.named_fragments.get(fragment_name)?; let (field_map2, fragment_names2) = self.get_referenced_fields_and_fragment_names(fragment, ctx); self.collect_conflicts_between(conflicts, mutually_exclusive, field_map, &field_map2, ctx); - for fragment_name2 in fragment_names2 { - // Early return on fragment recursion, as it makes no sense. - // Fragment recursions are prevented by `no_fragment_cycles` validator. - if fragment_name == fragment_name2 { - return; - } - self.collect_conflicts_between_fields_and_fragment( - conflicts, - field_map, - fragment_name2, - mutually_exclusive, - ctx, - ); - } + Some(fragment_names2) } fn collect_conflicts_between( diff --git a/juniper/src/validation/rules/variables_in_allowed_position.rs b/juniper/src/validation/rules/variables_in_allowed_position.rs index 550b47a8..776f12db 100644 --- a/juniper/src/validation/rules/variables_in_allowed_position.rs +++ b/juniper/src/validation/rules/variables_in_allowed_position.rs @@ -17,14 +17,6 @@ pub enum Scope<'a> { Fragment(&'a str), } -pub struct VariableInAllowedPosition<'a, S: fmt::Debug + 'a> { - spreads: HashMap, HashSet<&'a str>>, - variable_usages: HashMap, Vec<(Spanning<&'a String>, Type<'a>)>>, - #[allow(clippy::type_complexity)] - variable_defs: HashMap, Vec<&'a (Spanning<&'a str>, VariableDefinition<'a, S>)>>, - current_scope: Option>, -} - pub fn factory<'a, S: fmt::Debug>() -> VariableInAllowedPosition<'a, S> { VariableInAllowedPosition { spreads: HashMap::new(), @@ -34,16 +26,54 @@ pub fn factory<'a, S: fmt::Debug>() -> VariableInAllowedPosition<'a, S> { } } +pub struct VariableInAllowedPosition<'a, S: fmt::Debug + 'a> { + spreads: HashMap, HashSet<&'a str>>, + variable_usages: HashMap, Vec<(Spanning<&'a String>, Type<'a>)>>, + #[allow(clippy::type_complexity)] + variable_defs: HashMap, Vec<&'a (Spanning<&'a str>, VariableDefinition<'a, S>)>>, + current_scope: Option>, +} + impl<'a, S: fmt::Debug> VariableInAllowedPosition<'a, S> { - fn collect_incorrect_usages( - &self, + fn collect_incorrect_usages<'me>( + &'me self, from: &Scope<'a>, var_defs: &[&'a (Spanning<&'a str>, VariableDefinition)], ctx: &mut ValidatorContext<'a, S>, visited: &mut HashSet>, ) { + let mut to_visit = Vec::new(); + if let Some(spreads) = self.collect_incorrect_usages_inner(from, var_defs, ctx, visited) { + to_visit.push(spreads); + } + + while let Some(spreads) = to_visit.pop() { + for spread in spreads { + if let Some(spreads) = self.collect_incorrect_usages_inner( + &Scope::Fragment(spread), + var_defs, + ctx, + visited, + ) { + to_visit.push(spreads); + } + } + } + } + + /// This function should be called only inside + /// [`Self::collect_incorrect_usages()`], as it's a recursive function using + /// heap instead of a stack. So, instead of the recursive call, we return a + /// [`Vec`] that is visited inside [`Self::collect_incorrect_usages()`]. + fn collect_incorrect_usages_inner<'me>( + &'me self, + from: &Scope<'a>, + var_defs: &[&'a (Spanning<&'a str>, VariableDefinition)], + ctx: &mut ValidatorContext<'a, S>, + visited: &mut HashSet>, + ) -> Option<&'me HashSet<&'a str>> { if visited.contains(from) { - return; + return None; } visited.insert(from.clone()); @@ -74,11 +104,7 @@ impl<'a, S: fmt::Debug> VariableInAllowedPosition<'a, S> { } } - if let Some(spreads) = self.spreads.get(from) { - for spread in spreads { - self.collect_incorrect_usages(&Scope::Fragment(spread), var_defs, ctx, visited); - } - } + self.spreads.get(from) } } diff --git a/tests/integration/Cargo.toml b/tests/integration/Cargo.toml index 041edf23..f339a718 100644 --- a/tests/integration/Cargo.toml +++ b/tests/integration/Cargo.toml @@ -10,6 +10,7 @@ chrono = "0.4" derive_more = "0.99" fnv = "1.0" futures = "0.3" +itertools = "0.10" juniper = { path = "../../juniper" } juniper_subscriptions = { path = "../../juniper_subscriptions" } serde = { version = "1.0", features = ["derive"] } diff --git a/tests/integration/tests/cve_2022_31173.rs b/tests/integration/tests/cve_2022_31173.rs new file mode 100644 index 00000000..332306a3 --- /dev/null +++ b/tests/integration/tests/cve_2022_31173.rs @@ -0,0 +1,56 @@ +//! Checks that long looping chain of fragments doesn't cause a stack overflow. +//! +//! ```graphql +//! # Fragment loop example +//! query { +//! ...a +//! } +//! +//! fragment a on Query { +//! ...b +//! } +//! +//! fragment b on Query { +//! ...a +//! } +//! ``` + +use std::iter; + +use itertools::Itertools as _; +use juniper::{graphql_object, graphql_vars, EmptyMutation, EmptySubscription}; + +struct Query; + +#[graphql_object] +impl Query { + fn dummy() -> bool { + false + } +} + +type Schema = juniper::RootNode<'static, Query, EmptyMutation, EmptySubscription>; + +#[tokio::test] +async fn test() { + const PERM: &str = "abcefghijk"; + const CIRCLE_SIZE: usize = 7500; + + let query = iter::once(format!("query {{ ...{PERM} }} ")) + .chain( + PERM.chars() + .permutations(PERM.len()) + .map(|vec| vec.into_iter().collect::()) + .take(CIRCLE_SIZE) + .collect::>() + .into_iter() + .circular_tuple_windows::<(_, _)>() + .map(|(cur, next)| format!("fragment {cur} on Query {{ ...{next} }} ")), + ) + .collect::(); + + let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new()); + let _ = juniper::execute(&query, None, &schema, &graphql_vars! {}, &()) + .await + .unwrap_err(); +}