From 3f88718f387df7969e68ab52ed7714681cfe3273 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Tue, 10 May 2022 16:29:16 +0800 Subject: [PATCH 01/10] refactor aggregations --- src/aggregation/agg_result.rs | 186 +----------------- src/aggregation/bucket/histogram/histogram.rs | 12 +- src/aggregation/bucket/range.rs | 2 +- src/aggregation/collector.rs | 2 +- src/aggregation/intermediate_agg_result.rs | 182 ++++++++++++++++- src/aggregation/mod.rs | 11 +- 6 files changed, 193 insertions(+), 202 deletions(-) diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 3dba73d1a6..fc614990b4 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -4,21 +4,15 @@ //! intermediate average results, which is the sum and the number of values. The actual average is //! calculated on the step from intermediate to final aggregation result tree. -use std::cmp::Ordering; use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use super::agg_req::{ - Aggregations, AggregationsInternal, BucketAggregationInternal, MetricAggregation, -}; -use super::bucket::{intermediate_buckets_to_final_buckets, GetDocCount}; -use super::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, - IntermediateMetricResult, IntermediateRangeBucketEntry, -}; +use super::agg_req::BucketAggregationInternal; +use super::bucket::GetDocCount; +use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult}; use super::metric::{SingleMetricResult, Stats}; -use super::{Key, VecWithNames}; +use super::Key; use crate::TantivyError; #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] @@ -41,98 +35,6 @@ impl AggregationResults { ))) } } - - /// Convert and intermediate result and its aggregation request to the final result - pub fn from_intermediate_and_req( - results: IntermediateAggregationResults, - agg: Aggregations, - ) -> crate::Result { - 
AggregationResults::from_intermediate_and_req_internal(results, &(agg.into())) - } - - /// Convert and intermediate result and its aggregation request to the final result - /// - /// Internal function, CollectorAggregations is used instead Aggregations, which is optimized - /// for internal processing, by splitting metric and buckets into seperate groups. - pub(crate) fn from_intermediate_and_req_internal( - intermediate_results: IntermediateAggregationResults, - req: &AggregationsInternal, - ) -> crate::Result { - // Important assumption: - // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the - // request - let mut results: HashMap = HashMap::new(); - - if let Some(buckets) = intermediate_results.buckets { - add_coverted_final_buckets_to_result(&mut results, buckets, &req.buckets)? - } else { - // When there are no buckets, we create empty buckets, so that the serialized json - // format is constant - add_empty_final_buckets_to_result(&mut results, &req.buckets)? 
- }; - - if let Some(metrics) = intermediate_results.metrics { - add_converted_final_metrics_to_result(&mut results, metrics); - } else { - // When there are no metrics, we create empty metric results, so that the serialized - // json format is constant - add_empty_final_metrics_to_result(&mut results, &req.metrics)?; - } - Ok(Self(results)) - } -} - -fn add_converted_final_metrics_to_result( - results: &mut HashMap, - metrics: VecWithNames, -) { - results.extend( - metrics - .into_iter() - .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))), - ); -} - -fn add_empty_final_metrics_to_result( - results: &mut HashMap, - req_metrics: &VecWithNames, -) -> crate::Result<()> { - results.extend(req_metrics.iter().map(|(key, req)| { - let empty_bucket = IntermediateMetricResult::empty_from_req(req); - ( - key.to_string(), - AggregationResult::MetricResult(empty_bucket.into()), - ) - })); - Ok(()) -} - -fn add_empty_final_buckets_to_result( - results: &mut HashMap, - req_buckets: &VecWithNames, -) -> crate::Result<()> { - let requested_buckets = req_buckets.iter(); - for (key, req) in requested_buckets { - let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?); - results.insert(key.to_string(), empty_bucket); - } - Ok(()) -} - -fn add_coverted_final_buckets_to_result( - results: &mut HashMap, - buckets: VecWithNames, - req_buckets: &VecWithNames, -) -> crate::Result<()> { - assert_eq!(buckets.len(), req_buckets.len()); - - let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); - for ((key, bucket), req) in buckets_with_request { - let result = - AggregationResult::BucketResult(BucketResult::from_intermediate_and_req(bucket, req)?); - results.insert(key, result); - } - Ok(()) } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -154,7 +56,8 @@ impl AggregationResult { match self { AggregationResult::BucketResult(_bucket) => Err(TantivyError::InternalError( "Tried to retrieve value from 
bucket aggregation. This is not supported and \ - should not happen during collection, but should be catched during validation" + should not happen during collection phase, but should be catched during \ + validation" .to_string(), )), AggregationResult::MetricResult(metric) => metric.get_value(agg_property), @@ -230,48 +133,7 @@ pub enum BucketResult { impl BucketResult { pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result { let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg); - BucketResult::from_intermediate_and_req(empty_bucket, req) - } - - fn from_intermediate_and_req( - bucket_result: IntermediateBucketResult, - req: &BucketAggregationInternal, - ) -> crate::Result { - match bucket_result { - IntermediateBucketResult::Range(range_res) => { - let mut buckets: Vec = range_res - .buckets - .into_iter() - .map(|(_, bucket)| { - RangeBucketEntry::from_intermediate_and_req(bucket, &req.sub_aggregation) - }) - .collect::>>()?; - - buckets.sort_by(|left, right| { - // TODO use total_cmp next stable rust release - left.from - .unwrap_or(f64::MIN) - .partial_cmp(&right.from.unwrap_or(f64::MIN)) - .unwrap_or(Ordering::Equal) - }); - Ok(BucketResult::Range { buckets }) - } - IntermediateBucketResult::Histogram { buckets } => { - let buckets = intermediate_buckets_to_final_buckets( - buckets, - req.as_histogram() - .expect("unexpected aggregation, expected histogram aggregation"), - &req.sub_aggregation, - )?; - - Ok(BucketResult::Histogram { buckets }) - } - IntermediateBucketResult::Terms(terms) => terms.into_final_result( - req.as_term() - .expect("unexpected aggregation, expected term aggregation"), - &req.sub_aggregation, - ), - } + empty_bucket.into_final_bucket_result(req) } } @@ -311,22 +173,6 @@ pub struct BucketEntry { /// Sub-aggregations in this bucket. 
pub sub_aggregation: AggregationResults, } - -impl BucketEntry { - pub(crate) fn from_intermediate_and_req( - entry: IntermediateHistogramBucketEntry, - req: &AggregationsInternal, - ) -> crate::Result { - Ok(BucketEntry { - key: Key::F64(entry.key), - doc_count: entry.doc_count, - sub_aggregation: AggregationResults::from_intermediate_and_req_internal( - entry.sub_aggregation, - req, - )?, - }) - } -} impl GetDocCount for &BucketEntry { fn doc_count(&self) -> u64 { self.doc_count @@ -384,21 +230,3 @@ pub struct RangeBucketEntry { #[serde(skip_serializing_if = "Option::is_none")] pub to: Option, } - -impl RangeBucketEntry { - fn from_intermediate_and_req( - entry: IntermediateRangeBucketEntry, - req: &AggregationsInternal, - ) -> crate::Result { - Ok(RangeBucketEntry { - key: entry.key, - doc_count: entry.doc_count, - sub_aggregation: AggregationResults::from_intermediate_and_req_internal( - entry.sub_aggregation, - req, - )?, - to: entry.to, - from: entry.from, - }) - } -} diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index a2a4a87e59..0d5f5574c1 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -482,14 +482,12 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( sub_aggregation: empty_sub_aggregation.clone(), }, }) - .map(|intermediate_bucket| { - BucketEntry::from_intermediate_and_req(intermediate_bucket, sub_aggregation) - }) + .map(|intermediate_bucket| intermediate_bucket.into_final_bucket_entry(sub_aggregation)) .collect::>>() } // Convert to BucketEntry -pub(crate) fn intermediate_buckets_to_final_buckets( +pub(crate) fn intermediate_histogram_buckets_to_final_buckets( buckets: Vec, histogram_req: &HistogramAggregation, sub_aggregation: &AggregationsInternal, @@ -503,8 +501,8 @@ pub(crate) fn intermediate_buckets_to_final_buckets( } else { buckets .into_iter() - .filter(|bucket| bucket.doc_count >= histogram_req.min_doc_count()) - 
.map(|bucket| BucketEntry::from_intermediate_and_req(bucket, sub_aggregation)) + .filter(|histogram_bucket| histogram_bucket.doc_count >= histogram_req.min_doc_count()) + .map(|histogram_bucket| histogram_bucket.into_final_bucket_entry(sub_aggregation)) .collect::>>() } } @@ -546,7 +544,7 @@ pub(crate) fn generate_buckets_with_opt_minmax( let offset = req.offset.unwrap_or(0.0); let first_bucket_num = get_bucket_num_f64(min, req.interval, offset) as i64; let last_bucket_num = get_bucket_num_f64(max, req.interval, offset) as i64; - let mut buckets = vec![]; + let mut buckets = Vec::with_capacity((first_bucket_num..=last_bucket_num).count()); for bucket_pos in first_bucket_num..=last_bucket_num { let bucket_key = bucket_pos as f64 * req.interval + offset; buckets.push(bucket_key); diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index d570e96a37..69206b1100 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -317,7 +317,7 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu } /// Extends the provided buckets to contain the whole value range, by inserting buckets at the -/// beginning and end. +/// beginning and end and filling gaps. 
fn extend_validate_ranges( buckets: &[RangeAggregationRange], field_type: &Type, diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index f35a5e3e17..47cd94e6c9 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -87,7 +87,7 @@ impl Collector for AggregationCollector { segment_fruits: Vec<::Fruit>, ) -> crate::Result { let res = merge_fruits(segment_fruits)?; - AggregationResults::from_intermediate_and_req(res, self.agg.clone()) + res.into_final_bucket_result(self.agg.clone()) } } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 9bde00707e..cb2f9f416c 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -3,16 +3,20 @@ //! indices. use std::cmp::Ordering; +use std::collections::HashMap; use fnv::FnvHashMap; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use super::agg_req::{AggregationsInternal, BucketAggregationType, MetricAggregation}; -use super::agg_result::BucketResult; +use super::agg_req::{ + Aggregations, AggregationsInternal, BucketAggregationInternal, BucketAggregationType, + MetricAggregation, +}; +use super::agg_result::{AggregationResult, BucketResult, RangeBucketEntry}; use super::bucket::{ - cut_off_buckets, get_agg_name_and_property, GetDocCount, Order, OrderTarget, - SegmentHistogramBucketEntry, TermsAggregation, + cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets, + GetDocCount, Order, OrderTarget, SegmentHistogramBucketEntry, TermsAggregation, }; use super::metric::{IntermediateAverage, IntermediateStats}; use super::segment_agg_result::SegmentMetricResultCollector; @@ -31,6 +35,46 @@ pub struct IntermediateAggregationResults { } impl IntermediateAggregationResults { + /// Convert and intermediate result and its aggregation request to the final result + pub(crate) fn into_final_bucket_result( + self, + req: Aggregations, + ) -> 
crate::Result { + self.into_final_bucket_result_internal(&(req.into())) + } + + /// Convert and intermediate result and its aggregation request to the final result + /// + /// Internal function, AggregationsInternal is used instead Aggregations, which is optimized + /// for internal processing, by splitting metric and buckets into seperate groups. + pub(crate) fn into_final_bucket_result_internal( + self, + req: &AggregationsInternal, + ) -> crate::Result { + // Important assumption: + // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the + // request + let mut results: HashMap = HashMap::new(); + + if let Some(buckets) = self.buckets { + convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets)? + } else { + // When there are no buckets, we create empty buckets, so that the serialized json + // format is constant + add_empty_final_buckets_to_result(&mut results, &req.buckets)? + }; + + if let Some(metrics) = self.metrics { + convert_and_add_final_metrics_to_result(&mut results, metrics); + } else { + // When there are no metrics, we create empty metric results, so that the serialized + // json format is constant + add_empty_final_metrics_to_result(&mut results, &req.metrics)?; + } + + Ok(AggregationResults(results)) + } + pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self { let metrics = if req.metrics.is_empty() { None @@ -90,6 +134,58 @@ impl IntermediateAggregationResults { } } +fn convert_and_add_final_metrics_to_result( + results: &mut HashMap, + metrics: VecWithNames, +) { + results.extend( + metrics + .into_iter() + .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))), + ); +} + +fn add_empty_final_metrics_to_result( + results: &mut HashMap, + req_metrics: &VecWithNames, +) -> crate::Result<()> { + results.extend(req_metrics.iter().map(|(key, req)| { + let empty_bucket = IntermediateMetricResult::empty_from_req(req); + ( + key.to_string(), + 
AggregationResult::MetricResult(empty_bucket.into()), + ) + })); + Ok(()) +} + +fn add_empty_final_buckets_to_result( + results: &mut HashMap, + req_buckets: &VecWithNames, +) -> crate::Result<()> { + let requested_buckets = req_buckets.iter(); + for (key, req) in requested_buckets { + let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?); + results.insert(key.to_string(), empty_bucket); + } + Ok(()) +} + +fn convert_and_add_final_buckets_to_result( + results: &mut HashMap, + buckets: VecWithNames, + req_buckets: &VecWithNames, +) -> crate::Result<()> { + assert_eq!(buckets.len(), req_buckets.len()); + + let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); + for ((key, bucket), req) in buckets_with_request { + let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req)?); + results.insert(key, result); + } + Ok(()) +} + /// An aggregation is either a bucket or a metric. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum IntermediateAggregationResult { @@ -171,6 +267,45 @@ pub enum IntermediateBucketResult { } impl IntermediateBucketResult { + pub(crate) fn into_final_bucket_result( + self, + req: &BucketAggregationInternal, + ) -> crate::Result { + match self { + IntermediateBucketResult::Range(range_res) => { + let mut buckets: Vec = range_res + .buckets + .into_iter() + .map(|(_, bucket)| bucket.into_final_bucket_entry(&req.sub_aggregation)) + .collect::>>()?; + + buckets.sort_by(|left, right| { + // TODO use total_cmp next stable rust release + left.from + .unwrap_or(f64::MIN) + .partial_cmp(&right.from.unwrap_or(f64::MIN)) + .unwrap_or(Ordering::Equal) + }); + Ok(BucketResult::Range { buckets }) + } + IntermediateBucketResult::Histogram { buckets } => { + let buckets = intermediate_histogram_buckets_to_final_buckets( + buckets, + req.as_histogram() + .expect("unexpected aggregation, expected histogram aggregation"), + &req.sub_aggregation, + )?; + + 
Ok(BucketResult::Histogram { buckets }) + } + IntermediateBucketResult::Terms(terms) => terms.into_final_result( + req.as_term() + .expect("unexpected aggregation, expected term aggregation"), + &req.sub_aggregation, + ), + } + } + pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self { match req { BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()), @@ -267,10 +402,9 @@ impl IntermediateTermBucketResult { Ok(BucketEntry { key: Key::Str(key), doc_count: entry.doc_count, - sub_aggregation: AggregationResults::from_intermediate_and_req_internal( - entry.sub_aggregation, - sub_aggregation_req, - )?, + sub_aggregation: entry + .sub_aggregation + .into_final_bucket_result_internal(sub_aggregation_req)?, }) }) .collect::>()?; @@ -374,6 +508,21 @@ pub struct IntermediateHistogramBucketEntry { pub sub_aggregation: IntermediateAggregationResults, } +impl IntermediateHistogramBucketEntry { + pub(crate) fn into_final_bucket_entry( + self, + req: &AggregationsInternal, + ) -> crate::Result { + Ok(BucketEntry { + key: Key::F64(self.key), + doc_count: self.doc_count, + sub_aggregation: self + .sub_aggregation + .into_final_bucket_result_internal(req)?, + }) + } +} + impl From for IntermediateHistogramBucketEntry { fn from(entry: SegmentHistogramBucketEntry) -> Self { IntermediateHistogramBucketEntry { @@ -402,6 +551,23 @@ pub struct IntermediateRangeBucketEntry { pub to: Option, } +impl IntermediateRangeBucketEntry { + pub(crate) fn into_final_bucket_entry( + self, + req: &AggregationsInternal, + ) -> crate::Result { + Ok(RangeBucketEntry { + key: self.key, + doc_count: self.doc_count, + sub_aggregation: self + .sub_aggregation + .into_final_bucket_result_internal(req)?, + to: self.to, + from: self.from, + }) + } +} + /// This is the term entry for a bucket, which contains a count, and optionally /// sub_aggregations. 
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 37fa05c0ff..dfaaf3265a 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -545,11 +545,10 @@ mod tests { let collector = DistributedAggregationCollector::from_aggs(agg_req.clone()); let searcher = reader.searcher(); - AggregationResults::from_intermediate_and_req( - searcher.search(&AllQuery, &collector).unwrap(), - agg_req, - ) - .unwrap() + let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap(); + intermediate_agg_result + .into_final_bucket_result(agg_req) + .unwrap() } else { let collector = AggregationCollector::from_aggs(agg_req); @@ -985,7 +984,7 @@ mod tests { // Test de/serialization roundtrip on intermediate_agg_result let res: IntermediateAggregationResults = serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap(); - AggregationResults::from_intermediate_and_req(res, agg_req.clone()).unwrap() + res.into_final_bucket_result(agg_req.clone()).unwrap() } else { let collector = AggregationCollector::from_aggs(agg_req.clone()); From a99e5459e3e58064ab8917b81e21e924e5ca0548 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Wed, 11 May 2022 17:05:50 +0800 Subject: [PATCH 02/10] return result from segment collector --- examples/custom_collector.rs | 3 +- src/aggregation/collector.rs | 3 +- src/collector/count_collector.rs | 11 +++--- src/collector/custom_score_top_collector.rs | 3 +- src/collector/docset_collector.rs | 3 +- src/collector/facet_collector.rs | 3 +- src/collector/filter_collector_wrapper.rs | 5 +-- src/collector/histogram_collector.rs | 3 +- src/collector/mod.rs | 38 ++++++++++++--------- src/collector/multi_collector.rs | 16 +++++---- src/collector/tests.rs | 9 +++-- src/collector/top_score_collector.rs | 3 +- src/collector/tweak_score_top_collector.rs | 3 +- 13 files changed, 61 insertions(+), 42 deletions(-) diff --git a/examples/custom_collector.rs 
b/examples/custom_collector.rs index 7bdc9d06b4..12f846a430 100644 --- a/examples/custom_collector.rs +++ b/examples/custom_collector.rs @@ -102,11 +102,12 @@ struct StatsSegmentCollector { impl SegmentCollector for StatsSegmentCollector { type Fruit = Option; - fn collect(&mut self, doc: u32, _score: Score) { + fn collect(&mut self, doc: u32, _score: Score) -> crate::Result<()> { let value = self.fast_field_reader.get(doc) as f64; self.stats.count += 1; self.stats.sum += value; self.stats.squared_sum += value * value; + Ok(()) } fn harvest(self) -> ::Fruit { diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 47cd94e6c9..3cbbbcdc4e 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -132,8 +132,9 @@ impl SegmentCollector for AggregationSegmentCollector { type Fruit = crate::Result; #[inline] - fn collect(&mut self, doc: crate::DocId, _score: crate::Score) { + fn collect(&mut self, doc: crate::DocId, _score: crate::Score) -> crate::Result<()> { self.result.collect(doc, &self.aggs_with_accessor); + Ok(()) } fn harvest(mut self) -> Self::Fruit { diff --git a/src/collector/count_collector.rs b/src/collector/count_collector.rs index 075a4f36b4..02f30f85c1 100644 --- a/src/collector/count_collector.rs +++ b/src/collector/count_collector.rs @@ -65,8 +65,9 @@ pub struct SegmentCountCollector { impl SegmentCollector for SegmentCountCollector { type Fruit = usize; - fn collect(&mut self, _: DocId, _: Score) { + fn collect(&mut self, _: DocId, _: Score) -> crate::Result<()> { self.count += 1; + Ok(()) } fn harvest(self) -> usize { @@ -92,18 +93,18 @@ mod tests { } { let mut count_collector = SegmentCountCollector::default(); - count_collector.collect(0u32, 1.0); + count_collector.collect(0u32, 1.0).unwrap(); assert_eq!(count_collector.harvest(), 1); } { let mut count_collector = SegmentCountCollector::default(); - count_collector.collect(0u32, 1.0); + count_collector.collect(0u32, 1.0).unwrap(); 
assert_eq!(count_collector.harvest(), 1); } { let mut count_collector = SegmentCountCollector::default(); - count_collector.collect(0u32, 1.0); - count_collector.collect(1u32, 1.0); + count_collector.collect(0u32, 1.0).unwrap(); + count_collector.collect(1u32, 1.0).unwrap(); assert_eq!(count_collector.harvest(), 2); } } diff --git a/src/collector/custom_score_top_collector.rs b/src/collector/custom_score_top_collector.rs index d645004ade..d597727ed1 100644 --- a/src/collector/custom_score_top_collector.rs +++ b/src/collector/custom_score_top_collector.rs @@ -90,9 +90,10 @@ where { type Fruit = Vec<(TScore, DocAddress)>; - fn collect(&mut self, doc: DocId, _score: Score) { + fn collect(&mut self, doc: DocId, _score: Score) -> crate::Result<()> { let score = self.segment_scorer.score(doc); self.segment_collector.collect(doc, score); + Ok(()) } fn harvest(self) -> Vec<(TScore, DocAddress)> { diff --git a/src/collector/docset_collector.rs b/src/collector/docset_collector.rs index a27a394189..9f6a5fd3bd 100644 --- a/src/collector/docset_collector.rs +++ b/src/collector/docset_collector.rs @@ -50,8 +50,9 @@ pub struct DocSetChildCollector { impl SegmentCollector for DocSetChildCollector { type Fruit = (u32, HashSet); - fn collect(&mut self, doc: crate::DocId, _score: Score) { + fn collect(&mut self, doc: crate::DocId, _score: Score) -> crate::Result<()> { self.docs.insert(doc); + Ok(()) } fn harvest(self) -> (u32, HashSet) { diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs index e2ef47f989..8ad3311e28 100644 --- a/src/collector/facet_collector.rs +++ b/src/collector/facet_collector.rs @@ -333,7 +333,7 @@ impl Collector for FacetCollector { impl SegmentCollector for FacetSegmentCollector { type Fruit = FacetCounts; - fn collect(&mut self, doc: DocId, _: Score) { + fn collect(&mut self, doc: DocId, _: Score) -> crate::Result<()> { self.reader.facet_ords(doc, &mut self.facet_ords_buf); let mut previous_collapsed_ord: usize = usize::MAX; for 
&facet_ord in &self.facet_ords_buf { @@ -345,6 +345,7 @@ impl SegmentCollector for FacetSegmentCollector { }; previous_collapsed_ord = collapsed_ord; } + Ok(()) } /// Returns the results of the collection. diff --git a/src/collector/filter_collector_wrapper.rs b/src/collector/filter_collector_wrapper.rs index b1dbaaa203..15e7f80212 100644 --- a/src/collector/filter_collector_wrapper.rs +++ b/src/collector/filter_collector_wrapper.rs @@ -173,11 +173,12 @@ where { type Fruit = TSegmentCollector::Fruit; - fn collect(&mut self, doc: u32, score: Score) { + fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()> { let value = self.fast_field_reader.get(doc); if (self.predicate)(value) { - self.segment_collector.collect(doc, score) + self.segment_collector.collect(doc, score)?; } + Ok(()) } fn harvest(self) -> ::Fruit { diff --git a/src/collector/histogram_collector.rs b/src/collector/histogram_collector.rs index 22956a86a2..fbf398627a 100644 --- a/src/collector/histogram_collector.rs +++ b/src/collector/histogram_collector.rs @@ -91,9 +91,10 @@ pub struct SegmentHistogramCollector { impl SegmentCollector for SegmentHistogramCollector { type Fruit = Vec; - fn collect(&mut self, doc: DocId, _score: Score) { + fn collect(&mut self, doc: DocId, _score: Score) -> crate::Result<()> { let value = self.ff_reader.get(doc); self.histogram_computer.add_value(value); + Ok(()) } fn harvest(self) -> Self::Fruit { diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 1597d7fe45..97b6020340 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -175,12 +175,12 @@ pub trait Collector: Sync + Send { if let Some(alive_bitset) = reader.alive_bitset() { weight.for_each(reader, &mut |doc, score| { if alive_bitset.is_alive(doc) { - segment_collector.collect(doc, score); + segment_collector.collect(doc, score).unwrap(); // TODO } })?; } else { weight.for_each(reader, &mut |doc, score| { - segment_collector.collect(doc, score); + segment_collector.collect(doc, 
score).unwrap(); // TODO })?; } Ok(segment_collector.harvest()) @@ -190,10 +190,11 @@ pub trait Collector: Sync + Send { impl SegmentCollector for Option { type Fruit = Option; - fn collect(&mut self, doc: DocId, score: Score) { + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { if let Some(segment_collector) = self { - segment_collector.collect(doc, score); + segment_collector.collect(doc, score)?; } + Ok(()) } fn harvest(self) -> Self::Fruit { @@ -253,7 +254,7 @@ pub trait SegmentCollector: 'static { type Fruit: Fruit; /// The query pushes the scored document to the collector via this method. - fn collect(&mut self, doc: DocId, score: Score); + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()>; /// Extract the fruit of the collection from the `SegmentCollector`. fn harvest(self) -> Self::Fruit; @@ -308,9 +309,10 @@ where { type Fruit = (Left::Fruit, Right::Fruit); - fn collect(&mut self, doc: DocId, score: Score) { - self.0.collect(doc, score); - self.1.collect(doc, score); + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + self.0.collect(doc, score)?; + self.1.collect(doc, score)?; + Ok(()) } fn harvest(self) -> ::Fruit { @@ -372,10 +374,11 @@ where { type Fruit = (One::Fruit, Two::Fruit, Three::Fruit); - fn collect(&mut self, doc: DocId, score: Score) { - self.0.collect(doc, score); - self.1.collect(doc, score); - self.2.collect(doc, score); + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + self.0.collect(doc, score)?; + self.1.collect(doc, score)?; + self.2.collect(doc, score)?; + Ok(()) } fn harvest(self) -> ::Fruit { @@ -446,11 +449,12 @@ where { type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit); - fn collect(&mut self, doc: DocId, score: Score) { - self.0.collect(doc, score); - self.1.collect(doc, score); - self.2.collect(doc, score); - self.3.collect(doc, score); + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + 
self.0.collect(doc, score)?; + self.1.collect(doc, score)?; + self.2.collect(doc, score)?; + self.3.collect(doc, score)?; + Ok(()) } fn harvest(self) -> ::Fruit { diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index 039902ff4f..7b119ad868 100644 --- a/src/collector/multi_collector.rs +++ b/src/collector/multi_collector.rs @@ -52,8 +52,9 @@ impl Collector for CollectorWrapper { impl SegmentCollector for Box { type Fruit = Box; - fn collect(&mut self, doc: u32, score: Score) { - self.as_mut().collect(doc, score); + fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()> { + self.as_mut().collect(doc, score)?; + Ok(()) } fn harvest(self) -> Box { @@ -62,7 +63,7 @@ impl SegmentCollector for Box { } pub trait BoxableSegmentCollector { - fn collect(&mut self, doc: u32, score: Score); + fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()>; fn harvest_from_box(self: Box) -> Box; } @@ -71,8 +72,8 @@ pub struct SegmentCollectorWrapper(TSegment impl BoxableSegmentCollector for SegmentCollectorWrapper { - fn collect(&mut self, doc: u32, score: Score) { - self.0.collect(doc, score); + fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()> { + self.0.collect(doc, score) } fn harvest_from_box(self: Box) -> Box { @@ -228,10 +229,11 @@ pub struct MultiCollectorChild { impl SegmentCollector for MultiCollectorChild { type Fruit = MultiFruit; - fn collect(&mut self, doc: DocId, score: Score) { + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { for child in &mut self.children { - child.collect(doc, score); + child.collect(doc, score)?; } + Ok(()) } fn harvest(self) -> MultiFruit { diff --git a/src/collector/tests.rs b/src/collector/tests.rs index 3bda822a10..5e0a0cfb2d 100644 --- a/src/collector/tests.rs +++ b/src/collector/tests.rs @@ -138,9 +138,10 @@ impl Collector for TestCollector { impl SegmentCollector for TestSegmentCollector { type Fruit = TestFruit; - fn collect(&mut self, doc: DocId, 
score: Score) { + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { self.fruit.docs.push(DocAddress::new(self.segment_id, doc)); self.fruit.scores.push(score); + Ok(()) } fn harvest(self) -> ::Fruit { @@ -198,9 +199,10 @@ impl Collector for FastFieldTestCollector { impl SegmentCollector for FastFieldSegmentCollector { type Fruit = Vec; - fn collect(&mut self, doc: DocId, _score: Score) { + fn collect(&mut self, doc: DocId, _score: Score) -> crate::Result<()> { let val = self.reader.get(doc); self.vals.push(val); + Ok(()) } fn harvest(self) -> Vec { @@ -255,9 +257,10 @@ impl Collector for BytesFastFieldTestCollector { impl SegmentCollector for BytesFastFieldSegmentCollector { type Fruit = Vec; - fn collect(&mut self, doc: u32, _score: Score) { + fn collect(&mut self, doc: u32, _score: Score) -> crate::Result<()> { let data = self.reader.get_bytes(doc); self.vals.extend(data); + Ok(()) } fn harvest(self) -> ::Fruit { diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 516dedcb58..e0e3aeb9dc 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -699,8 +699,9 @@ pub struct TopScoreSegmentCollector(TopSegmentCollector); impl SegmentCollector for TopScoreSegmentCollector { type Fruit = Vec<(Score, DocAddress)>; - fn collect(&mut self, doc: DocId, score: Score) { + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { self.0.collect(doc, score); + Ok(()) } fn harvest(self) -> Vec<(Score, DocAddress)> { diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs index 1a81e73616..cb8d358abe 100644 --- a/src/collector/tweak_score_top_collector.rs +++ b/src/collector/tweak_score_top_collector.rs @@ -93,9 +93,10 @@ where { type Fruit = Vec<(TScore, DocAddress)>; - fn collect(&mut self, doc: DocId, score: Score) { + fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { let score = 
self.segment_scorer.score(doc, score); self.segment_collector.collect(doc, score); + Ok(()) } fn harvest(self) -> Vec<(TScore, DocAddress)> { From 6a4632211a3f23b2051c42157959eaa3255cd0eb Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Wed, 11 May 2022 18:03:29 +0800 Subject: [PATCH 03/10] forward error in aggregation collect --- src/aggregation/agg_req_with_accessor.rs | 2 ++ src/aggregation/bucket/histogram/histogram.rs | 25 ++++++++++-------- src/aggregation/bucket/range.rs | 20 +++++++------- src/aggregation/bucket/term_agg.rs | 26 ++++++++++--------- src/aggregation/collector.rs | 4 +-- src/aggregation/segment_agg_result.rs | 21 ++++++++------- 6 files changed, 55 insertions(+), 43 deletions(-) diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 8ed82ac5c6..ed44ab6ef9 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -1,5 +1,7 @@ //! This will enhance the request tree with access to the fastfield and metadata. 
+use std::rc::Rc; +use std::sync::atomic::AtomicU32; use std::sync::Arc; use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 0d5f5574c1..c2a7e3472d 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -311,7 +311,7 @@ impl SegmentHistogramCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { let bounds = self.bounds; let interval = self.interval; let offset = self.offset; @@ -341,28 +341,28 @@ impl SegmentHistogramCollector { bucket_pos0, docs[0], &bucket_with_accessor.sub_aggregation, - ); + )?; self.increment_bucket_if_in_bounds( val1, &bounds, bucket_pos1, docs[1], &bucket_with_accessor.sub_aggregation, - ); + )?; self.increment_bucket_if_in_bounds( val2, &bounds, bucket_pos2, docs[2], &bucket_with_accessor.sub_aggregation, - ); + )?; self.increment_bucket_if_in_bounds( val3, &bounds, bucket_pos3, docs[3], &bucket_with_accessor.sub_aggregation, - ); + )?; } for doc in iter.remainder() { let val = f64_from_fastfield_u64(accessor.get(*doc), &self.field_type); @@ -376,16 +376,17 @@ impl SegmentHistogramCollector { self.buckets[bucket_pos].key, get_bucket_val(val, self.interval, self.offset) as f64 ); - self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation)?; } if force_flush { if let Some(sub_aggregations) = self.sub_aggregations.as_mut() { for sub_aggregation in sub_aggregations { sub_aggregation - .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush); + .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush)?; } } } + Ok(()) } #[inline] @@ -396,15 +397,16 @@ impl SegmentHistogramCollector { bucket_pos: usize, doc: DocId, 
bucket_with_accessor: &AggregationsWithAccessor, - ) { + ) -> crate::Result<()> { if bounds.contains(val) { debug_assert_eq!( self.buckets[bucket_pos].key, get_bucket_val(val, self.interval, self.offset) as f64 ); - self.increment_bucket(bucket_pos, doc, bucket_with_accessor); + self.increment_bucket(bucket_pos, doc, bucket_with_accessor)?; } + Ok(()) } #[inline] @@ -413,12 +415,13 @@ impl SegmentHistogramCollector { bucket_pos: usize, doc: DocId, bucket_with_accessor: &AggregationsWithAccessor, - ) { + ) -> crate::Result<()> { let bucket = &mut self.buckets[bucket_pos]; bucket.doc_count += 1; if let Some(sub_aggregation) = self.sub_aggregations.as_mut() { - (&mut sub_aggregation[bucket_pos]).collect(doc, bucket_with_accessor); + (&mut sub_aggregation[bucket_pos]).collect(doc, bucket_with_accessor)?; } + Ok(()) } fn f64_from_fastfield_u64(&self, val: u64) -> f64 { diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 69206b1100..610453e036 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -224,7 +224,7 @@ impl SegmentRangeCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { let mut iter = doc.chunks_exact(4); let accessor = bucket_with_accessor .accessor @@ -240,24 +240,25 @@ impl SegmentRangeCollector { let bucket_pos3 = self.get_bucket_pos(val3); let bucket_pos4 = self.get_bucket_pos(val4); - self.increment_bucket(bucket_pos1, docs[0], &bucket_with_accessor.sub_aggregation); - self.increment_bucket(bucket_pos2, docs[1], &bucket_with_accessor.sub_aggregation); - self.increment_bucket(bucket_pos3, docs[2], &bucket_with_accessor.sub_aggregation); - self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos1, docs[0], &bucket_with_accessor.sub_aggregation)?; + self.increment_bucket(bucket_pos2, docs[1], &bucket_with_accessor.sub_aggregation)?; + 
self.increment_bucket(bucket_pos3, docs[2], &bucket_with_accessor.sub_aggregation)?; + self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation)?; } for doc in iter.remainder() { let val = accessor.get(*doc); let bucket_pos = self.get_bucket_pos(val); - self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation)?; } if force_flush { for bucket in &mut self.buckets { if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation { sub_aggregation - .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush); + .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush)?; } } } + Ok(()) } #[inline] @@ -266,13 +267,14 @@ impl SegmentRangeCollector { bucket_pos: usize, doc: DocId, bucket_with_accessor: &AggregationsWithAccessor, - ) { + ) -> crate::Result<()> { let bucket = &mut self.buckets[bucket_pos]; bucket.bucket.doc_count += 1; if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation { - sub_aggregation.collect(doc, bucket_with_accessor); + sub_aggregation.collect(doc, bucket_with_accessor)?; } + Ok(()) } #[inline] diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index c9833c8853..1ed63d3e65 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -246,8 +246,7 @@ impl TermBuckets { doc: DocId, bucket_with_accessor: &AggregationsWithAccessor, blueprint: &Option, - ) { - // self.ensure_vec_exists(term_ids); + ) -> crate::Result<()> { for &term_id in term_ids { let entry = self .entries @@ -255,17 +254,19 @@ impl TermBuckets { .or_insert_with(|| TermBucketEntry::from_blueprint(blueprint)); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, bucket_with_accessor); + sub_aggregations.collect(doc, bucket_with_accessor)?; } } + Ok(()) } - fn force_flush(&mut self, 
agg_with_accessor: &AggregationsWithAccessor) { + fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> { for entry in &mut self.entries.values_mut() { if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.flush_staged_docs(agg_with_accessor, false); + sub_aggregations.flush_staged_docs(agg_with_accessor, false)?; } } + Ok(()) } } @@ -421,7 +422,7 @@ impl SegmentTermCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { let accessor = bucket_with_accessor .accessor .as_multi() @@ -442,25 +443,25 @@ impl SegmentTermCollector { docs[0], &bucket_with_accessor.sub_aggregation, &self.blueprint, - ); + )?; self.term_buckets.increment_bucket( &vals2, docs[1], &bucket_with_accessor.sub_aggregation, &self.blueprint, - ); + )?; self.term_buckets.increment_bucket( &vals3, docs[2], &bucket_with_accessor.sub_aggregation, &self.blueprint, - ); + )?; self.term_buckets.increment_bucket( &vals4, docs[3], &bucket_with_accessor.sub_aggregation, &self.blueprint, - ); + )?; } for &doc in iter.remainder() { accessor.get_vals(doc, &mut vals1); @@ -470,12 +471,13 @@ impl SegmentTermCollector { doc, &bucket_with_accessor.sub_aggregation, &self.blueprint, - ); + )?; } if force_flush { self.term_buckets - .force_flush(&bucket_with_accessor.sub_aggregation); + .force_flush(&bucket_with_accessor.sub_aggregation)?; } + Ok(()) } } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 3cbbbcdc4e..69913931cd 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -133,13 +133,13 @@ impl SegmentCollector for AggregationSegmentCollector { #[inline] fn collect(&mut self, doc: crate::DocId, _score: crate::Score) -> crate::Result<()> { - self.result.collect(doc, &self.aggs_with_accessor); + self.result.collect(doc, &self.aggs_with_accessor)?; Ok(()) } fn harvest(mut self) -> Self::Fruit { self.result - 
.flush_staged_docs(&self.aggs_with_accessor, true); + .flush_staged_docs(&self.aggs_with_accessor, true)?; self.result .into_intermediate_aggregations_result(&self.aggs_with_accessor) } diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 81f2b85de9..121fb4cf3f 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -115,21 +115,22 @@ impl SegmentAggregationResultsCollector { &mut self, doc: crate::DocId, agg_with_accessor: &AggregationsWithAccessor, - ) { + ) -> crate::Result<()> { self.staged_docs[self.num_staged_docs] = doc; self.num_staged_docs += 1; if self.num_staged_docs == self.staged_docs.len() { - self.flush_staged_docs(agg_with_accessor, false); + self.flush_staged_docs(agg_with_accessor, false)?; } + Ok(()) } pub(crate) fn flush_staged_docs( &mut self, agg_with_accessor: &AggregationsWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { if self.num_staged_docs == 0 { - return; + return Ok(()); } if let Some(metrics) = &mut self.metrics { for (collector, agg_with_accessor) in @@ -148,11 +149,12 @@ impl SegmentAggregationResultsCollector { &self.staged_docs[..self.num_staged_docs], agg_with_accessor, force_flush, - ); + )?; } } self.num_staged_docs = 0; + Ok(()) } } @@ -256,17 +258,18 @@ impl SegmentBucketResultCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { match self { SegmentBucketResultCollector::Range(range) => { - range.collect_block(doc, bucket_with_accessor, force_flush); + range.collect_block(doc, bucket_with_accessor, force_flush)?; } SegmentBucketResultCollector::Histogram(histogram) => { - histogram.collect_block(doc, bucket_with_accessor, force_flush) + histogram.collect_block(doc, bucket_with_accessor, force_flush)?; } SegmentBucketResultCollector::Terms(terms) => { - terms.collect_block(doc, bucket_with_accessor, force_flush) + terms.collect_block(doc, 
bucket_with_accessor, force_flush)?; } } + Ok(()) } } From 11ac4512504926e8b301ba5fc593c46f066f0430 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Thu, 12 May 2022 12:18:29 +0800 Subject: [PATCH 04/10] abort aggregation when too many buckets are created Validation happens on different phases depending on the aggregation Term: During segment collection Histogram: At the end when converting in intermediate buckets (we preallocate empty buckets for the range) Revisit after #1370 Range: When validating the request update CHANGELOG --- CHANGELOG.md | 1 + examples/custom_collector.rs | 2 +- src/aggregation/agg_req_with_accessor.rs | 5 ++ src/aggregation/bucket/histogram/histogram.rs | 9 ++- src/aggregation/bucket/range.rs | 26 +++++++-- src/aggregation/bucket/term_agg.rs | 57 +++++++++++++++---- src/aggregation/mod.rs | 15 +++-- src/aggregation/segment_agg_result.rs | 14 ++++- src/collector/mod.rs | 6 +- src/query/boolean_query/boolean_weight.rs | 6 +- src/query/term_query/term_weight.rs | 4 +- src/query/weight.rs | 11 ++-- 12 files changed, 118 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 00d88d2938..4282b830a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ Unreleased - Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz) - Add support for fastfield on text fields (@PSeitz) - Add terms aggregation (@PSeitz) +- API Change: `SegmentCollector.collect` changed to return a `Result`. 
Tantivy 0.17 ================================ diff --git a/examples/custom_collector.rs b/examples/custom_collector.rs index 12f846a430..14a9d20358 100644 --- a/examples/custom_collector.rs +++ b/examples/custom_collector.rs @@ -102,7 +102,7 @@ struct StatsSegmentCollector { impl SegmentCollector for StatsSegmentCollector { type Fruit = Option; - fn collect(&mut self, doc: u32, _score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: u32, _score: Score) -> tantivy::Result<()> { let value = self.fast_field_reader.get(doc) as f64; self.stats.count += 1; self.stats.sum += value; diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index ed44ab6ef9..10597b3de0 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -62,6 +62,7 @@ pub struct BucketAggregationWithAccessor { pub(crate) field_type: Type, pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, + pub(crate) bucket_count: Rc, } impl BucketAggregationWithAccessor { @@ -69,6 +70,7 @@ impl BucketAggregationWithAccessor { bucket: &BucketAggregationType, sub_aggregation: &Aggregations, reader: &SegmentReader, + bucket_count: Rc, ) -> crate::Result { let mut inverted_index = None; let (accessor, field_type) = match &bucket { @@ -97,6 +99,7 @@ impl BucketAggregationWithAccessor { sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?, bucket_agg: bucket.clone(), inverted_index, + bucket_count, }) } } @@ -137,6 +140,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, ) -> crate::Result { + let bucket_count: Rc = Default::default(); let mut metrics = vec![]; let mut buckets = vec![]; for (key, agg) in aggs.iter() { @@ -147,6 +151,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate( &bucket.bucket_agg, &bucket.sub_aggregation, reader, + Rc::clone(&bucket_count), )?, )), Aggregation::Metric(metric) => 
metrics.push(( diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index c2a7e3472d..69111c71fc 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -13,7 +13,9 @@ use crate::aggregation::f64_from_fastfield_u64; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; +use crate::aggregation::segment_agg_result::{ + validate_bucket_count, SegmentAggregationResultsCollector, +}; use crate::fastfield::{DynamicFastFieldReader, FastFieldReader}; use crate::schema::Type; use crate::{DocId, TantivyError}; @@ -250,6 +252,11 @@ impl SegmentHistogramCollector { ); }; + agg_with_accessor + .bucket_count + .fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed); + validate_bucket_count(&agg_with_accessor.bucket_count)?; + Ok(IntermediateBucketResult::Histogram { buckets }) } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 610453e036..590165c158 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,6 +1,9 @@ use std::fmt::Debug; use std::ops::Range; +use std::rc::Rc; +use std::sync::atomic::AtomicU32; +use fnv::FnvHashMap; use serde::{Deserialize, Serialize}; use crate::aggregation::agg_req_with_accessor::{ @@ -9,8 +12,10 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; -use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key}; +use crate::aggregation::segment_agg_result::{ + validate_bucket_count, SegmentAggregationResultsCollector, +}; +use 
crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey}; use crate::fastfield::FastFieldReader; use crate::schema::Type; use crate::{DocId, TantivyError}; @@ -153,7 +158,7 @@ impl SegmentRangeCollector { ) -> crate::Result { let field_type = self.field_type; - let buckets = self + let buckets: FnvHashMap = self .buckets .into_iter() .map(move |range_bucket| { @@ -174,12 +179,13 @@ impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, sub_aggregation: &AggregationsWithAccessor, + bucket_count: &Rc, field_type: Type, ) -> crate::Result { // The range input on the request is f64. // We need to convert to u64 ranges, because we read the values as u64. // The mapping from the conversion is monotonic so ordering is preserved. - let buckets = extend_validate_ranges(&req.ranges, &field_type)? + let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)? .iter() .map(|range| { let to = if range.end == u64::MAX { @@ -212,6 +218,9 @@ impl SegmentRangeCollector { }) .collect::>()?; + bucket_count.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed); + validate_bucket_count(bucket_count)?; + Ok(SegmentRangeCollector { buckets, field_type, @@ -403,8 +412,13 @@ mod tests { ranges, }; - SegmentRangeCollector::from_req_and_validate(&req, &Default::default(), field_type) - .expect("unexpected error") + SegmentRangeCollector::from_req_and_validate( + &req, + &Default::default(), + &Default::default(), + field_type, + ) + .expect("unexpected error") } #[test] diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 1ed63d3e65..312522017e 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -11,7 +11,9 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use 
crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; +use crate::aggregation::segment_agg_result::{ + validate_bucket_count, SegmentAggregationResultsCollector, +}; use crate::error::DataCorruption; use crate::fastfield::MultiValuedFastFieldReader; use crate::schema::Type; @@ -244,19 +246,23 @@ impl TermBuckets { &mut self, term_ids: &[u64], doc: DocId, - bucket_with_accessor: &AggregationsWithAccessor, + bucket_with_accessor: &BucketAggregationWithAccessor, blueprint: &Option, ) -> crate::Result<()> { for &term_id in term_ids { - let entry = self - .entries - .entry(term_id as u32) - .or_insert_with(|| TermBucketEntry::from_blueprint(blueprint)); + let entry = self.entries.entry(term_id as u32).or_insert_with(|| { + bucket_with_accessor + .bucket_count + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + TermBucketEntry::from_blueprint(blueprint) + }); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, bucket_with_accessor)?; + sub_aggregations.collect(doc, &bucket_with_accessor.sub_aggregation)?; } } + validate_bucket_count(&bucket_with_accessor.bucket_count)?; Ok(()) } @@ -441,25 +447,25 @@ impl SegmentTermCollector { self.term_buckets.increment_bucket( &vals1, docs[0], - &bucket_with_accessor.sub_aggregation, + bucket_with_accessor, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals2, docs[1], - &bucket_with_accessor.sub_aggregation, + bucket_with_accessor, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals3, docs[2], - &bucket_with_accessor.sub_aggregation, + bucket_with_accessor, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals4, docs[3], - &bucket_with_accessor.sub_aggregation, + bucket_with_accessor, &self.blueprint, )?; } @@ -469,7 +475,7 @@ impl SegmentTermCollector { self.term_buckets.increment_bucket( &vals1, doc, - &bucket_with_accessor.sub_aggregation, + bucket_with_accessor, &self.blueprint, )?; } @@ -1175,6 +1181,33 
@@ mod tests { Ok(()) } + #[test] + fn terms_aggregation_term_bucket_limit() -> crate::Result<()> { + let terms: Vec = (0..100_000).map(|el| el.to_string()).collect(); + let terms_per_segment = vec![terms.iter().map(|el| el.as_str()).collect()]; + + let index = get_test_index_from_terms(true, &terms_per_segment)?; + + let agg_req: Aggregations = vec![( + "my_texts".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "string_id".to_string(), + min_doc_count: Some(0), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request_with_query(agg_req, &index, None); + assert!(res.is_err()); + + Ok(()) + } + #[test] fn test_json_format() -> crate::Result<()> { let agg_req: Aggregations = vec![( diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index dfaaf3265a..7fe2d82847 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -417,7 +417,9 @@ mod tests { let mut schema_builder = Schema::builder(); let text_fieldtype = crate::schema::TextOptions::default() .set_indexing_options( - TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs), + TextFieldIndexing::default() + .set_index_option(IndexRecordOption::Basic) + .set_fieldnorms(false), ) .set_fast() .set_stored(); @@ -435,7 +437,8 @@ mod tests { ); let index = Index::create_in_ram(schema_builder.build()); { - let mut index_writer = index.writer_for_tests()?; + // let mut index_writer = index.writer_for_tests()?; + let mut index_writer = index.writer_with_num_threads(1, 30_000_000)?; for values in segment_and_values { for (i, term) in values { let i = *i; @@ -457,9 +460,11 @@ mod tests { let segment_ids = index .searchable_segment_ids() .expect("Searchable segments failed."); - let mut index_writer = index.writer_for_tests()?; - index_writer.merge(&segment_ids).wait()?; - index_writer.wait_merging_threads()?; + if segment_ids.len() > 
1 { + let mut index_writer = index.writer_for_tests()?; + index_writer.merge(&segment_ids).wait()?; + index_writer.wait_merging_threads()?; + } } Ok(index) diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 121fb4cf3f..57c545f7d6 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -4,6 +4,8 @@ //! merging. use std::fmt::Debug; +use std::rc::Rc; +use std::sync::atomic::AtomicU32; use super::agg_req::MetricAggregation; use super::agg_req_with_accessor::{ @@ -16,7 +18,7 @@ use super::metric::{ }; use super::VecWithNames; use crate::aggregation::agg_req::BucketAggregationType; -use crate::DocId; +use crate::{DocId, TantivyError}; pub(crate) const DOC_BLOCK_SIZE: usize = 64; pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; @@ -236,6 +238,7 @@ impl SegmentBucketResultCollector { Ok(Self::Range(SegmentRangeCollector::from_req_and_validate( range_req, &req.sub_aggregation, + &req.bucket_count, req.field_type, )?)) } @@ -273,3 +276,12 @@ impl SegmentBucketResultCollector { Ok(()) } } + +pub(crate) fn validate_bucket_count(bucket_count: &Rc) -> crate::Result<()> { + if bucket_count.load(std::sync::atomic::Ordering::Relaxed) > 65000 { + return Err(TantivyError::InvalidArgument( + "Aborting aggregation because too many buckets were created".to_string(), + )); + } + Ok(()) +} diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 97b6020340..6d600bb488 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -175,12 +175,14 @@ pub trait Collector: Sync + Send { if let Some(alive_bitset) = reader.alive_bitset() { weight.for_each(reader, &mut |doc, score| { if alive_bitset.is_alive(doc) { - segment_collector.collect(doc, score).unwrap(); // TODO + segment_collector.collect(doc, score)?; } + Ok(()) })?; } else { weight.for_each(reader, &mut |doc, score| { - segment_collector.collect(doc, score).unwrap(); // TODO + segment_collector.collect(doc, score)?; + Ok(()) })?; } 
Ok(segment_collector.harvest()) diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index f2a5fd376a..9024a6abb8 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -186,17 +186,17 @@ impl Weight for BooleanWeight { fn for_each( &self, reader: &SegmentReader, - callback: &mut dyn FnMut(DocId, Score), + callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, ) -> crate::Result<()> { let scorer = self.complex_scorer::(reader, 1.0)?; match scorer { SpecializedScorer::TermUnion(term_scorers) => { let mut union_scorer = Union::::from(term_scorers); - for_each_scorer(&mut union_scorer, callback); + for_each_scorer(&mut union_scorer, callback)?; } SpecializedScorer::Other(mut scorer) => { - for_each_scorer(scorer.as_mut(), callback); + for_each_scorer(scorer.as_mut(), callback)?; } } Ok(()) diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index 4e742bc444..e1529027b2 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -49,10 +49,10 @@ impl Weight for TermWeight { fn for_each( &self, reader: &SegmentReader, - callback: &mut dyn FnMut(DocId, Score), + callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, ) -> crate::Result<()> { let mut scorer = self.specialized_scorer(reader, 1.0)?; - for_each_scorer(&mut scorer, callback); + for_each_scorer(&mut scorer, callback)?; Ok(()) } diff --git a/src/query/weight.rs b/src/query/weight.rs index 3a2ff3d33c..26bdfb0edb 100644 --- a/src/query/weight.rs +++ b/src/query/weight.rs @@ -7,13 +7,14 @@ use crate::{DocId, Score, TERMINATED}; /// `DocSet` and push the scored documents to the collector. 
pub(crate) fn for_each_scorer( scorer: &mut TScorer, - callback: &mut dyn FnMut(DocId, Score), -) { + callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, +) -> crate::Result<()> { let mut doc = scorer.doc(); while doc != TERMINATED { - callback(doc, scorer.score()); + callback(doc, scorer.score())?; doc = scorer.advance(); } + Ok(()) } /// Calls `callback` with all of the `(doc, score)` for which score @@ -71,10 +72,10 @@ pub trait Weight: Send + Sync + 'static { fn for_each( &self, reader: &SegmentReader, - callback: &mut dyn FnMut(DocId, Score), + callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, ) -> crate::Result<()> { let mut scorer = self.scorer(reader, 1.0)?; - for_each_scorer(scorer.as_mut(), callback); + for_each_scorer(scorer.as_mut(), callback)?; Ok(()) } From 44ea7313ca85265945090b18588ce863beb54818 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Thu, 12 May 2022 21:04:25 +0800 Subject: [PATCH 05/10] set max bucket size as parameter --- examples/aggregation.rs | 2 +- src/aggregation/agg_req_with_accessor.rs | 20 +++++++-- src/aggregation/bucket/histogram/histogram.rs | 8 ++-- src/aggregation/bucket/range.rs | 14 +++---- src/aggregation/bucket/term_agg.rs | 39 ++++++++++------- src/aggregation/collector.rs | 42 +++++++++++++++---- src/aggregation/intermediate_agg_result.rs | 4 +- src/aggregation/metric/stats.rs | 4 +- src/aggregation/mod.rs | 42 +++++++++---------- src/aggregation/segment_agg_result.rs | 38 ++++++++++++++--- 10 files changed, 140 insertions(+), 73 deletions(-) diff --git a/examples/aggregation.rs b/examples/aggregation.rs index 82cc0fccd3..ae11dc5a5a 100644 --- a/examples/aggregation.rs +++ b/examples/aggregation.rs @@ -117,7 +117,7 @@ fn main() -> tantivy::Result<()> { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = 
searcher.search(&term_query, &collector).unwrap(); diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 10597b3de0..491faf2137 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation}; use super::metric::{AverageAggregation, StatsAggregation}; +use super::segment_agg_result::BucketCount; use super::VecWithNames; use crate::fastfield::{ type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader, @@ -62,7 +63,7 @@ pub struct BucketAggregationWithAccessor { pub(crate) field_type: Type, pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, - pub(crate) bucket_count: Rc, + pub(crate) bucket_count: BucketCount, } impl BucketAggregationWithAccessor { @@ -71,6 +72,7 @@ impl BucketAggregationWithAccessor { sub_aggregation: &Aggregations, reader: &SegmentReader, bucket_count: Rc, + max_bucket_count: u32, ) -> crate::Result { let mut inverted_index = None; let (accessor, field_type) = match &bucket { @@ -96,10 +98,18 @@ impl BucketAggregationWithAccessor { Ok(BucketAggregationWithAccessor { accessor, field_type, - sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?, + sub_aggregation: get_aggs_with_accessor_and_validate( + &sub_aggregation, + reader, + bucket_count.clone(), + max_bucket_count, + )?, bucket_agg: bucket.clone(), inverted_index, - bucket_count, + bucket_count: BucketCount { + bucket_count, + max_bucket_count, + }, }) } } @@ -139,8 +149,9 @@ impl MetricAggregationWithAccessor { pub(crate) fn get_aggs_with_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, + bucket_count: Rc, + max_bucket_count: u32, ) -> crate::Result { - let bucket_count: Rc = 
Default::default(); let mut metrics = vec![]; let mut buckets = vec![]; for (key, agg) in aggs.iter() { @@ -152,6 +163,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate( &bucket.sub_aggregation, reader, Rc::clone(&bucket_count), + max_bucket_count, )?, )), Aggregation::Metric(metric) => metrics.push(( diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 69111c71fc..70acf0f117 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -13,9 +13,7 @@ use crate::aggregation::f64_from_fastfield_u64; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::{ - validate_bucket_count, SegmentAggregationResultsCollector, -}; +use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; use crate::fastfield::{DynamicFastFieldReader, FastFieldReader}; use crate::schema::Type; use crate::{DocId, TantivyError}; @@ -254,8 +252,8 @@ impl SegmentHistogramCollector { agg_with_accessor .bucket_count - .fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed); - validate_bucket_count(&agg_with_accessor.bucket_count)?; + .add_count(buckets.len() as u32); + agg_with_accessor.bucket_count.validate_bucket_count()?; Ok(IntermediateBucketResult::Histogram { buckets }) } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 590165c158..7faa500e7c 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,7 +1,5 @@ use std::fmt::Debug; use std::ops::Range; -use std::rc::Rc; -use std::sync::atomic::AtomicU32; use fnv::FnvHashMap; use serde::{Deserialize, Serialize}; @@ -12,9 +10,7 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateRangeBucketEntry, 
IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::{ - validate_bucket_count, SegmentAggregationResultsCollector, -}; +use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector}; use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey}; use crate::fastfield::FastFieldReader; use crate::schema::Type; @@ -179,7 +175,7 @@ impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, sub_aggregation: &AggregationsWithAccessor, - bucket_count: &Rc, + bucket_count: &BucketCount, field_type: Type, ) -> crate::Result { // The range input on the request is f64. @@ -218,8 +214,8 @@ impl SegmentRangeCollector { }) .collect::>()?; - bucket_count.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed); - validate_bucket_count(bucket_count)?; + bucket_count.add_count(buckets.len() as u32); + bucket_count.validate_bucket_count()?; Ok(SegmentRangeCollector { buckets, @@ -438,7 +434,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let reader = index.reader()?; let searcher = reader.searcher(); diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 312522017e..52e120cc96 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -11,9 +11,7 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::{ - validate_bucket_count, SegmentAggregationResultsCollector, -}; +use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector}; use crate::error::DataCorruption; use crate::fastfield::MultiValuedFastFieldReader; use crate::schema::Type; @@ -246,23 +244,23 @@ 
impl TermBuckets { &mut self, term_ids: &[u64], doc: DocId, - bucket_with_accessor: &BucketAggregationWithAccessor, + sub_aggregation: &AggregationsWithAccessor, + bucket_count: &BucketCount, blueprint: &Option, ) -> crate::Result<()> { for &term_id in term_ids { let entry = self.entries.entry(term_id as u32).or_insert_with(|| { - bucket_with_accessor - .bucket_count - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + bucket_count.add_count(1); TermBucketEntry::from_blueprint(blueprint) }); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, &bucket_with_accessor.sub_aggregation)?; + sub_aggregations.collect(doc, &sub_aggregation)?; } } - validate_bucket_count(&bucket_with_accessor.bucket_count)?; + bucket_count.validate_bucket_count()?; + Ok(()) } @@ -447,25 +445,29 @@ impl SegmentTermCollector { self.term_buckets.increment_bucket( &vals1, docs[0], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals2, docs[1], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals3, docs[2], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals4, docs[3], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; } @@ -475,7 +477,8 @@ impl SegmentTermCollector { self.term_buckets.increment_bucket( &vals1, doc, - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; } @@ -1326,9 +1329,15 @@ mod bench { let mut collector = get_collector_with_buckets(total_terms); let vals = get_rand_terms(total_terms, num_terms); let aggregations_with_accessor: 
AggregationsWithAccessor = Default::default(); + let bucket_count: BucketCount = BucketCount { + bucket_count: Default::default(), + max_bucket_count: 1_000_001u32, + }; b.iter(|| { for &val in &vals { - collector.increment_bucket(&[val], 0, &aggregations_with_accessor, &None); + collector + .increment_bucket(&[val], 0, &aggregations_with_accessor, &bucket_count, &None) + .unwrap(); } }) } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 69913931cd..cf2848f383 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use super::agg_req::Aggregations; use super::agg_req_with_accessor::AggregationsWithAccessor; use super::agg_result::AggregationResults; @@ -7,17 +9,25 @@ use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_valida use crate::collector::{Collector, SegmentCollector}; use crate::SegmentReader; +pub const MAX_BUCKET_COUNT: u32 = 65000; + /// Collector for aggregations. /// /// The collector collects all aggregations by the underlying aggregation request. pub struct AggregationCollector { agg: Aggregations, + max_bucket_count: u32, } impl AggregationCollector { /// Create collector from aggregation request. - pub fn from_aggs(agg: Aggregations) -> Self { - Self { agg } + /// + /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset + pub fn from_aggs(agg: Aggregations, max_bucket_count: Option) -> Self { + Self { + agg, + max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), + } + } } @@ -28,15 +38,21 @@ impl AggregationCollector { /// # Purpose /// AggregationCollector returns `IntermediateAggregationResults` and not the final /// `AggregationResults`, so that results from differenct indices can be merged and then converted -/// into the final `AggregationResults` via the `into()` method. +/// into the final `AggregationResults` via the `into_final_bucket_result()` method. 
pub struct DistributedAggregationCollector { agg: Aggregations, + max_bucket_count: u32, } impl DistributedAggregationCollector { /// Create collector from aggregation request. - pub fn from_aggs(agg: Aggregations) -> Self { - Self { agg } + /// + /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset + pub fn from_aggs(agg: Aggregations, max_bucket_count: Option) -> Self { + Self { + agg, + max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), + } } } @@ -50,7 +66,11 @@ impl Collector for DistributedAggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + self.max_bucket_count, + ) } fn requires_scoring(&self) -> bool { @@ -75,7 +95,11 @@ impl Collector for AggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + self.max_bucket_count, + ) } fn requires_scoring(&self) -> bool { @@ -117,8 +141,10 @@ impl AggregationSegmentCollector { pub fn from_agg_req_and_reader( agg: &Aggregations, reader: &SegmentReader, + max_bucket_count: u32, ) -> crate::Result { - let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader)?; + let aggs_with_accessor = + get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?; let result = SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?; Ok(AggregationSegmentCollector { diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index cb2f9f416c..20eef59c07 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -35,7 +35,7 @@ 
pub struct IntermediateAggregationResults { } impl IntermediateAggregationResults { - /// Convert and intermediate result and its aggregation request to the final result + /// Convert intermediate result and its aggregation request to the final result. pub(crate) fn into_final_bucket_result( self, req: Aggregations, @@ -43,7 +43,7 @@ impl IntermediateAggregationResults { self.into_final_bucket_result_internal(&(req.into())) } - /// Convert and intermediate result and its aggregation request to the final result + /// Convert intermediate result and its aggregation request to the final result. /// /// Internal function, AggregationsInternal is used instead Aggregations, which is optimized /// for internal processing, by splitting metric and buckets into seperate groups. diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 0498ffbe80..2f704b17d0 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -222,7 +222,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let reader = index.reader()?; let searcher = reader.searcher(); @@ -299,7 +299,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 7fe2d82847..7f6f8378ce 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -28,7 +28,7 @@ //! //! ```verbatim //! let agg_req: Aggregations = serde_json::from_str(json_request_string).unwrap(); -//! let collector = AggregationCollector::from_aggs(agg_req); +//! let collector = AggregationCollector::from_aggs(agg_req, None); //! let searcher = reader.searcher(); //! 
let agg_res = searcher.search(&term_query, &collector).unwrap_err(); //! let json_response_string: String = &serde_json::to_string(&agg_res)?; @@ -68,7 +68,7 @@ //! .into_iter() //! .collect(); //! -//! let collector = AggregationCollector::from_aggs(agg_req); +//! let collector = AggregationCollector::from_aggs(agg_req, None); //! //! let searcher = reader.searcher(); //! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); @@ -358,7 +358,7 @@ mod tests { index: &Index, query: Option<(&str, &str)>, ) -> crate::Result { - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let reader = index.reader()?; let searcher = reader.searcher(); @@ -547,7 +547,7 @@ mod tests { .unwrap(); let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone()); + let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap(); @@ -555,7 +555,7 @@ mod tests { .into_final_bucket_result(agg_req) .unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -792,7 +792,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); @@ -982,7 +982,7 @@ mod tests { assert_eq!(field_names, vec!["text".to_string()].into_iter().collect()); let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone()); + let collector = 
DistributedAggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); let res = searcher.search(&term_query, &collector).unwrap(); @@ -991,7 +991,7 @@ mod tests { serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap(); res.into_final_bucket_result(agg_req.clone()).unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req.clone()); + let collector = AggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -1049,7 +1049,7 @@ mod tests { ); // Test empty result set - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); searcher.search(&query_with_no_hits, &collector).unwrap(); @@ -1114,7 +1114,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); @@ -1227,7 +1227,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1258,7 +1258,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1289,7 +1289,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1328,7 +1328,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = 
AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1357,7 +1357,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1386,7 +1386,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1422,7 +1422,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1457,7 +1457,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1496,7 +1496,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1526,7 +1526,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1582,7 +1582,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs 
index 57c545f7d6..fe07400897 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -12,6 +12,7 @@ use super::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, }; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; +use super::collector::MAX_BUCKET_COUNT; use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult}; use super::metric::{ AverageAggregation, SegmentAverageCollector, SegmentStatsCollector, StatsAggregation, @@ -277,11 +278,36 @@ impl SegmentBucketResultCollector { } } -pub(crate) fn validate_bucket_count(bucket_count: &Rc) -> crate::Result<()> { - if bucket_count.load(std::sync::atomic::Ordering::Relaxed) > 65000 { - return Err(TantivyError::InvalidArgument( - "Aborting aggregation because too many buckets were created".to_string(), - )); +#[derive(Clone)] +pub(crate) struct BucketCount { + /// The counter which is shared between the aggregations for one request. 
+ pub(crate) bucket_count: Rc, + pub(crate) max_bucket_count: u32, +} + +impl Default for BucketCount { + fn default() -> Self { + Self { + bucket_count: Default::default(), + max_bucket_count: MAX_BUCKET_COUNT, + } + } +} + +impl BucketCount { + pub(crate) fn validate_bucket_count(&self) -> crate::Result<()> { + if self.get_count() > self.max_bucket_count { + return Err(TantivyError::InvalidArgument( + "Aborting aggregation because too many buckets were created".to_string(), + )); + } + Ok(()) + } + pub(crate) fn add_count(&self, count: u32) { + self.bucket_count + .fetch_add(count as u32, std::sync::atomic::Ordering::Relaxed); + } + pub(crate) fn get_count(&self) -> u32 { + self.bucket_count.load(std::sync::atomic::Ordering::Relaxed) } - Ok(()) } From c5c2e59b2bbb810db2e55758f220384f9e2c2724 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Tue, 17 May 2022 11:35:39 +0800 Subject: [PATCH 06/10] introduce optional collect_block in segmentcollector add collect_block in segment_collector to handle groups of documents as performance optimization add collect_block for MultiCollector --- src/aggregation/bucket/term_agg.rs | 2 +- src/collector/count_collector.rs | 5 +++ src/collector/custom_score_top_collector.rs | 14 +++++- src/collector/mod.rs | 47 ++++++++++++++++++++- src/collector/multi_collector.rs | 18 ++++++++ src/collector/top_score_collector.rs | 8 ++++ 6 files changed, 89 insertions(+), 5 deletions(-) diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 52e120cc96..8a9970e0fd 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -256,7 +256,7 @@ impl TermBuckets { }); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, &sub_aggregation)?; + sub_aggregations.collect(doc, sub_aggregation)?; } } bucket_count.validate_bucket_count()?; diff --git a/src/collector/count_collector.rs b/src/collector/count_collector.rs index 
02f30f85c1..3ff1368f83 100644 --- a/src/collector/count_collector.rs +++ b/src/collector/count_collector.rs @@ -70,6 +70,11 @@ impl SegmentCollector for SegmentCountCollector { Ok(()) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + self.count += docs.len(); + Ok(()) + } + fn harvest(self) -> usize { self.count } diff --git a/src/collector/custom_score_top_collector.rs b/src/collector/custom_score_top_collector.rs index d597727ed1..0f981c20a7 100644 --- a/src/collector/custom_score_top_collector.rs +++ b/src/collector/custom_score_top_collector.rs @@ -8,7 +8,8 @@ pub(crate) struct CustomScoreTopCollector { } impl CustomScoreTopCollector -where TScore: Clone + PartialOrd +where + TScore: Clone + PartialOrd, { pub(crate) fn new( custom_scorer: TCustomScorer, @@ -96,6 +97,14 @@ where Ok(()) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + for (doc, _score) in docs { + let score = self.segment_scorer.score(*doc); + self.segment_collector.collect(*doc, score); + } + Ok(()) + } + fn harvest(self) -> Vec<(TScore, DocAddress)> { self.segment_collector.harvest() } @@ -114,7 +123,8 @@ where } impl CustomSegmentScorer for F -where F: 'static + FnMut(DocId) -> TScore +where + F: 'static + FnMut(DocId) -> TScore, { fn score(&mut self, doc: DocId) -> TScore { (self)(doc) diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 6d600bb488..7e6d43f7b7 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -172,19 +172,33 @@ pub trait Collector: Sync + Send { ) -> crate::Result<::Fruit> { let mut segment_collector = self.for_segment(segment_ord as u32, reader)?; + let mut cache_pos = 0; + let mut cache = [(0, 0.0); 64]; + if let Some(alive_bitset) = reader.alive_bitset() { weight.for_each(reader, &mut |doc, score| { if alive_bitset.is_alive(doc) { - segment_collector.collect(doc, score)?; + cache[cache_pos] = (doc, score); + cache_pos += 1; + if cache_pos == 64 { + 
segment_collector.collect_block(&cache)?; + cache_pos = 0; + } } Ok(()) })?; } else { weight.for_each(reader, &mut |doc, score| { - segment_collector.collect(doc, score)?; + cache[cache_pos] = (doc, score); + cache_pos += 1; + if cache_pos == 64 { + segment_collector.collect_block(&cache)?; + cache_pos = 0; + } Ok(()) })?; } + segment_collector.collect_block(&cache[..cache_pos])?; Ok(segment_collector.harvest()) } } @@ -258,6 +272,14 @@ pub trait SegmentCollector: 'static { /// The query pushes the scored document to the collector via this method. fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()>; + /// The query pushes the scored document to the collector via this method. + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + for (doc, score) in docs { + self.collect(*doc, *score)?; + } + Ok(()) + } + /// Extract the fruit of the collection from the `SegmentCollector`. fn harvest(self) -> Self::Fruit; } @@ -317,6 +339,12 @@ where Ok(()) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + self.0.collect_block(docs)?; + self.1.collect_block(docs)?; + Ok(()) + } + fn harvest(self) -> ::Fruit { (self.0.harvest(), self.1.harvest()) } @@ -383,6 +411,13 @@ where Ok(()) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + self.0.collect_block(docs)?; + self.1.collect_block(docs)?; + self.2.collect_block(docs)?; + Ok(()) + } + fn harvest(self) -> ::Fruit { (self.0.harvest(), self.1.harvest(), self.2.harvest()) } @@ -459,6 +494,14 @@ where Ok(()) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + self.0.collect_block(docs)?; + self.1.collect_block(docs)?; + self.2.collect_block(docs)?; + self.3.collect_block(docs)?; + Ok(()) + } + fn harvest(self) -> ::Fruit { ( self.0.harvest(), diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index 7b119ad868..a7a8ed1692 100644 --- a/src/collector/multi_collector.rs +++ 
b/src/collector/multi_collector.rs @@ -57,6 +57,11 @@ impl SegmentCollector for Box { Ok(()) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + self.as_mut().collect_block(docs)?; + Ok(()) + } + fn harvest(self) -> Box { BoxableSegmentCollector::harvest_from_box(self) } @@ -64,6 +69,7 @@ impl SegmentCollector for Box { pub trait BoxableSegmentCollector { fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()>; + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()>; fn harvest_from_box(self: Box) -> Box; } @@ -76,6 +82,11 @@ impl BoxableSegmentCollector self.0.collect(doc, score) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + self.0.collect_block(docs)?; + Ok(()) + } + fn harvest_from_box(self: Box) -> Box { Box::new(self.0.harvest()) } @@ -236,6 +247,13 @@ impl SegmentCollector for MultiCollectorChild { Ok(()) } + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + for child in &mut self.children { + child.collect_block(docs)?; + } + Ok(()) + } + fn harvest(self) -> MultiFruit { MultiFruit { sub_fruits: self diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index e0e3aeb9dc..6074d088a3 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -704,6 +704,14 @@ impl SegmentCollector for TopScoreSegmentCollector { Ok(()) } + #[inline] + fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { + for (doc, score) in docs { + self.0.collect(*doc, *score); + } + Ok(()) + } + fn harvest(self) -> Vec<(Score, DocAddress)> { self.0.harvest() } From 17dcc99e43f203da354def2af22e33184030e739 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Thu, 19 May 2022 16:25:21 +0800 Subject: [PATCH 07/10] Revert "introduce optional collect_block in segmentcollector" This reverts commit c5c2e59b2bbb810db2e55758f220384f9e2c2724. 
--- src/aggregation/bucket/term_agg.rs | 2 +- src/collector/count_collector.rs | 5 --- src/collector/custom_score_top_collector.rs | 14 +----- src/collector/mod.rs | 47 +-------------------- src/collector/multi_collector.rs | 18 -------- src/collector/top_score_collector.rs | 8 ---- 6 files changed, 5 insertions(+), 89 deletions(-) diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 8a9970e0fd..52e120cc96 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -256,7 +256,7 @@ impl TermBuckets { }); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, sub_aggregation)?; + sub_aggregations.collect(doc, &sub_aggregation)?; } } bucket_count.validate_bucket_count()?; diff --git a/src/collector/count_collector.rs b/src/collector/count_collector.rs index 3ff1368f83..02f30f85c1 100644 --- a/src/collector/count_collector.rs +++ b/src/collector/count_collector.rs @@ -70,11 +70,6 @@ impl SegmentCollector for SegmentCountCollector { Ok(()) } - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - self.count += docs.len(); - Ok(()) - } - fn harvest(self) -> usize { self.count } diff --git a/src/collector/custom_score_top_collector.rs b/src/collector/custom_score_top_collector.rs index 0f981c20a7..d597727ed1 100644 --- a/src/collector/custom_score_top_collector.rs +++ b/src/collector/custom_score_top_collector.rs @@ -8,8 +8,7 @@ pub(crate) struct CustomScoreTopCollector { } impl CustomScoreTopCollector -where - TScore: Clone + PartialOrd, +where TScore: Clone + PartialOrd { pub(crate) fn new( custom_scorer: TCustomScorer, @@ -97,14 +96,6 @@ where Ok(()) } - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - for (doc, _score) in docs { - let score = self.segment_scorer.score(*doc); - self.segment_collector.collect(*doc, score); - } - Ok(()) - } - fn harvest(self) -> Vec<(TScore, 
DocAddress)> { self.segment_collector.harvest() } @@ -123,8 +114,7 @@ where } impl CustomSegmentScorer for F -where - F: 'static + FnMut(DocId) -> TScore, +where F: 'static + FnMut(DocId) -> TScore { fn score(&mut self, doc: DocId) -> TScore { (self)(doc) diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 7e6d43f7b7..6d600bb488 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -172,33 +172,19 @@ pub trait Collector: Sync + Send { ) -> crate::Result<::Fruit> { let mut segment_collector = self.for_segment(segment_ord as u32, reader)?; - let mut cache_pos = 0; - let mut cache = [(0, 0.0); 64]; - if let Some(alive_bitset) = reader.alive_bitset() { weight.for_each(reader, &mut |doc, score| { if alive_bitset.is_alive(doc) { - cache[cache_pos] = (doc, score); - cache_pos += 1; - if cache_pos == 64 { - segment_collector.collect_block(&cache)?; - cache_pos = 0; - } + segment_collector.collect(doc, score)?; } Ok(()) })?; } else { weight.for_each(reader, &mut |doc, score| { - cache[cache_pos] = (doc, score); - cache_pos += 1; - if cache_pos == 64 { - segment_collector.collect_block(&cache)?; - cache_pos = 0; - } + segment_collector.collect(doc, score)?; Ok(()) })?; } - segment_collector.collect_block(&cache[..cache_pos])?; Ok(segment_collector.harvest()) } } @@ -272,14 +258,6 @@ pub trait SegmentCollector: 'static { /// The query pushes the scored document to the collector via this method. fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()>; - /// The query pushes the scored document to the collector via this method. - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - for (doc, score) in docs { - self.collect(*doc, *score)?; - } - Ok(()) - } - /// Extract the fruit of the collection from the `SegmentCollector`. 
fn harvest(self) -> Self::Fruit; } @@ -339,12 +317,6 @@ where Ok(()) } - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - self.0.collect_block(docs)?; - self.1.collect_block(docs)?; - Ok(()) - } - fn harvest(self) -> ::Fruit { (self.0.harvest(), self.1.harvest()) } @@ -411,13 +383,6 @@ where Ok(()) } - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - self.0.collect_block(docs)?; - self.1.collect_block(docs)?; - self.2.collect_block(docs)?; - Ok(()) - } - fn harvest(self) -> ::Fruit { (self.0.harvest(), self.1.harvest(), self.2.harvest()) } @@ -494,14 +459,6 @@ where Ok(()) } - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - self.0.collect_block(docs)?; - self.1.collect_block(docs)?; - self.2.collect_block(docs)?; - self.3.collect_block(docs)?; - Ok(()) - } - fn harvest(self) -> ::Fruit { ( self.0.harvest(), diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index a7a8ed1692..7b119ad868 100644 --- a/src/collector/multi_collector.rs +++ b/src/collector/multi_collector.rs @@ -57,11 +57,6 @@ impl SegmentCollector for Box { Ok(()) } - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - self.as_mut().collect_block(docs)?; - Ok(()) - } - fn harvest(self) -> Box { BoxableSegmentCollector::harvest_from_box(self) } @@ -69,7 +64,6 @@ impl SegmentCollector for Box { pub trait BoxableSegmentCollector { fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()>; - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()>; fn harvest_from_box(self: Box) -> Box; } @@ -82,11 +76,6 @@ impl BoxableSegmentCollector self.0.collect(doc, score) } - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - self.0.collect_block(docs)?; - Ok(()) - } - fn harvest_from_box(self: Box) -> Box { Box::new(self.0.harvest()) } @@ -247,13 +236,6 @@ impl SegmentCollector for MultiCollectorChild { Ok(()) } - fn 
collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - for child in &mut self.children { - child.collect_block(docs)?; - } - Ok(()) - } - fn harvest(self) -> MultiFruit { MultiFruit { sub_fruits: self diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 6074d088a3..e0e3aeb9dc 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -704,14 +704,6 @@ impl SegmentCollector for TopScoreSegmentCollector { Ok(()) } - #[inline] - fn collect_block(&mut self, docs: &[(DocId, Score)]) -> crate::Result<()> { - for (doc, score) in docs { - self.0.collect(*doc, *score); - } - Ok(()) - } - fn harvest(self) -> Vec<(Score, DocAddress)> { self.0.harvest() } From b114e553cd2c373e19307ab3e922abe45fa49f42 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Thu, 19 May 2022 16:43:55 +0800 Subject: [PATCH 08/10] Revert "return result from segment collector" This reverts commit a99e5459e3e58064ab8917b81e21e924e5ca0548. 
--- examples/custom_collector.rs | 3 +- src/aggregation/collector.rs | 3 +- src/collector/count_collector.rs | 11 +++--- src/collector/custom_score_top_collector.rs | 3 +- src/collector/docset_collector.rs | 3 +- src/collector/facet_collector.rs | 3 +- src/collector/filter_collector_wrapper.rs | 5 ++- src/collector/histogram_collector.rs | 3 +- src/collector/mod.rs | 40 +++++++++------------ src/collector/multi_collector.rs | 16 ++++----- src/collector/tests.rs | 9 ++--- src/collector/top_score_collector.rs | 3 +- src/collector/tweak_score_top_collector.rs | 3 +- src/query/boolean_query/boolean_weight.rs | 6 ++-- src/query/term_query/term_weight.rs | 4 +-- src/query/weight.rs | 11 +++--- 16 files changed, 52 insertions(+), 74 deletions(-) diff --git a/examples/custom_collector.rs b/examples/custom_collector.rs index 14a9d20358..7bdc9d06b4 100644 --- a/examples/custom_collector.rs +++ b/examples/custom_collector.rs @@ -102,12 +102,11 @@ struct StatsSegmentCollector { impl SegmentCollector for StatsSegmentCollector { type Fruit = Option; - fn collect(&mut self, doc: u32, _score: Score) -> tantivy::Result<()> { + fn collect(&mut self, doc: u32, _score: Score) { let value = self.fast_field_reader.get(doc) as f64; self.stats.count += 1; self.stats.sum += value; self.stats.squared_sum += value * value; - Ok(()) } fn harvest(self) -> ::Fruit { diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index cf2848f383..09c8ecbe1d 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -158,9 +158,8 @@ impl SegmentCollector for AggregationSegmentCollector { type Fruit = crate::Result; #[inline] - fn collect(&mut self, doc: crate::DocId, _score: crate::Score) -> crate::Result<()> { + fn collect(&mut self, doc: crate::DocId, _score: crate::Score) { self.result.collect(doc, &self.aggs_with_accessor)?; - Ok(()) } fn harvest(mut self) -> Self::Fruit { diff --git a/src/collector/count_collector.rs b/src/collector/count_collector.rs index 
02f30f85c1..075a4f36b4 100644 --- a/src/collector/count_collector.rs +++ b/src/collector/count_collector.rs @@ -65,9 +65,8 @@ pub struct SegmentCountCollector { impl SegmentCollector for SegmentCountCollector { type Fruit = usize; - fn collect(&mut self, _: DocId, _: Score) -> crate::Result<()> { + fn collect(&mut self, _: DocId, _: Score) { self.count += 1; - Ok(()) } fn harvest(self) -> usize { @@ -93,18 +92,18 @@ mod tests { } { let mut count_collector = SegmentCountCollector::default(); - count_collector.collect(0u32, 1.0).unwrap(); + count_collector.collect(0u32, 1.0); assert_eq!(count_collector.harvest(), 1); } { let mut count_collector = SegmentCountCollector::default(); - count_collector.collect(0u32, 1.0).unwrap(); + count_collector.collect(0u32, 1.0); assert_eq!(count_collector.harvest(), 1); } { let mut count_collector = SegmentCountCollector::default(); - count_collector.collect(0u32, 1.0).unwrap(); - count_collector.collect(1u32, 1.0).unwrap(); + count_collector.collect(0u32, 1.0); + count_collector.collect(1u32, 1.0); assert_eq!(count_collector.harvest(), 2); } } diff --git a/src/collector/custom_score_top_collector.rs b/src/collector/custom_score_top_collector.rs index d597727ed1..d645004ade 100644 --- a/src/collector/custom_score_top_collector.rs +++ b/src/collector/custom_score_top_collector.rs @@ -90,10 +90,9 @@ where { type Fruit = Vec<(TScore, DocAddress)>; - fn collect(&mut self, doc: DocId, _score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, _score: Score) { let score = self.segment_scorer.score(doc); self.segment_collector.collect(doc, score); - Ok(()) } fn harvest(self) -> Vec<(TScore, DocAddress)> { diff --git a/src/collector/docset_collector.rs b/src/collector/docset_collector.rs index 9f6a5fd3bd..a27a394189 100644 --- a/src/collector/docset_collector.rs +++ b/src/collector/docset_collector.rs @@ -50,9 +50,8 @@ pub struct DocSetChildCollector { impl SegmentCollector for DocSetChildCollector { type Fruit = (u32, 
HashSet); - fn collect(&mut self, doc: crate::DocId, _score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: crate::DocId, _score: Score) { self.docs.insert(doc); - Ok(()) } fn harvest(self) -> (u32, HashSet) { diff --git a/src/collector/facet_collector.rs b/src/collector/facet_collector.rs index 8ad3311e28..e2ef47f989 100644 --- a/src/collector/facet_collector.rs +++ b/src/collector/facet_collector.rs @@ -333,7 +333,7 @@ impl Collector for FacetCollector { impl SegmentCollector for FacetSegmentCollector { type Fruit = FacetCounts; - fn collect(&mut self, doc: DocId, _: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, _: Score) { self.reader.facet_ords(doc, &mut self.facet_ords_buf); let mut previous_collapsed_ord: usize = usize::MAX; for &facet_ord in &self.facet_ords_buf { @@ -345,7 +345,6 @@ impl SegmentCollector for FacetSegmentCollector { }; previous_collapsed_ord = collapsed_ord; } - Ok(()) } /// Returns the results of the collection. diff --git a/src/collector/filter_collector_wrapper.rs b/src/collector/filter_collector_wrapper.rs index 15e7f80212..b1dbaaa203 100644 --- a/src/collector/filter_collector_wrapper.rs +++ b/src/collector/filter_collector_wrapper.rs @@ -173,12 +173,11 @@ where { type Fruit = TSegmentCollector::Fruit; - fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: u32, score: Score) { let value = self.fast_field_reader.get(doc); if (self.predicate)(value) { - self.segment_collector.collect(doc, score)?; + self.segment_collector.collect(doc, score) } - Ok(()) } fn harvest(self) -> ::Fruit { diff --git a/src/collector/histogram_collector.rs b/src/collector/histogram_collector.rs index fbf398627a..22956a86a2 100644 --- a/src/collector/histogram_collector.rs +++ b/src/collector/histogram_collector.rs @@ -91,10 +91,9 @@ pub struct SegmentHistogramCollector { impl SegmentCollector for SegmentHistogramCollector { type Fruit = Vec; - fn collect(&mut self, doc: DocId, _score: 
Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, _score: Score) { let value = self.ff_reader.get(doc); self.histogram_computer.add_value(value); - Ok(()) } fn harvest(self) -> Self::Fruit { diff --git a/src/collector/mod.rs b/src/collector/mod.rs index 6d600bb488..1597d7fe45 100644 --- a/src/collector/mod.rs +++ b/src/collector/mod.rs @@ -175,14 +175,12 @@ pub trait Collector: Sync + Send { if let Some(alive_bitset) = reader.alive_bitset() { weight.for_each(reader, &mut |doc, score| { if alive_bitset.is_alive(doc) { - segment_collector.collect(doc, score)?; + segment_collector.collect(doc, score); } - Ok(()) })?; } else { weight.for_each(reader, &mut |doc, score| { - segment_collector.collect(doc, score)?; - Ok(()) + segment_collector.collect(doc, score); })?; } Ok(segment_collector.harvest()) @@ -192,11 +190,10 @@ pub trait Collector: Sync + Send { impl SegmentCollector for Option { type Fruit = Option; - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, score: Score) { if let Some(segment_collector) = self { - segment_collector.collect(doc, score)?; + segment_collector.collect(doc, score); } - Ok(()) } fn harvest(self) -> Self::Fruit { @@ -256,7 +253,7 @@ pub trait SegmentCollector: 'static { type Fruit: Fruit; /// The query pushes the scored document to the collector via this method. - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()>; + fn collect(&mut self, doc: DocId, score: Score); /// Extract the fruit of the collection from the `SegmentCollector`. 
fn harvest(self) -> Self::Fruit; @@ -311,10 +308,9 @@ where { type Fruit = (Left::Fruit, Right::Fruit); - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { - self.0.collect(doc, score)?; - self.1.collect(doc, score)?; - Ok(()) + fn collect(&mut self, doc: DocId, score: Score) { + self.0.collect(doc, score); + self.1.collect(doc, score); } fn harvest(self) -> ::Fruit { @@ -376,11 +372,10 @@ where { type Fruit = (One::Fruit, Two::Fruit, Three::Fruit); - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { - self.0.collect(doc, score)?; - self.1.collect(doc, score)?; - self.2.collect(doc, score)?; - Ok(()) + fn collect(&mut self, doc: DocId, score: Score) { + self.0.collect(doc, score); + self.1.collect(doc, score); + self.2.collect(doc, score); } fn harvest(self) -> ::Fruit { @@ -451,12 +446,11 @@ where { type Fruit = (One::Fruit, Two::Fruit, Three::Fruit, Four::Fruit); - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { - self.0.collect(doc, score)?; - self.1.collect(doc, score)?; - self.2.collect(doc, score)?; - self.3.collect(doc, score)?; - Ok(()) + fn collect(&mut self, doc: DocId, score: Score) { + self.0.collect(doc, score); + self.1.collect(doc, score); + self.2.collect(doc, score); + self.3.collect(doc, score); } fn harvest(self) -> ::Fruit { diff --git a/src/collector/multi_collector.rs b/src/collector/multi_collector.rs index 7b119ad868..039902ff4f 100644 --- a/src/collector/multi_collector.rs +++ b/src/collector/multi_collector.rs @@ -52,9 +52,8 @@ impl Collector for CollectorWrapper { impl SegmentCollector for Box { type Fruit = Box; - fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()> { - self.as_mut().collect(doc, score)?; - Ok(()) + fn collect(&mut self, doc: u32, score: Score) { + self.as_mut().collect(doc, score); } fn harvest(self) -> Box { @@ -63,7 +62,7 @@ impl SegmentCollector for Box { } pub trait BoxableSegmentCollector { - fn collect(&mut self, doc: u32, score: Score) 
-> crate::Result<()>; + fn collect(&mut self, doc: u32, score: Score); fn harvest_from_box(self: Box) -> Box; } @@ -72,8 +71,8 @@ pub struct SegmentCollectorWrapper(TSegment impl BoxableSegmentCollector for SegmentCollectorWrapper { - fn collect(&mut self, doc: u32, score: Score) -> crate::Result<()> { - self.0.collect(doc, score) + fn collect(&mut self, doc: u32, score: Score) { + self.0.collect(doc, score); } fn harvest_from_box(self: Box) -> Box { @@ -229,11 +228,10 @@ pub struct MultiCollectorChild { impl SegmentCollector for MultiCollectorChild { type Fruit = MultiFruit; - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, score: Score) { for child in &mut self.children { - child.collect(doc, score)?; + child.collect(doc, score); } - Ok(()) } fn harvest(self) -> MultiFruit { diff --git a/src/collector/tests.rs b/src/collector/tests.rs index 5e0a0cfb2d..3bda822a10 100644 --- a/src/collector/tests.rs +++ b/src/collector/tests.rs @@ -138,10 +138,9 @@ impl Collector for TestCollector { impl SegmentCollector for TestSegmentCollector { type Fruit = TestFruit; - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, score: Score) { self.fruit.docs.push(DocAddress::new(self.segment_id, doc)); self.fruit.scores.push(score); - Ok(()) } fn harvest(self) -> ::Fruit { @@ -199,10 +198,9 @@ impl Collector for FastFieldTestCollector { impl SegmentCollector for FastFieldSegmentCollector { type Fruit = Vec; - fn collect(&mut self, doc: DocId, _score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, _score: Score) { let val = self.reader.get(doc); self.vals.push(val); - Ok(()) } fn harvest(self) -> Vec { @@ -257,10 +255,9 @@ impl Collector for BytesFastFieldTestCollector { impl SegmentCollector for BytesFastFieldSegmentCollector { type Fruit = Vec; - fn collect(&mut self, doc: u32, _score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: u32, 
_score: Score) { let data = self.reader.get_bytes(doc); self.vals.extend(data); - Ok(()) } fn harvest(self) -> ::Fruit { diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index e0e3aeb9dc..516dedcb58 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -699,9 +699,8 @@ pub struct TopScoreSegmentCollector(TopSegmentCollector); impl SegmentCollector for TopScoreSegmentCollector { type Fruit = Vec<(Score, DocAddress)>; - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, score: Score) { self.0.collect(doc, score); - Ok(()) } fn harvest(self) -> Vec<(Score, DocAddress)> { diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs index cb8d358abe..1a81e73616 100644 --- a/src/collector/tweak_score_top_collector.rs +++ b/src/collector/tweak_score_top_collector.rs @@ -93,10 +93,9 @@ where { type Fruit = Vec<(TScore, DocAddress)>; - fn collect(&mut self, doc: DocId, score: Score) -> crate::Result<()> { + fn collect(&mut self, doc: DocId, score: Score) { let score = self.segment_scorer.score(doc, score); self.segment_collector.collect(doc, score); - Ok(()) } fn harvest(self) -> Vec<(TScore, DocAddress)> { diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 9024a6abb8..f2a5fd376a 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -186,17 +186,17 @@ impl Weight for BooleanWeight { fn for_each( &self, reader: &SegmentReader, - callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, + callback: &mut dyn FnMut(DocId, Score), ) -> crate::Result<()> { let scorer = self.complex_scorer::(reader, 1.0)?; match scorer { SpecializedScorer::TermUnion(term_scorers) => { let mut union_scorer = Union::::from(term_scorers); - for_each_scorer(&mut union_scorer, callback)?; + for_each_scorer(&mut 
union_scorer, callback); } SpecializedScorer::Other(mut scorer) => { - for_each_scorer(scorer.as_mut(), callback)?; + for_each_scorer(scorer.as_mut(), callback); } } Ok(()) diff --git a/src/query/term_query/term_weight.rs b/src/query/term_query/term_weight.rs index e1529027b2..4e742bc444 100644 --- a/src/query/term_query/term_weight.rs +++ b/src/query/term_query/term_weight.rs @@ -49,10 +49,10 @@ impl Weight for TermWeight { fn for_each( &self, reader: &SegmentReader, - callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, + callback: &mut dyn FnMut(DocId, Score), ) -> crate::Result<()> { let mut scorer = self.specialized_scorer(reader, 1.0)?; - for_each_scorer(&mut scorer, callback)?; + for_each_scorer(&mut scorer, callback); Ok(()) } diff --git a/src/query/weight.rs b/src/query/weight.rs index 26bdfb0edb..3a2ff3d33c 100644 --- a/src/query/weight.rs +++ b/src/query/weight.rs @@ -7,14 +7,13 @@ use crate::{DocId, Score, TERMINATED}; /// `DocSet` and push the scored documents to the collector. 
pub(crate) fn for_each_scorer( scorer: &mut TScorer, - callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, -) -> crate::Result<()> { + callback: &mut dyn FnMut(DocId, Score), +) { let mut doc = scorer.doc(); while doc != TERMINATED { - callback(doc, scorer.score())?; + callback(doc, scorer.score()); doc = scorer.advance(); } - Ok(()) } /// Calls `callback` with all of the `(doc, score)` for which score @@ -72,10 +71,10 @@ pub trait Weight: Send + Sync + 'static { fn for_each( &self, reader: &SegmentReader, - callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>, + callback: &mut dyn FnMut(DocId, Score), ) -> crate::Result<()> { let mut scorer = self.scorer(reader, 1.0)?; - for_each_scorer(scorer.as_mut(), callback)?; + for_each_scorer(scorer.as_mut(), callback); Ok(()) } From 71f75071d2ab4d9380a2254cf2197bdd93688fc2 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Thu, 19 May 2022 16:58:56 +0800 Subject: [PATCH 09/10] cache and return error in aggregations --- src/aggregation/bucket/term_agg.rs | 2 +- src/aggregation/collector.rs | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 52e120cc96..8a9970e0fd 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -256,7 +256,7 @@ impl TermBuckets { }); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, &sub_aggregation)?; + sub_aggregations.collect(doc, sub_aggregation)?; } } bucket_count.validate_bucket_count()?; diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 09c8ecbe1d..c9510d9263 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -7,7 +7,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults; use super::segment_agg_result::SegmentAggregationResultsCollector; use 
crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; -use crate::SegmentReader; +use crate::{SegmentReader, TantivyError}; pub const MAX_BUCKET_COUNT: u32 = 65000; @@ -133,6 +133,7 @@ fn merge_fruits( pub struct AggregationSegmentCollector { aggs_with_accessor: AggregationsWithAccessor, result: SegmentAggregationResultsCollector, + error: Option, } impl AggregationSegmentCollector { @@ -150,6 +151,7 @@ impl AggregationSegmentCollector { Ok(AggregationSegmentCollector { aggs_with_accessor, result, + error: None, }) } } @@ -159,10 +161,18 @@ impl SegmentCollector for AggregationSegmentCollector { #[inline] fn collect(&mut self, doc: crate::DocId, _score: crate::Score) { - self.result.collect(doc, &self.aggs_with_accessor)?; + if self.error.is_some() { + return; + } + if let Err(err) = self.result.collect(doc, &self.aggs_with_accessor) { + self.error = Some(err); + } } fn harvest(mut self) -> Self::Fruit { + if let Some(err) = self.error { + return Err(err); + } self.result .flush_staged_docs(&self.aggs_with_accessor, true)?; self.result From 2e2822f89d885391dd9fe7897e9409bf2e4f1fc3 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 23 Jun 2022 09:48:28 +0900 Subject: [PATCH 10/10] Apply suggestions from code review --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4282b830a7..00d88d2938 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,6 @@ Unreleased - Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz) - Add support for fastfield on text fields (@PSeitz) - Add terms aggregation (@PSeitz) -- API Change: `SegmentCollector.collect` changed to return a `Result`. Tantivy 0.17 ================================