From 4cb8932ee18d2ad45c13f647f67070d7172560ae Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Wed, 9 Mar 2022 11:21:14 +0100 Subject: [PATCH] Add Histogram aggregation --- src/aggregation/agg_req.rs | 7 + src/aggregation/agg_req_with_accessor.rs | 5 +- src/aggregation/agg_result.rs | 105 ++- src/aggregation/bucket/histogram/histogram.rs | 753 ++++++++++++++++++ src/aggregation/bucket/histogram/mod.rs | 2 + src/aggregation/bucket/mod.rs | 3 + src/aggregation/bucket/range.rs | 2 +- src/aggregation/intermediate_agg_result.rs | 137 +++- src/aggregation/mod.rs | 121 ++- src/aggregation/segment_agg_result.rs | 21 +- 10 files changed, 1116 insertions(+), 40 deletions(-) create mode 100644 src/aggregation/bucket/histogram/histogram.rs create mode 100644 src/aggregation/bucket/histogram/mod.rs diff --git a/src/aggregation/agg_req.rs b/src/aggregation/agg_req.rs index 0ceb6cb1c4..aaf4ee7603 100644 --- a/src/aggregation/agg_req.rs +++ b/src/aggregation/agg_req.rs @@ -48,6 +48,7 @@ use std::collections::{HashMap, HashSet}; use serde::{Deserialize, Serialize}; +use super::bucket::HistogramAggregation; pub use super::bucket::RangeAggregation; use super::metric::{AverageAggregation, StatsAggregation}; @@ -123,12 +124,18 @@ pub enum BucketAggregationType { /// Put data into buckets of user-defined ranges. #[serde(rename = "range")] Range(RangeAggregation), + /// Put data into buckets of user-defined ranges. + #[serde(rename = "histogram")] + Histogram(HistogramAggregation), } impl BucketAggregationType { fn get_fast_field_names(&self, fast_field_names: &mut HashSet) { match self { BucketAggregationType::Range(range) => fast_field_names.insert(range.field.to_string()), + BucketAggregationType::Histogram(histogram) => { + fast_field_names.insert(histogram.field.to_string()) + } }; } } diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index c84f11cefd..bf87e51009 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -1,7 +1,7 @@ //! This will enhance the request tree with access to the fastfield and metadata. use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; -use super::bucket::RangeAggregation; +use super::bucket::{HistogramAggregation, RangeAggregation}; use super::metric::{AverageAggregation, StatsAggregation}; use super::VecWithNames; use crate::fastfield::{type_and_cardinality, DynamicFastFieldReader, FastType}; @@ -48,6 +48,9 @@ impl BucketAggregationWithAccessor { field: field_name, ranges: _, }) => get_ff_reader_and_validate(reader, field_name)?, + BucketAggregationType::Histogram(HistogramAggregation { + field: field_name, .. 
+ }) => get_ff_reader_and_validate(reader, field_name)?, }; let sub_aggregation = sub_aggregation.clone(); Ok(BucketAggregationWithAccessor { diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs index 3d0ed20b29..0dbb6c5fc5 100644 --- a/src/aggregation/agg_result.rs +++ b/src/aggregation/agg_result.rs @@ -10,14 +10,15 @@ use std::collections::HashMap; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use super::bucket::generate_buckets; use super::intermediate_agg_result::{ IntermediateAggregationResult, IntermediateAggregationResults, IntermediateBucketResult, - IntermediateMetricResult, IntermediateRangeBucketEntry, + IntermediateHistogramBucketEntry, IntermediateMetricResult, IntermediateRangeBucketEntry, }; use super::metric::{SingleMetricResult, Stats}; use super::Key; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] /// The final aggegation result. pub struct AggregationResults(pub HashMap); @@ -81,12 +82,18 @@ impl From for MetricResult { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(untagged)] pub enum BucketResult { - /// This is the default entry for a bucket, which contains a key, count, and optionally + /// This is the range entry for a bucket, which contains a key, count, from, to, and optionally /// sub_aggregations. Range { /// The range buckets sorted by range. buckets: Vec, }, + /// This is the histogram entry for a bucket, which contains a key, count, and optionally + /// sub_aggregations. + Histogram { + /// The buckets. + buckets: Vec, + }, } impl From for BucketResult { @@ -106,6 +113,96 @@ impl From for BucketResult { }); BucketResult::Range { buckets } } + IntermediateBucketResult::Histogram { buckets, req } => { + let buckets = if req.min_doc_count() == 0 { + // We need to fill up the buckets for the total ranges, so that there are no + // gaps + let max = buckets + .iter() + .map(|bucket| bucket.key) + .fold(f64::NEG_INFINITY, f64::max); + let min = buckets + .iter() + .map(|bucket| bucket.key) + .fold(f64::INFINITY, f64::min); + let all_buckets = generate_buckets(&req, min, max); + + buckets + .into_iter() + .merge_join_by(all_buckets.into_iter(), |existing_bucket, all_bucket| { + existing_bucket + .key + .partial_cmp(all_bucket) + .unwrap_or(Ordering::Equal) + }) + .map(|either| match either { + itertools::EitherOrBoth::Both(existing, _) => existing.into(), + itertools::EitherOrBoth::Left(existing) => existing.into(), + // Add missing bucket + itertools::EitherOrBoth::Right(bucket) => BucketEntry { + key: Key::F64(bucket), + doc_count: 0, + sub_aggregation: Default::default(), + }, + }) + .collect_vec() + } else { + buckets + .into_iter() + .filter(|bucket| bucket.doc_count >= req.min_doc_count()) + .map(|bucket| bucket.into()) + .collect_vec() + }; + + BucketResult::Histogram { buckets } + } + } + } +} + +/// This is the default entry for a bucket, which contains a key, count, and optionally +/// sub_aggregations. +/// +/// # JSON Format +/// ```ignore +/// { +/// ... +/// "my_histogram": { +/// "buckets": [ +/// { +/// "key": "2.0", +/// "doc_count": 5 +/// }, +/// { +/// "key": "4.0", +/// "doc_count": 2 +/// }, +/// { +/// "key": "6.0", +/// "doc_count": 3 +/// } +/// ] +/// } +/// ... +/// } +/// ``` +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct BucketEntry { + /// The identifier of the bucket. + pub key: Key, + /// Number of documents in the bucket. 
+ pub doc_count: u64, + #[serde(flatten)] + /// sub-aggregations in this bucket. + pub sub_aggregation: AggregationResults, +} + +impl From for BucketEntry { + fn from(entry: IntermediateHistogramBucketEntry) -> Self { + BucketEntry { + key: Key::F64(entry.key), + doc_count: entry.doc_count, + sub_aggregation: entry.sub_aggregation.into(), } } } @@ -114,7 +211,7 @@ impl From for BucketResult { /// sub_aggregations. /// /// # JSON Format -/// ```json +/// ```ignore /// { /// ... /// "my_ranges": { diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs new file mode 100644 index 0000000000..383021a852 --- /dev/null +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -0,0 +1,753 @@ +use serde::{Deserialize, Serialize}; + +use crate::aggregation::agg_req_with_accessor::{ + AggregationsWithAccessor, BucketAggregationWithAccessor, +}; +use crate::aggregation::f64_from_fastfield_u64; +use crate::aggregation::intermediate_agg_result::IntermediateBucketResult; +use crate::aggregation::segment_agg_result::{ + SegmentAggregationResultsCollector, SegmentHistogramBucketEntry, +}; +use crate::fastfield::{DynamicFastFieldReader, FastFieldReader}; +use crate::schema::Type; +use crate::{DocId, TantivyError}; + +/// Provide user-defined buckets to aggregate on. +/// Two special buckets will automatically be created to cover the whole range of values. +/// The provided buckets have to be continous. +/// During the aggregation, the values extracted from the fast_field `field` will be checked +/// against each bucket range. Note that this aggregation includes the from value and excludes the +/// to value for each range. +/// +/// Result type is [BucketResult](crate::aggregation::agg_result::BucketResult) with +/// [RangeBucketEntry](crate::aggregation::agg_result::RangeBucketEntry) on the +/// AggregationCollector. +/// +/// Result type is +/// [crate::aggregation::intermediate_agg_result::IntermediateBucketResult] with +/// [crate::aggregation::intermediate_agg_result::IntermediateRangeBucketEntry] on the +/// DistributedAggregationCollector. +/// +/// # Request JSON Format +/// ```ignore +/// { +/// "prices": { +/// "histogram": { +/// "field": "price", +/// "interval": 50, +/// "extended_bounds": { +/// "min": 0, +/// "max": 500 +/// } +/// } +/// } +/// } +/// ``` +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] +pub struct HistogramAggregation { + /// The field to aggregate on. + pub field: String, + /// The interval must be a positive value. + pub interval: f64, + /// Intervals usually start at 0, offset can move the interval. + pub offset: Option, + /// The minimum number of documents in a bucket to be returned. 
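+    /// Buckets with fewer documents are omitted from the result. Defaults to 0.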
+ pub min_doc_count: Option, + /// hard bounds + pub hard_bounds: Option, + /// extended_bounds + pub extended_bounds: Option, +} + +impl HistogramAggregation { + fn validate(&self) -> crate::Result<()> { + if self.interval <= 0.0f64 { + return Err(TantivyError::InvalidArgument( + "interval must be a positive value".to_string(), + )); + } + + if self.min_doc_count.unwrap_or(0) > 0 && self.hard_bounds.is_some() { + return Err(TantivyError::InvalidArgument( + "Cannot set min_doc_count and hard_bounds at the same time".to_string(), + )); + } + if self.min_doc_count.unwrap_or(0) > 0 && self.extended_bounds.is_some() { + return Err(TantivyError::InvalidArgument( + "Cannot set min_doc_count and extended_bounds at the same time".to_string(), + )); + } + + Ok(()) + } + + /// Returns the minimum number of documents required for a bucket to be returned. + pub fn min_doc_count(&self) -> u64 { + self.min_doc_count.unwrap_or(0) + } +} + +/// Used to set extended or hard bounds on the histogram. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct HistogramBounds { + /// The lower bounds. + pub min: f64, + /// The upper bounds. + pub max: f64, +} + +/// The collector puts values from the fast field into the correct buckets and does a conversion to +/// the correct datatype. +#[derive(Clone, Debug, PartialEq)] +pub struct SegmentHistogramCollector { + /// The buckets containing the aggregation data. + buckets: Vec, + field_type: Type, + req: HistogramAggregation, + offset: f64, + first_bucket_num: i64, +} + +impl SegmentHistogramCollector { + pub fn into_intermediate_bucket_result(self) -> IntermediateBucketResult { + let buckets = self + .buckets + .into_iter() + .map(|bucket| bucket.into()) + .collect(); + + IntermediateBucketResult::Histogram { + buckets, + req: self.req, + } + } + + pub(crate) fn from_req_and_validate( + req: &HistogramAggregation, + sub_aggregation: &AggregationsWithAccessor, + field_type: Type, + accessor: &DynamicFastFieldReader, + ) -> crate::Result { + req.validate()?; + let min = f64_from_fastfield_u64(accessor.min_value(), &field_type); + let max = f64_from_fastfield_u64(accessor.max_value(), &field_type); + + let (min, max) = get_req_min_max(req, min, max); + + // We compute and generate the buckets range (min, max) based on the request and the min + // max in the fast field, but this is likely not ideal when this is a subbucket, where many + // unnecessary buckets may be generated. + let buckets = generate_buckets(req, min, max); + + let sub_aggregation = if sub_aggregation.is_empty() { + None + } else { + Some(SegmentAggregationResultsCollector::from_req_and_validate( + sub_aggregation, + )?) 
+ }; + + let buckets = buckets + .iter() + .map(|bucket| SegmentHistogramBucketEntry { + key: *bucket, + doc_count: 0, + sub_aggregation: sub_aggregation.clone(), + }) + .collect(); + + let (min, _) = get_req_min_max(req, min, max); + + let first_bucket_num = + get_bucket_num_f64(min, req.interval, req.offset.unwrap_or(0.0)) as i64; + + Ok(Self { + buckets, + field_type, + req: req.clone(), + offset: req.offset.unwrap_or(0f64), + first_bucket_num, + }) + } + + #[inline] + pub(crate) fn collect_block( + &mut self, + doc: &[DocId], + bucket_with_accessor: &BucketAggregationWithAccessor, + force_flush: bool, + ) { + if let Some(bounds) = self.req.hard_bounds.clone() { + let interval = self.req.interval; + let offset = self.offset; + let first_bucket_num = self.first_bucket_num; + let is_in_bounds = move |val| val >= bounds.min && val <= bounds.max; + let get_bucket_num = |val| { + if is_in_bounds(val) { + (get_bucket_num_f64(val, interval, offset) as i64 - first_bucket_num) as usize + } else { + 0 + } + }; + let mut iter = doc.chunks_exact(4); + for docs in iter.by_ref() { + let val1 = f64_from_fastfield_u64( + bucket_with_accessor.accessor.get(docs[0]), + &self.field_type, + ); + let val2 = f64_from_fastfield_u64( + bucket_with_accessor.accessor.get(docs[1]), + &self.field_type, + ); + let val3 = f64_from_fastfield_u64( + bucket_with_accessor.accessor.get(docs[2]), + &self.field_type, + ); + let val4 = f64_from_fastfield_u64( + bucket_with_accessor.accessor.get(docs[3]), + &self.field_type, + ); + + let bucket_pos1 = get_bucket_num(val1); + let bucket_pos2 = get_bucket_num(val2); + let bucket_pos3 = get_bucket_num(val3); + let bucket_pos4 = get_bucket_num(val4); + + if is_in_bounds(val1) { + self.increment_bucket( + bucket_pos1, + docs[0], + &bucket_with_accessor.sub_aggregation, + ); + } + if is_in_bounds(val2) { + self.increment_bucket( + bucket_pos2, + docs[1], + &bucket_with_accessor.sub_aggregation, + ); + } + if is_in_bounds(val3) { + self.increment_bucket( + bucket_pos3, + docs[2], + &bucket_with_accessor.sub_aggregation, + ); + } + if is_in_bounds(val4) { + self.increment_bucket( + bucket_pos4, + docs[3], + &bucket_with_accessor.sub_aggregation, + ); + } + } + for doc in iter.remainder() { + let val = f64_from_fastfield_u64( + bucket_with_accessor.accessor.get(*doc), + &self.field_type, + ); + if !is_in_bounds(val) { + continue; + } + let bucket_pos = (get_bucket_num_f64(val, self.req.interval, self.offset) as i64 + - self.first_bucket_num) as usize; + + debug_assert_eq!( + self.buckets[bucket_pos].key, + get_bucket_val(val, self.req.interval, self.req.offset.unwrap_or(0.0)) as f64 + ); + self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); + } + } else { + let mut iter = doc.chunks_exact(4); + for docs in iter.by_ref() { + let val1 = bucket_with_accessor.accessor.get(docs[0]); + let val2 = bucket_with_accessor.accessor.get(docs[1]); + let val3 = bucket_with_accessor.accessor.get(docs[2]); + let val4 = bucket_with_accessor.accessor.get(docs[3]); + let bucket_pos1 = + (self.get_bucket_num(val1) as i64 - self.first_bucket_num) as usize; + let bucket_pos2 = + (self.get_bucket_num(val2) as i64 - self.first_bucket_num) as usize; + let bucket_pos3 = + (self.get_bucket_num(val3) as i64 - self.first_bucket_num) as usize; + let bucket_pos4 = + (self.get_bucket_num(val4) as i64 - self.first_bucket_num) as usize; + + self.increment_bucket(bucket_pos1, docs[0], &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos2, docs[1], 
&bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos3, docs[2], &bucket_with_accessor.sub_aggregation); + self.increment_bucket(bucket_pos4, docs[3], &bucket_with_accessor.sub_aggregation); + } + for doc in iter.remainder() { + let val = bucket_with_accessor.accessor.get(*doc); + let bucket_pos = (self.get_bucket_num(val) as i64 - self.first_bucket_num) as usize; + debug_assert_eq!( + self.buckets[bucket_pos].key, + get_bucket_val( + f64_from_fastfield_u64(val, &self.field_type), + self.req.interval, + self.req.offset.unwrap_or(0.0) + ) as f64 + ); + self.increment_bucket(bucket_pos, *doc, &bucket_with_accessor.sub_aggregation); + } + } + if force_flush { + for bucket in &mut self.buckets { + if let Some(sub_aggregation) = &mut bucket.sub_aggregation { + sub_aggregation + .flush_staged_docs(&bucket_with_accessor.sub_aggregation, force_flush); + } + } + } + } + + #[inline] + fn increment_bucket( + &mut self, + bucket_pos: usize, + doc: DocId, + bucket_with_accessor: &AggregationsWithAccessor, + ) { + let bucket = &mut self.buckets[bucket_pos]; + + bucket.doc_count += 1; + if let Some(sub_aggregation) = &mut bucket.sub_aggregation { + sub_aggregation.collect(doc, bucket_with_accessor); + } + } + + #[inline] + fn get_bucket_num(&self, val: u64) -> f64 { + let val = f64_from_fastfield_u64(val, &self.field_type); + get_bucket_num_f64(val, self.req.interval, self.offset) + } +} + +#[inline] +fn get_bucket_num_f64(val: f64, interval: f64, offset: f64) -> f64 { + ((val - offset) / interval).floor() +} + +#[inline] +fn get_bucket_val(val: f64, interval: f64, offset: f64) -> f64 { + let bucket_pos = get_bucket_num_f64(val, interval, offset); + bucket_pos * interval + offset +} + +fn get_req_min_max(req: &HistogramAggregation, mut min: f64, mut max: f64) -> (f64, f64) { + if let Some(extended_bounds) = &req.extended_bounds { + min = min.min(extended_bounds.min); + max = max.max(extended_bounds.max); + } + if let Some(hard_bounds) = &req.hard_bounds { + min = hard_bounds.min; + max = hard_bounds.max; + } + + (min, max) +} + +/// Generates buckets with req.interval, for given min, max +pub(crate) fn generate_buckets(req: &HistogramAggregation, min: f64, max: f64) -> Vec { + let (min, max) = get_req_min_max(req, min, max); + + let offset = req.offset.unwrap_or(0.0); + let first_bucket_num = get_bucket_num_f64(min, req.interval, offset) as i64; + let last_bucket_num = get_bucket_num_f64(max, req.interval, offset) as i64; + let mut buckets = vec![]; + for bucket_pos in first_bucket_num..=last_bucket_num { + let bucket_key = bucket_pos as f64 * req.interval + offset; + buckets.push(bucket_key); + } + + buckets +} + +#[test] +fn generate_buckets_test() { + let histogram_req = HistogramAggregation { + field: "dummy".to_string(), + interval: 2.0, + ..Default::default() + }; + + let buckets = generate_buckets(&histogram_req, 0.0, 10.0); + assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]); + + let buckets = generate_buckets(&histogram_req, 2.5, 5.5); + assert_eq!(buckets, vec![2.0, 4.0]); + + // Single bucket + let buckets = generate_buckets(&histogram_req, 0.5, 0.75); + assert_eq!(buckets, vec![0.0]); + + // With offset + let histogram_req = HistogramAggregation { + field: "dummy".to_string(), + interval: 2.0, + offset: Some(0.5), + ..Default::default() + }; + + let buckets = generate_buckets(&histogram_req, 0.0, 10.0); + assert_eq!(buckets, vec![-1.5, 0.5, 2.5, 4.5, 6.5, 8.5]); + + let buckets = generate_buckets(&histogram_req, 2.5, 5.5); + assert_eq!(buckets, vec![2.5, 4.5]); + + 
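// Worked example of the bucket math behind the offset case above: with interval 2.0
+    // and offset 0.5, the value 5.5 falls into bucket floor((5.5 - 0.5) / 2.0) = 2,
+    // whose key is 2 * 2.0 + 0.5 = 4.5.
+    assert_eq!(get_bucket_val(5.5, 2.0, 0.5), 4.5);
+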
// Single bucket + let buckets = generate_buckets(&histogram_req, 0.5, 0.75); + assert_eq!(buckets, vec![0.5]); + + // With extended_bounds + let histogram_req = HistogramAggregation { + field: "dummy".to_string(), + interval: 2.0, + extended_bounds: Some(HistogramBounds { + min: 0.0, + max: 10.0, + }), + ..Default::default() + }; + + let buckets = generate_buckets(&histogram_req, 0.0, 10.0); + assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]); + + let buckets = generate_buckets(&histogram_req, 2.5, 5.5); + assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]); + + // Single bucket, but extended_bounds + let buckets = generate_buckets(&histogram_req, 0.5, 0.75); + assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]); + + // With invalid extended_bounds + let histogram_req = HistogramAggregation { + field: "dummy".to_string(), + interval: 2.0, + extended_bounds: Some(HistogramBounds { min: 3.0, max: 5.0 }), + ..Default::default() + }; + + let buckets = generate_buckets(&histogram_req, 0.0, 10.0); + assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]); + + // With hard_bounds reducing + let histogram_req = HistogramAggregation { + field: "dummy".to_string(), + interval: 2.0, + hard_bounds: Some(HistogramBounds { min: 3.0, max: 5.0 }), + ..Default::default() + }; + + let buckets = generate_buckets(&histogram_req, 0.0, 10.0); + assert_eq!(buckets, vec![2.0, 4.0]); + + // With hard_bounds extending + let histogram_req = HistogramAggregation { + field: "dummy".to_string(), + interval: 2.0, + hard_bounds: Some(HistogramBounds { + min: 0.0, + max: 10.0, + }), + ..Default::default() + }; + + let buckets = generate_buckets(&histogram_req, 2.5, 5.5); + assert_eq!(buckets, vec![0.0, 2.0, 4.0, 6.0, 8.0, 10.0]); + + // Blubber + let histogram_req = HistogramAggregation { + field: "dummy".to_string(), + interval: 2.0, + ..Default::default() + }; + + let buckets = generate_buckets(&histogram_req, 4.0, 10.0); + assert_eq!(buckets, vec![4.0, 6.0, 8.0, 10.0]); +} + +#[cfg(test)] +mod tests { + + use serde_json::Value; + + use super::*; + use crate::aggregation::agg_req::{ + Aggregation, Aggregations, BucketAggregation, BucketAggregationType, + }; + use crate::aggregation::tests::{get_test_index_from_values, get_test_index_with_num_docs}; + use crate::aggregation::AggregationCollector; + use crate::query::AllQuery; + use crate::Index; + + fn exec_request(agg_req: Aggregations, index: &Index) -> crate::Result { + let collector = AggregationCollector::from_aggs(agg_req); + + let reader = index.reader()?; + let searcher = reader.searcher(); + let agg_res = searcher.search(&AllQuery, &collector).unwrap(); + + let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?; + Ok(res) + } + + #[test] + fn histogram_test_crooked_values() -> crate::Result<()> { + let values = vec![-12.0, 12.31, 14.33, 16.23]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req: Aggregations = vec![( + "my_interval".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 3.5, + offset: Some(0.0), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["my_interval"]["buckets"][0]["key"], -14.0); + assert_eq!(res["my_interval"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["my_interval"]["buckets"][7]["key"], 10.5); + 
assert_eq!(res["my_interval"]["buckets"][7]["doc_count"], 1); + assert_eq!(res["my_interval"]["buckets"][8]["key"], 14.0); + assert_eq!(res["my_interval"]["buckets"][8]["doc_count"], 2); + assert_eq!(res["my_interval"]["buckets"][9], Value::Null); + + // With offset + let agg_req: Aggregations = vec![( + "my_interval".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 3.5, + offset: Some(1.2), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["my_interval"]["buckets"][0]["key"], -12.8); + assert_eq!(res["my_interval"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["my_interval"]["buckets"][1]["key"], -9.3); + assert_eq!(res["my_interval"]["buckets"][1]["doc_count"], 0); + assert_eq!(res["my_interval"]["buckets"][2]["key"], -5.8); + assert_eq!(res["my_interval"]["buckets"][2]["doc_count"], 0); + assert_eq!(res["my_interval"]["buckets"][3]["key"], -2.3); + assert_eq!(res["my_interval"]["buckets"][3]["doc_count"], 0); + + assert_eq!(res["my_interval"]["buckets"][7]["key"], 11.7); + assert_eq!(res["my_interval"]["buckets"][7]["doc_count"], 2); + assert_eq!(res["my_interval"]["buckets"][8]["key"], 15.2); + assert_eq!(res["my_interval"]["buckets"][8]["doc_count"], 1); + assert_eq!(res["my_interval"]["buckets"][9], Value::Null); + + Ok(()) + } + + #[test] + fn histogram_test_min_value_positive_force_merge_segments() -> crate::Result<()> { + histogram_test_min_value_positive_merge_segments(true) + } + + #[test] + fn histogram_test_min_value_positive() -> crate::Result<()> { + histogram_test_min_value_positive_merge_segments(false) + } + fn histogram_test_min_value_positive_merge_segments(merge_segments: bool) -> crate::Result<()> { + let values = vec![10.0, 12.0, 14.0, 16.23]; + + let index = get_test_index_from_values(merge_segments, &values)?; + + let agg_req: Aggregations = vec![( + "my_interval".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["my_interval"]["buckets"][0]["key"], 10.0); + assert_eq!(res["my_interval"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["my_interval"]["buckets"][1]["key"], 11.0); + assert_eq!(res["my_interval"]["buckets"][1]["doc_count"], 0); + assert_eq!(res["my_interval"]["buckets"][2]["key"], 12.0); + assert_eq!(res["my_interval"]["buckets"][2]["doc_count"], 1); + assert_eq!(res["my_interval"]["buckets"][3]["key"], 13.0); + assert_eq!(res["my_interval"]["buckets"][3]["doc_count"], 0); + assert_eq!(res["my_interval"]["buckets"][6]["key"], 16.0); + assert_eq!(res["my_interval"]["buckets"][6]["doc_count"], 1); + assert_eq!(res["my_interval"]["buckets"][7], Value::Null); + + Ok(()) + } + + #[test] + fn histogram_simple_test() -> crate::Result<()> { + let index = get_test_index_with_num_docs(false, 100)?; + + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + 
.collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["histogram"]["buckets"][0]["key"], 0.0); + assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 1); + assert_eq!(res["histogram"]["buckets"][1]["key"], 1.0); + assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 1); + assert_eq!(res["histogram"]["buckets"][99]["key"], 99.0); + assert_eq!(res["histogram"]["buckets"][99]["doc_count"], 1); + assert_eq!(res["histogram"]["buckets"][100], Value::Null); + Ok(()) + } + + #[test] + fn histogram_merge_test() -> crate::Result<()> { + // Merge buckets counts from different segments + let values = vec![10.0, 12.0, 14.0, 16.23, 10.0, 13.0, 10.0, 12.0]; + + let index = get_test_index_from_values(false, &values)?; + + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["histogram"]["buckets"][0]["key"], 10.0); + assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 3); + assert_eq!(res["histogram"]["buckets"][1]["key"], 11.0); + assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][2]["key"], 12.0); + assert_eq!(res["histogram"]["buckets"][2]["doc_count"], 2); + assert_eq!(res["histogram"]["buckets"][3]["key"], 13.0); + assert_eq!(res["histogram"]["buckets"][3]["doc_count"], 1); + + Ok(()) + } + #[test] + fn histogram_min_doc_test_multi_segments() -> crate::Result<()> { + histogram_min_doc_test_with_opt(false) + } + #[test] + fn histogram_min_doc_test_single_segments() -> crate::Result<()> { + histogram_min_doc_test_with_opt(true) + } + fn histogram_min_doc_test_with_opt(merge_segments: bool) -> crate::Result<()> { + let values = vec![10.0, 12.0, 14.0, 16.23, 10.0, 13.0, 10.0, 12.0]; + + let index = get_test_index_from_values(merge_segments, &values)?; + + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + min_doc_count: Some(2), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["histogram"]["buckets"][0]["key"], 10.0); + assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 3); + assert_eq!(res["histogram"]["buckets"][1]["key"], 12.0); + assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 2); + assert_eq!(res["histogram"]["buckets"][2], Value::Null); + + Ok(()) + } + #[test] + fn histogram_hard_bounds_test_multi_segment() -> crate::Result<()> { + histogram_hard_bounds_test_with_opt(false) + } + #[test] + fn histogram_hard_bounds_test_single_segment() -> crate::Result<()> { + histogram_hard_bounds_test_with_opt(true) + } + fn histogram_hard_bounds_test_with_opt(merge_segments: bool) -> crate::Result<()> { + let values = vec![10.0, 12.0, 14.0, 16.23, 10.0, 13.0, 10.0, 12.0]; + + let index = get_test_index_from_values(merge_segments, &values)?; + + let agg_req: Aggregations = vec![( + "histogram".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 1.0, + hard_bounds: 
Some(HistogramBounds { + min: 2.0, + max: 12.0, + }), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let res = exec_request(agg_req, &index)?; + + assert_eq!(res["histogram"]["buckets"][0]["key"], 2.0); + assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 0); + assert_eq!(res["histogram"]["buckets"][10]["key"], 12.0); + assert_eq!(res["histogram"]["buckets"][10]["doc_count"], 2); + assert_eq!(res["histogram"]["buckets"][11], Value::Null); + + Ok(()) + } +} diff --git a/src/aggregation/bucket/histogram/mod.rs b/src/aggregation/bucket/histogram/mod.rs new file mode 100644 index 0000000000..77b526147c --- /dev/null +++ b/src/aggregation/bucket/histogram/mod.rs @@ -0,0 +1,2 @@ +mod histogram; +pub use histogram::*; diff --git a/src/aggregation/bucket/mod.rs b/src/aggregation/bucket/mod.rs index e69d95be9f..0a9991fce7 100644 --- a/src/aggregation/bucket/mod.rs +++ b/src/aggregation/bucket/mod.rs @@ -7,7 +7,10 @@ //! Results of intermediate buckets are //! [IntermediateBucketResult](super::intermediate_agg_result::IntermediateBucketResult) +mod histogram; mod range; +pub(crate) use histogram::SegmentHistogramCollector; +pub use histogram::*; pub(crate) use range::SegmentRangeCollector; pub use range::*; diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 4650221717..39631ffee2 100644 --- a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -31,7 +31,7 @@ use crate::{DocId, TantivyError}; /// DistributedAggregationCollector. /// /// # Request JSON Format -/// ```json +/// ```ignore /// { /// "range": { /// "field": "score", diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index 7a205fb698..bc18216f77 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -2,16 +2,19 @@ //! Intermediate aggregation results can be used to merge results between segments or between //! indices. -use std::collections::HashMap; +use std::cmp::Ordering; +use fnv::FnvHashMap; +use itertools::Itertools; use serde::{Deserialize, Serialize}; +use super::bucket::HistogramAggregation; use super::metric::{IntermediateAverage, IntermediateStats}; use super::segment_agg_result::{ - SegmentAggregationResultsCollector, SegmentBucketResultCollector, SegmentMetricResultCollector, - SegmentRangeBucketEntry, + SegmentAggregationResultsCollector, SegmentBucketResultCollector, SegmentHistogramBucketEntry, + SegmentMetricResultCollector, SegmentRangeBucketEntry, }; -use super::{Key, SerializedKey, VecWithNames}; +use super::{Key, MergeFruits, SerializedKey, VecWithNames}; /// Contains the intermediate aggregation result, which is optimized to be merged with other /// intermediate results. @@ -124,13 +127,25 @@ impl IntermediateMetricResult { pub enum IntermediateBucketResult { /// This is the range entry for a bucket, which contains a key, count, from, to, and optionally /// sub_aggregations. - Range(HashMap), + Range(FnvHashMap), + /// This is the histogram entry for a bucket, which contains a key, count, and optionally + /// sub_aggregations. + Histogram { + /// The buckets + buckets: Vec, + /// The original request. It is used to compute the total range after merging segments and + /// get min_doc_count after merging all segment results. 
+ req: HistogramAggregation, + }, } impl From for IntermediateBucketResult { fn from(collector: SegmentBucketResultCollector) -> Self { match collector { SegmentBucketResultCollector::Range(range) => range.into_intermediate_bucket_result(), + SegmentBucketResultCollector::Histogram(histogram) => { + histogram.into_intermediate_bucket_result() + } } } } @@ -142,22 +157,96 @@ impl IntermediateBucketResult { IntermediateBucketResult::Range(entries_left), IntermediateBucketResult::Range(entries_right), ) => { - for (name, entry_left) in entries_left.iter_mut() { - if let Some(entry_right) = entries_right.get(name) { - entry_left.merge_fruits(entry_right); - } - } - - for (key, res) in entries_right.iter() { - if !entries_left.contains_key(key) { - entries_left.insert(key.clone(), res.clone()); - } - } + merge_maps(entries_left, entries_right); + } + ( + IntermediateBucketResult::Histogram { + buckets: entries_left, + .. + }, + IntermediateBucketResult::Histogram { + buckets: entries_right, + .. + }, + ) => { + let mut buckets = entries_left + .drain(..) + .merge_join_by(entries_right.iter(), |left, right| { + left.key.partial_cmp(&right.key).unwrap_or(Ordering::Equal) + }) + .map(|either| match either { + itertools::EitherOrBoth::Both(mut left, right) => { + left.merge_fruits(right); + left + } + itertools::EitherOrBoth::Left(left) => left, + itertools::EitherOrBoth::Right(right) => right.clone(), + }) + .collect(); + + std::mem::swap(entries_left, &mut buckets); + } + (IntermediateBucketResult::Range(_), _) => { + panic!("try merge on different types") + } + (IntermediateBucketResult::Histogram { .. }, _) => { + panic!("try merge on different types") } } } } +// fn merge_sorted_vecs(entries_left: &mut Vec, entries_right: &Vec) { +// for el in entries_left +//.iter_mut() +//.merge_join_by(entries_right.iter(), |left, right| left.key.cmp(right.key)) +//{} +//} + +fn merge_maps( + entries_left: &mut FnvHashMap, + entries_right: &FnvHashMap, +) { + for (name, entry_left) in entries_left.iter_mut() { + if let Some(entry_right) = entries_right.get(name) { + entry_left.merge_fruits(entry_right); + } + } + + for (key, res) in entries_right.iter() { + if !entries_left.contains_key(key) { + entries_left.insert(key.clone(), res.clone()); + } + } +} + +/// This is the histogram entry for a bucket, which contains a key, count, and optionally +/// sub_aggregations. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct IntermediateHistogramBucketEntry { + /// The unique the bucket is identified. + pub key: f64, + /// The number of documents in the bucket. + pub doc_count: u64, + /// The sub_aggregation in this bucket. + pub sub_aggregation: IntermediateAggregationResults, +} + +impl From for IntermediateHistogramBucketEntry { + fn from(entry: SegmentHistogramBucketEntry) -> Self { + let sub_aggregation = if let Some(sub_aggregation) = entry.sub_aggregation { + sub_aggregation.into() + } else { + Default::default() + }; + IntermediateHistogramBucketEntry { + key: entry.key, + doc_count: entry.doc_count, + sub_aggregation, + } + } +} + /// This is the range entry for a bucket, which contains a key, count, and optionally /// sub_aggregations. 
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -184,7 +273,6 @@ impl From for IntermediateRangeBucketEntry { } else { Default::default() }; - // let sub_aggregation = entry.sub_aggregation.into(); IntermediateRangeBucketEntry { key: entry.key, @@ -197,22 +285,31 @@ impl From for IntermediateRangeBucketEntry { } } -impl IntermediateRangeBucketEntry { +impl MergeFruits for IntermediateRangeBucketEntry { fn merge_fruits(&mut self, other: &IntermediateRangeBucketEntry) { self.doc_count += other.doc_count; self.sub_aggregation.merge_fruits(&other.sub_aggregation); } } +impl MergeFruits for IntermediateHistogramBucketEntry { + fn merge_fruits(&mut self, other: &IntermediateHistogramBucketEntry) { + self.doc_count += other.doc_count; + self.sub_aggregation.merge_fruits(&other.sub_aggregation); + } +} + #[cfg(test)] mod tests { + use std::collections::HashMap; + use pretty_assertions::assert_eq; use super::*; fn get_sub_test_tree(data: &[(String, u64)]) -> IntermediateAggregationResults { let mut map = HashMap::new(); - let mut buckets = HashMap::new(); + let mut buckets = FnvHashMap::default(); for (key, doc_count) in data { buckets.insert( key.to_string(), @@ -235,7 +332,7 @@ mod tests { fn get_test_tree(data: &[(String, u64, String, u64)]) -> IntermediateAggregationResults { let mut map = HashMap::new(); - let mut buckets = HashMap::new(); + let mut buckets: FnvHashMap<_, _> = Default::default(); for (key, doc_count, sub_aggregation_key, sub_aggregation_count) in data { buckets.insert( key.to_string(), diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index d4f77fa946..d7d838ee32 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -239,6 +239,10 @@ pub enum Key { F64(f64), } +trait MergeFruits { + fn merge_fruits(&mut self, other: &Self); +} + impl Display for Key { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -310,6 +314,16 @@ mod tests { pub fn get_test_index_with_num_docs( merge_segments: bool, num_docs: usize, + ) -> crate::Result { + get_test_index_from_values( + merge_segments, + &(0..num_docs).map(|el| el as f64).collect::>(), + ) + } + + pub fn get_test_index_from_values( + merge_segments: bool, + values: &[f64], ) -> crate::Result { let mut schema_builder = Schema::builder(); let text_fieldtype = crate::schema::TextOptions::default() @@ -332,18 +346,17 @@ mod tests { let index = Index::create_in_ram(schema_builder.build()); { let mut index_writer = index.writer_for_tests()?; - for i in 0..num_docs { + for i in values { // writing the segment index_writer.add_document(doc!( text_field => "cool", - score_field => i as u64, - score_field_f64 => i as f64, - score_field_i64 => i as i64, - fraction_field => i as f64/100.0, + score_field => *i as u64, + score_field_f64 => *i as f64, + score_field_i64 => *i as i64, + fraction_field => *i as f64/100.0, ))?; + index_writer.commit()?; } - - index_writer.commit()?; } if merge_segments { let segment_ids = index @@ -385,27 +398,42 @@ mod tests { // A second bucket on the first level should have the cache unfilled // let elasticsearch_compatible_json_req = r#" - let elasticsearch_compatible_json_req = r#" + let elasticsearch_compatible_json = json!( { "bucketsL1": { "range": { "field": "score", - "ranges": [ { "to": 3.0 }, { "from": 3.0, "to": 266.0 }, { "from": 266.0 } ] + "ranges": [ { "to": 3.0f64 }, { "from": 3.0f64, "to": 266.0f64 }, { "from": 266.0f64 } ] }, "aggs": { "bucketsL2": { "range": { "field": "score", - "ranges": [ { "to": 100.0 }, { "from": 100.0, "to": 
266.0 }, { "from": 266.0 } ] + "ranges": [ { "to": 100.0f64 }, { "from": 100.0f64, "to": 266.0f64 }, { "from": 266.0f64 } ] + } + } + } + }, + "histogram_test":{ + "histogram": { + "field": "score", + "interval": 263.0, + "offset": 3.0, + }, + "aggs": { + "bucketsL2": { + "histogram": { + "field": "score", + "interval": 263.0 } } } } - } - "#; + }); let agg_req: Aggregations = - serde_json::from_str(elasticsearch_compatible_json_req).unwrap(); + serde_json::from_str(&serde_json::to_string(&elasticsearch_compatible_json).unwrap()) + .unwrap(); let agg_res: AggregationResults = if use_distributed_collector { let collector = DistributedAggregationCollector::from_aggs(agg_req); @@ -950,6 +978,8 @@ mod tests { use test::{self, Bencher}; use super::*; + use crate::aggregation::bucket::HistogramAggregation; + use crate::aggregation::bucket::HistogramBounds; use crate::aggregation::metric::StatsAggregation; use crate::query::AllQuery; @@ -1165,6 +1195,71 @@ mod tests { }); } + // hard bounds has a different algorithm, because it actually limits collection range + #[bench] + fn bench_aggregation_histogram_only_hard_bounds(b: &mut Bencher) { + let index = get_test_index_bench(false).unwrap(); + let reader = index.reader().unwrap(); + + b.iter(|| { + let agg_req_1: Aggregations = vec![( + "rangef64".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 100f64, // 1000 buckets + hard_bounds: Some(HistogramBounds { + min: 1000.0, + max: 1500_000.0, + }), + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let collector = AggregationCollector::from_aggs(agg_req_1); + + let searcher = reader.searcher(); + let agg_res: AggregationResults = + searcher.search(&AllQuery, &collector).unwrap().into(); + + agg_res + }); + } + + #[bench] + fn bench_aggregation_histogram_only(b: &mut Bencher) { + let index = get_test_index_bench(false).unwrap(); + let reader = index.reader().unwrap(); + + b.iter(|| { + let agg_req_1: Aggregations = vec![( + "rangef64".to_string(), + Aggregation::Bucket(BucketAggregation { + bucket_agg: BucketAggregationType::Histogram(HistogramAggregation { + field: "score_f64".to_string(), + interval: 100f64, // 1000 buckets + ..Default::default() + }), + sub_aggregation: Default::default(), + }), + )] + .into_iter() + .collect(); + + let collector = AggregationCollector::from_aggs(agg_req_1); + + let searcher = reader.searcher(); + let agg_res: AggregationResults = + searcher.search(&AllQuery, &collector).unwrap().into(); + + agg_res + }); + } + #[bench] fn bench_aggregation_sub_tree(b: &mut Bencher) { let index = get_test_index_bench(false).unwrap(); diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 3b17bfdbd9..56aafe3fc5 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -9,7 +9,7 @@ use super::agg_req::MetricAggregation; use super::agg_req_with_accessor::{ AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, }; -use super::bucket::SegmentRangeCollector; +use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector}; use super::metric::{ AverageAggregation, SegmentAverageCollector, SegmentStatsCollector, StatsAggregation, }; @@ -151,6 +151,7 @@ impl SegmentMetricResultCollector { #[derive(Clone, Debug, PartialEq)] pub(crate) enum SegmentBucketResultCollector { 
Range(SegmentRangeCollector), + Histogram(SegmentHistogramCollector), } impl SegmentBucketResultCollector { @@ -163,6 +164,14 @@ impl SegmentBucketResultCollector { req.field_type, )?)) } + BucketAggregationType::Histogram(histogram) => Ok(Self::Histogram( + SegmentHistogramCollector::from_req_and_validate( + histogram, + &req.sub_aggregation, + req.field_type, + &req.accessor, + )?, + )), } } @@ -177,10 +186,20 @@ impl SegmentBucketResultCollector { SegmentBucketResultCollector::Range(range) => { range.collect_block(doc, bucket_with_accessor, force_flush); } + SegmentBucketResultCollector::Histogram(histogram) => { + histogram.collect_block(doc, bucket_with_accessor, force_flush) + } } } } +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct SegmentHistogramBucketEntry { + pub key: f64, + pub doc_count: u64, + pub sub_aggregation: Option, +} + #[derive(Clone, PartialEq)] pub(crate) struct SegmentRangeBucketEntry { pub key: Key,
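
For reference, a minimal end-to-end sketch of the histogram aggregation introduced by this patch, mirroring its tests and benches. It assumes an existing index with a fast field named `score_f64` and that the `aggregation` submodules are reachable under the paths used in the doc links above; index setup is elided and `serde_json` is only used to print the result.

```rust
use tantivy::aggregation::agg_req::{
    Aggregation, Aggregations, BucketAggregation, BucketAggregationType,
};
use tantivy::aggregation::agg_result::AggregationResults;
use tantivy::aggregation::bucket::HistogramAggregation;
use tantivy::aggregation::AggregationCollector;
use tantivy::query::AllQuery;
use tantivy::Index;

fn score_histogram(index: &Index) -> Result<(), Box<dyn std::error::Error>> {
    // One top-level histogram over the "score_f64" fast field, with buckets 2.0 wide.
    let agg_req: Aggregations = vec![(
        "score_histogram".to_string(),
        Aggregation::Bucket(BucketAggregation {
            bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
                field: "score_f64".to_string(),
                interval: 2.0,
                ..Default::default()
            }),
            sub_aggregation: Default::default(),
        }),
    )]
    .into_iter()
    .collect();

    let collector = AggregationCollector::from_aggs(agg_req);
    let searcher = index.reader()?.searcher();
    let agg_res: AggregationResults = searcher.search(&AllQuery, &collector)?.into();
    // The result serializes to the JSON shape documented on BucketEntry above:
    // a "buckets" array of { "key": f64, "doc_count": u64 } entries.
    println!("{}", serde_json::to_string_pretty(&agg_res)?);
    Ok(())
}
```

Setting `min_doc_count`, `extended_bounds`, or `hard_bounds` on the `HistogramAggregation` request changes which buckets are emitted, as exercised by the tests above.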