diff --git a/examples/aggregation.rs b/examples/aggregation.rs
index fb0d131c17..fbe412e8e1 100644
--- a/examples/aggregation.rs
+++ b/examples/aggregation.rs
@@ -118,7 +118,7 @@ fn main() -> tantivy::Result<()> {
     .into_iter()
     .collect();
 
-    let collector = AggregationCollector::from_aggs(agg_req_1, None);
+    let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
     let searcher = reader.searcher();
     let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
diff --git a/src/aggregation/agg_result.rs b/src/aggregation/agg_result.rs
index 30a884be37..4f71e9d7a5 100644
--- a/src/aggregation/agg_result.rs
+++ b/src/aggregation/agg_result.rs
@@ -12,6 +12,7 @@ use super::bucket::GetDocCount;
 use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult};
 use super::metric::{SingleMetricResult, Stats};
 use super::Key;
+use crate::schema::Schema;
 use crate::TantivyError;
 
 #[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
@@ -129,9 +130,12 @@ pub enum BucketResult {
 }
 
 impl BucketResult {
-    pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result<Self> {
+    pub(crate) fn empty_from_req(
+        req: &BucketAggregationInternal,
+        schema: &Schema,
+    ) -> crate::Result<Self> {
         let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
-        empty_bucket.into_final_bucket_result(req)
+        empty_bucket.into_final_bucket_result(req, schema)
     }
 }
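
The collector now carries the schema so that the final result conversion can inspect field types. A minimal sketch of the new call shape (`run_aggs` and its parameters are illustrative, not part of this patch):

    use tantivy::aggregation::agg_req::Aggregations;
    use tantivy::aggregation::agg_result::AggregationResults;
    use tantivy::aggregation::AggregationCollector;
    use tantivy::query::AllQuery;
    use tantivy::Index;

    fn run_aggs(index: &Index, agg_req: Aggregations) -> tantivy::Result<AggregationResults> {
        // `Schema` is cheap to pass by value (it is reference-counted internally),
        // so call sites can simply hand over `index.schema()`.
        let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());
        let searcher = index.reader()?.searcher();
        searcher.search(&AllQuery, &collector)
    }
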
diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs
index 329bea6d4b..c37cc65162 100644
--- a/src/aggregation/bucket/histogram/histogram.rs
+++ b/src/aggregation/bucket/histogram/histogram.rs
@@ -4,6 +4,8 @@ use std::fmt::Display;
 use fastfield_codecs::Column;
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
+use time::format_description::well_known::Rfc3339;
+use time::OffsetDateTime;
 
 use crate::aggregation::agg_req::AggregationsInternal;
 use crate::aggregation::agg_req_with_accessor::{
@@ -15,7 +17,7 @@ use crate::aggregation::intermediate_agg_result::{
     IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
 };
 use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
-use crate::schema::Type;
+use crate::schema::{Schema, Type};
 use crate::{DocId, TantivyError};
 
 /// Histogram is a bucket aggregation, where buckets are created dynamically for given `interval`.
@@ -451,6 +453,7 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
     buckets: Vec<IntermediateHistogramBucketEntry>,
     histogram_req: &HistogramAggregation,
     sub_aggregation: &AggregationsInternal,
+    schema: &Schema,
 ) -> crate::Result<Vec<BucketEntry>> {
     // Generate the full list of buckets without gaps.
     //
@@ -491,7 +494,9 @@ fn intermediate_buckets_to_final_buckets_fill_gaps(
             sub_aggregation: empty_sub_aggregation.clone(),
         },
     })
-    .map(|intermediate_bucket| intermediate_bucket.into_final_bucket_entry(sub_aggregation))
+    .map(|intermediate_bucket| {
+        intermediate_bucket.into_final_bucket_entry(sub_aggregation, schema)
+    })
     .collect::<crate::Result<Vec<_>>>()
 }
 
@@ -500,20 +505,56 @@ pub(crate) fn intermediate_histogram_buckets_to_final_buckets(
     buckets: Vec<IntermediateHistogramBucketEntry>,
     histogram_req: &HistogramAggregation,
     sub_aggregation: &AggregationsInternal,
+    schema: &Schema,
 ) -> crate::Result<Vec<BucketEntry>> {
-    if histogram_req.min_doc_count() == 0 {
+    let mut buckets = if histogram_req.min_doc_count() == 0 {
         // With min_doc_count != 0, we may need to add buckets, so that there are no
         // gaps, since intermediate result does not contain empty buckets (filtered to
         // reduce serialization size).
-        intermediate_buckets_to_final_buckets_fill_gaps(buckets, histogram_req, sub_aggregation)
+        intermediate_buckets_to_final_buckets_fill_gaps(
+            buckets,
+            histogram_req,
+            sub_aggregation,
+            schema,
+        )?
     } else {
         buckets
             .into_iter()
            .filter(|histogram_bucket| histogram_bucket.doc_count >= histogram_req.min_doc_count())
-            .map(|histogram_bucket| histogram_bucket.into_final_bucket_entry(sub_aggregation))
-            .collect::<crate::Result<Vec<_>>>()
+            .map(|histogram_bucket| {
+                histogram_bucket.into_final_bucket_entry(sub_aggregation, schema)
+            })
+            .collect::<crate::Result<Vec<_>>>()?
+    };
+
+    // If the histogram is on a date field, we add the `key_as_string` field, formatted as RFC 3339.
+    let field = schema
+        .get_field(&histogram_req.field)
+        .ok_or_else(|| TantivyError::FieldNotFound(histogram_req.field.to_string()))?;
+    if schema.get_field_entry(field).field_type().is_date() {
+        for bucket in buckets.iter_mut() {
+            match bucket.key {
+                crate::aggregation::Key::F64(val) => {
+                    let datetime = OffsetDateTime::from_unix_timestamp_nanos(1_000 * (val as i128))
+                        .map_err(|err| {
+                            TantivyError::InvalidArgument(format!(
+                                "Could not convert {:?} to OffsetDateTime, err {:?}",
+                                val, err
+                            ))
+                        })?;
+                    let key_as_string = datetime.format(&Rfc3339).map_err(|_err| {
+                        TantivyError::InvalidArgument("Could not serialize date".to_string())
+                    })?;
+
+                    bucket.key_as_string = Some(key_as_string);
+                }
+                _ => {}
+            }
+        }
     }
+
+    Ok(buckets)
 }
 
 /// Applies req extended_bounds/hard_bounds on the min_max value
@@ -1404,13 +1445,25 @@ mod tests {
 
         let res: Value = serde_json::from_str(&serde_json::to_string(&agg_res)?)?;
 
         assert_eq!(res["histogram"]["buckets"][0]["key"], 1546300800000000.0);
+        assert_eq!(
+            res["histogram"]["buckets"][0]["key_as_string"],
+            "2019-01-01T00:00:00Z"
+        );
         assert_eq!(res["histogram"]["buckets"][0]["doc_count"], 1);
         assert_eq!(res["histogram"]["buckets"][1]["key"], 1546387200000000.0);
+        assert_eq!(
+            res["histogram"]["buckets"][1]["key_as_string"],
+            "2019-01-02T00:00:00Z"
+        );
+        assert_eq!(res["histogram"]["buckets"][1]["doc_count"], 5);
         assert_eq!(res["histogram"]["buckets"][2]["key"], 1546473600000000.0);
-        assert_eq!(res["histogram"]["buckets"][2]["key"], 1546473600000000.0);
+        assert_eq!(
+            res["histogram"]["buckets"][2]["key_as_string"],
+            "2019-01-03T00:00:00Z"
+        );
         assert_eq!(res["histogram"]["buckets"][3], Value::Null);
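
Bucket keys of a date histogram are microseconds since the Unix epoch, while `OffsetDateTime::from_unix_timestamp_nanos` expects nanoseconds; that is why the conversion above multiplies by 1_000. A standalone sketch of the same conversion (the helper name `key_to_rfc3339` is illustrative, not part of this patch):

    use time::format_description::well_known::Rfc3339;
    use time::OffsetDateTime;

    fn key_to_rfc3339(key_micros: f64) -> Result<String, String> {
        // Scale microseconds up to the nanoseconds the `time` crate expects.
        let datetime = OffsetDateTime::from_unix_timestamp_nanos(1_000 * (key_micros as i128))
            .map_err(|err| format!("invalid timestamp {}: {}", key_micros, err))?;
        datetime
            .format(&Rfc3339)
            .map_err(|err| format!("could not format date: {}", err))
    }

    // key_to_rfc3339(1546300800000000.0) == Ok("2019-01-01T00:00:00Z".to_string()),
    // matching the first bucket in the test above.
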
diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs
index cd91ec20c8..a53ac268ae 100644
--- a/src/aggregation/collector.rs
+++ b/src/aggregation/collector.rs
@@ -7,6 +7,7 @@ use super::intermediate_agg_result::IntermediateAggregationResults;
 use super::segment_agg_result::SegmentAggregationResultsCollector;
 use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_validate;
 use crate::collector::{Collector, SegmentCollector};
+use crate::schema::Schema;
 use crate::{SegmentReader, TantivyError};
 
 /// The default max bucket count, before the aggregation fails.
@@ -16,6 +17,7 @@ pub const MAX_BUCKET_COUNT: u32 = 65000;
 ///
 /// The collector collects all aggregations by the underlying aggregation request.
 pub struct AggregationCollector {
+    schema: Schema,
     agg: Aggregations,
     max_bucket_count: u32,
 }
@@ -25,8 +27,9 @@ impl AggregationCollector {
     ///
     /// Aggregation fails when the total bucket count is higher than max_bucket_count.
     /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
-    pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
+    pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>, schema: Schema) -> Self {
         Self {
+            schema,
             agg,
             max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
         }
@@ -113,7 +116,7 @@ impl Collector for AggregationCollector {
         segment_fruits: Vec<<Self::Child as SegmentCollector>::Fruit>,
     ) -> crate::Result<Self::Fruit> {
         let res = merge_fruits(segment_fruits)?;
-        res.into_final_bucket_result(self.agg.clone())
+        res.into_final_bucket_result(self.agg.clone(), &self.schema)
     }
 }
diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs
index 5a9613b3c5..e0117228ac 100644
--- a/src/aggregation/intermediate_agg_result.rs
+++ b/src/aggregation/intermediate_agg_result.rs
@@ -22,6 +22,7 @@ use super::segment_agg_result::SegmentMetricResultCollector;
 use super::{Key, SerializedKey, VecWithNames};
 use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
 use crate::aggregation::bucket::TermsAggregationInternal;
+use crate::schema::Schema;
 
 /// Contains the intermediate aggregation result, which is optimized to be merged with other
 /// intermediate results.
@@ -35,8 +36,12 @@ pub struct IntermediateAggregationResults {
 
 impl IntermediateAggregationResults {
     /// Convert intermediate result and its aggregation request to the final result.
-    pub fn into_final_bucket_result(self, req: Aggregations) -> crate::Result<AggregationResults> {
-        self.into_final_bucket_result_internal(&(req.into()))
+    pub fn into_final_bucket_result(
+        self,
+        req: Aggregations,
+        schema: &Schema,
+    ) -> crate::Result<AggregationResults> {
+        self.into_final_bucket_result_internal(&(req.into()), schema)
     }
 
     /// Convert intermediate result and its aggregation request to the final result.
@@ -46,6 +51,7 @@ impl IntermediateAggregationResults {
     pub(crate) fn into_final_bucket_result_internal(
         self,
         req: &AggregationsInternal,
+        schema: &Schema,
     ) -> crate::Result<AggregationResults> {
         // Important assumption:
         // When the tree contains buckets/metric, we expect it to have all buckets/metrics from the
         // request
         let mut results: FxHashMap<String, AggregationResult> = FxHashMap::default();
 
         if let Some(buckets) = self.buckets {
-            convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets)?
+            convert_and_add_final_buckets_to_result(&mut results, buckets, &req.buckets, schema)?
         } else {
             // When there are no buckets, we create empty buckets, so that the serialized json
             // format is constant
-            add_empty_final_buckets_to_result(&mut results, &req.buckets)?
+            add_empty_final_buckets_to_result(&mut results, &req.buckets, schema)?
         };
 
         if let Some(metrics) = self.metrics {
@@ -158,10 +164,12 @@ fn add_empty_final_metrics_to_result(
 fn add_empty_final_buckets_to_result(
     results: &mut FxHashMap<String, AggregationResult>,
     req_buckets: &VecWithNames<BucketAggregationInternal>,
+    schema: &Schema,
 ) -> crate::Result<()> {
     let requested_buckets = req_buckets.iter();
     for (key, req) in requested_buckets {
-        let empty_bucket = AggregationResult::BucketResult(BucketResult::empty_from_req(req)?);
+        let empty_bucket =
+            AggregationResult::BucketResult(BucketResult::empty_from_req(req, schema)?);
         results.insert(key.to_string(), empty_bucket);
     }
     Ok(())
@@ -171,12 +179,13 @@ fn convert_and_add_final_buckets_to_result(
     results: &mut FxHashMap<String, AggregationResult>,
     buckets: VecWithNames<IntermediateBucketResult>,
     req_buckets: &VecWithNames<BucketAggregationInternal>,
+    schema: &Schema,
 ) -> crate::Result<()> {
     assert_eq!(buckets.len(), req_buckets.len());
 
     let buckets_with_request = buckets.into_iter().zip(req_buckets.values());
     for ((key, bucket), req) in buckets_with_request {
-        let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req)?);
+        let result = AggregationResult::BucketResult(bucket.into_final_bucket_result(req, schema)?);
         results.insert(key, result);
     }
     Ok(())
@@ -266,13 +275,14 @@ impl IntermediateBucketResult {
     pub(crate) fn into_final_bucket_result(
         self,
         req: &BucketAggregationInternal,
+        schema: &Schema,
     ) -> crate::Result<BucketResult> {
         match self {
             IntermediateBucketResult::Range(range_res) => {
                 let mut buckets: Vec<RangeBucketEntry> = range_res
                     .buckets
                     .into_iter()
-                    .map(|(_, bucket)| bucket.into_final_bucket_entry(&req.sub_aggregation))
+                    .map(|(_, bucket)| bucket.into_final_bucket_entry(&req.sub_aggregation, schema))
                     .collect::<crate::Result<Vec<_>>>()?;
 
                 buckets.sort_by(|left, right| {
@@ -303,6 +313,7 @@ impl IntermediateBucketResult {
                 req.as_histogram()
                     .expect("unexpected aggregation, expected histogram aggregation"),
                 &req.sub_aggregation,
+                schema,
             )?;
 
             let buckets = if req.as_histogram().unwrap().keyed {
@@ -321,6 +332,7 @@ impl IntermediateBucketResult {
                 req.as_term()
                     .expect("unexpected aggregation, expected term aggregation"),
                 &req.sub_aggregation,
+                schema,
             ),
         }
     }
@@ -411,6 +423,7 @@ impl IntermediateTermBucketResult {
         self,
         req: &TermsAggregation,
         sub_aggregation_req: &AggregationsInternal,
+        schema: &Schema,
     ) -> crate::Result<BucketResult> {
         let req = TermsAggregationInternal::from_req(req);
         let mut buckets: Vec<BucketEntry> = self
@@ -424,7 +437,7 @@ impl IntermediateTermBucketResult {
                     doc_count: entry.doc_count,
                     sub_aggregation: entry
                         .sub_aggregation
-                        .into_final_bucket_result_internal(sub_aggregation_req)?,
+                        .into_final_bucket_result_internal(sub_aggregation_req, schema)?,
                 })
             })
             .collect::<crate::Result<_>>()?;
@@ -529,6 +542,7 @@ impl IntermediateHistogramBucketEntry {
     pub(crate) fn into_final_bucket_entry(
         self,
         req: &AggregationsInternal,
+        schema: &Schema,
     ) -> crate::Result<BucketEntry> {
         Ok(BucketEntry {
             key_as_string: None,
             key: self.key,
             doc_count: self.doc_count,
             sub_aggregation: self
                 .sub_aggregation
-                .into_final_bucket_result_internal(req)?,
+                .into_final_bucket_result_internal(req, schema)?,
         })
     }
 }
@@ -573,13 +587,14 @@ impl IntermediateRangeBucketEntry {
     pub(crate) fn into_final_bucket_entry(
         self,
         req: &AggregationsInternal,
+        schema: &Schema,
     ) -> crate::Result<RangeBucketEntry> {
         Ok(RangeBucketEntry {
             key: self.key,
             doc_count: self.doc_count,
             sub_aggregation: self
                 .sub_aggregation
-                .into_final_bucket_result_internal(req)?,
+                .into_final_bucket_result_internal(req, schema)?,
             to: self.to,
             from: self.from,
         })
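
The same schema handle is needed when finalizing intermediate results directly, e.g. after merging fruits from several segments or nodes. A sketch of that path (assuming `intermediate` came from a collector that produces `IntermediateAggregationResults`, as in the distributed-collector tests further down):

    use tantivy::aggregation::agg_req::Aggregations;
    use tantivy::aggregation::agg_result::AggregationResults;
    use tantivy::aggregation::intermediate_agg_result::IntermediateAggregationResults;
    use tantivy::Index;

    fn finalize(
        index: &Index,
        intermediate: IntermediateAggregationResults,
        agg_req: Aggregations,
    ) -> tantivy::Result<AggregationResults> {
        // The schema lets the conversion attach `key_as_string` to histogram
        // buckets over date fields.
        intermediate.into_final_bucket_result(agg_req, &index.schema())
    }
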
diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs
index f84944c261..dec50bdf0d 100644
--- a/src/aggregation/metric/stats.rs
+++ b/src/aggregation/metric/stats.rs
@@ -222,7 +222,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let reader = index.reader()?;
         let searcher = reader.searcher();
@@ -300,7 +300,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs
index 4469bc6610..a9dc69c775 100644
--- a/src/aggregation/mod.rs
+++ b/src/aggregation/mod.rs
@@ -53,9 +53,10 @@
 //! use tantivy::query::AllQuery;
 //! use tantivy::aggregation::agg_result::AggregationResults;
 //! use tantivy::IndexReader;
+//! use tantivy::schema::Schema;
 //!
 //! # #[allow(dead_code)]
-//! fn aggregate_on_index(reader: &IndexReader) {
+//! fn aggregate_on_index(reader: &IndexReader, schema: Schema) {
 //!     let agg_req: Aggregations = vec![
 //!         (
 //!             "average".to_string(),
@@ -67,7 +68,7 @@
 //!     .into_iter()
 //!     .collect();
 //!
-//!     let collector = AggregationCollector::from_aggs(agg_req, None);
+//!     let collector = AggregationCollector::from_aggs(agg_req, None, schema);
 //!
 //!     let searcher = reader.searcher();
 //!     let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
@@ -360,7 +361,7 @@ mod tests {
         index: &Index,
         query: Option<(&str, &str)>,
     ) -> crate::Result<AggregationResults> {
-        let collector = AggregationCollector::from_aggs(agg_req, None);
+        let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());
 
         let reader = index.reader()?;
         let searcher = reader.searcher();
@@ -554,10 +555,10 @@ mod tests {
             let searcher = reader.searcher();
             let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap();
             intermediate_agg_result
-                .into_final_bucket_result(agg_req)
+                .into_final_bucket_result(agg_req, &index.schema())
                 .unwrap()
         } else {
-            let collector = AggregationCollector::from_aggs(agg_req, None);
+            let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());
 
             let searcher = reader.searcher();
             searcher.search(&AllQuery, &collector).unwrap()
@@ -807,7 +808,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
@@ -1007,9 +1008,10 @@ mod tests {
             // Test de/serialization roundtrip on intermediate_agg_result
             let res: IntermediateAggregationResults =
                 serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap();
-            res.into_final_bucket_result(agg_req.clone()).unwrap()
+            res.into_final_bucket_result(agg_req.clone(), &index.schema())
+                .unwrap()
         } else {
-            let collector = AggregationCollector::from_aggs(agg_req.clone(), None);
+            let collector = AggregationCollector::from_aggs(agg_req.clone(), None, index.schema());
 
             let searcher = reader.searcher();
             searcher.search(&term_query, &collector).unwrap()
@@ -1067,7 +1069,7 @@ mod tests {
         );
 
         // Test empty result set
-        let collector = AggregationCollector::from_aggs(agg_req, None);
+        let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());
         let searcher = reader.searcher();
         searcher.search(&query_with_no_hits, &collector).unwrap();
@@ -1132,7 +1134,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
@@ -1245,7 +1247,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1276,7 +1278,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1307,7 +1309,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1346,7 +1348,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1375,7 +1377,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
+        let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1404,7 +1406,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req, None);
+        let collector = AggregationCollector::from_aggs(agg_req, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1441,7 +1443,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1476,7 +1478,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1515,7 +1517,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1545,7 +1547,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
@@ -1602,7 +1604,7 @@ mod tests {
         .into_iter()
         .collect();
 
-        let collector = AggregationCollector::from_aggs(agg_req_1, None);
+        let collector = AggregationCollector::from_aggs(agg_req_1, None, index.schema());
 
         let searcher = reader.searcher();
         let agg_res: AggregationResults =
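
The last file adds the `FieldType::is_date` helper that the histogram conversion above relies on. A usage sketch (the function and the `Option`-returning `get_field` reflect this version of the API; names are illustrative):

    use tantivy::schema::Schema;

    fn is_date_field(schema: &Schema, field_name: &str) -> bool {
        schema
            .get_field(field_name)
            .map(|field| schema.get_field_entry(field).field_type().is_date())
            .unwrap_or(false)
    }
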
diff --git a/src/schema/field_type.rs b/src/schema/field_type.rs
index c2cae2f1e1..b92f4448d7 100644
--- a/src/schema/field_type.rs
+++ b/src/schema/field_type.rs
@@ -181,6 +181,11 @@ impl FieldType {
         matches!(self, FieldType::IpAddr(_))
     }
 
+    /// returns true if this is a date field
+    pub fn is_date(&self) -> bool {
+        matches!(self, FieldType::Date(_))
+    }
+
     /// returns true if the field is indexed.
     pub fn is_indexed(&self) -> bool {
         match *self {