Skip to content

Commit

Permalink
Support DictionaryArray in add_scalar kernel (#2018)
Browse files Browse the repository at this point in the history
* Add add_scalar_dyn

* Trigger Build

* Trigger Build
  • Loading branch information
viirya committed Jul 7, 2022
1 parent 62053a8 commit f8fa984
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -30,6 +30,7 @@ use crate::buffer::Buffer;
#[cfg(feature = "simd")]
use crate::buffer::MutableBuffer;
use crate::compute::kernels::arity::unary;
use crate::compute::unary_dyn;
use crate::compute::util::combine_option_bitmap;
use crate::datatypes;
use crate::datatypes::{ArrowNumericType, DataType};
Expand Down Expand Up @@ -707,6 +708,17 @@ where
Ok(unary(array, |value| value + scalar))
}

/// Add every value in an array by a scalar. If any value in the array is null then the
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
where
T: datatypes::ArrowNumericType,
T::Native: Add<Output = T::Native>,
{
unary_dyn::<_, T>(array, |value| value + scalar)
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
/// then the result is also null.
pub fn subtract<T>(
Expand Down Expand Up @@ -958,6 +970,46 @@ mod tests {
assert_eq!(19, c.value(4));
}

#[test]
fn test_primitive_array_add_scalar_dyn() {
let a = Int32Array::from(vec![Some(5), Some(6), Some(7), None, Some(9)]);
let b = 1_i32;
let c = add_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(6, c.value(0));
assert_eq!(7, c.value(1));
assert_eq!(8, c.value(2));
assert!(c.is_null(3));
assert_eq!(10, c.value(4));

let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
builder.append(5).unwrap();
builder.append_null().unwrap();
builder.append(7).unwrap();
builder.append(8).unwrap();
builder.append(9).unwrap();
let a = builder.finish();
let b = -1_i32;

let c = add_scalar_dyn::<Int32Type>(&a, b).unwrap();
let c = c
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap();
let values = c
.values()
.as_any()
.downcast_ref::<PrimitiveArray<Int32Type>>()
.unwrap();
assert_eq!(4, values.value(c.key(0).unwrap()));
assert!(c.is_null(1));
assert_eq!(6, values.value(c.key(2).unwrap()));
assert_eq!(7, values.value(c.key(3).unwrap()));
assert_eq!(8, values.value(c.key(4).unwrap()));
}

#[test]
fn test_primitive_array_subtract_dyn() {
let a = Int32Array::from(vec![Some(51), Some(6), Some(15), Some(8), Some(9)]);
Expand Down

0 comments on commit f8fa984

Please sign in to comment.