Add aggregation bucket limit #1363

Merged
merged 10 commits on Jun 23, 2022
examples/aggregation.rs (2 changes: 1 addition & 1 deletion)
@@ -117,7 +117,7 @@ fn main() -> tantivy::Result<()> {
.into_iter()
.collect();

-let collector = AggregationCollector::from_aggs(agg_req_1);
+let collector = AggregationCollector::from_aggs(agg_req_1, None);

let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
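The example change above captures the new public API: `AggregationCollector::from_aggs` now takes a second argument for the bucket limit. Below is a minimal sketch of the two call styles, assuming the new parameter is an `Option<u32>` and that `None` falls back to the crate's default limit rather than disabling the check:

```rust
use tantivy::aggregation::agg_req::Aggregations;
use tantivy::aggregation::AggregationCollector;

fn build_collectors(agg_req: Aggregations) -> (AggregationCollector, AggregationCollector) {
    // None: keep the built-in default bucket limit.
    let with_default = AggregationCollector::from_aggs(agg_req.clone(), None);
    // Some(n): abort collection once the query would create more than n buckets.
    let with_custom = AggregationCollector::from_aggs(agg_req, Some(10_000));
    (with_default, with_custom)
}
```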
src/aggregation/agg_req_with_accessor.rs (21 changes: 20 additions & 1 deletion)
@@ -1,10 +1,13 @@
//! This will enhance the request tree with access to the fastfield and metadata.

+use std::rc::Rc;
+use std::sync::atomic::AtomicU32;
use std::sync::Arc;

use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation};
use super::metric::{AverageAggregation, StatsAggregation};
+use super::segment_agg_result::BucketCount;
use super::VecWithNames;
use crate::fastfield::{
type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader,
@@ -60,13 +63,16 @@ pub struct BucketAggregationWithAccessor {
pub(crate) field_type: Type,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
+pub(crate) bucket_count: BucketCount,
}

impl BucketAggregationWithAccessor {
fn try_from_bucket(
bucket: &BucketAggregationType,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
+bucket_count: Rc<AtomicU32>,
+max_bucket_count: u32,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut inverted_index = None;
let (accessor, field_type) = match &bucket {
@@ -92,9 +98,18 @@ impl BucketAggregationWithAccessor {
Ok(BucketAggregationWithAccessor {
accessor,
field_type,
-sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?,
+sub_aggregation: get_aggs_with_accessor_and_validate(
+&sub_aggregation,
+reader,
+bucket_count.clone(),
+max_bucket_count,
+)?,
bucket_agg: bucket.clone(),
inverted_index,
+bucket_count: BucketCount {
+bucket_count,
+max_bucket_count,
+},
})
}
}
@@ -134,6 +149,8 @@ impl MetricAggregationWithAccessor {
pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
+bucket_count: Rc<AtomicU32>,
+max_bucket_count: u32,
) -> crate::Result<AggregationsWithAccessor> {
let mut metrics = vec![];
let mut buckets = vec![];
@@ -145,6 +162,8 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
&bucket.bucket_agg,
&bucket.sub_aggregation,
reader,
+Rc::clone(&bucket_count),
+max_bucket_count,
)?,
)),
Aggregation::Metric(metric) => metrics.push((
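The key design point in this file is that every bucket aggregation in the request tree clones the same `Rc<AtomicU32>`, so the configured maximum bounds the total number of buckets across the whole request rather than each aggregation separately. Below is a self-contained sketch of that shared-counter pattern; the `add_count` and `validate_bucket_count` method names are illustrative approximations, not necessarily the exact API introduced by this PR:

```rust
use std::rc::Rc;
use std::sync::atomic::{AtomicU32, Ordering};

/// Shared bucket budget: every clone points at the same counter.
struct BucketCount {
    bucket_count: Rc<AtomicU32>,
    max_bucket_count: u32,
}

impl BucketCount {
    fn add_count(&self, count: u32) {
        // Relaxed ordering suffices; the counter is a budget, not a lock.
        self.bucket_count.fetch_add(count, Ordering::Relaxed);
    }

    fn validate_bucket_count(&self) -> Result<(), String> {
        if self.bucket_count.load(Ordering::Relaxed) > self.max_bucket_count {
            return Err("too many buckets in aggregation".to_string());
        }
        Ok(())
    }
}

fn main() {
    let shared = Rc::new(AtomicU32::new(0));
    let terms = BucketCount { bucket_count: Rc::clone(&shared), max_bucket_count: 5 };
    let histogram = BucketCount { bucket_count: Rc::clone(&shared), max_bucket_count: 5 };

    terms.add_count(3);
    histogram.add_count(3);
    // Each aggregation only created 3 buckets, but the shared total is 6 > 5.
    assert!(terms.validate_bucket_count().is_err());
    assert!(histogram.validate_bucket_count().is_err());
}
```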
src/aggregation/agg_result.rs (186 changes: 7 additions & 179 deletions)
@@ -4,21 +4,15 @@
//! intermediate average results, which is the sum and the number of values. The actual average is
//! calculated on the step from intermediate to final aggregation result tree.

-use std::cmp::Ordering;
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

-use super::agg_req::{
-Aggregations, AggregationsInternal, BucketAggregationInternal, MetricAggregation,
-};
-use super::bucket::{intermediate_buckets_to_final_buckets, GetDocCount};
-use super::intermediate_agg_result::{
-IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
-IntermediateMetricResult, IntermediateRangeBucketEntry,
-};
+use super::agg_req::BucketAggregationInternal;
+use super::bucket::GetDocCount;
+use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult};
use super::metric::{SingleMetricResult, Stats};
-use super::{Key, VecWithNames};
+use super::Key;
use crate::TantivyError;

#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -41,98 +35,6 @@ impl AggregationResults {
)))
}
}

-/// Convert and intermediate result and its aggregation request to the final result
-pub fn from_intermediate_and_req(
-results: IntermediateAggregationResults,
-agg: Aggregations,
-) -> crate::Result<Self> {
-AggregationResults::from_intermediate_and_req_internal(results, &(agg.into()))
-}
-
-/// Convert and intermediate result and its aggregation request to the final result
-///
-/// Internal function, CollectorAggregations is used instead Aggregations, which is optimized
-/// for internal processing, by splitting metric and buckets into seperate groups.
-pub(crate) fn from_intermediate_and_req_internal(
-intermediate_results: IntermediateAggregationResults,
-req: &AggregationsInternal,
-) -> crate::Result<Self> {
-// Important assumption:
-// When the tree contains buckets/metric, we expect it to have all buckets/metrics from the
-// request
-let mut results: HashMap<String, AggregationResult> = HashMap::new();
-
-if let Some(buckets) = intermediate_results.buckets {
-add_coverted_final_buckets_to_result(&mut results, buckets, &req.buckets)?
-} else {
-// When there are no buckets, we create empty buckets, so that the serialized json
-// format is constant
-add_empty_final_buckets_to_result(&mut results, &req.buckets)?
-};
-
-if let Some(metrics) = intermediate_results.metrics {
-add_converted_final_metrics_to_result(&mut results, metrics);
-} else {
-// When there are no metrics, we create empty metric results, so that the serialized
-// json format is constant
-add_empty_final_metrics_to_result(&mut results, &req.metrics)?;
-}
-Ok(Self(results))
-}
-}
-
-fn add_converted_final_metrics_to_result(
-results: &mut HashMap<String, AggregationResult>,
-metrics: VecWithNames<IntermediateMetricResult>,
-) {
-results.extend(
-metrics
-.into_iter()
-.map(|(key, metric)| (key, AggregationResult::MetricResult(metric.into()))),
-);
-}
-
-fn add_empty_final_metrics_to_result(
-results: &mut HashMap<String, AggregationResult>,
-req_metrics: &VecWithNames<MetricAggregation>,
-) -> crate::Result<()> {
-results.extend(req_metrics.iter().map(|(key, req)| {
-let empty_bucket = IntermediateMetricResult::empty_from_req(req);
-(
-key.to_string(),
-AggregationResult::MetricResult(empty_bucket.into()),
-)
-}));
-Ok(())
-}
-
-fn add_empty_final_buckets_to_result(
-results: &mut HashMap<String, AggregationResult>,
-req_buckets: &VecWithNames<BucketAggregationInternal>,
-) -> crate::Result<()> {
-let requested_buckets = req_buckets.iter();
-for (key, req) in requested_buckets {
-let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?);
-results.insert(key.to_string(), empty_bucket);
-}
-Ok(())
-}
-
-fn add_coverted_final_buckets_to_result(
-results: &mut HashMap<String, AggregationResult>,
-buckets: VecWithNames<IntermediateBucketResult>,
-req_buckets: &VecWithNames<BucketAggregationInternal>,
-) -> crate::Result<()> {
-assert_eq!(buckets.len(), req_buckets.len());
-
-let buckets_with_request = buckets.into_iter().zip(req_buckets.values());
-for ((key, bucket), req) in buckets_with_request {
-let result =
-AggregationResult::BucketResult(BucketResult::from_intermediate_and_req(bucket, req)?);
-results.insert(key, result);
-}
-Ok(())
-}
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
@@ -154,7 +56,8 @@ impl AggregationResult {
match self {
AggregationResult::BucketResult(_bucket) => Err(TantivyError::InternalError(
"Tried to retrieve value from bucket aggregation. This is not supported and \
-should not happen during collection, but should be catched during validation"
+should not happen during collection phase, but should be caught during \
+validation"
.to_string(),
)),
AggregationResult::MetricResult(metric) => metric.get_value(agg_property),
@@ -230,48 +133,7 @@ pub enum BucketResult {
impl BucketResult {
pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result<Self> {
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
-BucketResult::from_intermediate_and_req(empty_bucket, req)
-}
-
-fn from_intermediate_and_req(
-bucket_result: IntermediateBucketResult,
-req: &BucketAggregationInternal,
-) -> crate::Result<Self> {
-match bucket_result {
-IntermediateBucketResult::Range(range_res) => {
-let mut buckets: Vec<RangeBucketEntry> = range_res
-.buckets
-.into_iter()
-.map(|(_, bucket)| {
-RangeBucketEntry::from_intermediate_and_req(bucket, &req.sub_aggregation)
-})
-.collect::<crate::Result<Vec<_>>>()?;
-
-buckets.sort_by(|left, right| {
-// TODO use total_cmp next stable rust release
-left.from
-.unwrap_or(f64::MIN)
-.partial_cmp(&right.from.unwrap_or(f64::MIN))
-.unwrap_or(Ordering::Equal)
-});
-Ok(BucketResult::Range { buckets })
-}
-IntermediateBucketResult::Histogram { buckets } => {
-let buckets = intermediate_buckets_to_final_buckets(
-buckets,
-req.as_histogram()
-.expect("unexpected aggregation, expected histogram aggregation"),
-&req.sub_aggregation,
-)?;
-
-Ok(BucketResult::Histogram { buckets })
-}
-IntermediateBucketResult::Terms(terms) => terms.into_final_result(
-req.as_term()
-.expect("unexpected aggregation, expected term aggregation"),
-&req.sub_aggregation,
-),
-}
+empty_bucket.into_final_bucket_result(req)
}
}

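Structurally, the replacement at the end of this hunk inverts the conversion: instead of `BucketResult::from_intermediate_and_req(intermediate, req)`, the intermediate result now converts itself via `into_final_bucket_result(req)`, consuming self. A minimal sketch of that inversion with simplified placeholder types (`Req`, `Intermediate`, and `Final` are illustrative, not tantivy types):

```rust
struct Req;
struct Intermediate(u64);
struct Final(u64);

impl Intermediate {
    // The conversion hangs off the intermediate result and consumes it,
    // so the final-result types no longer need to know every intermediate layout.
    fn into_final_bucket_result(self, _req: &Req) -> Final {
        Final(self.0)
    }
}

fn main() {
    let inter = Intermediate(42);
    let fin = inter.into_final_bucket_result(&Req);
    assert_eq!(fin.0, 42);
}
```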
@@ -311,22 +173,6 @@ pub struct BucketEntry {
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
}

-impl BucketEntry {
-pub(crate) fn from_intermediate_and_req(
-entry: IntermediateHistogramBucketEntry,
-req: &AggregationsInternal,
-) -> crate::Result<Self> {
-Ok(BucketEntry {
-key: Key::F64(entry.key),
-doc_count: entry.doc_count,
-sub_aggregation: AggregationResults::from_intermediate_and_req_internal(
-entry.sub_aggregation,
-req,
-)?,
-})
-}
-}
impl GetDocCount for &BucketEntry {
fn doc_count(&self) -> u64 {
self.doc_count
@@ -384,21 +230,3 @@ pub struct RangeBucketEntry {
#[serde(skip_serializing_if = "Option::is_none")]
pub to: Option<f64>,
}

-impl RangeBucketEntry {
-fn from_intermediate_and_req(
-entry: IntermediateRangeBucketEntry,
-req: &AggregationsInternal,
-) -> crate::Result<Self> {
-Ok(RangeBucketEntry {
-key: entry.key,
-doc_count: entry.doc_count,
-sub_aggregation: AggregationResults::from_intermediate_and_req_internal(
-entry.sub_aggregation,
-req,
-)?,
-to: entry.to,
-from: entry.from,
-})
-}
-}
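One detail worth noting in the surviving `RangeBucketEntry` definition above: the `#[serde(skip_serializing_if = "Option::is_none")]` attributes mean unbounded range edges are dropped from the JSON output entirely rather than serialized as `null`, keeping the response shape in line with Elasticsearch-style range buckets. A small standalone demonstration using a simplified stand-in struct (not the exact tantivy type):

```rust
use serde::Serialize;

// Simplified stand-in for RangeBucketEntry; the real struct also carries
// sub-aggregation results.
#[derive(Serialize)]
struct RangeEntry {
    key: String,
    doc_count: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    from: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    to: Option<f64>,
}

fn main() {
    let open_ended = RangeEntry {
        key: "*-10".to_string(),
        doc_count: 3,
        from: None, // unbounded lower edge
        to: Some(10.0),
    };
    // Prints {"key":"*-10","doc_count":3,"to":10.0} -- `from` is omitted
    // rather than emitted as null.
    println!("{}", serde_json::to_string(&open_ended).unwrap());
}
```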