From a25af3dfe4d67456ef02089b75b81a85995431f0 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Fri, 14 Oct 2022 15:55:33 +1300 Subject: [PATCH] Take decimal as primitive (#2637) --- arrow/src/compute/kernels/take.rs | 52 ++++++++----------------------- 1 file changed, 13 insertions(+), 39 deletions(-) diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index a399f060200..0ef2025cf38 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -134,10 +134,19 @@ where let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_boolean(values, indices)?)) } - DataType::Decimal128(_, _) => { - let decimal_values = - values.as_any().downcast_ref::().unwrap(); - Ok(Arc::new(take_decimal128(decimal_values, indices)?)) + DataType::Decimal128(p, s) => { + let decimal_values = values.as_any().downcast_ref::().unwrap(); + let array = take_primitive(decimal_values, indices)? + .with_precision_and_scale(*p, *s) + .unwrap(); + Ok(Arc::new(array)) + } + DataType::Decimal256(p, s) => { + let decimal_values = values.as_any().downcast_ref::().unwrap(); + let array = take_primitive(decimal_values, indices)? + .with_precision_and_scale(*p, *s) + .unwrap(); + Ok(Arc::new(array)) } DataType::Utf8 => { let values = values @@ -429,41 +438,6 @@ where Ok((buffer, nulls)) } -/// `take` implementation for decimal arrays -fn take_decimal128( - decimal_values: &Decimal128Array, - indices: &PrimitiveArray, -) -> Result -where - IndexType: ArrowNumericType, - IndexType::Native: ToPrimitive, -{ - indices - .iter() - .map(|index| { - // Use type annotations below for readability (was blowing - // my mind otherwise) - let t: Option>> = index.map(|index| { - let index = ToPrimitive::to_usize(&index).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) - })?; - - if decimal_values.is_null(index) { - Ok(None) - } else { - Ok(Some(decimal_values.value(index))) - } - }); - let t: Result>> = t.transpose(); - let t: Result> = t.map(|t| t.flatten()); - t - }) - .collect::>()? - // PERF: we could avoid re-validating that the data in - // Decimal128Array was in range as we know it came from a valid Decimal128Array - .with_precision_and_scale(decimal_values.precision()?, decimal_values.scale()?) -} - /// `take` implementation for all primitive arrays /// /// This checks if an `indices` slot is populated, and gets the value from `values`