diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 0189dade2ef..f64038c19a0 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -30,6 +30,7 @@ use crate::buffer::Buffer; #[cfg(feature = "simd")] use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; +use crate::compute::unary_dyn; use crate::compute::util::combine_option_bitmap; use crate::datatypes; use crate::datatypes::{ArrowNumericType, DataType}; @@ -707,6 +708,17 @@ where Ok(unary(array, |value| value + scalar)) } +/// Add every value in an array by a scalar. If any value in the array is null then the +/// result is also null. The given array must be a `PrimitiveArray` of the type same as +/// the scalar, or a `DictionaryArray` of the value type same as the scalar. +pub fn add_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result +where + T: datatypes::ArrowNumericType, + T::Native: Add, +{ + unary_dyn::<_, T>(array, |value| value + scalar) +} + /// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. pub fn subtract( @@ -958,6 +970,46 @@ mod tests { assert_eq!(19, c.value(4)); } + #[test] + fn test_primitive_array_add_scalar_dyn() { + let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); + let b = 1_i32; + let c = add_scalar_dyn::(&a, b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(6, c.value(0)); + assert_eq!(7, c.value(1)); + assert_eq!(8, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(10, c.value(4)); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(5).unwrap(); + builder.append_null().unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + let a = builder.finish(); + let b = -1_i32; + + let c = add_scalar_dyn::(&a, b).unwrap(); + let c = c + .as_any() + .downcast_ref::>() + .unwrap(); + let values = c + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(4, values.value(c.key(0).unwrap())); + assert!(c.is_null(1)); + assert_eq!(6, values.value(c.key(2).unwrap())); + assert_eq!(7, values.value(c.key(3).unwrap())); + assert_eq!(8, values.value(c.key(4).unwrap())); + } + #[test] fn test_primitive_array_subtract_dyn() { let a = Int32Array::from(vec![Some(51), Some(6), Some(15), Some(8), Some(9)]);