From cedced5bb00300811e41edb225eeb2cce74fc640 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Halvor=20Fladsrud=20B=C3=B8?= Date: Mon, 7 Mar 2022 06:29:18 +0000 Subject: [PATCH] Slop support for phrase queries (#1241) Closes #1068 --- src/query/phrase_query/mod.rs | 84 ++++++++++ src/query/phrase_query/phrase_query.rs | 17 +- src/query/phrase_query/phrase_scorer.rs | 206 +++++++++++++++++------- src/query/phrase_query/phrase_weight.rs | 8 + 4 files changed, 255 insertions(+), 60 deletions(-) diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index edec22fec5..2b3f5469ed 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -181,6 +181,90 @@ pub mod tests { Ok(()) } + #[ignore] + #[test] + pub fn test_phrase_score_with_slop() -> crate::Result<()> { + let index = create_index(&["a c b", "a b c a b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader().unwrap().searcher(); + let test_query = |texts: Vec<&str>| { + let terms: Vec = texts + .iter() + .map(|text| Term::from_field_text(text_field, text)) + .collect(); + let mut phrase_query = PhraseQuery::new(terms); + phrase_query.set_slop(1); + searcher + .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) + .expect("search should succeed") + .scores() + .to_vec() + }; + let scores = test_query(vec!["a", "b"]); + assert_nearly_equals!(scores[0], 0.40618482); + assert_nearly_equals!(scores[1], 0.46844664); + Ok(()) + } + + #[test] + pub fn test_phrase_score_with_slop_size() -> crate::Result<()> { + let index = create_index(&["a b e c", "a e e e c", "a e e e e c"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader().unwrap().searcher(); + let test_query = |texts: Vec<&str>| { + let terms: Vec = texts + .iter() + .map(|text| Term::from_field_text(text_field, text)) + .collect(); + let mut phrase_query = PhraseQuery::new(terms); + phrase_query.set_slop(3); + searcher + .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) + .expect("search should succeed") + .scores() + .to_vec() + }; + let scores = test_query(vec!["a", "c"]); + assert_nearly_equals!(scores[0], 0.29086056); + assert_nearly_equals!(scores[1], 0.26706287); + Ok(()) + } + + #[test] + pub fn test_phrase_score_with_slop_ordering() -> crate::Result<()> { + let index = create_index(&[ + "a e b e c", + "a e e e e e b e e e e c", + "a c b", + "a c e b e", + "a e c b", + "a e b c", + ])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader().unwrap().searcher(); + let test_query = |texts: Vec<&str>| { + let terms: Vec = texts + .iter() + .map(|text| Term::from_field_text(text_field, text)) + .collect(); + let mut phrase_query = PhraseQuery::new(terms); + phrase_query.set_slop(3); + searcher + .search(&phrase_query, &TEST_COLLECTOR_WITH_SCORE) + .expect("search should succeed") + .scores() + .to_vec() + }; + let scores = test_query(vec!["a", "b", "c"]); + // The first and last matches. + assert_nearly_equals!(scores[0], 0.23091172); + assert_nearly_equals!(scores[1], 0.25024384); + Ok(()) + } + #[test] // motivated by #234 pub fn test_phrase_query_docfreq_order() -> crate::Result<()> { let mut schema_builder = Schema::builder(); diff --git a/src/query/phrase_query/phrase_query.rs b/src/query/phrase_query/phrase_query.rs index 34af3e571a..b147158a40 100644 --- a/src/query/phrase_query/phrase_query.rs +++ b/src/query/phrase_query/phrase_query.rs @@ -23,6 +23,7 @@ use crate::schema::{Field, IndexRecordOption, Term}; pub struct PhraseQuery { field: Field, phrase_terms: Vec<(usize, Term)>, + slop: u32, } impl PhraseQuery { @@ -53,9 +54,15 @@ impl PhraseQuery { PhraseQuery { field, phrase_terms: terms, + slop: 0, } } + /// Slop allowed for the phrase. + pub fn set_slop(&mut self, value: u32) { + self.slop = value; + } + /// The `Field` this `PhraseQuery` is targeting. pub fn field(&self) -> Field { self.field @@ -94,11 +101,11 @@ impl PhraseQuery { } let terms = self.phrase_terms(); let bm25_weight = Bm25Weight::for_terms(searcher, &terms)?; - Ok(PhraseWeight::new( - self.phrase_terms.clone(), - bm25_weight, - scoring_enabled, - )) + let mut weight = PhraseWeight::new(self.phrase_terms.clone(), bm25_weight, scoring_enabled); + if self.slop > 0 { + weight.slop(self.slop); + } + Ok(weight) } } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 4f8d45894e..37d83477af 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -52,24 +52,25 @@ pub struct PhraseScorer { fieldnorm_reader: FieldNormReader, similarity_weight: Bm25Weight, scoring_enabled: bool, + slop: u32, } /// Returns true if and only if the two sorted arrays contain a common element fn intersection_exists(left: &[u32], right: &[u32]) -> bool { - let mut left_i = 0; - let mut right_i = 0; - while left_i < left.len() && right_i < right.len() { - let left_val = left[left_i]; - let right_val = right[right_i]; + let mut left_index = 0; + let mut right_index = 0; + while left_index < left.len() && right_index < right.len() { + let left_val = left[left_index]; + let right_val = right[right_index]; match left_val.cmp(&right_val) { Ordering::Less => { - left_i += 1; + left_index += 1; } Ordering::Equal => { return true; } Ordering::Greater => { - right_i += 1; + right_index += 1; } } } @@ -77,23 +78,23 @@ fn intersection_exists(left: &[u32], right: &[u32]) -> bool { } fn intersection_count(left: &[u32], right: &[u32]) -> usize { - let mut left_i = 0; - let mut right_i = 0; + let mut left_index = 0; + let mut right_index = 0; let mut count = 0; - while left_i < left.len() && right_i < right.len() { - let left_val = left[left_i]; - let right_val = right[right_i]; + while left_index < left.len() && right_index < right.len() { + let left_val = left[left_index]; + let right_val = right[right_index]; match left_val.cmp(&right_val) { Ordering::Less => { - left_i += 1; + left_index += 1; } Ordering::Equal => { count += 1; - left_i += 1; - right_i += 1; + left_index += 1; + right_index += 1; } Ordering::Greater => { - right_i += 1; + right_index += 1; } } } @@ -105,38 +106,91 @@ fn intersection_count(left: &[u32], right: &[u32]) -> usize { /// /// Returns the length of the intersection fn intersection(left: &mut [u32], right: &[u32]) -> usize { - let mut left_i = 0; - let mut right_i = 0; + let mut left_index = 0; + let mut right_index = 0; let mut count = 0; let left_len = left.len(); let right_len = right.len(); - while left_i < left_len && right_i < right_len { - let left_val = left[left_i]; - let right_val = right[right_i]; + while left_index < left_len && right_index < right_len { + let left_val = left[left_index]; + let right_val = right[right_index]; match left_val.cmp(&right_val) { Ordering::Less => { - left_i += 1; + left_index += 1; } Ordering::Equal => { left[count] = left_val; count += 1; - left_i += 1; - right_i += 1; + left_index += 1; + right_index += 1; } Ordering::Greater => { - right_i += 1; + right_index += 1; } } } count } +/// Intersect twos sorted arrays `left` and `right` and outputs the +/// resulting array in left. +/// +/// Condition for match is that the value stored in left is less than or equal to +/// the value in right and that the distance to the previous token is lte to the slop. +/// +/// Returns the length of the intersection +fn intersection_with_slop(left: &mut [u32], right: &[u32], slop: u32) -> usize { + let mut left_index = 0; + let mut right_index = 0; + let mut count = 0; + let left_len = left.len(); + let right_len = right.len(); + while left_index < left_len && right_index < right_len { + let left_val = left[left_index]; + let right_val = right[right_index]; + + // The three conditions are: + // left_val < right_slop -> left index increment. + // right_slop <= left_val <= right -> find the best match. + // left_val > right -> right index increment. + let right_slop = if right_val >= slop { + right_val - slop + } else { + 0 + }; + + if left_val < right_slop { + left_index += 1; + } else if right_slop <= left_val && left_val <= right_val { + while left_index + 1 < left_len { + // there could be a better match + let next_left_val = left[left_index + 1]; + if next_left_val > right_val { + // the next value is outside the range, so current one is the best. + break; + } + // the next value is better. + left_index += 1; + } + // store the match in left. + left[count] = right_val; + count += 1; + left_index += 1; + right_index += 1; + } else if left_val > right_val { + right_index += 1; + } + } + count +} + impl PhraseScorer { pub fn new( term_postings: Vec<(usize, TPostings)>, similarity_weight: Bm25Weight, fieldnorm_reader: FieldNormReader, scoring_enabled: bool, + slop: u32, ) -> PhraseScorer { let max_offset = term_postings .iter() @@ -159,6 +213,7 @@ impl PhraseScorer { similarity_weight, fieldnorm_reader, scoring_enabled, + slop, }; if scorer.doc() != TERMINATED && !scorer.phrase_match() { scorer.advance(); @@ -181,51 +236,54 @@ impl PhraseScorer { } fn phrase_exists(&mut self) -> bool { - self.intersection_docset - .docset_mut_specialized(0) - .positions(&mut self.left); - let mut intersection_len = self.left.len(); - for i in 1..self.num_terms - 1 { - { - self.intersection_docset - .docset_mut_specialized(i) - .positions(&mut self.right); - } - intersection_len = intersection(&mut self.left[..intersection_len], &self.right[..]); - if intersection_len == 0 { - return false; - } - } - - self.intersection_docset - .docset_mut_specialized(self.num_terms - 1) - .positions(&mut self.right); + let intersection_len = self.compute_phrase_match(); intersection_exists(&self.left[..intersection_len], &self.right[..]) } fn compute_phrase_count(&mut self) -> u32 { + let intersection_len = self.compute_phrase_match(); + intersection_count(&self.left[..intersection_len], &self.right[..]) as u32 + } + + fn compute_phrase_match(&mut self) -> usize { { self.intersection_docset .docset_mut_specialized(0) .positions(&mut self.left); } let mut intersection_len = self.left.len(); - for i in 1..self.num_terms - 1 { + let end_term = if self.has_slop() { + self.num_terms + } else { + self.num_terms - 1 + }; + for i in 1..end_term { { self.intersection_docset .docset_mut_specialized(i) .positions(&mut self.right); } - intersection_len = intersection(&mut self.left[..intersection_len], &self.right[..]); + intersection_len = if self.has_slop() { + intersection_with_slop( + &mut self.left[..intersection_len], + &self.right[..], + self.slop, + ) + } else { + intersection(&mut self.left[..intersection_len], &self.right[..]) + }; if intersection_len == 0 { - return 0u32; + return 0; } } - self.intersection_docset .docset_mut_specialized(self.num_terms - 1) .positions(&mut self.right); - intersection_count(&self.left[..intersection_len], &self.right[..]) as u32 + intersection_len + } + + fn has_slop(&self) -> bool { + self.slop > 0 } } @@ -268,18 +326,26 @@ impl Scorer for PhraseScorer { #[cfg(test)] mod tests { - use super::{intersection, intersection_count}; + use super::{intersection, intersection_count, intersection_with_slop}; fn test_intersection_sym(left: &[u32], right: &[u32], expected: &[u32]) { - test_intersection_aux(left, right, expected); - test_intersection_aux(right, left, expected); + test_intersection_aux(left, right, expected, 0); + test_intersection_aux(right, left, expected, 0); } - fn test_intersection_aux(left: &[u32], right: &[u32], expected: &[u32]) { + fn test_intersection_aux(left: &[u32], right: &[u32], expected: &[u32], slop: u32) { let mut left_vec = Vec::from(left); let left_mut = &mut left_vec[..]; - assert_eq!(intersection_count(left_mut, right), expected.len()); - let count = intersection(left_mut, right); + if slop == 0 { + let left_mut = &mut left_vec[..]; + assert_eq!(intersection_count(left_mut, right), expected.len()); + let count = intersection(left_mut, right); + assert_eq!(&left_mut[..count], expected); + return; + } + let mut right_vec = Vec::from(right); + let right_mut = &mut right_vec[..]; + let count = intersection_with_slop(left_mut, right_mut, slop); assert_eq!(&left_mut[..count], expected); } @@ -291,6 +357,36 @@ mod tests { test_intersection_sym(&[5, 7], &[1, 5, 10, 12], &[5]); test_intersection_sym(&[1, 5, 6, 9, 10, 12], &[6, 8, 9, 12], &[6, 9, 12]); } + #[test] + fn test_slop() { + // The slop is not symetric. It does not allow for the phrase to be out of order. + test_intersection_aux(&[1], &[2], &[2], 1); + test_intersection_aux(&[1], &[3], &[], 1); + test_intersection_aux(&[1], &[3], &[3], 2); + test_intersection_aux(&[], &[2], &[], 100000); + test_intersection_aux(&[5, 7, 11], &[1, 5, 10, 12], &[5, 12], 1); + test_intersection_aux(&[1, 5, 6, 9, 10, 12], &[6, 8, 9, 12], &[6, 9, 12], 1); + test_intersection_aux(&[1, 5, 6, 9, 10, 12], &[6, 8, 9, 12], &[6, 9, 12], 10); + test_intersection_aux(&[1, 3, 5], &[2, 4, 6], &[2, 4, 6], 1); + test_intersection_aux(&[1, 3, 5], &[2, 4, 6], &[], 0); + } + + fn test_merge(left: &[u32], right: &[u32], expected_left: &[u32], slop: u32) { + let mut left_vec = Vec::from(left); + let left_mut = &mut left_vec[..]; + let mut right_vec = Vec::from(right); + let right_mut = &mut right_vec[..]; + let count = intersection_with_slop(left_mut, right_mut, slop); + assert_eq!(&left_mut[..count], expected_left); + } + + #[test] + fn test_merge_slop() { + test_merge(&[1, 2], &[1], &[1], 1); + test_merge(&[3], &[4], &[4], 2); + test_merge(&[3], &[4], &[4], 2); + test_merge(&[1, 5, 6, 9, 10, 12], &[6, 8, 9, 12], &[6, 9, 12], 10); + } } #[cfg(all(test, feature = "unstable"))] diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs index 5a8e507fd6..ff39c2e8ec 100644 --- a/src/query/phrase_query/phrase_weight.rs +++ b/src/query/phrase_query/phrase_weight.rs @@ -12,6 +12,7 @@ pub struct PhraseWeight { phrase_terms: Vec<(usize, Term)>, similarity_weight: Bm25Weight, scoring_enabled: bool, + slop: u32, } impl PhraseWeight { @@ -21,10 +22,12 @@ impl PhraseWeight { similarity_weight: Bm25Weight, scoring_enabled: bool, ) -> PhraseWeight { + let slop = 0; PhraseWeight { phrase_terms, similarity_weight, scoring_enabled, + slop, } } @@ -74,8 +77,13 @@ impl PhraseWeight { similarity_weight, fieldnorm_reader, self.scoring_enabled, + self.slop, ))) } + + pub fn slop(&mut self, slop: u32) { + self.slop = slop; + } } impl Weight for PhraseWeight {