Skip to content

Commit

Permalink
Add dictionary support to subtract_scalar, multiply_scalar, divide_sc…
Browse files Browse the repository at this point in the history
…alar (#2020)
  • Loading branch information
viirya committed Jul 7, 2022
1 parent f8fa984 commit da3879e
Showing 1 changed file with 171 additions and 0 deletions.
171 changes: 171 additions & 0 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -760,6 +760,21 @@ where
Ok(unary(array, |value| value - scalar))
}

/// Subtract every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Zero,
{
unary_dyn::<_, T>(array, |value| value - scalar)
}

/// Perform `-` operation on an array. If value is null then the result is also null.
pub fn negate<T>(array: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
where
Expand Down Expand Up @@ -824,6 +839,23 @@ where
Ok(unary(array, |value| value * scalar))
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ One,
{
unary_dyn::<_, T>(array, |value| value * scalar)
}

/// Perform `left % right` operation on two arrays. If either left or right value is null
/// then the result is also null. If any right hand value is zero then the result of this
/// operation will be `Err(ArrowError::DivideByZero)`.
Expand Down Expand Up @@ -909,6 +941,21 @@ where
Ok(unary(array, |a| a / divisor))
}

/// Divide every value in an array by a scalar. If any value in the array is null then the
/// result is also null. If the scalar is zero then the result of this operation will be
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
where
T: datatypes::ArrowNumericType,
T::Native: Div<Output = T::Native> + Zero,
{
if divisor.is_zero() {
return Err(ArrowError::DivideByZero);
}
unary_dyn::<_, T>(array, |value| value / divisor)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1054,6 +1101,46 @@ mod tests {
assert_eq!(10, c.value(4));
}

#[test]
fn test_primitive_array_subtract_scalar_dyn() {
let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]);
let b = 1_i32;
let c = subtract_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(4, c.value(0));
assert_eq!(5, c.value(1));
assert_eq!(6, c.value(2));
assert!(c.is_null(3));
assert_eq!(8, c.value(4));

let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
builder.append(5).unwrap();
builder.append_null().unwrap();
builder.append(7).unwrap();
builder.append(8).unwrap();
builder.append(9).unwrap();
let a = builder.finish();
let b = -1_i32;

let c = subtract_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap();
let values = c
.values()
.as_any()
.downcast_ref::<PrimitiveArray<Int32Type>>()
.unwrap();
assert_eq!(6, values.value(c.key(0).unwrap()));
assert!(c.is_null(1));
assert_eq!(8, values.value(c.key(2).unwrap()));
assert_eq!(9, values.value(c.key(3).unwrap()));
assert_eq!(10, values.value(c.key(4).unwrap()));
}

#[test]
fn test_primitive_array_multiply_dyn() {
let a = Int32Array::from(vec![Some(5), Some(6), Some(7), Some(8), Some(9)]);
Expand Down Expand Up @@ -1098,6 +1185,46 @@ mod tests {
assert_eq!(90, c.value(4));
}

#[test]
fn test_primitive_array_multiply_scalar_dyn() {
let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]);
let b = 2_i32;
let c = multiply_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(10, c.value(0));
assert_eq!(12, c.value(1));
assert_eq!(14, c.value(2));
assert!(c.is_null(3));
assert_eq!(18, c.value(4));

let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
builder.append(5).unwrap();
builder.append_null().unwrap();
builder.append(7).unwrap();
builder.append(8).unwrap();
builder.append(9).unwrap();
let a = builder.finish();
let b = -1_i32;

let c = multiply_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap();
let values = c
.values()
.as_any()
.downcast_ref::<PrimitiveArray<Int32Type>>()
.unwrap();
assert_eq!(-5, values.value(c.key(0).unwrap()));
assert!(c.is_null(1));
assert_eq!(-7, values.value(c.key(2).unwrap()));
assert_eq!(-8, values.value(c.key(3).unwrap()));
assert_eq!(-9, values.value(c.key(4).unwrap()));
}

#[test]
fn test_primitive_array_add_sliced() {
let a = Int32Array::from(vec![0, 0, 0, 5, 6, 7, 8, 9, 0]);
Expand Down Expand Up @@ -1244,6 +1371,50 @@ mod tests {
assert_eq!(c, expected);
}

#[test]
fn test_primitive_array_divide_scalar_dyn() {
let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]);
let b = 2_i32;
let c = divide_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(2, c.value(0));
assert_eq!(3, c.value(1));
assert_eq!(3, c.value(2));
assert!(c.is_null(3));
assert_eq!(4, c.value(4));

let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
builder.append(5).unwrap();
builder.append_null().unwrap();
builder.append(7).unwrap();
builder.append(8).unwrap();
builder.append(9).unwrap();
let a = builder.finish();
let b = -2_i32;

let c = divide_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap();
let values = c
.values()
.as_any()
.downcast_ref::<PrimitiveArray<Int32Type>>()
.unwrap();
assert_eq!(-2, values.value(c.key(0).unwrap()));
assert!(c.is_null(1));
assert_eq!(-3, values.value(c.key(2).unwrap()));
assert_eq!(-4, values.value(c.key(3).unwrap()));
assert_eq!(-4, values.value(c.key(4).unwrap()));

let e = divide_scalar_dyn::<Int32Type>(&a, 0_i32)
.expect_err("should have failed due to divide by zero");
assert_eq!("DivideByZero", format!("{:?}", e));
}

#[test]
fn test_primitive_array_divide_scalar_sliced() {
let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]);
Expand Down

0 comments on commit da3879e

Please sign in to comment.