diff --git a/examples/aggregation.rs b/examples/aggregation.rs index 82cc0fccd3..ae11dc5a5a 100644 --- a/examples/aggregation.rs +++ b/examples/aggregation.rs @@ -117,7 +117,7 @@ fn main() -> tantivy::Result<()> { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 8ed82ac5c6..491faf2137 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -1,10 +1,13 @@ //! This will enhance the request tree with access to the fastfield and metadata. +use std::rc::Rc; +use std::sync::atomic::AtomicU32; use std::sync::Arc; use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation}; use super::metric::{AverageAggregation, StatsAggregation}; +use super::segment_agg_result::BucketCount; use super::VecWithNames; use crate::fastfield::{ type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader, @@ -60,6 +63,7 @@ pub struct BucketAggregationWithAccessor { pub(crate) field_type: Type, pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, + pub(crate) bucket_count: BucketCount, } impl BucketAggregationWithAccessor { @@ -67,6 +71,8 @@ impl BucketAggregationWithAccessor { bucket: &BucketAggregationType, sub_aggregation: &Aggregations, reader: &SegmentReader, + bucket_count: Rc<AtomicU32>, + max_bucket_count: u32, ) -> crate::Result<BucketAggregationWithAccessor> { let mut inverted_index = None; let (accessor, field_type) = match &bucket { @@ -92,9 +98,18 @@ impl BucketAggregationWithAccessor { Ok(BucketAggregationWithAccessor { accessor, field_type, - sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?, + sub_aggregation: get_aggs_with_accessor_and_validate( + &sub_aggregation, + reader, + bucket_count.clone(), + max_bucket_count, + )?, bucket_agg: bucket.clone(), inverted_index, + bucket_count: BucketCount { + bucket_count, + max_bucket_count, + }, }) } } @@ -134,6 +149,8 @@ impl MetricAggregationWithAccessor { pub(crate) fn get_aggs_with_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, + bucket_count: Rc<AtomicU32>, + max_bucket_count: u32, ) -> crate::Result<AggregationsWithAccessor> { let mut metrics = vec![]; let mut buckets = vec![]; @@ -145,6 +162,8 @@ pub(crate) fn get_aggs_with_accessor_and_validate( &bucket.bucket_agg, &bucket.sub_aggregation, reader, + Rc::clone(&bucket_count), + max_bucket_count, )?, )), Aggregation::Metric(metric) => metrics.push(( diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 3dba73d1a6..fc614990b4 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -4,21 +4,15 @@ //! intermediate average results, which is the sum and the number of values. The actual average is //! calculated on the step from intermediate to final aggregation result tree.
-use std::cmp::Ordering; use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use super::agg_req::{ - Aggregations, AggregationsInternal, BucketAggregationInternal, MetricAggregation, -}; -use super::bucket::{intermediate_buckets_to_final_buckets, GetDocCount}; -use super::intermediate_agg_result::{ - IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, - IntermediateMetricResult, IntermediateRangeBucketEntry, -}; +use super::agg_req::BucketAggregationInternal; +use super::bucket::GetDocCount; +use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult}; use super::metric::{SingleMetricResult, Stats}; -use super::{Key, VecWithNames}; +use super::Key; use crate::TantivyError; #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] @@ -41,98 +35,6 @@ impl AggregationResults { ))) } } - - /// Convert and intermediate result and its aggregation request to the final result - pub fn from_intermediate_and_req( - results: IntermediateAggregationResults, - agg: Aggregations, - ) -> crate::Result { - AggregationResults::from_intermediate_and_req_internal(results, &(agg.into())) - } - - /// Convert and intermediate result and its aggregation request to the final result - /// - /// Internal function, CollectorAggregations is used instead Aggregations, which is optimized - /// for internal processing, by splitting metric and buckets into seperate groups. - pub(crate) fn from_intermediate_and_req_internal( - intermediate_results: IntermediateAggregationResults, - req: &AggregationsInternal, - ) -> crate::Result { - // Important assumption: - // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the - // request - let mut results: HashMap = HashMap::new(); - - if let Some(buckets) = intermediate_results.buckets { - add_coverted_final_buckets_to_result(&mut results, buckets, &req.buckets)? - } else { - // When there are no buckets, we create empty buckets, so that the serialized json - // format is constant - add_empty_final_buckets_to_result(&mut results, &req.buckets)? 
- }; - - if let Some(metrics) = intermediate_results.metrics { - add_converted_final_metrics_to_result(&mut results, metrics); - } else { - // When there are no metrics, we create empty metric results, so that the serialized - // json format is constant - add_empty_final_metrics_to_result(&mut results, &req.metrics)?; - } - Ok(Self(results)) - } -} - -fn add_converted_final_metrics_to_result( - results: &mut HashMap, - metrics: VecWithNames, -) { - results.extend( - metrics - .into_iter() - .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))), - ); -} - -fn add_empty_final_metrics_to_result( - results: &mut HashMap, - req_metrics: &VecWithNames, -) -> crate::Result<()> { - results.extend(req_metrics.iter().map(|(key, req)| { - let empty_bucket = IntermediateMetricResult::empty_from_req(req); - ( - key.to_string(), - AggregationResult::MetricResult(empty_bucket.into()), - ) - })); - Ok(()) -} - -fn add_empty_final_buckets_to_result( - results: &mut HashMap, - req_buckets: &VecWithNames, -) -> crate::Result<()> { - let requested_buckets = req_buckets.iter(); - for (key, req) in requested_buckets { - let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?); - results.insert(key.to_string(), empty_bucket); - } - Ok(()) -} - -fn add_coverted_final_buckets_to_result( - results: &mut HashMap, - buckets: VecWithNames, - req_buckets: &VecWithNames, -) -> crate::Result<()> { - assert_eq!(buckets.len(), req_buckets.len()); - - let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); - for ((key, bucket), req) in buckets_with_request { - let result = - AggregationResult::BucketResult(BucketResult::from_intermediate_and_req(bucket, req)?); - results.insert(key, result); - } - Ok(()) } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -154,7 +56,8 @@ impl AggregationResult { match self { AggregationResult::BucketResult(_bucket) => Err(TantivyError::InternalError( "Tried to retrieve value from bucket aggregation. 
This is not supported and \ - should not happen during collection, but should be catched during validation" + should not happen during collection phase, but should be catched during \ + validation" .to_string(), )), AggregationResult::MetricResult(metric) => metric.get_value(agg_property), @@ -230,48 +133,7 @@ pub enum BucketResult { impl BucketResult { pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result { let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg); - BucketResult::from_intermediate_and_req(empty_bucket, req) - } - - fn from_intermediate_and_req( - bucket_result: IntermediateBucketResult, - req: &BucketAggregationInternal, - ) -> crate::Result { - match bucket_result { - IntermediateBucketResult::Range(range_res) => { - let mut buckets: Vec = range_res - .buckets - .into_iter() - .map(|(_, bucket)| { - RangeBucketEntry::from_intermediate_and_req(bucket, &req.sub_aggregation) - }) - .collect::>>()?; - - buckets.sort_by(|left, right| { - // TODO use total_cmp next stable rust release - left.from - .unwrap_or(f64::MIN) - .partial_cmp(&right.from.unwrap_or(f64::MIN)) - .unwrap_or(Ordering::Equal) - }); - Ok(BucketResult::Range { buckets }) - } - IntermediateBucketResult::Histogram { buckets } => { - let buckets = intermediate_buckets_to_final_buckets( - buckets, - req.as_histogram() - .expect("unexpected aggregation, expected histogram aggregation"), - &req.sub_aggregation, - )?; - - Ok(BucketResult::Histogram { buckets }) - } - IntermediateBucketResult::Terms(terms) => terms.into_final_result( - req.as_term() - .expect("unexpected aggregation, expected term aggregation"), - &req.sub_aggregation, - ), - } + empty_bucket.into_final_bucket_result(req) } } @@ -311,22 +173,6 @@ pub struct BucketEntry { /// Sub-aggregations in this bucket. 
pub sub_aggregation: AggregationResults, } - -impl BucketEntry { - pub(crate) fn from_intermediate_and_req( - entry: IntermediateHistogramBucketEntry, - req: &AggregationsInternal, - ) -> crate::Result { - Ok(BucketEntry { - key: Key::F64(entry.key), - doc_count: entry.doc_count, - sub_aggregation: AggregationResults::from_intermediate_and_req_internal( - entry.sub_aggregation, - req, - )?, - }) - } -} impl GetDocCount for &BucketEntry { fn doc_count(&self) -> u64 { self.doc_count @@ -384,21 +230,3 @@ pub struct RangeBucketEntry { #[serde(skip_serializing_if = "Option::is_none")] pub to: Option, } - -impl RangeBucketEntry { - fn from_intermediate_and_req( - entry: IntermediateRangeBucketEntry, - req: &AggregationsInternal, - ) -> crate::Result { - Ok(RangeBucketEntry { - key: entry.key, - doc_count: entry.doc_count, - sub_aggregation: AggregationResults::from_intermediate_and_req_internal( - entry.sub_aggregation, - req, - )?, - to: entry.to, - from: entry.from, - }) - } -} diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index a2a4a87e59..70acf0f117 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -250,6 +250,11 @@ impl SegmentHistogramCollector { ); }; + agg_with_accessor + .bucket_count + .add_count(buckets.len() as u32); + agg_with_accessor.bucket_count.validate_bucket_count()?; + Ok(IntermediateBucketResult::Histogram { buckets }) } @@ -311,7 +316,7 @@ impl SegmentHistogramCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { let bounds = self.bounds; let interval = self.interval; let offset = self.offset; @@ -341,28 +346,28 @@ impl SegmentHistogramCollector { bucket_pos0, docs[0], &bucket_with_accessor.sub_aggregation, - ); + )?; self.increment_bucket_if_in_bounds( val1, &bounds, bucket_pos1, docs[1], &bucket_with_accessor.sub_aggregation, - ); + )?; self.increment_bucket_if_in_bounds( val2, &bounds, bucket_pos2, docs[2], &bucket_with_accessor.sub_aggregation, - ); + )?; self.increment_bucket_if_in_bounds( val3, &bounds, bucket_pos3, docs[3], &bucket_with_accessor.sub_aggregation, - ); + )?; } for doc in iter.remainder() { let val = f64_from_fastfield_u64(accessor.get(*doc), &self.field_type); @@ -376,16 +381,17 @@ impl SegmentHistogramCollector { self.buckets[bucket_pos].key, get_bucket_val(val, self.interval, self.offset) as f64 ); - self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation)?; } if force_flush { if let Some(sub_aggregations) = self.sub_aggregations.as_mut() { for sub_aggregation in sub_aggregations { sub_aggregation - .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush); + .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush)?; } } } + Ok(()) } #[inline] @@ -396,15 +402,16 @@ impl SegmentHistogramCollector { bucket_pos: usize, doc: DocId, bucket_with_accessor: &AggregationsWithAccessor, - ) { + ) -> crate::Result<()> { if bounds.contains(val) { debug_assert_eq!( self.buckets[bucket_pos].key, get_bucket_val(val, self.interval, self.offset) as f64 ); - self.increment_bucket(bucket_pos, doc, bucket_with_accessor); + self.increment_bucket(bucket_pos, doc, bucket_with_accessor)?; } + Ok(()) } #[inline] @@ -413,12 +420,13 @@ impl SegmentHistogramCollector { bucket_pos: usize, doc: DocId, bucket_with_accessor: &AggregationsWithAccessor, - 
) { + ) -> crate::Result<()> { let bucket = &mut self.buckets[bucket_pos]; bucket.doc_count += 1; if let Some(sub_aggregation) = self.sub_aggregations.as_mut() { - (&mut sub_aggregation[bucket_pos]).collect(doc, bucket_with_accessor); + (&mut sub_aggregation[bucket_pos]).collect(doc, bucket_with_accessor)?; } + Ok(()) } fn f64_from_fastfield_u64(&self, val: u64) -> f64 { @@ -482,14 +490,12 @@ fn intermediate_buckets_to_final_buckets_fill_gaps( sub_aggregation: empty_sub_aggregation.clone(), }, }) - .map(|intermediate_bucket| { - BucketEntry::from_intermediate_and_req(intermediate_bucket, sub_aggregation) - }) + .map(|intermediate_bucket| intermediate_bucket.into_final_bucket_entry(sub_aggregation)) .collect::>>() } // Convert to BucketEntry -pub(crate) fn intermediate_buckets_to_final_buckets( +pub(crate) fn intermediate_histogram_buckets_to_final_buckets( buckets: Vec, histogram_req: &HistogramAggregation, sub_aggregation: &AggregationsInternal, @@ -503,8 +509,8 @@ pub(crate) fn intermediate_buckets_to_final_buckets( } else { buckets .into_iter() - .filter(|bucket| bucket.doc_count >= histogram_req.min_doc_count()) - .map(|bucket| BucketEntry::from_intermediate_and_req(bucket, sub_aggregation)) + .filter(|histogram_bucket| histogram_bucket.doc_count >= histogram_req.min_doc_count()) + .map(|histogram_bucket| histogram_bucket.into_final_bucket_entry(sub_aggregation)) .collect::>>() } } @@ -546,7 +552,7 @@ pub(crate) fn generate_buckets_with_opt_minmax( let offset = req.offset.unwrap_or(0.0); let first_bucket_num = get_bucket_num_f64(min, req.interval, offset) as i64; let last_bucket_num = get_bucket_num_f64(max, req.interval, offset) as i64; - let mut buckets = vec![]; + let mut buckets = Vec::with_capacity((first_bucket_num..=last_bucket_num).count()); for bucket_pos in first_bucket_num..=last_bucket_num { let bucket_key = bucket_pos as f64 * req.interval + offset; buckets.push(bucket_key); diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index d570e96a37..7faa500e7c 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use std::ops::Range; +use fnv::FnvHashMap; use serde::{Deserialize, Serialize}; use crate::aggregation::agg_req_with_accessor::{ @@ -9,8 +10,8 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; -use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key}; +use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector}; +use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey}; use crate::fastfield::FastFieldReader; use crate::schema::Type; use crate::{DocId, TantivyError}; @@ -153,7 +154,7 @@ impl SegmentRangeCollector { ) -> crate::Result { let field_type = self.field_type; - let buckets = self + let buckets: FnvHashMap = self .buckets .into_iter() .map(move |range_bucket| { @@ -174,12 +175,13 @@ impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, sub_aggregation: &AggregationsWithAccessor, + bucket_count: &BucketCount, field_type: Type, ) -> crate::Result { // The range input on the request is f64. // We need to convert to u64 ranges, because we read the values as u64. // The mapping from the conversion is monotonic so ordering is preserved. 
- let buckets = extend_validate_ranges(&req.ranges, &field_type)? + let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)? .iter() .map(|range| { let to = if range.end == u64::MAX { @@ -212,6 +214,9 @@ impl SegmentRangeCollector { }) .collect::>()?; + bucket_count.add_count(buckets.len() as u32); + bucket_count.validate_bucket_count()?; + Ok(SegmentRangeCollector { buckets, field_type, @@ -224,7 +229,7 @@ impl SegmentRangeCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { let mut iter = doc.chunks_exact(4); let accessor = bucket_with_accessor .accessor @@ -240,24 +245,25 @@ impl SegmentRangeCollector { let bucket_pos3 = self.get_bucket_pos(val3); let bucket_pos4 = self.get_bucket_pos(val4); - self.increment_bucket(bucket_pos1, docs[0], &bucket_with_accessor.sub_aggregation); - self.increment_bucket(bucket_pos2, docs[1], &bucket_with_accessor.sub_aggregation); - self.increment_bucket(bucket_pos3, docs[2], &bucket_with_accessor.sub_aggregation); - self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos1, docs[0], &bucket_with_accessor.sub_aggregation)?; + self.increment_bucket(bucket_pos2, docs[1], &bucket_with_accessor.sub_aggregation)?; + self.increment_bucket(bucket_pos3, docs[2], &bucket_with_accessor.sub_aggregation)?; + self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation)?; } for doc in iter.remainder() { let val = accessor.get(*doc); let bucket_pos = self.get_bucket_pos(val); - self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation)?; } if force_flush { for bucket in &mut self.buckets { if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation { sub_aggregation - .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush); + .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush)?; } } } + Ok(()) } #[inline] @@ -266,13 +272,14 @@ impl SegmentRangeCollector { bucket_pos: usize, doc: DocId, bucket_with_accessor: &AggregationsWithAccessor, - ) { + ) -> crate::Result<()> { let bucket = &mut self.buckets[bucket_pos]; bucket.bucket.doc_count += 1; if let Some(sub_aggregation) = &mut bucket.bucket.sub_aggregation { - sub_aggregation.collect(doc, bucket_with_accessor); + sub_aggregation.collect(doc, bucket_with_accessor)?; } + Ok(()) } #[inline] @@ -317,7 +324,7 @@ fn to_u64_range(range: &RangeAggregationRange, field_type: &Type) -> crate::Resu } /// Extends the provided buckets to contain the whole value range, by inserting buckets at the -/// beginning and end. +/// beginning and end and filling gaps. 
fn extend_validate_ranges( buckets: &[RangeAggregationRange], field_type: &Type, @@ -401,8 +408,13 @@ mod tests { ranges, }; - SegmentRangeCollector::from_req_and_validate(&req, &Default::default(), field_type) - .expect("unexpected error") + SegmentRangeCollector::from_req_and_validate( + &req, + &Default::default(), + &Default::default(), + field_type, + ) + .expect("unexpected error") } #[test] @@ -422,7 +434,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let reader = index.reader()?; let searcher = reader.searcher(); diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index c9833c8853..8a9970e0fd 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -11,7 +11,7 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; +use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector}; use crate::error::DataCorruption; use crate::fastfield::MultiValuedFastFieldReader; use crate::schema::Type; @@ -244,28 +244,33 @@ impl TermBuckets { &mut self, term_ids: &[u64], doc: DocId, - bucket_with_accessor: &AggregationsWithAccessor, + sub_aggregation: &AggregationsWithAccessor, + bucket_count: &BucketCount, blueprint: &Option, - ) { - // self.ensure_vec_exists(term_ids); + ) -> crate::Result<()> { for &term_id in term_ids { - let entry = self - .entries - .entry(term_id as u32) - .or_insert_with(|| TermBucketEntry::from_blueprint(blueprint)); + let entry = self.entries.entry(term_id as u32).or_insert_with(|| { + bucket_count.add_count(1); + + TermBucketEntry::from_blueprint(blueprint) + }); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, bucket_with_accessor); + sub_aggregations.collect(doc, sub_aggregation)?; } } + bucket_count.validate_bucket_count()?; + + Ok(()) } - fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) { + fn force_flush(&mut self, agg_with_accessor: &AggregationsWithAccessor) -> crate::Result<()> { for entry in &mut self.entries.values_mut() { if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.flush_staged_docs(agg_with_accessor, false); + sub_aggregations.flush_staged_docs(agg_with_accessor, false)?; } } + Ok(()) } } @@ -421,7 +426,7 @@ impl SegmentTermCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { let accessor = bucket_with_accessor .accessor .as_multi() @@ -441,26 +446,30 @@ impl SegmentTermCollector { &vals1, docs[0], &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, - ); + )?; self.term_buckets.increment_bucket( &vals2, docs[1], &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, - ); + )?; self.term_buckets.increment_bucket( &vals3, docs[2], &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, - ); + )?; self.term_buckets.increment_bucket( &vals4, docs[3], &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, - ); + )?; } for &doc in iter.remainder() { accessor.get_vals(doc, 
&mut vals1); @@ -469,13 +478,15 @@ impl SegmentTermCollector { &vals1, doc, &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, - ); + )?; } if force_flush { self.term_buckets - .force_flush(&bucket_with_accessor.sub_aggregation); + .force_flush(&bucket_with_accessor.sub_aggregation)?; } + Ok(()) } } @@ -1173,6 +1184,33 @@ mod tests { Ok(()) } + #[test] + fn terms_aggregation_term_bucket_limit() -> crate::Result<()> { + let terms: Vec = (0..100_000).map(|el| el.to_string()).collect(); + let terms_per_segment = vec![terms.iter().map(|el| el.as_str()).collect()]; + + let index = get_test_index_from_terms(true, &terms_per_segment)?; + + let agg_req: Aggregations = vec![( + "my_texts".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Terms(TermsAggregation { + field: "string_id".to_string(), + min_doc_count: Some(0), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request_with_query(agg_req, &index, None); + assert!(res.is_err()); + + Ok(()) + } + #[test] fn test_json_format() -> crate::Result<()> { let agg_req: Aggregations = vec![( @@ -1291,9 +1329,15 @@ mod bench { let mut collector = get_collector_with_buckets(total_terms); let vals = get_rand_terms(total_terms, num_terms); let aggregations_with_accessor: AggregationsWithAccessor = Default::default(); + let bucket_count: BucketCount = BucketCount { + bucket_count: Default::default(), + max_bucket_count: 1_000_001u32, + }; b.iter(|| { for &val in &vals { - collector.increment_bucket(&[val], 0, &aggregations_with_accessor, &None); + collector + .increment_bucket(&[val], 0, &aggregations_with_accessor, &bucket_count, &None) + .unwrap(); } }) } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index f35a5e3e17..c9510d9263 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use super::agg_req::Aggregations; use super::agg_req_with_accessor::AggregationsWithAccessor; use super::agg_result::AggregationResults; @@ -5,19 +7,27 @@ use super::intermediate_agg_result::IntermediateAggregationResults; use super::segment_agg_result::SegmentAggregationResultsCollector; use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate; use crate::collector::{Collector, SegmentCollector}; -use crate::SegmentReader; +use crate::{SegmentReader, TantivyError}; + +pub const MAX_BUCKET_COUNT: u32 = 65000; /// Collector for aggregations. /// /// The collector collects all aggregations by the underlying aggregation request. pub struct AggregationCollector { agg: Aggregations, + max_bucket_count: u32, } impl AggregationCollector { /// Create collector from aggregation request. - pub fn from_aggs(agg: Aggregations) -> Self { - Self { agg } + /// + /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset + pub fn from_aggs(agg: Aggregations, max_bucket_count: Option) -> Self { + Self { + agg, + max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), + } } } @@ -28,15 +38,21 @@ impl AggregationCollector { /// # Purpose /// AggregationCollector returns `IntermediateAggregationResults` and not the final /// `AggregationResults`, so that results from differenct indices can be merged and then converted -/// into the final `AggregationResults` via the `into()` method. +/// into the final `AggregationResults` via the `into_final_result()` method. 
pub struct DistributedAggregationCollector { agg: Aggregations, + max_bucket_count: u32, } impl DistributedAggregationCollector { /// Create collector from aggregation request. - pub fn from_aggs(agg: Aggregations) -> Self { - Self { agg } + /// + /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset + pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self { + Self { + agg, + max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), + } + } } } @@ -50,7 +66,11 @@ impl Collector for DistributedAggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result<Self::Child> { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + self.max_bucket_count, + ) } fn requires_scoring(&self) -> bool { @@ -75,7 +95,11 @@ impl Collector for AggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result<Self::Child> { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + self.max_bucket_count, + ) } fn requires_scoring(&self) -> bool { @@ -87,7 +111,7 @@ impl Collector for AggregationCollector { segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>, ) -> crate::Result<Self::Fruit> { let res = merge_fruits(segment_fruits)?; - AggregationResults::from_intermediate_and_req(res, self.agg.clone()) + res.into_final_bucket_result(self.agg.clone()) } } @@ -109,6 +133,7 @@ fn merge_fruits( pub struct AggregationSegmentCollector { aggs_with_accessor: AggregationsWithAccessor, result: SegmentAggregationResultsCollector, + error: Option<TantivyError>, } impl AggregationSegmentCollector { @@ -117,13 +142,16 @@ impl AggregationSegmentCollector { pub fn from_agg_req_and_reader( agg: &Aggregations, reader: &SegmentReader, + max_bucket_count: u32, ) -> crate::Result<Self> { - let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader)?; + let aggs_with_accessor = + get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?; let result = SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?; Ok(AggregationSegmentCollector { aggs_with_accessor, result, + error: None, }) } } @@ -133,12 +161,20 @@ impl SegmentCollector for AggregationSegmentCollector { #[inline] fn collect(&mut self, doc: crate::DocId, _score: crate::Score) { - self.result.collect(doc, &self.aggs_with_accessor); + if self.error.is_some() { + return; + } + if let Err(err) = self.result.collect(doc, &self.aggs_with_accessor) { + self.error = Some(err); + } } fn harvest(mut self) -> Self::Fruit { + if let Some(err) = self.error { + return Err(err); + } self.result - .flush_staged_docs(&self.aggs_with_accessor, true); + .flush_staged_docs(&self.aggs_with_accessor, true)?; self.result .into_intermediate_aggregations_result(&self.aggs_with_accessor) } diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 9bde00707e..20eef59c07 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -3,16 +3,20 @@ //! indices.
use std::cmp::Ordering; +use std::collections::HashMap; use fnv::FnvHashMap; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use super::agg_req::{AggregationsInternal, BucketAggregationType, MetricAggregation}; -use super::agg_result::BucketResult; +use super::agg_req::{ + Aggregations, AggregationsInternal, BucketAggregationInternal, BucketAggregationType, + MetricAggregation, +}; +use super::agg_result::{AggregationResult, BucketResult, RangeBucketEntry}; use super::bucket::{ - cut_off_buckets, get_agg_name_and_property, GetDocCount, Order, OrderTarget, - SegmentHistogramBucketEntry, TermsAggregation, + cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets, + GetDocCount, Order, OrderTarget, SegmentHistogramBucketEntry, TermsAggregation, }; use super::metric::{IntermediateAverage, IntermediateStats}; use super::segment_agg_result::SegmentMetricResultCollector; @@ -31,6 +35,46 @@ pub struct IntermediateAggregationResults { } impl IntermediateAggregationResults { + /// Convert intermediate result and its aggregation request to the final result. + pub(crate) fn into_final_bucket_result( + self, + req: Aggregations, + ) -> crate::Result<AggregationResults> { + self.into_final_bucket_result_internal(&(req.into())) + } + + /// Convert intermediate result and its aggregation request to the final result. + /// + /// Internal function, AggregationsInternal is used instead Aggregations, which is optimized + /// for internal processing, by splitting metric and buckets into seperate groups. + pub(crate) fn into_final_bucket_result_internal( + self, + req: &AggregationsInternal, + ) -> crate::Result<AggregationResults> { + // Important assumption: + // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the + // request + let mut results: HashMap<String, AggregationResult> = HashMap::new(); + + if let Some(buckets) = self.buckets { + convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets)? + } else { + // When there are no buckets, we create empty buckets, so that the serialized json + // format is constant + add_empty_final_buckets_to_result(&mut results, &req.buckets)? + }; + + if let Some(metrics) = self.metrics { + convert_and_add_final_metrics_to_result(&mut results, metrics); + } else { + // When there are no metrics, we create empty metric results, so that the serialized + // json format is constant + add_empty_final_metrics_to_result(&mut results, &req.metrics)?; + } + + Ok(AggregationResults(results)) + } + pub(crate) fn empty_from_req(req: &AggregationsInternal) -> Self { let metrics = if req.metrics.is_empty() { None @@ -90,6 +134,58 @@ impl IntermediateAggregationResults { } } +fn convert_and_add_final_metrics_to_result( + results: &mut HashMap<String, AggregationResult>, + metrics: VecWithNames<IntermediateMetricResult>, +) { + results.extend( + metrics + .into_iter() + .map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))), + ); +} + +fn add_empty_final_metrics_to_result( + results: &mut HashMap<String, AggregationResult>, + req_metrics: &VecWithNames<MetricAggregation>, +) -> crate::Result<()> { + results.extend(req_metrics.iter().map(|(key, req)| { + let empty_bucket = IntermediateMetricResult::empty_from_req(req); + ( + key.to_string(), + AggregationResult::MetricResult(empty_bucket.into()), + ) + })); + Ok(()) +} + +fn add_empty_final_buckets_to_result( + results: &mut HashMap<String, AggregationResult>, + req_buckets: &VecWithNames<BucketAggregationInternal>, +) -> crate::Result<()> { + let requested_buckets = req_buckets.iter(); + for (key, req) in requested_buckets { + let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?); + results.insert(key.to_string(), empty_bucket); + } + Ok(()) +} + +fn convert_and_add_final_buckets_to_result( + results: &mut HashMap<String, AggregationResult>, + buckets: VecWithNames<IntermediateBucketResult>, + req_buckets: &VecWithNames<BucketAggregationInternal>, +) -> crate::Result<()> { + assert_eq!(buckets.len(), req_buckets.len()); + + let buckets_with_request = buckets.into_iter().zip(req_buckets.values()); + for ((key, bucket), req) in buckets_with_request { + let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req)?); + results.insert(key, result); + } + Ok(()) +} + /// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum IntermediateAggregationResult { @@ -171,6 +267,45 @@ pub enum IntermediateBucketResult { } impl IntermediateBucketResult { + pub(crate) fn into_final_bucket_result( + self, + req: &BucketAggregationInternal, + ) -> crate::Result { + match self { + IntermediateBucketResult::Range(range_res) => { + let mut buckets: Vec = range_res + .buckets + .into_iter() + .map(|(_, bucket)| bucket.into_final_bucket_entry(&req.sub_aggregation)) + .collect::>>()?; + + buckets.sort_by(|left, right| { + // TODO use total_cmp next stable rust release + left.from + .unwrap_or(f64::MIN) + .partial_cmp(&right.from.unwrap_or(f64::MIN)) + .unwrap_or(Ordering::Equal) + }); + Ok(BucketResult::Range { buckets }) + } + IntermediateBucketResult::Histogram { buckets } => { + let buckets = intermediate_histogram_buckets_to_final_buckets( + buckets, + req.as_histogram() + .expect("unexpected aggregation, expected histogram aggregation"), + &req.sub_aggregation, + )?; + + Ok(BucketResult::Histogram { buckets }) + } + IntermediateBucketResult::Terms(terms) => terms.into_final_result( + req.as_term() + .expect("unexpected aggregation, expected term aggregation"), + &req.sub_aggregation, + ), + } + } + pub(crate) fn empty_from_req(req: &BucketAggregationType) -> Self { match req { BucketAggregationType::Terms(_) => IntermediateBucketResult::Terms(Default::default()), @@ -267,10 +402,9 @@ impl IntermediateTermBucketResult { Ok(BucketEntry { key: Key::Str(key), doc_count: entry.doc_count, - sub_aggregation: AggregationResults::from_intermediate_and_req_internal( - entry.sub_aggregation, - sub_aggregation_req, - )?, + sub_aggregation: entry + .sub_aggregation + .into_final_bucket_result_internal(sub_aggregation_req)?, }) }) .collect::>()?; @@ -374,6 +508,21 @@ pub struct IntermediateHistogramBucketEntry { pub sub_aggregation: IntermediateAggregationResults, } +impl IntermediateHistogramBucketEntry { + pub(crate) fn into_final_bucket_entry( + self, + req: &AggregationsInternal, + ) -> crate::Result { + Ok(BucketEntry { + key: Key::F64(self.key), + doc_count: self.doc_count, + sub_aggregation: self + .sub_aggregation + .into_final_bucket_result_internal(req)?, + }) + } +} + impl From for IntermediateHistogramBucketEntry { fn from(entry: SegmentHistogramBucketEntry) -> Self { IntermediateHistogramBucketEntry { @@ -402,6 +551,23 @@ pub struct IntermediateRangeBucketEntry { pub to: Option, } +impl IntermediateRangeBucketEntry { + pub(crate) fn into_final_bucket_entry( + self, + req: &AggregationsInternal, + ) -> crate::Result { + Ok(RangeBucketEntry { + key: self.key, + doc_count: self.doc_count, + sub_aggregation: self + .sub_aggregation + .into_final_bucket_result_internal(req)?, + to: self.to, + from: self.from, + }) + } +} + /// This is the term entry for a bucket, which contains a count, and optionally /// sub_aggregations. 
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 0498ffbe80..2f704b17d0 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -222,7 +222,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let reader = index.reader()?; let searcher = reader.searcher(); @@ -299,7 +299,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 37fa05c0ff..7f6f8378ce 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -28,7 +28,7 @@ //! //! ```verbatim //! let agg_req: Aggregations = serde_json::from_str(json_request_string).unwrap(); -//! let collector = AggregationCollector::from_aggs(agg_req); +//! let collector = AggregationCollector::from_aggs(agg_req, None); //! let searcher = reader.searcher(); //! let agg_res = searcher.search(&term_query, &collector).unwrap_err(); //! let json_response_string: String = &serde_json::to_string(&agg_res)?; @@ -68,7 +68,7 @@ //! .into_iter() //! .collect(); //! -//! let collector = AggregationCollector::from_aggs(agg_req); +//! let collector = AggregationCollector::from_aggs(agg_req, None); //! //! let searcher = reader.searcher(); //! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); @@ -358,7 +358,7 @@ mod tests { index: &Index, query: Option<(&str, &str)>, ) -> crate::Result { - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let reader = index.reader()?; let searcher = reader.searcher(); @@ -417,7 +417,9 @@ mod tests { let mut schema_builder = Schema::builder(); let text_fieldtype = crate::schema::TextOptions::default() .set_indexing_options( - TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs), + TextFieldIndexing::default() + .set_index_option(IndexRecordOption::Basic) + .set_fieldnorms(false), ) .set_fast() .set_stored(); @@ -435,7 +437,8 @@ mod tests { ); let index = Index::create_in_ram(schema_builder.build()); { - let mut index_writer = index.writer_for_tests()?; + // let mut index_writer = index.writer_for_tests()?; + let mut index_writer = index.writer_with_num_threads(1, 30_000_000)?; for values in segment_and_values { for (i, term) in values { let i = *i; @@ -457,9 +460,11 @@ mod tests { let segment_ids = index .searchable_segment_ids() .expect("Searchable segments failed."); - let mut index_writer = index.writer_for_tests()?; - index_writer.merge(&segment_ids).wait()?; - index_writer.wait_merging_threads()?; + if segment_ids.len() > 1 { + let mut index_writer = index.writer_for_tests()?; + index_writer.merge(&segment_ids).wait()?; + index_writer.wait_merging_threads()?; + } } Ok(index) @@ -542,16 +547,15 @@ mod tests { .unwrap(); let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone()); + let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); - AggregationResults::from_intermediate_and_req( - searcher.search(&AllQuery, 
&collector).unwrap(), - agg_req, - ) - .unwrap() + let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap(); + intermediate_agg_result + .into_final_bucket_result(agg_req) + .unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -788,7 +792,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); @@ -978,16 +982,16 @@ mod tests { assert_eq!(field_names, vec!["text".to_string()].into_iter().collect()); let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone()); + let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); let res = searcher.search(&term_query, &collector).unwrap(); // Test de/serialization roundtrip on intermediate_agg_result let res: IntermediateAggregationResults = serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap(); - AggregationResults::from_intermediate_and_req(res, agg_req.clone()).unwrap() + res.into_final_bucket_result(agg_req.clone()).unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req.clone()); + let collector = AggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -1045,7 +1049,7 @@ mod tests { ); // Test empty result set - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); searcher.search(&query_with_no_hits, &collector).unwrap(); @@ -1110,7 +1114,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); @@ -1223,7 +1227,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1254,7 +1258,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1285,7 +1289,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1324,7 +1328,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1353,7 +1357,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1382,7 +1386,7 @@ mod tests { .into_iter() .collect(); - let 
collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1418,7 +1422,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1453,7 +1457,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1492,7 +1496,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1522,7 +1526,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1578,7 +1582,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 81f2b85de9..fe07400897 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -4,19 +4,22 @@ //! merging. use std::fmt::Debug; +use std::rc::Rc; +use std::sync::atomic::AtomicU32; use super::agg_req::MetricAggregation; use super::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, }; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; +use super::collector::MAX_BUCKET_COUNT; use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult}; use super::metric::{ AverageAggregation, SegmentAverageCollector, SegmentStatsCollector, StatsAggregation, }; use super::VecWithNames; use crate::aggregation::agg_req::BucketAggregationType; -use crate::DocId; +use crate::{DocId, TantivyError}; pub(crate) const DOC_BLOCK_SIZE: usize = 64; pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE]; @@ -115,21 +118,22 @@ impl SegmentAggregationResultsCollector { &mut self, doc: crate::DocId, agg_with_accessor: &AggregationsWithAccessor, - ) { + ) -> crate::Result<()> { self.staged_docs[self.num_staged_docs] = doc; self.num_staged_docs += 1; if self.num_staged_docs == self.staged_docs.len() { - self.flush_staged_docs(agg_with_accessor, false); + self.flush_staged_docs(agg_with_accessor, false)?; } + Ok(()) } pub(crate) fn flush_staged_docs( &mut self, agg_with_accessor: &AggregationsWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { if self.num_staged_docs == 0 { - return; + return Ok(()); } if let Some(metrics) = &mut self.metrics { for (collector, agg_with_accessor) in @@ -148,11 +152,12 @@ impl SegmentAggregationResultsCollector { &self.staged_docs[..self.num_staged_docs], agg_with_accessor, force_flush, - ); + )?; } } self.num_staged_docs = 0; + Ok(()) } } @@ -234,6 +239,7 @@ impl SegmentBucketResultCollector { Ok(Self::Range(SegmentRangeCollector::from_req_and_validate( range_req, 
&req.sub_aggregation, + &req.bucket_count, req.field_type, )?)) } @@ -256,17 +262,52 @@ impl SegmentBucketResultCollector { doc: &[DocId], bucket_with_accessor: &BucketAggregationWithAccessor, force_flush: bool, - ) { + ) -> crate::Result<()> { match self { SegmentBucketResultCollector::Range(range) => { - range.collect_block(doc, bucket_with_accessor, force_flush); + range.collect_block(doc, bucket_with_accessor, force_flush)?; } SegmentBucketResultCollector::Histogram(histogram) => { - histogram.collect_block(doc, bucket_with_accessor, force_flush) + histogram.collect_block(doc, bucket_with_accessor, force_flush)?; } SegmentBucketResultCollector::Terms(terms) => { - terms.collect_block(doc, bucket_with_accessor, force_flush) + terms.collect_block(doc, bucket_with_accessor, force_flush)?; } } + Ok(()) + } +} + +#[derive(Clone)] +pub(crate) struct BucketCount { + /// The counter which is shared between the aggregations for one request. + pub(crate) bucket_count: Rc<AtomicU32>, + pub(crate) max_bucket_count: u32, +} + +impl Default for BucketCount { + fn default() -> Self { + Self { + bucket_count: Default::default(), + max_bucket_count: MAX_BUCKET_COUNT, + } + } +} + +impl BucketCount { + pub(crate) fn validate_bucket_count(&self) -> crate::Result<()> { + if self.get_count() > self.max_bucket_count { + return Err(TantivyError::InvalidArgument( + "Aborting aggregation because too many buckets were created".to_string(), + )); + } + Ok(()) + } + pub(crate) fn add_count(&self, count: u32) { + self.bucket_count + .fetch_add(count as u32, std::sync::atomic::Ordering::Relaxed); + } + pub(crate) fn get_count(&self) -> u32 { + self.bucket_count.load(std::sync::atomic::Ordering::Relaxed) } }
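
Taken together, the changes above let a caller cap how many buckets a single aggregation request may allocate: every segment collector increments a shared `BucketCount` and the request aborts once the limit is exceeded. Below is a minimal usage sketch of the new `from_aggs(_, Option<u32>)` signature. It is not part of this diff; the index, `reader`, and the `string_id` fast field are assumed to exist, and the JSON request mirrors the terms aggregation used in the new test.

```rust
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;

// Terms aggregation over an (assumed) fast field, as in the new test above.
let agg_req: Aggregations =
    serde_json::from_str(r#"{ "my_texts": { "terms": { "field": "string_id" } } }"#)?;

// Passing `None` falls back to MAX_BUCKET_COUNT (65000); `Some(n)` sets a custom cap.
let collector = AggregationCollector::from_aggs(agg_req, Some(10_000));

// If more than 10_000 buckets get created while collecting, the search returns
// Err(TantivyError::InvalidArgument("Aborting aggregation because too many buckets were created")).
let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector)?;
let json_response: String = serde_json::to_string(&agg_res)?;
```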