diff --git a/Cargo.lock b/Cargo.lock index 52d02cc4e9..7b3e299c71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -162,6 +162,22 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "ctor" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "diff" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e25ea47919b1560c4e3b7fe0aaab9becf5b84a10325ddf7db0f0ba5e1026499" + [[package]] name = "difference" version = "2.0.0" @@ -360,6 +376,15 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10acf907b94fc1b1a152d08ef97e7759650268cf986bf127f387e602b02c7e5a" +[[package]] +name = "output_vt100" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53cdc5b785b7a58c5aad8216b3dfa114df64b0b06ae6e1501cef91df2fbdf8f9" +dependencies = [ + "winapi", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -372,6 +397,18 @@ version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +[[package]] +name = "pretty_assertions" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cab0e7c02cf376875e9335e0ba1da535775beb5450d21e1dffca068818ed98b" +dependencies = [ + "ansi_term 0.12.1", + "ctor", + "diff", + "output_vt100", +] + [[package]] name = "proc-macro2" version = "1.0.24" @@ -568,9 +605,9 @@ checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" [[package]] name = "syn" -version = "1.0.60" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c700597eca8a5a762beb35753ef6b94df201c81cca676604f547495a0d7f0081" +checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702" dependencies = [ "proc-macro2", "quote", @@ -701,6 +738,7 @@ dependencies = [ "indexmap", "lazy_static", "log", + "pretty_assertions", "rand", "regex", "regex-syntax", diff --git a/cli/Cargo.toml b/cli/Cargo.toml index e559842f4a..2a0d49275f 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -76,6 +76,7 @@ features = ["std"] [dev-dependencies] rand = "0.8" tempfile = "3" +pretty_assertions = "0.7.2" [build-dependencies] toml = "0.5" diff --git a/cli/src/tests/corpus_test.rs b/cli/src/tests/corpus_test.rs index 5699f063d6..d2e586deac 100644 --- a/cli/src/tests/corpus_test.rs +++ b/cli/src/tests/corpus_test.rs @@ -1,40 +1,22 @@ -use super::helpers::edits::{get_random_edit, invert_edit}; -use super::helpers::fixtures::{fixtures_dir, get_language, get_test_language}; -use super::helpers::random::Rand; -use super::helpers::scope_sequence::ScopeSequence; -use crate::generate; -use crate::parse::perform_edit; -use crate::test::{parse_tests, print_diff, print_diff_key, strip_sexp_fields, TestEntry}; -use crate::util; -use lazy_static::lazy_static; -use std::{env, fs, time, usize}; +use super::helpers::{ + edits::{get_random_edit, invert_edit}, + fixtures::{fixtures_dir, get_language, get_test_language}, + random::Rand, + scope_sequence::ScopeSequence, + EXAMPLE_FILTER, LANGUAGE_FILTER, LOG_ENABLED, LOG_GRAPH_ENABLED, SEED, TRIAL_FILTER, +}; +use crate::{ + generate, + parse::perform_edit, + test::{parse_tests, print_diff, print_diff_key, strip_sexp_fields, TestEntry}, + util, +}; +use std::{fs, usize}; use tree_sitter::{allocations, LogType, Node, Parser, Tree}; const EDIT_COUNT: usize = 3; const TRIAL_COUNT: usize = 10; -lazy_static! { - static ref LOG_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG").is_ok(); - static ref LOG_GRAPH_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG_GRAPHS").is_ok(); - static ref LANGUAGE_FILTER: Option = env::var("TREE_SITTER_TEST_LANGUAGE_FILTER").ok(); - static ref EXAMPLE_FILTER: Option = env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok(); - static ref TRIAL_FILTER: Option = env::var("TREE_SITTER_TEST_TRIAL_FILTER") - .map(|s| usize::from_str_radix(&s, 10).unwrap()) - .ok(); - pub static ref SEED: usize = { - let seed = env::var("TREE_SITTER_TEST_SEED") - .map(|s| usize::from_str_radix(&s, 10).unwrap()) - .unwrap_or( - time::SystemTime::now() - .duration_since(time::UNIX_EPOCH) - .unwrap() - .as_secs() as usize, - ); - eprintln!("\n\nRandom seed: {}\n", seed); - seed - }; -} - #[test] fn test_bash_corpus() { test_language_corpus("bash"); diff --git a/cli/src/tests/helpers/mod.rs b/cli/src/tests/helpers/mod.rs index 3a75dad3a5..e492a42e39 100644 --- a/cli/src/tests/helpers/mod.rs +++ b/cli/src/tests/helpers/mod.rs @@ -1,4 +1,32 @@ pub(super) mod edits; pub(super) mod fixtures; +pub(super) mod query_helpers; pub(super) mod random; pub(super) mod scope_sequence; + +use lazy_static::lazy_static; +use std::{env, time, usize}; + +lazy_static! { + pub static ref SEED: usize = { + let seed = env::var("TREE_SITTER_TEST_SEED") + .map(|s| usize::from_str_radix(&s, 10).unwrap()) + .unwrap_or( + time::SystemTime::now() + .duration_since(time::UNIX_EPOCH) + .unwrap() + .as_secs() as usize, + ); + eprintln!("\n\nRandom seed: {}\n", seed); + seed + }; + pub static ref LOG_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG").is_ok(); + pub static ref LOG_GRAPH_ENABLED: bool = env::var("TREE_SITTER_TEST_ENABLE_LOG_GRAPHS").is_ok(); + pub static ref LANGUAGE_FILTER: Option = + env::var("TREE_SITTER_TEST_LANGUAGE_FILTER").ok(); + pub static ref EXAMPLE_FILTER: Option = + env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok(); + pub static ref TRIAL_FILTER: Option = env::var("TREE_SITTER_TEST_TRIAL_FILTER") + .map(|s| usize::from_str_radix(&s, 10).unwrap()) + .ok(); +} diff --git a/cli/src/tests/helpers/query_helpers.rs b/cli/src/tests/helpers/query_helpers.rs new file mode 100644 index 0000000000..78ae559ccb --- /dev/null +++ b/cli/src/tests/helpers/query_helpers.rs @@ -0,0 +1,306 @@ +use rand::prelude::Rng; +use std::{cmp::Ordering, fmt::Write, ops::Range}; +use tree_sitter::{Node, Point, Tree, TreeCursor}; + +#[derive(Debug)] +pub struct Pattern { + kind: Option<&'static str>, + named: bool, + field: Option<&'static str>, + capture: Option, + children: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Match<'a, 'tree> { + pub captures: Vec<(&'a str, Node<'tree>)>, + pub last_node: Option>, +} + +const CAPTURE_NAMES: &'static [&'static str] = &[ + "one", "two", "three", "four", "five", "six", "seven", "eight", +]; + +impl Pattern { + pub fn random_pattern_in_tree(tree: &Tree, rng: &mut impl Rng) -> (Self, Range) { + let mut cursor = tree.walk(); + + // Descend to the node at a random byte offset and depth. + let mut max_depth = 0; + let byte_offset = rng.gen_range(0..cursor.node().end_byte()); + while cursor.goto_first_child_for_byte(byte_offset).is_some() { + max_depth += 1; + } + let depth = rng.gen_range(0..=max_depth); + for _ in 0..depth { + cursor.goto_parent(); + } + + // Build a pattern that matches that node. + // Sometimes include subsequent siblings of the node. + let pattern_start = cursor.node().start_position(); + let mut roots = vec![Self::random_pattern_for_node(&mut cursor, rng)]; + while roots.len() < 5 && cursor.goto_next_sibling() { + if rng.gen_bool(0.2) { + roots.push(Self::random_pattern_for_node(&mut cursor, rng)); + } + } + let pattern_end = cursor.node().end_position(); + + let mut pattern = Self { + kind: None, + named: true, + field: None, + capture: None, + children: roots, + }; + + if pattern.children.len() == 1 { + pattern = pattern.children.pop().unwrap(); + } + // In a parenthesized list of sibling patterns, the first + // sibling can't be an anonymous `_` wildcard. + else if pattern.children[0].kind == Some("_") && !pattern.children[0].named { + pattern = pattern.children.pop().unwrap(); + } + // In a parenthesized list of sibling patterns, the first + // sibling can't have a field name. + else { + pattern.children[0].field = None; + } + + (pattern, pattern_start..pattern_end) + } + + fn random_pattern_for_node(cursor: &mut TreeCursor, rng: &mut impl Rng) -> Self { + let node = cursor.node(); + + // Sometimes specify the node's type, sometimes use a wildcard. + let (kind, named) = if rng.gen_bool(0.9) { + (Some(node.kind()), node.is_named()) + } else { + (Some("_"), node.is_named() && rng.gen_bool(0.8)) + }; + + // Sometimes specify the node's field. + let field = if rng.gen_bool(0.75) { + cursor.field_name() + } else { + None + }; + + // Sometimes capture the node. + let capture = if rng.gen_bool(0.7) { + Some(CAPTURE_NAMES[rng.gen_range(0..CAPTURE_NAMES.len())].to_string()) + } else { + None + }; + + // Walk the children and include child patterns for some of them. + let mut children = Vec::new(); + if named && cursor.goto_first_child() { + let max_children = rng.gen_range(0..4); + while cursor.goto_next_sibling() { + if rng.gen_bool(0.6) { + let child_ast = Self::random_pattern_for_node(cursor, rng); + children.push(child_ast); + if children.len() >= max_children { + break; + } + } + } + cursor.goto_parent(); + } + + Self { + kind, + named, + field, + capture, + children, + } + } + + pub fn to_string(&self) -> String { + let mut result = String::new(); + self.write_to_string(&mut result, 0); + result + } + + fn write_to_string(&self, string: &mut String, indent: usize) { + if let Some(field) = self.field { + write!(string, "{}: ", field).unwrap(); + } + + if self.named { + string.push('('); + let mut has_contents = false; + if let Some(kind) = &self.kind { + write!(string, "{}", kind).unwrap(); + has_contents = true; + } + for child in &self.children { + let indent = indent + 2; + if has_contents { + string.push('\n'); + string.push_str(&" ".repeat(indent)); + } + child.write_to_string(string, indent); + has_contents = true; + } + string.push(')'); + } else if self.kind == Some("_") { + string.push('_'); + } else { + write!(string, "\"{}\"", self.kind.unwrap().replace("\"", "\\\"")).unwrap(); + } + + if let Some(capture) = &self.capture { + write!(string, " @{}", capture).unwrap(); + } + } + + pub fn matches_in_tree<'tree>(&self, tree: &'tree Tree) -> Vec> { + let mut matches = Vec::new(); + + // Compute the matches naively: walk the tree and + // retry the entire pattern for each node. + let mut cursor = tree.walk(); + let mut ascending = false; + loop { + if ascending { + if cursor.goto_next_sibling() { + ascending = false; + } else if !cursor.goto_parent() { + break; + } + } else { + let matches_here = self.match_node(&mut cursor); + matches.extend_from_slice(&matches_here); + if !cursor.goto_first_child() { + ascending = true; + } + } + } + + matches.sort_unstable(); + matches.iter_mut().for_each(|m| m.last_node = None); + matches.dedup(); + matches + } + + pub fn match_node<'tree>(&self, cursor: &mut TreeCursor<'tree>) -> Vec> { + let node = cursor.node(); + + // If a kind is specified, check that it matches the node. + if let Some(kind) = self.kind { + if kind == "_" { + if self.named && !node.is_named() { + return Vec::new(); + } + } else if kind != node.kind() || self.named != node.is_named() { + return Vec::new(); + } + } + + // If a field is specified, check that it matches the node. + if let Some(field) = self.field { + if cursor.field_name() != Some(field) { + return Vec::new(); + } + } + + // Create a match for the current node. + let mat = Match { + captures: if let Some(name) = &self.capture { + vec![(name.as_str(), node)] + } else { + Vec::new() + }, + last_node: Some(node), + }; + + // If there are no child patterns to match, then return this single match. + if self.children.is_empty() { + return vec![mat]; + } + + // Find every matching combination of child patterns and child nodes. + let mut finished_matches = Vec::::new(); + if cursor.goto_first_child() { + let mut match_states = vec![(0, mat)]; + loop { + let mut new_match_states = Vec::new(); + for (pattern_index, mat) in &match_states { + let child_pattern = &self.children[*pattern_index]; + let child_matches = child_pattern.match_node(cursor); + for child_match in child_matches { + let mut combined_match = mat.clone(); + combined_match.last_node = child_match.last_node; + combined_match + .captures + .extend_from_slice(&child_match.captures); + if pattern_index + 1 < self.children.len() { + new_match_states.push((*pattern_index + 1, combined_match)); + } else { + let mut existing = false; + for existing_match in finished_matches.iter_mut() { + if existing_match.captures == combined_match.captures { + if child_pattern.capture.is_some() { + existing_match.last_node = combined_match.last_node; + } + existing = true; + } + } + if !existing { + finished_matches.push(combined_match); + } + } + } + } + match_states.extend_from_slice(&new_match_states); + if !cursor.goto_next_sibling() { + break; + } + } + cursor.goto_parent(); + } + finished_matches + } +} + +impl<'a, 'tree> PartialOrd for Match<'a, 'tree> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl<'a, 'tree> Ord for Match<'a, 'tree> { + // Tree-sitter returns matches in the order that they terminate + // during a depth-first walk of the tree. If multiple matches + // terminate on the same node, those matches are produced in the + // order that their captures were discovered. + fn cmp(&self, other: &Self) -> Ordering { + if let Some((last_node_a, last_node_b)) = self.last_node.zip(other.last_node) { + let cmp = compare_depth_first(last_node_a, last_node_b); + if cmp.is_ne() { + return cmp; + } + } + + for (a, b) in self.captures.iter().zip(other.captures.iter()) { + let cmp = compare_depth_first(a.1, b.1); + if !cmp.is_eq() { + return cmp; + } + } + + self.captures.len().cmp(&other.captures.len()) + } +} + +fn compare_depth_first(a: Node, b: Node) -> Ordering { + let a = a.byte_range(); + let b = b.byte_range(); + a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)) +} diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 6b28cdd5f3..ef827eb2d2 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1,7 +1,10 @@ -use super::helpers::fixtures::get_language; +use super::helpers::{ + fixtures::get_language, + query_helpers::{Match, Pattern}, +}; use lazy_static::lazy_static; -use std::env; -use std::fmt::Write; +use rand::{prelude::StdRng, SeedableRng}; +use std::{env, fmt::Write}; use tree_sitter::{ allocations, Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryError, QueryErrorKind, QueryMatch, QueryPredicate, QueryPredicateArg, QueryProperty, @@ -3444,7 +3447,74 @@ fn test_query_alternative_predicate_prefix() { } #[test] -fn test_query_step_is_definite() { +fn test_query_random() { + use pretty_assertions::assert_eq; + + allocations::record(|| { + let language = get_language("rust"); + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let mut cursor = QueryCursor::new(); + cursor.set_match_limit(64); + + let pattern_tree = parser + .parse(include_str!("helpers/query_helpers.rs"), None) + .unwrap(); + let test_tree = parser + .parse(include_str!("helpers/query_helpers.rs"), None) + .unwrap(); + + // let start_seed = *SEED; + let start_seed = 0; + + for i in 0..100 { + let seed = (start_seed + i) as u64; + let mut rand = StdRng::seed_from_u64(seed); + let (pattern_ast, range) = Pattern::random_pattern_in_tree(&pattern_tree, &mut rand); + let pattern = pattern_ast.to_string(); + let expected_matches = pattern_ast.matches_in_tree(&test_tree); + + eprintln!( + "seed: {}\nsource_range: {:?}\npattern:\n{}\nexpected match count: {}\n", + seed, + range, + pattern, + expected_matches.len(), + ); + + let query = Query::new(language, &pattern).unwrap(); + let mut actual_matches = cursor + .matches( + &query, + test_tree.root_node(), + (include_str!("parser_test.rs")).as_bytes(), + ) + .map(|mat| Match { + last_node: None, + captures: mat + .captures + .iter() + .map(|c| (query.capture_names()[c.index as usize].as_str(), c.node)) + .collect::>(), + }) + .collect::>(); + + // actual_matches.sort_unstable(); + actual_matches.dedup(); + + if !cursor.did_exceed_match_limit() { + assert_eq!( + actual_matches, expected_matches, + "seed: {}, pattern:\n{}", + seed, pattern + ); + } + } + }); +} + +#[test] +fn test_query_is_pattern_guaranteed_at_step() { struct Row { language: Language, description: &'static str, @@ -3454,19 +3524,19 @@ fn test_query_step_is_definite() { let rows = &[ Row { - description: "no definite steps", + description: "no guaranteed steps", language: get_language("python"), pattern: r#"(expression_statement (string))"#, results_by_substring: &[("expression_statement", false), ("string", false)], }, Row { - description: "all definite steps", + description: "all guaranteed steps", language: get_language("javascript"), pattern: r#"(object "{" "}")"#, results_by_substring: &[("object", false), ("{", true), ("}", true)], }, Row { - description: "an indefinite step that is optional", + description: "a fallible step that is optional", language: get_language("javascript"), pattern: r#"(object "{" (identifier)? @foo "}")"#, results_by_substring: &[ @@ -3477,7 +3547,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "multiple indefinite steps that are optional", + description: "multiple fallible steps that are optional", language: get_language("javascript"), pattern: r#"(object "{" (identifier)? @id1 ("," (identifier) @id2)? "}")"#, results_by_substring: &[ @@ -3489,13 +3559,13 @@ fn test_query_step_is_definite() { ], }, Row { - description: "definite step after indefinite step", + description: "guaranteed step after fallibe step", language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)], }, Row { - description: "indefinite step in between two definite steps", + description: "fallible step in between two guaranteed steps", language: get_language("javascript"), pattern: r#"(ternary_expression condition: (_) @@ -3512,13 +3582,13 @@ fn test_query_step_is_definite() { ], }, Row { - description: "one definite step after a repetition", + description: "one guaranteed step after a repetition", language: get_language("javascript"), pattern: r#"(object "{" (_) "}")"#, results_by_substring: &[("object", false), ("{", false), ("(_)", false), ("}", true)], }, Row { - description: "definite steps after multiple repetitions", + description: "guaranteed steps after multiple repetitions", language: get_language("json"), pattern: r#"(object "{" (pair) "," (pair) "," (_) "}")"#, results_by_substring: &[ @@ -3532,7 +3602,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "a definite with a field", + description: "a guaranteed step with a field", language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, results_by_substring: &[ @@ -3542,7 +3612,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "multiple definite steps with fields", + description: "multiple guaranteed steps with fields", language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, results_by_substring: &[ @@ -3552,7 +3622,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "nesting, one definite step", + description: "nesting, one guaranteed step", language: get_language("javascript"), pattern: r#" (function_declaration @@ -3568,7 +3638,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "definite step after some deeply nested hidden nodes", + description: "a guaranteed step after some deeply nested hidden nodes", language: get_language("ruby"), pattern: r#" (singleton_class @@ -3582,7 +3652,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "nesting, no definite steps", + description: "nesting, no guaranteed steps", language: get_language("javascript"), pattern: r#" (call_expression @@ -3593,7 +3663,7 @@ fn test_query_step_is_definite() { results_by_substring: &[("property_identifier", false), ("template_string", false)], }, Row { - description: "a definite step after a nested node", + description: "a guaranteed step after a nested node", language: get_language("javascript"), pattern: r#" (subscript_expression @@ -3609,7 +3679,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "a step that is indefinite due to a predicate", + description: "a step that is fallible due to a predicate", language: get_language("javascript"), pattern: r#" (subscript_expression @@ -3626,7 +3696,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "alternation where one branch has definite steps", + description: "alternation where one branch has guaranteed steps", language: get_language("javascript"), pattern: r#" [ @@ -3645,7 +3715,7 @@ fn test_query_step_is_definite() { ], }, Row { - description: "aliased parent node", + description: "guaranteed step at the end of an aliased parent node", language: get_language("ruby"), pattern: r#" (method_parameters "(" (identifier) @id")") @@ -3700,6 +3770,21 @@ fn test_query_step_is_definite() { ("(heredoc_end)", true), ], }, + Row { + description: "multiple extra nodes", + language: get_language("rust"), + pattern: r#" + (call_expression + (line_comment) @a + (line_comment) @b + (arguments)) + "#, + results_by_substring: &[ + ("(line_comment) @a", false), + ("(line_comment) @b", false), + ("(arguments)", true), + ], + }, ]; allocations::record(|| { @@ -3716,7 +3801,7 @@ fn test_query_step_is_definite() { for (substring, is_definite) in row.results_by_substring { let offset = row.pattern.find(substring).unwrap(); assert_eq!( - query.step_is_definite(offset), + query.is_pattern_guaranteed_at_step(offset), *is_definite, "Description: {}, Pattern: {:?}, substring: {:?}, expected is_definite to be {}", row.description, diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 91b5570840..171b9a529a 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -1,4 +1,4 @@ -/* automatically generated by rust-bindgen 0.58.1 */ +/* automatically generated by rust-bindgen 0.59.1 */ pub type __darwin_size_t = ::std::os::raw::c_ulong; pub type FILE = [u64; 19usize]; @@ -659,7 +659,7 @@ extern "C" { ) -> *const TSQueryPredicateStep; } extern "C" { - pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool; + pub fn ts_query_is_pattern_guaranteed_at_step(self_: *const TSQuery, byte_offset: u32) -> bool; } extern "C" { #[doc = " Get the name and length of one of the query's captures, or one of the"] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 92ba4d1bd0..366959c827 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1586,8 +1586,10 @@ impl Query { /// /// A query step is 'definite' if its parent pattern will be guaranteed to match /// successfully once it reaches the step. - pub fn step_is_definite(&self, byte_offset: usize) -> bool { - unsafe { ffi::ts_query_step_is_definite(self.ptr.as_ptr(), byte_offset as u32) } + pub fn is_pattern_guaranteed_at_step(&self, byte_offset: usize) -> bool { + unsafe { + ffi::ts_query_is_pattern_guaranteed_at_step(self.ptr.as_ptr(), byte_offset as u32) + } } fn parse_property( diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 7168faec36..66ceaea1c0 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -725,7 +725,7 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); -bool ts_query_step_is_definite( +bool ts_query_is_pattern_guaranteed_at_step( const TSQuery *self, uint32_t byte_offset ); diff --git a/lib/src/query.c b/lib/src/query.c index 66d377dea3..5feb62452c 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -8,8 +8,7 @@ #include // #define DEBUG_ANALYZE_QUERY -// #define LOG(...) fprintf(stderr, __VA_ARGS__) -#define LOG(...) +// #define DEBUG_EXECUTE_QUERY #define MAX_STEP_CAPTURE_COUNT 3 #define MAX_STATE_PREDECESSOR_COUNT 100 @@ -41,6 +40,15 @@ typedef struct { * associated with this node in the pattern, terminated by a `NONE` value. * - `depth` - The depth where this node occurs in the pattern. The root node * of the pattern has depth zero. + * - `negated_field_list_id` - An id representing a set of fields that must + * that must not be present on a node matching this step. + * + * Steps have some additional fields in order to handle the `.` (or "anchor") operator, + * which forbids additional child nodes: + * - `is_immediate` - Indicates that the node matching this step cannot be preceded + * by other sibling nodes that weren't specified in the pattern. + * - `is_last_child` - Indicates that the node matching this step cannot have any + * subsequent named siblings. * * For simple patterns, steps are matched in sequential order. But in order to * handle alternative/repeated/optional sub-patterns, query steps are not always @@ -52,18 +60,26 @@ typedef struct { * is duplicated, with one copy remaining at the original step, and one copy * moving to the alternative step. The alternative may have its own alternative * step, so this splitting is an iterative process. - * - `is_dead_end` - Indication that this state cannot be passed directly, and + * - `is_dead_end` - Indicates that this state cannot be passed directly, and * exists only in order to redirect to an alternative index, with no splitting. - * - `is_pass_through` - Indication that state has no matching logic of its own, + * - `is_pass_through` - Indicates that state has no matching logic of its own, * and exists only to split a state. One copy of the state advances immediately * to the next step, and one moves to the alternative step. + * - `alternative_is_immediate` - Indicates that this step's alternative step + * should be treated as if `is_immediate` is true. * - * Steps have some additional fields in order to handle the `.` (or "anchor") operator, - * which forbids additional child nodes: - * - `is_immediate` - Indication that the node matching this step cannot be preceded - * by other sibling nodes that weren't specified in the pattern. - * - `is_last_child` - Indicates that the node matching this step cannot have any - * subsequent named siblings. + * Steps also store some derived state that summarizes how they relate to other + * steps within the same pattern. This is used to optimize the matching process: + * - `contains_captures` - Indicates that this step or one of its child steps + * has a non-empty `capture_ids` list. + * - `parent_pattern_guaranteed` - Indicates that if this step is reached, then + * it and all of its subsequent sibling steps within the same parent pattern + * are guaranteed to match. + * - `root_pattern_guaranteed` - Similar to `parent_pattern_guaranteed`, but + * for the entire top-level pattern. When iterating through a query's + * captures using `ts_query_cursor_next_capture`, this field is used to + * detect that a capture can safely be returned from a match that has not + * even completed yet. */ typedef struct { TSSymbol symbol; @@ -73,13 +89,15 @@ typedef struct { uint16_t depth; uint16_t alternative_index; uint16_t negated_field_list_id; - bool contains_captures: 1; + bool is_named: 1; bool is_immediate: 1; bool is_last_child: 1; bool is_pass_through: 1; bool is_dead_end: 1; bool alternative_is_immediate: 1; - bool is_definite: 1; + bool contains_captures: 1; + bool root_pattern_guaranteed: 1; + bool parent_pattern_guaranteed: 1; } QueryStep; /* @@ -279,7 +297,6 @@ static const TSQueryError PARENT_DONE = -1; static const uint16_t PATTERN_DONE_MARKER = UINT16_MAX; static const uint16_t NONE = UINT16_MAX; static const TSSymbol WILDCARD_SYMBOL = 0; -static const TSSymbol NAMED_WILDCARD_SYMBOL = UINT16_MAX - 1; /********** * Stream @@ -512,9 +529,10 @@ static QueryStep query_step__new( .negated_field_list_id = 0, .contains_captures = false, .is_last_child = false, + .is_named = false, .is_pass_through = false, .is_dead_end = false, - .is_definite = false, + .root_pattern_guaranteed = false, .is_immediate = is_immediate, .alternative_is_immediate = false, }; @@ -548,9 +566,14 @@ static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) { * StatePredecessorMap **********************/ -static inline StatePredecessorMap state_predecessor_map_new(const TSLanguage *language) { +static inline StatePredecessorMap state_predecessor_map_new( + const TSLanguage *language +) { return (StatePredecessorMap) { - .contents = ts_calloc(language->state_count * (MAX_STATE_PREDECESSOR_COUNT + 1), sizeof(TSStateId)), + .contents = ts_calloc( + language->state_count * (MAX_STATE_PREDECESSOR_COUNT + 1), + sizeof(TSStateId) + ), }; } @@ -565,7 +588,10 @@ static inline void state_predecessor_map_add( ) { unsigned index = state * (MAX_STATE_PREDECESSOR_COUNT + 1); TSStateId *count = &self->contents[index]; - if (*count == 0 || (*count < MAX_STATE_PREDECESSOR_COUNT && self->contents[index + *count] != predecessor)) { + if ( + *count == 0 || + (*count < MAX_STATE_PREDECESSOR_COUNT && self->contents[index + *count] != predecessor) + ) { (*count)++; self->contents[index + *count] = predecessor; } @@ -744,25 +770,39 @@ static inline void ts_query__pattern_map_insert( } static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { - // Identify all of the patterns in the query that have child patterns, both at the - // top level and nested within other larger patterns. Record the step index where - // each pattern starts. + // Walk forward through all of the steps in the query, computing some + // basic information about each step. Mark all of the steps that contain + // captures, and record the indices of all of the steps that have child steps. Array(uint32_t) parent_step_indices = array_new(); for (unsigned i = 0; i < self->steps.size; i++) { QueryStep *step = &self->steps.contents[i]; - if (i + 1 < self->steps.size) { - QueryStep *next_step = &self->steps.contents[i + 1]; + if (step->depth == PATTERN_DONE_MARKER) { + step->parent_pattern_guaranteed = true; + step->root_pattern_guaranteed = true; + continue; + } + + bool has_children = false; + bool is_wildcard = step->symbol == WILDCARD_SYMBOL; + step->contains_captures = step->capture_ids[0] != NONE; + for (unsigned j = i + 1; j < self->steps.size; j++) { + QueryStep *next_step = &self->steps.contents[j]; if ( - step->symbol != WILDCARD_SYMBOL && - step->symbol != NAMED_WILDCARD_SYMBOL && - next_step->depth > step->depth && - next_step->depth != PATTERN_DONE_MARKER - ) { - array_push(&parent_step_indices, i); + next_step->depth == PATTERN_DONE_MARKER || + next_step->depth <= step->depth + ) break; + if (next_step->capture_ids[0] != NONE) { + step->contains_captures = true; } + if (!is_wildcard) { + next_step->root_pattern_guaranteed = true; + next_step->parent_pattern_guaranteed = true; + } + has_children = true; } - if (step->depth > 0) { - step->is_definite = true; + + if (has_children && !is_wildcard) { + array_push(&parent_step_indices, i); } } @@ -945,7 +985,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { if (parent_symbol == ts_builtin_sym_error) continue; // Find the subgraph that corresponds to this pattern's root symbol. If the pattern's - // root symbols is not a non-terminal, then return an error. + // root symbol is a terminal, then return an error. unsigned subgraph_index, exists; array_search_sorted_by(&subgraphs, .symbol, parent_symbol, &subgraph_index, &exists); if (!exists) { @@ -1073,24 +1113,27 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { while (ts_lookahead_iterator_next(&lookahead_iterator)) { TSSymbol sym = lookahead_iterator.symbol; - TSStateId next_parse_state; + AnalysisSubgraphNode successor = { + .state = parse_state, + .child_index = child_index, + }; if (lookahead_iterator.action_count) { const TSParseAction *action = &lookahead_iterator.actions[lookahead_iterator.action_count - 1]; if (action->type == TSParseActionTypeShift) { - next_parse_state = action->shift.extra ? parse_state : action->shift.state; + if (!action->shift.extra) { + successor.state = action->shift.state; + successor.child_index++; + } } else { continue; } } else if (lookahead_iterator.next_state != 0) { - next_parse_state = lookahead_iterator.next_state; + successor.state = lookahead_iterator.next_state; + successor.child_index++; } else { continue; } - AnalysisSubgraphNode successor = { - .state = next_parse_state, - .child_index = child_index + 1, - }; unsigned node_index; array_search_sorted_with( &subgraph->nodes, @@ -1124,7 +1167,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // Create a new state that has advanced past this hypothetical subtree. AnalysisState next_state = *state; AnalysisStateEntry *next_state_top = analysis_state__top(&next_state); - next_state_top->child_index++; + next_state_top->child_index = successor.child_index; next_state_top->parse_state = successor.state; if (node->done) next_state_top->done = true; @@ -1133,10 +1176,13 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { bool does_match = false; if (visible_symbol) { does_match = true; - if (step->symbol == NAMED_WILDCARD_SYMBOL) { - if (!self->language->symbol_metadata[visible_symbol].named) does_match = false; - } else if (step->symbol != WILDCARD_SYMBOL) { - if (step->symbol != visible_symbol) does_match = false; + if (step->symbol == WILDCARD_SYMBOL) { + if ( + step->is_named && + !self->language->symbol_metadata[visible_symbol].named + ) does_match = false; + } else if (step->symbol != visible_symbol) { + does_match = false; } if (step->field && step->field != field_id) { does_match = false; @@ -1198,7 +1244,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { next_step->depth <= parent_depth + 1 ) break; } - } else if (next_parse_state == parse_state) { + } else if (successor.state == parse_state) { continue; } @@ -1259,7 +1305,8 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { step->depth > parent_depth && !step->is_dead_end ) { - step->is_definite = false; + step->parent_pattern_guaranteed = false; + step->root_pattern_guaranteed = false; } } @@ -1271,7 +1318,8 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { step->depth == PATTERN_DONE_MARKER ) break; if (!step->is_dead_end) { - step->is_definite = false; + step->parent_pattern_guaranteed = false; + step->root_pattern_guaranteed = false; } } } @@ -1321,25 +1369,27 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { unsigned index, exists; array_search_sorted_by(&predicate_capture_ids, , capture_id, &index, &exists); if (exists) { - step->is_definite = false; + step->root_pattern_guaranteed = false; break; } } } } - // Propagate indefiniteness backwards. + // Propagate fallibility. If a pattern is fallible at a given step, then it is + // fallible at all of its preceding steps. bool done = self->steps.size == 0; while (!done) { done = true; for (unsigned i = self->steps.size - 1; i > 0; i--) { QueryStep *step = &self->steps.contents[i]; + if (step->depth == PATTERN_DONE_MARKER) continue; // Determine if this step is definite or has definite alternatives. - bool is_definite = false; + bool parent_pattern_guaranteed = false; for (;;) { - if (step->is_definite) { - is_definite = true; + if (step->root_pattern_guaranteed) { + parent_pattern_guaranteed = true; break; } if (step->alternative_index == NONE || step->alternative_index < i) { @@ -1349,14 +1399,14 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } // If not, mark its predecessor as indefinite. - if (!is_definite) { + if (!parent_pattern_guaranteed) { QueryStep *prev_step = &self->steps.contents[i - 1]; if ( !prev_step->is_dead_end && prev_step->depth != PATTERN_DONE_MARKER && - prev_step->is_definite + prev_step->root_pattern_guaranteed ) { - prev_step->is_definite = false; + prev_step->root_pattern_guaranteed = false; done = false; } } @@ -1371,13 +1421,15 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { printf(" %u: DONE\n", i); } else { printf( - " %u: {symbol: %s, field: %s, is_definite: %d}\n", + " %u: {symbol: %s, field: %s, depth: %u, parent_pattern_guaranteed: %d, root_pattern_guaranteed: %d}\n", i, - (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) + (step->symbol == WILDCARD_SYMBOL) ? "ANY" : ts_language_symbol_name(self->language, step->symbol), (step->field ? ts_language_field_name_for_id(self->language, step->field) : "-"), - step->is_definite + step->depth, + step->parent_pattern_guaranteed, + step->root_pattern_guaranteed ); } } @@ -1401,23 +1453,6 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { return all_patterns_are_valid; } -static void ts_query__finalize_steps(TSQuery *self) { - for (unsigned i = 0; i < self->steps.size; i++) { - QueryStep *step = &self->steps.contents[i]; - uint32_t depth = step->depth; - if (step->capture_ids[0] != NONE) { - step->contains_captures = true; - } else { - step->contains_captures = false; - for (unsigned j = i + 1; j < self->steps.size; j++) { - QueryStep *s = &self->steps.contents[j]; - if (s->depth == PATTERN_DONE_MARKER || s->depth <= depth) break; - if (s->capture_ids[0] != NONE) step->contains_captures = true; - } - } - } -} - static void ts_query__add_negated_fields( TSQuery *self, uint16_t step_index, @@ -1746,15 +1781,8 @@ static TSQueryError ts_query__parse_pattern( else { TSSymbol symbol; - // TODO - remove. - // For temporary backward compatibility, handle '*' as a wildcard. - if (stream->next == '*') { - symbol = depth > 0 ? NAMED_WILDCARD_SYMBOL : WILDCARD_SYMBOL; - stream_advance(stream); - } - // Parse a normal node name - else if (stream_is_ident_start(stream)) { + if (stream_is_ident_start(stream)) { const char *node_name = stream->input; stream_scan_identifier(stream); uint32_t length = stream->input - node_name; @@ -1768,7 +1796,7 @@ static TSQueryError ts_query__parse_pattern( // Parse the wildcard symbol else if (length == 1 && node_name[0] == '_') { - symbol = depth > 0 ? NAMED_WILDCARD_SYMBOL : WILDCARD_SYMBOL; + symbol = WILDCARD_SYMBOL; } else { @@ -1789,10 +1817,13 @@ static TSQueryError ts_query__parse_pattern( // Add a step for the node. array_push(&self->steps, query_step__new(symbol, depth, is_immediate)); + QueryStep *step = array_back(&self->steps); if (ts_language_symbol_metadata(self->language, symbol).supertype) { - QueryStep *step = array_back(&self->steps); step->supertype_symbol = step->symbol; - step->symbol = NAMED_WILDCARD_SYMBOL; + step->symbol = WILDCARD_SYMBOL; + } + if (symbol == WILDCARD_SYMBOL) { + step->is_named = true; } stream_skip_whitespace(stream); @@ -1807,7 +1838,6 @@ static TSQueryError ts_query__parse_pattern( stream_scan_identifier(stream); uint32_t length = stream->input - node_name; - QueryStep *step = array_back(&self->steps); step->symbol = ts_language_symbol_for_name( self->language, node_name, @@ -1901,13 +1931,7 @@ static TSQueryError ts_query__parse_pattern( } // Parse a wildcard pattern - else if ( - stream->next == '_' || - - // TODO remove. - // For temporary backward compatibility, handle '*' as a wildcard. - stream->next == '*' - ) { + else if (stream->next == '_') { stream_advance(stream); stream_skip_whitespace(stream); @@ -2150,7 +2174,7 @@ TSQuery *ts_query_new( // then optimize the matching process by skipping matching the wildcard. // Later, during the matching process, the query cursor will check that // there is a parent node, and capture it if necessary. - if (step->symbol == WILDCARD_SYMBOL && step->depth == 0) { + if (step->symbol == WILDCARD_SYMBOL && step->depth == 0 && !step->field) { QueryStep *second_step = &self->steps.contents[start_step_index + 1]; if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth == 1) { wildcard_root_alternative_index = step->alternative_index; @@ -2201,7 +2225,6 @@ TSQuery *ts_query_new( return NULL; } - ts_query__finalize_steps(self); array_delete(&self->string_buffer); return self; } @@ -2269,7 +2292,7 @@ uint32_t ts_query_start_byte_for_pattern( return self->patterns.contents[pattern_index].start_byte; } -bool ts_query_step_is_definite( +bool ts_query_is_pattern_guaranteed_at_step( const TSQuery *self, uint32_t byte_offset ) { @@ -2280,12 +2303,26 @@ bool ts_query_step_is_definite( step_index = step_offset->step_index; } if (step_index < self->steps.size) { - return self->steps.contents[step_index].is_definite; + return self->steps.contents[step_index].root_pattern_guaranteed; } else { return false; } } +bool ts_query__step_is_fallible( + const TSQuery *self, + uint16_t step_index +) { + assert((uint32_t)step_index + 1 < self->steps.size); + QueryStep *step = &self->steps.contents[step_index]; + QueryStep *next_step = &self->steps.contents[step_index + 1]; + return ( + next_step->depth != PATTERN_DONE_MARKER && + next_step->depth > step->depth && + !next_step->parent_pattern_guaranteed + ); +} + void ts_query_disable_capture( TSQuery *self, const char *name, @@ -2299,7 +2336,6 @@ void ts_query_disable_capture( QueryStep *step = &self->steps.contents[i]; query_step__remove_capture(step, id); } - ts_query__finalize_steps(self); } } @@ -2409,7 +2445,7 @@ static bool ts_query_cursor__first_in_progress_capture( uint32_t *state_index, uint32_t *byte_offset, uint32_t *pattern_index, - bool *is_definite + bool *root_pattern_guaranteed ) { bool result = false; *state_index = UINT32_MAX; @@ -2444,9 +2480,9 @@ static bool ts_query_cursor__first_in_progress_capture( (node_start_byte == *byte_offset && state->pattern_index < *pattern_index) ) { QueryStep *step = &self->query->steps.contents[state->step_index]; - if (is_definite) { - *is_definite = step->is_definite; - } else if (step->is_definite) { + if (root_pattern_guaranteed) { + *root_pattern_guaranteed = step->root_pattern_guaranteed; + } else if (step->root_pattern_guaranteed) { continue; } @@ -2532,6 +2568,12 @@ void ts_query_cursor__compare_captures( } } +#ifdef DEBUG_EXECUTE_QUERY +#define LOG(...) fprintf(stderr, __VA_ARGS__) +#else +#define LOG(...) +#endif + static void ts_query_cursor__add_state( TSQueryCursor *self, const PatternEntry *pattern @@ -2563,12 +2605,13 @@ static void ts_query_cursor__add_state( QueryState *prev_state = &self->states.contents[index - 1]; if (prev_state->start_depth < start_depth) break; if (prev_state->start_depth == start_depth) { - if (prev_state->pattern_index < pattern->pattern_index) break; - if (prev_state->pattern_index == pattern->pattern_index) { - // Avoid inserting an unnecessary duplicate state, which would be - // immediately pruned by the longest-match criteria. - if (prev_state->step_index == pattern->step_index) return; - } + // Avoid inserting an unnecessary duplicate state, which would be + // immediately pruned by the longest-match criteria. + if ( + prev_state->pattern_index == pattern->pattern_index && + prev_state->step_index == pattern->step_index + ) return; + if (prev_state->pattern_index <= pattern->pattern_index) break; } index--; } @@ -2721,7 +2764,11 @@ static inline bool ts_query_cursor__advance( // Exit the current node. if (self->ascending) { - LOG("leave node. type:%s\n", ts_node_type(ts_tree_cursor_current_node(&self->cursor))); + LOG( + "leave node. depth:%u, type:%s\n", + self->depth, + ts_node_type(ts_tree_cursor_current_node(&self->cursor)) + ); // Leave this node by stepping to its next sibling or to its parent. if (ts_tree_cursor_goto_next_sibling(&self->cursor)) { @@ -2797,7 +2844,8 @@ static inline bool ts_query_cursor__advance( &supertype_count ); LOG( - "enter node. type:%s, field:%s, row:%u state_count:%u, finished_state_count:%u\n", + "enter node. depth:%u, type:%s, field:%s, row:%u state_count:%u, finished_state_count:%u\n", + self->depth, ts_node_type(node), ts_language_field_name_for_id(self->query->language, field_id), ts_node_start_point(node).row, @@ -2873,10 +2921,12 @@ static inline bool ts_query_cursor__advance( // Determine if this node matches this step of the pattern, and also // if this node can have later siblings that match this step of the // pattern. - bool node_does_match = - step->symbol == symbol || - step->symbol == WILDCARD_SYMBOL || - (step->symbol == NAMED_WILDCARD_SYMBOL && is_named); + bool node_does_match = false; + if (step->symbol == WILDCARD_SYMBOL) { + node_does_match = is_named || !step->is_named; + } else { + node_does_match = symbol == step->symbol; + } bool later_sibling_can_match = has_later_siblings; if ((step->is_immediate && is_named) || state->seeking_immediate_match) { later_sibling_can_match = false; @@ -2943,7 +2993,10 @@ static inline bool ts_query_cursor__advance( // parent, then this query state cannot simply be updated in place. It must be // split into two states: one that matches this node, and one which skips over // this node, to preserve the possibility of matching later siblings. - if (later_sibling_can_match && (step->contains_captures || !step->is_definite)) { + if (later_sibling_can_match && ( + step->contains_captures || + ts_query__step_is_fallible(self->query, state->step_index) + )) { if (ts_query_cursor__copy_state(self, &state)) { LOG( " split state for capture. pattern:%u, step:%u\n", @@ -3005,7 +3058,7 @@ static inline bool ts_query_cursor__advance( ); QueryStep *next_step = &self->query->steps.contents[state->step_index]; - if (stop_on_definite_step && next_step->is_definite) did_match = true; + if (stop_on_definite_step && next_step->root_pattern_guaranteed) did_match = true; // If this state's next step has an alternative step, then copy the state in order // to pursue both alternatives. The alternative step itself may have an alternative, @@ -3116,7 +3169,7 @@ static inline bool ts_query_cursor__advance( } } - // If there the state is at the end of its pattern, remove it from the list + // If the state is at the end of its pattern, remove it from the list // of in-progress states and add it to the list of finished states. if (!did_remove) { LOG(