diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 0189dade2ef..5cfb5a8711f 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -30,6 +30,7 @@ use crate::buffer::Buffer; #[cfg(feature = "simd")] use crate::buffer::MutableBuffer; use crate::compute::kernels::arity::unary; +use crate::compute::unary_dyn; use crate::compute::util::combine_option_bitmap; use crate::datatypes; use crate::datatypes::{ArrowNumericType, DataType}; @@ -748,6 +749,21 @@ where Ok(unary(array, |value| value - scalar)) } +/// Subtract every value in an array by a scalar. If any value in the array is null then the +/// result is also null. The given array must be a `PrimitiveArray` of the type same as +/// the scalar, or a `DictionaryArray` of the value type same as the scalar. +pub fn subtract_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result +where + T: datatypes::ArrowNumericType, + T::Native: Add + + Sub + + Mul + + Div + + Zero, +{ + unary_dyn::<_, T>(array, |value| value - scalar) +} + /// Perform `-` operation on an array. If value is null then the result is also null. pub fn negate(array: &PrimitiveArray) -> Result> where @@ -812,6 +828,23 @@ where Ok(unary(array, |value| value * scalar)) } +/// Multiply every value in an array by a scalar. If any value in the array is null then the +/// result is also null. The given array must be a `PrimitiveArray` of the type same as +/// the scalar, or a `DictionaryArray` of the value type same as the scalar. +pub fn multiply_scalar_dyn(array: &dyn Array, scalar: T::Native) -> Result +where + T: datatypes::ArrowNumericType, + T::Native: Add + + Sub + + Mul + + Div + + Rem + + Zero + + One, +{ + unary_dyn::<_, T>(array, |value| value * scalar) +} + /// Perform `left % right` operation on two arrays. If either left or right value is null /// then the result is also null. If any right hand value is zero then the result of this /// operation will be `Err(ArrowError::DivideByZero)`. @@ -897,6 +930,21 @@ where Ok(unary(array, |a| a / divisor)) } +/// Divide every value in an array by a scalar. If any value in the array is null then the +/// result is also null. If the scalar is zero then the result of this operation will be +/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type +/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar. +pub fn divide_scalar_dyn(array: &dyn Array, divisor: T::Native) -> Result +where + T: datatypes::ArrowNumericType, + T::Native: Div + Zero, +{ + if divisor.is_zero() { + return Err(ArrowError::DivideByZero); + } + unary_dyn::<_, T>(array, |value| value / divisor) +} + #[cfg(test)] mod tests { use super::*; @@ -1002,6 +1050,46 @@ mod tests { assert_eq!(10, c.value(4)); } + #[test] + fn test_primitive_array_subtract_scalar_dyn() { + let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); + let b = 1_i32; + let c = subtract_scalar_dyn::(&a, b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(4, c.value(0)); + assert_eq!(5, c.value(1)); + assert_eq!(6, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(8, c.value(4)); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(5).unwrap(); + builder.append_null().unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + let a = builder.finish(); + let b = -1_i32; + + let c = subtract_scalar_dyn::(&a, b).unwrap(); + let c = c + .as_any() + .downcast_ref::>() + .unwrap(); + let values = c + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(6, values.value(c.key(0).unwrap())); + assert!(c.is_null(1)); + assert_eq!(8, values.value(c.key(2).unwrap())); + assert_eq!(9, values.value(c.key(3).unwrap())); + assert_eq!(10, values.value(c.key(4).unwrap())); + } + #[test] fn test_primitive_array_multiply_dyn() { let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]); @@ -1046,6 +1134,46 @@ mod tests { assert_eq!(90, c.value(4)); } + #[test] + fn test_primitive_array_multiply_scalar_dyn() { + let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); + let b = 2_i32; + let c = multiply_scalar_dyn::(&a, b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(10, c.value(0)); + assert_eq!(12, c.value(1)); + assert_eq!(14, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(18, c.value(4)); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(5).unwrap(); + builder.append_null().unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + let a = builder.finish(); + let b = -1_i32; + + let c = multiply_scalar_dyn::(&a, b).unwrap(); + let c = c + .as_any() + .downcast_ref::>() + .unwrap(); + let values = c + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(-5, values.value(c.key(0).unwrap())); + assert!(c.is_null(1)); + assert_eq!(-7, values.value(c.key(2).unwrap())); + assert_eq!(-8, values.value(c.key(3).unwrap())); + assert_eq!(-9, values.value(c.key(4).unwrap())); + } + #[test] fn test_primitive_array_add_sliced() { let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]); @@ -1192,6 +1320,50 @@ mod tests { assert_eq!(c, expected); } + #[test] + fn test_primitive_array_divide_scalar_dyn() { + let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]); + let b = 2_i32; + let c = divide_scalar_dyn::(&a, b).unwrap(); + let c = c.as_any().downcast_ref::().unwrap(); + assert_eq!(2, c.value(0)); + assert_eq!(3, c.value(1)); + assert_eq!(3, c.value(2)); + assert!(c.is_null(3)); + assert_eq!(4, c.value(4)); + + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(5).unwrap(); + builder.append_null().unwrap(); + builder.append(7).unwrap(); + builder.append(8).unwrap(); + builder.append(9).unwrap(); + let a = builder.finish(); + let b = -2_i32; + + let c = divide_scalar_dyn::(&a, b).unwrap(); + let c = c + .as_any() + .downcast_ref::>() + .unwrap(); + let values = c + .values() + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(-2, values.value(c.key(0).unwrap())); + assert!(c.is_null(1)); + assert_eq!(-3, values.value(c.key(2).unwrap())); + assert_eq!(-4, values.value(c.key(3).unwrap())); + assert_eq!(-4, values.value(c.key(4).unwrap())); + + let e = divide_scalar_dyn::(&a, 0_i32) + .expect_err("should have failed due to divide by zero"); + assert_eq!("DivideByZero", format!("{:?}", e)); + } + #[test] fn test_primitive_array_divide_scalar_sliced() { let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);