Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add divide dyn kernel which produces null for division by zero #2764

Merged
merged 4 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
102 changes: 102 additions & 0 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,39 @@ where
)
}

/// Applies the fallible binary `op` to the values of two dictionary arrays and
/// returns the outcome as a primitive array of type `T`.
///
/// Both dictionaries are viewed as typed dictionaries over `PrimitiveArray<T>`
/// and handed to `binary_opt`; any error from `binary_opt` is propagated.
///
/// # Panics
///
/// Panics if the values of either dictionary cannot be downcast to
/// `PrimitiveArray<T>`.
#[cfg(feature = "dyn_arith_dict")]
fn math_divide_safe_op_dict<K, T, F>(
    left: &DictionaryArray<K>,
    right: &DictionaryArray<K>,
    op: F,
) -> Result<ArrayRef>
where
    K: ArrowNumericType,
    T: ArrowNumericType,
    T::Native: One + Zero,
    F: Fn(T::Native, T::Native) -> Option<T::Native>,
{
    let lhs_typed = left.downcast_dict::<PrimitiveArray<T>>().unwrap();
    let rhs_typed = right.downcast_dict::<PrimitiveArray<T>>().unwrap();
    binary_opt::<_, _, _, T>(lhs_typed, rhs_typed, op)
        .map(|quotients| Arc::new(quotients) as ArrayRef)
}

/// Applies the fallible binary `op` to two primitive arrays via `binary_opt`
/// and returns the resulting `PrimitiveArray<LT>` as an `ArrayRef`.
///
/// Any error reported by `binary_opt` is returned unchanged.
fn math_safe_divide_op<LT, RT, F>(
    left: &PrimitiveArray<LT>,
    right: &PrimitiveArray<RT>,
    op: F,
) -> Result<ArrayRef>
where
    LT: ArrowNumericType,
    RT: ArrowNumericType,
    RT::Native: One + Zero,
    F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
    binary_opt::<_, _, _, LT>(left, right, op)
        .map(|quotients| Arc::new(quotients) as ArrayRef)
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
/// then the result is also null.
///
Expand Down Expand Up @@ -1406,6 +1439,51 @@ pub fn divide_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result<ArrayRe
}
}

/// Perform `left / right` operation on two arrays, producing null instead of failing
/// on division by zero.
///
/// A result slot is null when either the left or the right value is null, or when the
/// right hand value is zero.
///
/// Unlike `divide_dyn` or `divide_dyn_checked`, division by zero yields a null value
/// instead of returning an `Err`. Overflow is not checked either: an overflowing
/// division simply wraps around.
pub fn divide_dyn_opt(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
    if matches!(left.data_type(), DataType::Dictionary(_, _)) {
        typed_dict_math_op!(
            left,
            right,
            // `then` is lazy, so the division is never evaluated for a zero divisor.
            |a, b| (!b.is_zero()).then(|| a.div_wrapping(b)),
            math_divide_safe_op_dict
        )
    } else {
        downcast_primitive_array!(
            (left, right) => {
                math_safe_divide_op(left, right, |a, b| {
                    // `then` is lazy, so the division is never evaluated for a zero divisor.
                    (!b.is_zero()).then(|| a.div_wrapping(b))
                })
            }
            _ => Err(ArrowError::CastError(format!(
                "Unsupported data type {}, {}",
                left.data_type(), right.data_type()
            )))
        )
    }
}

/// Perform `left / right` operation on two arrays without checking for division by zero.
/// For floating point types, the result of dividing by zero follows normal floating point
/// rules. For other numeric types, dividing by zero will panic,
Expand Down Expand Up @@ -2752,4 +2830,28 @@ mod tests {
let overflow = divide_dyn_checked(&a, &b);
overflow.expect_err("overflow should be detected");
}

/// `divide_dyn_opt` must yield a null (not an error or panic) when dividing
/// `i32::MIN` by zero, for both primitive and dictionary-encoded inputs.
#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_div_dyn_opt_overflow_division_by_zero() {
    let expected = Arc::new(Int32Array::from(vec![None])) as ArrayRef;

    // Plain primitive arrays.
    let dividend = Int32Array::from(vec![i32::MIN]);
    let divisor = Int32Array::from(vec![0]);
    let quotient = divide_dyn_opt(&dividend, &divisor);
    assert_eq!(&expected, &quotient.unwrap());

    // Dictionary-encoded arrays holding the same single values.
    let single_value_dict = |value: i32| {
        let mut builder =
            PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
        builder.append(value).unwrap();
        builder.finish()
    };
    let dividend = single_value_dict(i32::MIN);
    let divisor = single_value_dict(0);

    let quotient = divide_dyn_opt(&dividend, &divisor);
    assert_eq!(&expected, &quotient.unwrap());
}
}
69 changes: 41 additions & 28 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,26 @@ where
Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) })
}

/// Applies `op` to the values of `a` and `b` at every index in `0..len` without
/// checking for nulls, collecting the optional results into a [`PrimitiveArray`].
///
/// A `None` returned by `op` is passed through to the collected array as-is.
///
/// Values are read with `value_unchecked`, so the caller must guarantee that
/// `len` does not exceed the length of either `a` or `b`.
#[inline(never)]
fn try_binary_opt_no_nulls<A: ArrayAccessor, B: ArrayAccessor, F, O>(
    len: usize,
    a: A,
    b: B,
    op: F,
) -> Result<PrimitiveArray<O>>
where
    O: ArrowPrimitiveType,
    F: Fn(A::Item, B::Item) -> Option<O::Native>,
{
    // Exactly `len` results are produced, so reserve that many slots up front
    // instead of a fixed guess that forces repeated reallocation on large inputs.
    let mut buffer = Vec::with_capacity(len);
    for idx in 0..len {
        // SAFETY: `idx < len`, and the caller guarantees that `len` is within the
        // bounds of both `a` and `b`.
        unsafe {
            buffer.push(op(a.value_unchecked(idx), b.value_unchecked(idx)));
        }
    }
    Ok(buffer.iter().collect())
}

/// Applies the provided binary operation across `a` and `b`, collecting the optional results
/// into a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the corresponding
/// index in the result will also be null. The binary operation could return `None` which
Expand All @@ -367,16 +387,14 @@ where
/// # Error
///
/// This function gives error if the arrays have different lengths
pub(crate) fn binary_opt<A, B, F, O>(
a: &PrimitiveArray<A>,
b: &PrimitiveArray<B>,
pub(crate) fn binary_opt<A: ArrayAccessor + Array, B: ArrayAccessor + Array, F, O>(
a: A,
b: B,
op: F,
) -> Result<PrimitiveArray<O>>
where
A: ArrowPrimitiveType,
B: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(A::Native, B::Native) -> Option<O::Native>,
F: Fn(A::Item, B::Item) -> Option<O::Native>,
{
if a.len() != b.len() {
return Err(ArrowError::ComputeError(
Expand All @@ -389,29 +407,24 @@ where
}

if a.null_count() == 0 && b.null_count() == 0 {
Ok(a.values()
.iter()
.zip(b.values().iter())
.map(|(a, b)| op(*a, *b))
.collect())
} else {
let iter_a = ArrayIter::new(a);
let iter_b = ArrayIter::new(b);

let values =
iter_a
.into_iter()
.zip(iter_b.into_iter())
.map(|(item_a, item_b)| {
if let (Some(a), Some(b)) = (item_a, item_b) {
op(a, b)
} else {
None
}
});

Ok(values.collect())
return try_binary_opt_no_nulls(a.len(), a, b, op);
}

let iter_a = ArrayIter::new(a);
let iter_b = ArrayIter::new(b);

let values = iter_a
.into_iter()
.zip(iter_b.into_iter())
.map(|(item_a, item_b)| {
if let (Some(a), Some(b)) = (item_a, item_b) {
op(a, b)
} else {
None
}
});

Ok(values.collect())
}

#[cfg(test)]
Expand Down