From dc2071e712ec110c27172184f4af11261c87932a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 29 Jun 2022 23:30:57 -0700 Subject: [PATCH 1/2] Support dictionary array for subtract kernel --- arrow/src/compute/kernels/arithmetic.rs | 55 +++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 1f812b67e98..72fe807f3d9 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -720,6 +720,17 @@ where math_op(left, right, |a, b| a - b) } +/// Perform `left - right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn subtract_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!(left, right, |a, b| a - b) + } + _ => typed_math_op!(left, right, |a, b| a - b), + } +} + /// Subtract every value in an array by a scalar. If any value in the array is null then the /// result is also null. pub fn subtract_scalar( @@ -936,6 +947,50 @@ mod tests { assert_eq!(19, c.value(4)); } + #[test] + fn test_primitive_array_subtract_dyn() { + let a = Int32Array::from(vec![Some(51), Some(6), Some(15), Some(8), Some(9)]); + let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]); + let c = subtract_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(45, c.value(0)); + assert_eq!(-1, c.value(1)); + assert_eq!(7, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(1, c.value(4)); + } + + #[test] + fn test_primitive_array_subtract_dyn_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(15).unwrap(); + builder.append(8).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(20).unwrap(); + let a = builder.finish(); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append_null().unwrap(); + builder.append(10).unwrap(); + let b = builder.finish(); + + let c = subtract_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(9, c.value(0)); + assert_eq!(1, c.value(1)); + assert_eq!(-1, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(10, c.value(4)); + } + #[test] fn test_primitive_array_add_sliced() { let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]); From 4d280747d7029e9551da55e6b4d2214887df0248 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 29 Jun 2022 23:39:37 -0700 Subject: [PATCH 2/2] Support dictionary array in multiply kernel --- arrow/src/compute/kernels/arithmetic.rs | 55 +++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 72fe807f3d9..248e8df2770 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -782,6 +782,17 @@ where math_op(left, right, |a, b| a * b) } +/// Perform `left * right` operation on two arrays. If either left or right value is null +/// then the result is also null. +pub fn multiply_dyn(left: &dyn Array, right: &dyn Array) -> Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!(left, right, |a, b| a * b) + } + _ => typed_math_op!(left, right, |a, b| a * b), + } +} + /// Multiply every value in an array by a scalar. If any value in the array is null then the /// result is also null. pub fn multiply_scalar( @@ -991,6 +1002,50 @@ mod tests { assert_eq!(10, c.value(4)); } + #[test] + fn test_primitive_array_multiply_dyn() { + let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]); + let b = Int32Array::from(vec![Some(6), Some(7), Some(8), None, Some(8)]); + let c = multiply_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(30, c.value(0)); + assert_eq!(42, c.value(1)); + assert_eq!(56, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(72, c.value(4)); + } + + #[test] + fn test_primitive_array_multiply_dyn_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(5).unwrap(); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + let a = builder.finish(); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(6).unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append_null().unwrap(); + builder.append(10).unwrap(); + let b = builder.finish(); + + let c = multiply_dyn(&a, &b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(30, c.value(0)); + assert_eq!(42, c.value(1)); + assert_eq!(56, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(90, c.value(4)); + } + #[test] fn test_primitive_array_add_sliced() { let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]);