From a89c1cf58161473b162fe1251d1da4cd09b25e36 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jun 2022 09:05:40 -0700 Subject: [PATCH] Support dictionary array for subtract and multiply kernel (#1971) * Support dictionary array for subtract kernel * Support dictionary array in multiply kernel --- arrow/src/compute/kernels/arithmetic.rs | 110 ++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 1f812b67e98..248e8df2770 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -720,6 +720,17 @@ where math_op(left, right, |a, b| a - b) } +/// Perform `left - right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!(left, right, |a, b| a - b) + } + _ => typed_math_op!(left, right, |a, b| a - b), + } +} + /// Subtract every value in an array by a scalar. If any value in the array is null then the /// result is also null. pub fn subtract_scalar( @@ -771,6 +782,17 @@ where math_op(left, right, |a, b| a * b) } +/// Perform `left * right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!(left, right, |a, b| a * b) + } + _ => typed_math_op!(left, right, |a, b| a * b), + } +} + /// Multiply every value in an array by a scalar. If any value in the array is null then the /// result is also null. pub fn multiply_scalar( @@ -936,6 +958,94 @@ mod tests { assert_eq!(19, c.value(4)); } + #[test] + fn test_primitive_array_subtract_dyn() { + let a = Int32Array::from(vec![Some(51), Some(6), Some(15), Some(8), Some(9)]); + let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]); + let c = subtract_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(45, c.value(0)); + assert_eq!(-1, c.value(1)); + assert_eq!(7, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(1, c.value(4)); + } + + #[test] + fn test_primitive_array_subtract_dyn_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(15).unwrap(); + builder.append(8).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(20).unwrap(); + let a = builder.finish(); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append_null().unwrap(); + builder.append(10).unwrap(); + let b = builder.finish(); + + let c = subtract_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(9, c.value(0)); + assert_eq!(1, c.value(1)); + assert_eq!(-1, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(10, c.value(4)); + } + + #[test] + fn test_primitive_array_multiply_dyn() { + let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]); + let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]); + let c = multiply_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(30, c.value(0)); + assert_eq!(42, c.value(1)); + assert_eq!(56, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(72, c.value(4)); + } + + #[test] + fn test_primitive_array_multiply_dyn_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(5).unwrap(); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + let a = builder.finish(); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append_null().unwrap(); + builder.append(10).unwrap(); + let b = builder.finish(); + + let c = multiply_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(30, c.value(0)); + assert_eq!(42, c.value(1)); + assert_eq!(56, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(90, c.value(4)); + } + #[test] fn test_primitive_array_add_sliced() { let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]);