Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for keyed parameter in range and histgram aggregations #1424

Merged
1 change: 1 addition & 0 deletions examples/aggregation.rs
Expand Up @@ -110,6 +110,7 @@ fn main() -> tantivy::Result<()> {
(9f64..14f64).into(),
(14f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: sub_agg_req_1.clone(),
}),
Expand Down
16 changes: 14 additions & 2 deletions src/aggregation/agg_req.rs
Expand Up @@ -20,6 +20,7 @@
//! bucket_agg: BucketAggregationType::Range(RangeAggregation{
//! field: "score".to_string(),
//! ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//! keyed: false,
//! }),
//! sub_aggregation: Default::default(),
//! }),
Expand All @@ -36,7 +37,8 @@
//! "ranges": [
//! { "from": 3.0, "to": 7.0 },
//! { "from": 7.0, "to": 20.0 }
//! ]
//! ],
//! "keyed": false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the default, I think we can just omit it

Copy link
Contributor Author

@k-yomo k-yomo Jul 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't pass the tests...
https://github.com/quickwit-oss/tantivy/runs/7517376126?check_suite_focus=true

Ah I see, I removed serde(default) by mistake, let me fix

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 6444516

//! }
//! }
//! }"#;
Expand Down Expand Up @@ -100,6 +102,12 @@ pub(crate) struct BucketAggregationInternal {
}

impl BucketAggregationInternal {
pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
match &self.bucket_agg {
BucketAggregationType::Range(range) => Some(range),
_ => None,
}
}
pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> {
match &self.bucket_agg {
BucketAggregationType::Histogram(histogram) => Some(histogram),
Expand Down Expand Up @@ -264,6 +272,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
Expand All @@ -290,7 +299,8 @@ mod tests {
{
"from": 20.0
}
]
],
"keyed": true
}
}
}"#;
Expand All @@ -312,6 +322,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Expand All @@ -337,6 +348,7 @@ mod tests {
(7f64..20f64).into(),
(20f64..f64::MAX).into(),
],
..Default::default()
}),
sub_aggregation: agg_req2,
}),
Expand Down
3 changes: 1 addition & 2 deletions src/aggregation/agg_req_with_accessor.rs
Expand Up @@ -77,8 +77,7 @@ impl BucketAggregationWithAccessor {
let mut inverted_index = None;
let (accessor, field_type) = match &bucket {
BucketAggregationType::Range(RangeAggregation {
field: field_name,
ranges: _,
field: field_name, ..
}) => get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?,
BucketAggregationType::Histogram(HistogramAggregation {
field: field_name, ..
Expand Down
16 changes: 14 additions & 2 deletions src/aggregation/agg_result.rs
Expand Up @@ -6,6 +6,7 @@

use std::collections::HashMap;

use fnv::FnvHashMap;
use serde::{Deserialize, Serialize};

use super::agg_req::BucketAggregationInternal;
Expand Down Expand Up @@ -104,7 +105,7 @@ pub enum BucketResult {
/// sub_aggregations.
Range {
/// The range buckets sorted by range.
buckets: Vec<RangeBucketEntry>,
buckets: BucketEntries<RangeBucketEntry>,
},
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
Expand All @@ -114,7 +115,7 @@ pub enum BucketResult {
/// If there are holes depends on the request, if min_doc_count is 0, then there are no
/// holes between the first and last bucket.
/// See [HistogramAggregation](super::bucket::HistogramAggregation)
buckets: Vec<BucketEntry>,
buckets: BucketEntries<BucketEntry>,
},
/// This is the term result
Terms {
Expand All @@ -137,6 +138,17 @@ impl BucketResult {
}
}

/// This is the wrapper of buckets entries, which can be vector or hashmap
/// depending on if it's keyed or not.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum BucketEntries<T> {
/// Vector format bucket entries
Vec(Vec<T>),
/// HashMap format bucket entries
HashMap(FnvHashMap<String, T>),
}

/// This is the default entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
///
Expand Down
46 changes: 44 additions & 2 deletions src/aggregation/bucket/histogram/histogram.rs
Expand Up @@ -48,8 +48,6 @@ use crate::{DocId, TantivyError};
///
/// # Limitations/Compatibility
///
/// The keyed parameter (elasticsearch) is not yet supported.
///
/// # JSON Format
/// ```json
/// {
Expand Down Expand Up @@ -117,6 +115,8 @@ pub struct HistogramAggregation {
/// Cannot be set in conjunction with min_doc_count > 0, since the empty buckets from extended
/// bounds would not be returned.
pub extended_bounds: Option<HistogramBounds>,
/// Whether to return the buckets as a hash map
pub keyed: bool,
}

impl HistogramAggregation {
Expand Down Expand Up @@ -1395,4 +1395,46 @@ mod tests {

Ok(())
}

#[test]
fn histogram_keyed_buckets_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;

let agg_req: Aggregations = vec![(
"histogram".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Histogram(HistogramAggregation {
field: "score_f64".to_string(),
interval: 50.0,
keyed: true,
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();

let res = exec_request(agg_req, &index)?;

assert_eq!(
res,
json!({
"histogram": {
"buckets": {
"0": {
"key": 0.0,
"doc_count": 50
},
"50": {
"key": 50.0,
"doc_count": 50
}
}
}
})
);

Ok(())
}
}
51 changes: 48 additions & 3 deletions src/aggregation/bucket/range.rs
Expand Up @@ -35,8 +35,6 @@ use crate::{DocId, TantivyError};
/// # Limitations/Compatibility
/// Overlapping ranges are not yet supported.
///
/// The keyed parameter (elasticsearch) is not yet supported.
///
/// # Request JSON Format
/// ```json
/// {
Expand All @@ -51,13 +49,15 @@ use crate::{DocId, TantivyError};
/// }
/// }
/// ```
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct RangeAggregation {
/// The field to aggregate on.
pub field: String,
/// Note that this aggregation includes the from value and excludes the to value for each
/// range. Extra buckets will be created until the first to, and last from, if necessary.
pub ranges: Vec<RangeAggregationRange>,
/// Whether to return the buckets as a hash map
pub keyed: bool,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
Expand Down Expand Up @@ -406,6 +406,7 @@ mod tests {
let req = RangeAggregation {
field: "dummy".to_string(),
ranges,
..Default::default()
};

SegmentRangeCollector::from_req_and_validate(
Expand All @@ -427,6 +428,7 @@ mod tests {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
..Default::default()
}),
sub_aggregation: Default::default(),
}),
Expand Down Expand Up @@ -454,6 +456,49 @@ mod tests {
Ok(())
}

#[test]
fn range_keyed_buckets_test() -> crate::Result<()> {
let index = get_test_index_with_num_docs(false, 100)?;

let agg_req: Aggregations = vec![(
"range".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Range(RangeAggregation {
field: "fraction_f64".to_string(),
ranges: vec![(0f64..0.1f64).into(), (0.1f64..0.2f64).into()],
keyed: true,
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();

let collector = AggregationCollector::from_aggs(agg_req, None);

let reader = index.reader()?;
let searcher = reader.searcher();
let agg_res = searcher.search(&AllQuery, &collector).unwrap();

let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;

assert_eq!(
res,
json!({
"range": {
"buckets": {
"*-0": { "key": "*-0", "doc_count": 0, "to": 0.0},
"0-0.1": {"key": "0-0.1", "doc_count": 10, "from": 0.0, "to": 0.1},
"0.1-0.2": {"key": "0.1-0.2", "doc_count": 10, "from": 0.1, "to": 0.2},
"0.2-*": {"key": "0.2-*", "doc_count": 80, "from": 0.2},
}
}
})
);

Ok(())
}

#[test]
fn bucket_test_extend_range_hole() {
let buckets = vec![(10f64..20f64).into(), (30f64..40f64).into()];
Expand Down
27 changes: 26 additions & 1 deletion src/aggregation/intermediate_agg_result.rs
Expand Up @@ -21,7 +21,7 @@ use super::bucket::{
use super::metric::{IntermediateAverage, IntermediateStats};
use super::segment_agg_result::SegmentMetricResultCollector;
use super::{Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntry};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
use crate::aggregation::bucket::TermsAggregationInternal;

/// Contains the intermediate aggregation result, which is optimized to be merged with other
Expand Down Expand Up @@ -281,6 +281,21 @@ impl IntermediateBucketResult {
.unwrap_or(f64::MIN)
.total_cmp(&right.from.unwrap_or(f64::MIN))
});

let is_keyed = req
.as_range()
.expect("unexpected aggregation, expected range aggregation")
.keyed;
let buckets = if is_keyed {
let mut bucket_map =
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
for bucket in buckets {
bucket_map.insert(bucket.key.to_string(), bucket);
}
BucketEntries::HashMap(bucket_map)
} else {
BucketEntries::Vec(buckets)
};
Ok(BucketResult::Range { buckets })
}
IntermediateBucketResult::Histogram { buckets } => {
Expand All @@ -291,6 +306,16 @@ impl IntermediateBucketResult {
&req.sub_aggregation,
)?;

let buckets = if req.as_histogram().unwrap().keyed {
let mut bucket_map =
FnvHashMap::with_capacity_and_hasher(buckets.len(), Default::default());
for bucket in buckets {
bucket_map.insert(bucket.key.to_string(), bucket);
}
BucketEntries::HashMap(bucket_map)
} else {
BucketEntries::Vec(buckets)
};
Ok(BucketResult::Histogram { buckets })
}
IntermediateBucketResult::Terms(terms) => terms.into_final_result(
Expand Down
1 change: 1 addition & 0 deletions src/aggregation/metric/stats.rs
Expand Up @@ -285,6 +285,7 @@ mod tests {
(7f64..19f64).into(),
(19f64..20f64).into(),
],
..Default::default()
}),
sub_aggregation: iter::once((
"stats".to_string(),
Expand Down