Merge pull request from GHSA-4rx6-g5vg-5f3j

* Replace recursions with heap allocations

* Some corrections [skip ci]

* Add recursive nested fragments test case

* Docs and small corrections

* Corrections

Co-authored-by: Kai Ren <tyranron@gmail.com>
This commit is contained in:
ilslv 2022-07-28 14:33:16 +03:00 committed by GitHub
parent 6d6c71fc3b
commit 2b609ee057
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 292 additions and 101 deletions

View file

@ -7,19 +7,6 @@ use crate::{
value::ScalarValue, value::ScalarValue,
}; };
pub struct NoFragmentCycles<'a> {
current_fragment: Option<&'a str>,
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
fragment_order: Vec<&'a str>,
}
struct CycleDetector<'a> {
visited: HashSet<&'a str>,
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
path_indices: HashMap<&'a str, usize>,
errors: Vec<RuleError>,
}
pub fn factory<'a>() -> NoFragmentCycles<'a> { pub fn factory<'a>() -> NoFragmentCycles<'a> {
NoFragmentCycles { NoFragmentCycles {
current_fragment: None, current_fragment: None,
@ -28,6 +15,12 @@ pub fn factory<'a>() -> NoFragmentCycles<'a> {
} }
} }
/// Visitor state for detecting cycles among fragment spreads.
///
/// The collected `spreads` are handed to a `CycleDetector`, which is started
/// from every fragment in `fragment_order` that has not been visited yet.
pub struct NoFragmentCycles<'a> {
// Starts as `None` in `factory()`.
// NOTE(review): presumably the name of the fragment definition currently
// being visited — confirm against the visitor methods.
current_fragment: Option<&'a str>,
/// Spreads contained in each fragment, keyed by the fragment name.
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
/// Fragment names in the order cycle detection is started from them.
fragment_order: Vec<&'a str>,
}
impl<'a, S> Visitor<'a, S> for NoFragmentCycles<'a> impl<'a, S> Visitor<'a, S> for NoFragmentCycles<'a>
where where
S: ScalarValue, S: ScalarValue,
@ -38,14 +31,12 @@ where
let mut detector = CycleDetector { let mut detector = CycleDetector {
visited: HashSet::new(), visited: HashSet::new(),
spreads: &self.spreads, spreads: &self.spreads,
path_indices: HashMap::new(),
errors: Vec::new(), errors: Vec::new(),
}; };
for frag in &self.fragment_order { for frag in &self.fragment_order {
if !detector.visited.contains(frag) { if !detector.visited.contains(frag) {
let mut path = Vec::new(); detector.detect_from(frag);
detector.detect_from(frag, &mut path);
} }
} }
@ -91,19 +82,46 @@ where
} }
} }
/// Pending unit of work of [`CycleDetector::detect_from()`]: the fragment name
/// to visit, the path of spreads leading to it, and the index in that path at
/// which every fragment name already on the path was entered.
type CycleDetectorState<'a> = (&'a str, Vec<&'a Spanning<&'a str>>, HashMap<&'a str, usize>);
/// Detects cycles among fragment spreads, collecting a [`RuleError`] for each
/// spread that points back to a fragment on the path currently being walked.
struct CycleDetector<'a> {
/// Names of the fragments detection has already been run for.
visited: HashSet<&'a str>,
/// Spreads contained in each fragment, keyed by the fragment name.
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
/// Errors reported for the detected cycles.
errors: Vec<RuleError>,
}
impl<'a> CycleDetector<'a> { impl<'a> CycleDetector<'a> {
fn detect_from(&mut self, from: &'a str, path: &mut Vec<&'a Spanning<&'a str>>) { fn detect_from(&mut self, from: &'a str) {
let mut to_visit = Vec::new();
to_visit.push((from, Vec::new(), HashMap::new()));
while let Some((from, path, path_indices)) = to_visit.pop() {
to_visit.extend(self.detect_from_inner(from, path, path_indices));
}
}
/// This function should be called only inside [`Self::detect_from()`], as
/// it performs a single step of a traversal that keeps its pending work on
/// the heap instead of the call stack. So, instead of making a recursive
/// call, it returns a [`Vec`] of states to visit next, which is consumed
/// inside [`Self::detect_from()`].
fn detect_from_inner(
&mut self,
from: &'a str,
path: Vec<&'a Spanning<&'a str>>,
mut path_indices: HashMap<&'a str, usize>,
) -> Vec<CycleDetectorState<'a>> {
self.visited.insert(from); self.visited.insert(from);
if !self.spreads.contains_key(from) { if !self.spreads.contains_key(from) {
return; return Vec::new();
} }
self.path_indices.insert(from, path.len()); path_indices.insert(from, path.len());
let mut to_visit = Vec::new();
for node in &self.spreads[from] { for node in &self.spreads[from] {
let name = &node.item; let name = node.item;
let index = self.path_indices.get(name).cloned(); let index = path_indices.get(name).cloned();
if let Some(index) = index { if let Some(index) = index {
let err_pos = if index < path.len() { let err_pos = if index < path.len() {
@ -114,14 +132,14 @@ impl<'a> CycleDetector<'a> {
self.errors self.errors
.push(RuleError::new(&error_message(name), &[err_pos.start])); .push(RuleError::new(&error_message(name), &[err_pos.start]));
} else if !self.visited.contains(name) { } else {
let mut path = path.clone();
path.push(node); path.push(node);
self.detect_from(name, path); to_visit.push((name, path, path_indices.clone()));
path.pop();
} }
} }
self.path_indices.remove(from); to_visit
} }
} }

View file

@ -12,13 +12,6 @@ pub enum Scope<'a> {
Fragment(&'a str), Fragment(&'a str),
} }
pub struct NoUndefinedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}
pub fn factory<'a>() -> NoUndefinedVariables<'a> { pub fn factory<'a>() -> NoUndefinedVariables<'a> {
NoUndefinedVariables { NoUndefinedVariables {
defined_variables: HashMap::new(), defined_variables: HashMap::new(),
@ -28,6 +21,13 @@ pub fn factory<'a>() -> NoUndefinedVariables<'a> {
} }
} }
/// State walked by `find_undef_vars()`.
pub struct NoUndefinedVariables<'a> {
// NOTE(review): presumably keyed by operation name (`None` for an anonymous
// operation), holding the operation's position and its declared variable
// names — confirm against the visitor methods.
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
// NOTE(review): presumably the variable usages found in each scope — confirm
// against the visitor methods.
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
// NOTE(review): presumably the scope currently being visited — confirm
// against the visitor methods.
current_scope: Option<Scope<'a>>,
/// Names of the fragments spread in each scope; `find_undef_vars()` follows
/// them as `Scope::Fragment` scopes.
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}
impl<'a> NoUndefinedVariables<'a> { impl<'a> NoUndefinedVariables<'a> {
fn find_undef_vars( fn find_undef_vars(
&'a self, &'a self,
@ -36,8 +36,34 @@ impl<'a> NoUndefinedVariables<'a> {
unused: &mut Vec<&'a Spanning<&'a str>>, unused: &mut Vec<&'a Spanning<&'a str>>,
visited: &mut HashSet<Scope<'a>>, visited: &mut HashSet<Scope<'a>>,
) { ) {
let mut to_visit = Vec::new();
if let Some(spreads) = self.find_undef_vars_inner(scope, defined, unused, visited) {
to_visit.push(spreads);
}
while let Some(spreads) = to_visit.pop() {
for spread in spreads {
if let Some(spreads) =
self.find_undef_vars_inner(&Scope::Fragment(spread), defined, unused, visited)
{
to_visit.push(spreads);
}
}
}
}
/// This function should be called only inside [`Self::find_undef_vars()`],
/// as it performs a single step of a traversal that keeps its pending work
/// on the heap instead of the call stack. So, instead of making a recursive
/// call, it returns the [`Vec`] of fragment spreads to visit next (if any),
/// which is consumed inside [`Self::find_undef_vars()`].
fn find_undef_vars_inner(
&'a self,
scope: &Scope<'a>,
defined: &HashSet<&'a str>,
unused: &mut Vec<&'a Spanning<&'a str>>,
visited: &mut HashSet<Scope<'a>>,
) -> Option<&'a Vec<&'a str>> {
if visited.contains(scope) { if visited.contains(scope) {
return; return None;
} }
visited.insert(scope.clone()); visited.insert(scope.clone());
@ -50,11 +76,7 @@ impl<'a> NoUndefinedVariables<'a> {
} }
} }
if let Some(spreads) = self.spreads.get(scope) { self.spreads.get(scope)
for spread in spreads {
self.find_undef_vars(&Scope::Fragment(spread), defined, unused, visited);
}
}
} }
} }

View file

@ -13,12 +13,6 @@ pub enum Scope<'a> {
Fragment(&'a str), Fragment(&'a str),
} }
pub struct NoUnusedFragments<'a> {
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
defined_fragments: HashSet<Spanning<&'a str>>,
current_scope: Option<Scope<'a>>,
}
pub fn factory<'a>() -> NoUnusedFragments<'a> { pub fn factory<'a>() -> NoUnusedFragments<'a> {
NoUnusedFragments { NoUnusedFragments {
spreads: HashMap::new(), spreads: HashMap::new(),
@ -27,22 +21,43 @@ pub fn factory<'a>() -> NoUnusedFragments<'a> {
} }
} }
/// State walked by `find_reachable_fragments()`.
pub struct NoUnusedFragments<'a> {
/// Names of the fragments spread in each scope; looked up by
/// `Scope::Fragment` while collecting reachable fragments.
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
// NOTE(review): presumably the names (with positions) of all fragment
// definitions in the document — confirm against the visitor methods.
defined_fragments: HashSet<Spanning<&'a str>>,
// NOTE(review): presumably the scope currently being visited — confirm
// against the visitor methods.
current_scope: Option<Scope<'a>>,
}
impl<'a> NoUnusedFragments<'a> { impl<'a> NoUnusedFragments<'a> {
fn find_reachable_fragments(&self, from: &Scope<'a>, result: &mut HashSet<&'a str>) { fn find_reachable_fragments(&'a self, from: &Scope<'a>, result: &mut HashSet<&'a str>) {
let mut to_visit = Vec::new();
if let Scope::Fragment(name) = *from { if let Scope::Fragment(name) = *from {
if result.contains(name) { to_visit.push(name);
return;
} else {
result.insert(name);
}
} }
if let Some(spreads) = self.spreads.get(from) { while let Some(from) = to_visit.pop() {
for spread in spreads { if let Some(next) = self.find_reachable_fragments_inner(from, result) {
self.find_reachable_fragments(&Scope::Fragment(spread), result) to_visit.extend(next);
} }
} }
} }
/// Marks the `from` fragment as reachable and returns the fragments it
/// spreads.
///
/// This function should be called only inside
/// [`Self::find_reachable_fragments()`], as it performs a single step of a
/// traversal that keeps its pending work on the heap instead of the call
/// stack. So, instead of making a recursive call, it returns the [`Vec`] of
/// fragment names to visit next, which is consumed inside
/// [`Self::find_reachable_fragments()`].
///
/// Returns [`None`] if `from` has already been recorded in `result`, or if no
/// spreads are registered for it.
fn find_reachable_fragments_inner(
    &'a self,
    from: &'a str,
    result: &mut HashSet<&'a str>,
) -> Option<&'a Vec<&'a str>> {
    // `HashSet::insert()` returns `false` when the value was already present,
    // so a single call both records the fragment and detects a repeated visit.
    if !result.insert(from) {
        return None;
    }

    self.spreads.get(&Scope::Fragment(from))
}
} }
impl<'a, S> Visitor<'a, S> for NoUnusedFragments<'a> impl<'a, S> Visitor<'a, S> for NoUnusedFragments<'a>

View file

@ -12,13 +12,6 @@ pub enum Scope<'a> {
Fragment(&'a str), Fragment(&'a str),
} }
pub struct NoUnusedVariables<'a> {
defined_variables: HashMap<Option<&'a str>, HashSet<&'a Spanning<&'a str>>>,
used_variables: HashMap<Scope<'a>, Vec<&'a str>>,
current_scope: Option<Scope<'a>>,
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}
pub fn factory<'a>() -> NoUnusedVariables<'a> { pub fn factory<'a>() -> NoUnusedVariables<'a> {
NoUnusedVariables { NoUnusedVariables {
defined_variables: HashMap::new(), defined_variables: HashMap::new(),
@ -28,16 +21,49 @@ pub fn factory<'a>() -> NoUnusedVariables<'a> {
} }
} }
/// State walked by `find_used_vars()`.
pub struct NoUnusedVariables<'a> {
// NOTE(review): presumably keyed by operation name (`None` for an anonymous
// operation), holding the variables the operation declares — confirm against
// the visitor methods.
defined_variables: HashMap<Option<&'a str>, HashSet<&'a Spanning<&'a str>>>,
// NOTE(review): presumably the names of the variables used in each scope —
// confirm against the visitor methods.
used_variables: HashMap<Scope<'a>, Vec<&'a str>>,
// NOTE(review): presumably the scope currently being visited — confirm
// against the visitor methods.
current_scope: Option<Scope<'a>>,
/// Names of the fragments spread in each scope; `find_used_vars()` follows
/// them as `Scope::Fragment` scopes.
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
}
impl<'a> NoUnusedVariables<'a> { impl<'a> NoUnusedVariables<'a> {
fn find_used_vars( fn find_used_vars(
&self, &'a self,
from: &Scope<'a>, from: &Scope<'a>,
defined: &HashSet<&'a str>, defined: &HashSet<&'a str>,
used: &mut HashSet<&'a str>, used: &mut HashSet<&'a str>,
visited: &mut HashSet<Scope<'a>>, visited: &mut HashSet<Scope<'a>>,
) { ) {
let mut to_visit = Vec::new();
if let Some(spreads) = self.find_used_vars_inner(from, defined, used, visited) {
to_visit.push(spreads);
}
while let Some(spreads) = to_visit.pop() {
for spread in spreads {
if let Some(spreads) =
self.find_used_vars_inner(&Scope::Fragment(spread), defined, used, visited)
{
to_visit.push(spreads);
}
}
}
}
/// This function should be called only inside [`Self::find_used_vars()`],
/// as it performs a single step of a traversal that keeps its pending work
/// on the heap instead of the call stack. So, instead of making a recursive
/// call, it returns the [`Vec`] of fragment spreads to visit next (if any),
/// which is consumed inside [`Self::find_used_vars()`].
fn find_used_vars_inner(
&'a self,
from: &Scope<'a>,
defined: &HashSet<&'a str>,
used: &mut HashSet<&'a str>,
visited: &mut HashSet<Scope<'a>>,
) -> Option<&'a Vec<&'a str>> {
if visited.contains(from) { if visited.contains(from) {
return; return None;
} }
visited.insert(from.clone()); visited.insert(from.clone());
@ -50,11 +76,7 @@ impl<'a> NoUnusedVariables<'a> {
} }
} }
if let Some(spreads) = self.spreads.get(from) { self.spreads.get(from)
for spread in spreads {
self.find_used_vars(&Scope::Fragment(spread), defined, used, visited);
}
}
} }
} }

View file

@ -274,30 +274,61 @@ impl<'a, S: Debug> OverlappingFieldsCanBeMerged<'a, S> {
) where ) where
S: ScalarValue, S: ScalarValue,
{ {
let fragment = match self.named_fragments.get(fragment_name) { let mut to_check = Vec::new();
Some(f) => f, if let Some(fragments) = self.collect_conflicts_between_fields_and_fragment_inner(
None => return, conflicts,
}; field_map,
fragment_name,
mutually_exclusive,
ctx,
) {
to_check.push((fragment_name, fragments))
}
while let Some((fragment_name, fragment_names2)) = to_check.pop() {
for fragment_name2 in fragment_names2 {
// Early return on fragment recursion, as it makes no sense.
// Fragment recursions are prevented by `no_fragment_cycles` validator.
if fragment_name == fragment_name2 {
return;
}
if let Some(fragments) = self.collect_conflicts_between_fields_and_fragment_inner(
conflicts,
field_map,
fragment_name2,
mutually_exclusive,
ctx,
) {
to_check.push((fragment_name2, fragments));
};
}
}
}
/// This function should be called only inside
/// [`Self::collect_conflicts_between_fields_and_fragment()`], as it performs
/// a single step of a traversal that keeps its pending work on the heap
/// instead of the call stack. So, instead of making a recursive call, it
/// returns a [`Vec`] of the fragment names to check next (if any), which is
/// consumed inside [`Self::collect_conflicts_between_fields_and_fragment()`].
fn collect_conflicts_between_fields_and_fragment_inner(
&self,
conflicts: &mut Vec<Conflict>,
field_map: &AstAndDefCollection<'a, S>,
fragment_name: &str,
mutually_exclusive: bool,
ctx: &ValidatorContext<'a, S>,
) -> Option<Vec<&'a str>>
where
S: ScalarValue,
{
let fragment = self.named_fragments.get(fragment_name)?;
let (field_map2, fragment_names2) = let (field_map2, fragment_names2) =
self.get_referenced_fields_and_fragment_names(fragment, ctx); self.get_referenced_fields_and_fragment_names(fragment, ctx);
self.collect_conflicts_between(conflicts, mutually_exclusive, field_map, &field_map2, ctx); self.collect_conflicts_between(conflicts, mutually_exclusive, field_map, &field_map2, ctx);
for fragment_name2 in fragment_names2 { Some(fragment_names2)
// Early return on fragment recursion, as it makes no sense.
// Fragment recursions are prevented by `no_fragment_cycles` validator.
if fragment_name == fragment_name2 {
return;
}
self.collect_conflicts_between_fields_and_fragment(
conflicts,
field_map,
fragment_name2,
mutually_exclusive,
ctx,
);
}
} }
fn collect_conflicts_between( fn collect_conflicts_between(

View file

@ -17,14 +17,6 @@ pub enum Scope<'a> {
Fragment(&'a str), Fragment(&'a str),
} }
pub struct VariableInAllowedPosition<'a, S: fmt::Debug + 'a> {
spreads: HashMap<Scope<'a>, HashSet<&'a str>>,
variable_usages: HashMap<Scope<'a>, Vec<(Spanning<&'a String>, Type<'a>)>>,
#[allow(clippy::type_complexity)]
variable_defs: HashMap<Scope<'a>, Vec<&'a (Spanning<&'a str>, VariableDefinition<'a, S>)>>,
current_scope: Option<Scope<'a>>,
}
pub fn factory<'a, S: fmt::Debug>() -> VariableInAllowedPosition<'a, S> { pub fn factory<'a, S: fmt::Debug>() -> VariableInAllowedPosition<'a, S> {
VariableInAllowedPosition { VariableInAllowedPosition {
spreads: HashMap::new(), spreads: HashMap::new(),
@ -34,16 +26,54 @@ pub fn factory<'a, S: fmt::Debug>() -> VariableInAllowedPosition<'a, S> {
} }
} }
/// State walked by `collect_incorrect_usages()`.
pub struct VariableInAllowedPosition<'a, S: fmt::Debug + 'a> {
/// Names of the fragments spread in each scope; `collect_incorrect_usages()`
/// follows them as `Scope::Fragment` scopes.
spreads: HashMap<Scope<'a>, HashSet<&'a str>>,
// NOTE(review): presumably each variable usage found in a scope, paired with
// the type expected at the position of the usage — confirm against the
// visitor methods.
variable_usages: HashMap<Scope<'a>, Vec<(Spanning<&'a String>, Type<'a>)>>,
// NOTE(review): presumably the variable definitions declared by each
// operation scope — confirm against the visitor methods.
#[allow(clippy::type_complexity)]
variable_defs: HashMap<Scope<'a>, Vec<&'a (Spanning<&'a str>, VariableDefinition<'a, S>)>>,
// NOTE(review): presumably the scope currently being visited — confirm
// against the visitor methods.
current_scope: Option<Scope<'a>>,
}
impl<'a, S: fmt::Debug> VariableInAllowedPosition<'a, S> { impl<'a, S: fmt::Debug> VariableInAllowedPosition<'a, S> {
fn collect_incorrect_usages( fn collect_incorrect_usages<'me>(
&self, &'me self,
from: &Scope<'a>, from: &Scope<'a>,
var_defs: &[&'a (Spanning<&'a str>, VariableDefinition<S>)], var_defs: &[&'a (Spanning<&'a str>, VariableDefinition<S>)],
ctx: &mut ValidatorContext<'a, S>, ctx: &mut ValidatorContext<'a, S>,
visited: &mut HashSet<Scope<'a>>, visited: &mut HashSet<Scope<'a>>,
) { ) {
let mut to_visit = Vec::new();
if let Some(spreads) = self.collect_incorrect_usages_inner(from, var_defs, ctx, visited) {
to_visit.push(spreads);
}
while let Some(spreads) = to_visit.pop() {
for spread in spreads {
if let Some(spreads) = self.collect_incorrect_usages_inner(
&Scope::Fragment(spread),
var_defs,
ctx,
visited,
) {
to_visit.push(spreads);
}
}
}
}
/// This function should be called only inside
/// [`Self::collect_incorrect_usages()`], as it performs a single step of a
/// traversal that keeps its pending work on the heap instead of the call
/// stack. So, instead of making a recursive call, it returns the [`HashSet`]
/// of fragment spreads to visit next (if any), which is consumed inside
/// [`Self::collect_incorrect_usages()`].
fn collect_incorrect_usages_inner<'me>(
&'me self,
from: &Scope<'a>,
var_defs: &[&'a (Spanning<&'a str>, VariableDefinition<S>)],
ctx: &mut ValidatorContext<'a, S>,
visited: &mut HashSet<Scope<'a>>,
) -> Option<&'me HashSet<&'a str>> {
if visited.contains(from) { if visited.contains(from) {
return; return None;
} }
visited.insert(from.clone()); visited.insert(from.clone());
@ -74,11 +104,7 @@ impl<'a, S: fmt::Debug> VariableInAllowedPosition<'a, S> {
} }
} }
if let Some(spreads) = self.spreads.get(from) { self.spreads.get(from)
for spread in spreads {
self.collect_incorrect_usages(&Scope::Fragment(spread), var_defs, ctx, visited);
}
}
} }
} }

View file

@ -10,6 +10,7 @@ chrono = "0.4"
derive_more = "0.99" derive_more = "0.99"
fnv = "1.0" fnv = "1.0"
futures = "0.3" futures = "0.3"
itertools = "0.10"
juniper = { path = "../../juniper" } juniper = { path = "../../juniper" }
juniper_subscriptions = { path = "../../juniper_subscriptions" } juniper_subscriptions = { path = "../../juniper_subscriptions" }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }

View file

@ -0,0 +1,56 @@
//! Checks that a long looping chain of fragments doesn't cause a stack overflow.
//!
//! ```graphql
//! # Fragment loop example
//! query {
//! ...a
//! }
//!
//! fragment a on Query {
//! ...b
//! }
//!
//! fragment b on Query {
//! ...a
//! }
//! ```
use std::iter;
use itertools::Itertools as _;
use juniper::{graphql_object, graphql_vars, EmptyMutation, EmptySubscription};
// Root query type of the test schema. Plain `//` comments are used on purpose:
// NOTE(review): `///` doc comments on `#[graphql_object]` items presumably end
// up as descriptions in the generated schema — confirm before converting.
struct Query;
#[graphql_object]
impl Query {
// Placeholder field; the test query only spreads fragments and never selects it.
fn dummy() -> bool {
false
}
}
// Schema with `Query` as the query root and empty mutation and subscription roots.
type Schema = juniper::RootNode<'static, Query, EmptyMutation, EmptySubscription>;
#[tokio::test]
async fn test() {
    const PERM: &str = "abcefghijk";
    const CIRCLE_SIZE: usize = 7500;

    // Distinct fragment names: the first `CIRCLE_SIZE` permutations of `PERM`.
    // They are collected into a `Vec` before being windowed.
    // NOTE(review): presumably required because `circular_tuple_windows()`
    // needs a cloneable, exact-size iterator — confirm against itertools docs.
    let fragment_names = PERM
        .chars()
        .permutations(PERM.len())
        .take(CIRCLE_SIZE)
        .map(|chars| chars.into_iter().collect::<String>())
        .collect::<Vec<_>>();

    // Every fragment spreads the next one, and the last one spreads the first,
    // closing the loop.
    let fragment_defs = fragment_names
        .into_iter()
        .circular_tuple_windows::<(_, _)>()
        .map(|(cur, next)| format!("fragment {cur} on Query {{ ...{next} }} "));

    // The operation enters the loop through the fragment named `PERM` itself,
    // which is the first permutation produced.
    let operation = format!("query {{ ...{PERM} }} ");
    let query = iter::once(operation)
        .chain(fragment_defs)
        .collect::<String>();

    let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new());
    let response = juniper::execute(&query, None, &schema, &graphql_vars! {}, &()).await;

    // The looping fragments must be rejected with an error instead of
    // overflowing the stack.
    let _ = response.unwrap_err();
}