Support DictionaryArray in unary kernel #1990

Merged 7 commits on Jul 6, 2022
Changes from 4 commits
154 changes: 150 additions & 4 deletions arrow/src/compute/kernels/arity.rs
@@ -17,9 +17,13 @@

//! Defines kernels suitable to perform operations to primitive arrays.

-use crate::array::{Array, ArrayData, PrimitiveArray};
+use crate::array::{Array, ArrayData, DictionaryArray, PrimitiveArray};
use crate::buffer::Buffer;
-use crate::datatypes::ArrowPrimitiveType;
+use crate::datatypes::{
+    ArrowNumericType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type,
+    Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
use crate::error::{ArrowError, Result};

#[inline]
fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
@@ -78,10 +82,120 @@ where
    PrimitiveArray::<O>::from(data)
}

macro_rules! unary_dict_op {
    ($array: expr, $op: expr, $value_ty: ty) => {{
        // Safety justification: Since the inputs are valid Arrow arrays, all values are
        // valid indexes into the dictionary (which is verified during construction)

        let array_iter = unsafe {
            $array
                .values()
                .as_any()
                .downcast_ref::<$value_ty>()
                .unwrap()
                .take_iter_unchecked($array.keys_iter())
Member:

Hmm, is it possible to apply the op directly to the dictionary values? If the values are large strings, the current approach first has to decode the dictionary into a "plain" array and then apply the op to every element, which is expensive.

Member (Author):

I thought about that too, but didn't try it. Let me see whether it is feasible.

Member (Author):

Ah, I remember now. It is feasible, but it has one drawback: because unary_dict and unary_dyn must return ArrayRef, the compiler cannot infer T, so the caller must specify it.

Member (Author):

I made the change in the latest commit. Please check whether it looks okay.

        };

        let values = array_iter.map(|v| v.map(|value| $op(value))).collect();

        Ok(values)
    }};
}
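To illustrate the trade-off raised in the review thread above, here is a minimal std-only sketch. It does not use the arrow crate: `Dict`, `map_decoded`, `map_values`, and `decode` are hypothetical names standing in for a dictionary-encoded array and the two strategies being discussed. Decoding first calls the op once per row, while mapping the dictionary values calls it once per distinct value and reuses the keys unchanged.

```rust
/// Hypothetical stand-in for a dictionary-encoded array (not an arrow type):
/// `keys[i]` indexes into `values`, and `None` marks a null row.
#[derive(Debug, Clone, PartialEq)]
pub struct Dict<V> {
    pub keys: Vec<Option<usize>>,
    pub values: Vec<V>,
}

impl<V: Copy> Dict<V> {
    /// Decode-then-apply: `op` runs once per row, and the result is a plain array.
    /// Panics if a key is out of range; this sketch assumes valid keys.
    pub fn map_decoded<F: Fn(V) -> V>(&self, op: F) -> Vec<Option<V>> {
        self.keys
            .iter()
            .map(|k| k.map(|i| op(self.values[i])))
            .collect()
    }

    /// Apply-to-values: `op` runs once per distinct dictionary value,
    /// and the keys are reused as-is, so the result stays dictionary-encoded.
    pub fn map_values<F: Fn(V) -> V>(&self, op: F) -> Dict<V> {
        Dict {
            keys: self.keys.clone(),
            values: self.values.iter().map(|&v| op(v)).collect(),
        }
    }

    /// Expands the dictionary into a plain array, for comparison.
    pub fn decode(&self) -> Vec<Option<V>> {
        self.keys.iter().map(|k| k.map(|i| self.values[i])).collect()
    }
}
```

Both strategies yield the same logical rows; the second only pays off when the dictionary is much smaller than the number of rows, which is the case the reviewer describes.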

/// A helper function that applies a unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<PrimitiveArray<T>>
Member:

Does this need to be public so that other modules, such as arithmetic.rs, can use it?

Member (Author):

They will use unary_dyn directly.

where
    K: ArrowNumericType,
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    unary_dict_op!(array, op, PrimitiveArray<T>)
viirya marked this conversation as resolved.
}
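The inference drawback mentioned in the first review thread can be reproduced without arrow. The sketch below is an assumption-laden model, not the PR's code: `map_typed` and `map_erased` are hypothetical functions, `&dyn Any` stands in for `&dyn Array`, and `Box<dyn Any>` stands in for `ArrayRef`. When the return type mentions `T`, the caller's expected type can drive inference; once the return type is erased, nothing at the call site names `T`, so it has to be supplied explicitly.

```rust
use std::any::Any;

/// Typed output: `T` appears in the return type, so it can be inferred
/// from what the caller expects to receive.
pub fn map_typed<T, F>(input: &dyn Any, op: F) -> Option<Vec<T>>
where
    T: Copy + 'static,
    F: Fn(T) -> T,
{
    // Returns None if the erased input is not actually a Vec<T>.
    let v = input.downcast_ref::<Vec<T>>()?;
    Some(v.iter().map(|&x| op(x)).collect())
}

/// Erased output: `T` appears in neither the input nor the output type,
/// so the caller must name it, e.g. `map_erased::<i32, _>(..)`.
pub fn map_erased<T, F>(input: &dyn Any, op: F) -> Option<Box<dyn Any>>
where
    T: Copy + 'static,
    F: Fn(T) -> T,
{
    let v = input.downcast_ref::<Vec<T>>()?;
    let out: Vec<T> = v.iter().map(|&x| op(x)).collect();
    Some(Box::new(out))
}
```

A caller would write `let r: Option<Vec<i32>> = map_typed(&input, |n| n + 1);` for the typed version, but `map_erased::<i32, _>(&input, |n| n + 1)` for the erased one.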

/// Applies a unary function to an array with primitive values.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<PrimitiveArray<T>>
where
    T: ArrowPrimitiveType,
    F: Fn(T::Native) -> T::Native,
{
    match array.data_type() {
        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
            DataType::Int8 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int8Type>>()
                    .unwrap(),
                op,
            ),
            DataType::Int16 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int16Type>>()
                    .unwrap(),
                op,
            ),
            DataType::Int32 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int32Type>>()
                    .unwrap(),
                op,
            ),
            DataType::Int64 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<Int64Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt8 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt8Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt16 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt16Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt32 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt32Type>>()
                    .unwrap(),
                op,
            ),
            DataType::UInt64 => unary_dict(
                array
                    .as_any()
                    .downcast_ref::<DictionaryArray<UInt64Type>>()
                    .unwrap(),
                op,
            ),
            t => Err(ArrowError::NotYetImplemented(format!(
                "Cannot perform unary operation on dictionary array of key type {}.",
                t
            ))),
        },
        _ => Ok(unary::<T, F, T>(
            array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
            op,
        )),
    }
}
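The eight dictionary arms in unary_dyn differ only in the key type. As a possible way to express that pattern more compactly, here is a std-only sketch; it is not part of this PR and does not use arrow. The `Keys` enum, `map_with_keys`, the `dispatch_keys!` macro, and `unary_keys` are all hypothetical names, and unlike the real kernel this sketch maps invalid keys (negative or out of range) to `None` instead of relying on validated input.

```rust
use std::convert::TryInto;

/// Hypothetical stand-in for dictionary keys of different integer widths.
pub enum Keys {
    Int8(Vec<Option<i8>>),
    Int32(Vec<Option<i32>>),
    UInt16(Vec<Option<u16>>),
}

/// Generic worker, written once for every key width.
/// Null keys, negative keys, and out-of-range keys all produce `None`.
fn map_with_keys<K, V, F>(keys: &[Option<K>], values: &[V], op: &F) -> Vec<Option<V>>
where
    K: Copy + TryInto<usize>,
    V: Copy,
    F: Fn(V) -> V,
{
    keys.iter()
        .map(|k| {
            let idx: usize = (*k)?.try_into().ok()?;
            values.get(idx).map(|&v| op(v))
        })
        .collect()
}

/// Expands to one match arm per listed variant, so adding a key width
/// means adding one identifier rather than one hand-written arm.
macro_rules! dispatch_keys {
    ($keys:expr, $values:expr, $op:expr, [$($variant:ident),+]) => {
        match $keys {
            $(Keys::$variant(k) => map_with_keys(k, $values, $op),)+
        }
    };
}

/// Entry point: picks the key width at run time, then runs the generic worker.
pub fn unary_keys<V: Copy, F: Fn(V) -> V>(keys: &Keys, values: &[V], op: F) -> Vec<Option<V>> {
    dispatch_keys!(keys, values, &op, [Int8, Int32, UInt16])
}
```

Whether a macro like this is preferable to the explicit arms is a style choice; the explicit form in the diff is longer but easier to step through.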

#[cfg(test)]
mod tests {
    use super::*;
-    use crate::array::{as_primitive_array, Float64Array};
+    use crate::array::{
+        as_primitive_array, Float64Array, Int32Array, PrimitiveBuilder,
+        PrimitiveDictionaryBuilder,
+    };
+    use crate::datatypes::{Int32Type, Int8Type};

    #[test]
    fn test_unary_f64_slice() {
@@ -93,6 +207,38 @@ mod tests {
        assert_eq!(
            result,
            Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
-        )
+        );

        let result = unary_dyn(input_slice, |n| n + 1.0).unwrap();
        assert_eq!(
            result,
            Float64Array::from(vec![None, Some(7.8), None, Some(8.2)])
        );
    }

    #[test]
    fn test_unary_dict_and_unary_dyn() {
        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
        builder.append(5).unwrap();
        builder.append(6).unwrap();
        builder.append(7).unwrap();
        builder.append(8).unwrap();
        builder.append_null().unwrap();
        builder.append(9).unwrap();
        let dictionary_array = builder.finish();

        let result = unary_dict(&dictionary_array, |n| n + 1).unwrap();
        assert_eq!(
            result,
            Int32Array::from(vec![Some(6), Some(7), Some(8), Some(9), None, Some(10)])
        );

        let result = unary_dyn(&dictionary_array, |n| n + 1).unwrap();
        assert_eq!(
            result,
            Int32Array::from(vec![Some(6), Some(7), Some(8), Some(9), None, Some(10)])
        );
    }
}