Skip to content

Commit

Permalink
Update try_binary and checked_ops, and remove math_checked_op (#…
Browse files Browse the repository at this point in the history
…2717)

* update try_binary
delete math_checked_op
update the return type of checked ops

Signed-off-by: remzi <13716567376yh@gmail.com>

* float div not panic on zero

Signed-off-by: remzi <13716567376yh@gmail.com>

* fix nan test

Signed-off-by: remzi <13716567376yh@gmail.com>

* add float divide by zero

Signed-off-by: remzi <13716567376yh@gmail.com>

* add float tests

Signed-off-by: remzi <13716567376yh@gmail.com>

* fix compile error

Signed-off-by: remzi <13716567376yh@gmail.com>

Signed-off-by: remzi <13716567376yh@gmail.com>
  • Loading branch information
HaoYang670 committed Sep 16, 2022
1 parent 43d912c commit f572ec1
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 149 deletions.
2 changes: 1 addition & 1 deletion arrow/Cargo.toml
Expand Up @@ -51,7 +51,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"], op
indexmap = { version = "1.9", default-features = false, features = ["std"] }
rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true }
num = { version = "0.4", default-features = false, features = ["std"] }
half = { version = "2.0", default-features = false }
half = { version = "2.0", default-features = false, features = ["num-traits"]}
hashbrown = { version = "0.12", default-features = false }
csv_crate = { version = "1.1", default-features = false, optional = true, package = "csv" }
regex = { version = "1.5.6", default-features = false, features = ["std", "unicode"] }
Expand Down
220 changes: 93 additions & 127 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -78,32 +78,6 @@ where
Ok(binary(left, right, op))
}

/// This is similar to `math_op` as it performs given operation between two input primitive arrays.
/// But the given operation can return `None` if overflow is detected. For the case, this function
/// returns an `Err`.
fn math_checked_op<LT, RT, F>(
left: &PrimitiveArray<LT>,
right: &PrimitiveArray<RT>,
op: F,
) -> Result<PrimitiveArray<LT>>
where
LT: ArrowNumericType,
RT: ArrowNumericType,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform math operation on arrays of different length".to_string(),
));
}

try_binary(left, right, |a, b| {
op(a, b).ok_or_else(|| {
ArrowError::ComputeError(format!("Overflow happened on: {:?}, {:?}", a, b))
})
})
}

/// Helper function for operations where a valid `0` on the right array should
/// result in an [ArrowError::DivideByZero], namely the division and modulo operations
///
Expand All @@ -121,26 +95,9 @@ where
LT: ArrowNumericType,
RT: ArrowNumericType,
RT::Native: One + Zero,
F: Fn(LT::Native, RT::Native) -> Option<LT::Native>,
F: Fn(LT::Native, RT::Native) -> Result<LT::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(
"Cannot perform math operation on arrays of different length".to_string(),
));
}

try_binary(left, right, |l, r| {
if r.is_zero() {
Err(ArrowError::DivideByZero)
} else {
op(l, r).ok_or_else(|| {
ArrowError::ComputeError(format!(
"Overflow happened on: {:?}, {:?}",
l, r
))
})
}
})
try_binary(left, right, op)
}

/// Helper function for operations where a valid `0` on the right array should
Expand All @@ -161,16 +118,12 @@ fn math_checked_divide_op_on_iters<T, F>(
where
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> T::Native,
F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
let buffer = if null_bit_buffer.is_some() {
let values = left.zip(right).map(|(left, right)| {
if let (Some(l), Some(r)) = (left, right) {
if r.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(op(l, r))
}
op(l, r)
} else {
Ok(T::default_value())
}
Expand All @@ -179,15 +132,10 @@ where
unsafe { Buffer::try_from_trusted_len_iter(values) }
} else {
// no value is null
let values = left.map(|l| l.unwrap()).zip(right.map(|r| r.unwrap())).map(
|(left, right)| {
if right.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(op(left, right))
}
},
);
let values = left
.map(|l| l.unwrap())
.zip(right.map(|r| r.unwrap()))
.map(|(left, right)| op(left, right));
// Safety: Iterator comes from a PrimitiveArray which reports its size correctly
unsafe { Buffer::try_from_trusted_len_iter(values) }
}?;
Expand Down Expand Up @@ -654,7 +602,7 @@ where
K: ArrowNumericType,
T: ArrowNumericType,
T::Native: One + Zero,
F: Fn(T::Native, T::Native) -> T::Native,
F: Fn(T::Native, T::Native) -> Result<T::Native>,
{
if left.len() != right.len() {
return Err(ArrowError::ComputeError(format!(
Expand Down Expand Up @@ -725,7 +673,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_checked_op(left, right, |a, b| a.add_checked(b))
try_binary(left, right, |a, b| a.add_checked(b))
}

/// Perform `left + right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -826,11 +774,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary(array, |value| {
value.add_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value))
})
})
try_unary(array, |value| value.add_checked(scalar))
}

/// Add every value in an array by a scalar. If any value in the array is null then the
Expand Down Expand Up @@ -863,12 +807,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary_dyn::<_, T>(array, |value| {
value.add_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value))
})
})
.map(|a| Arc::new(a) as ArrayRef)
try_unary_dyn::<_, T>(array, |value| value.add_checked(scalar))
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -900,7 +840,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_checked_op(left, right, |a, b| a.sub_checked(b))
try_binary(left, right, |a, b| a.sub_checked(b))
}

/// Perform `left - right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -953,14 +893,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero,
{
try_unary(array, |value| {
value.sub_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: subtracting {:?} from {:?}",
scalar, value
))
})
})
try_unary(array, |value| value.sub_checked(scalar))
}

/// Subtract every value in an array by a scalar. If any value in the array is null then the
Expand Down Expand Up @@ -991,15 +924,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary_dyn::<_, T>(array, |value| {
value.sub_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: subtracting {:?} from {:?}",
scalar, value
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
try_unary_dyn::<_, T>(array, |value| value.sub_checked(scalar))
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `-` operation on an array. If value is null then the result is also null.
Expand Down Expand Up @@ -1052,7 +978,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
math_checked_op(left, right, |a, b| a.mul_checked(b))
try_binary(left, right, |a, b| a.mul_checked(b))
}

/// Perform `left * right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -1105,14 +1031,7 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp + Zero + One,
{
try_unary(array, |value| {
value.mul_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: multiplying {:?} by {:?}",
value, scalar,
))
})
})
try_unary(array, |value| value.mul_checked(scalar))
}

/// Multiply every value in an array by a scalar. If any value in the array is null then the
Expand Down Expand Up @@ -1143,15 +1062,8 @@ where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
try_unary_dyn::<_, T>(array, |value| {
value.mul_checked(scalar).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: multiplying {:?} by {:?}",
value, scalar
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
try_unary_dyn::<_, T>(array, |value| value.mul_checked(scalar))
.map(|a| Arc::new(a) as ArrayRef)
}

/// Perform `left % right` operation on two arrays. If either left or right value is null
Expand All @@ -1170,7 +1082,13 @@ where
a % b
});
#[cfg(not(feature = "simd"))]
return math_checked_divide_op(left, right, |a, b| Some(a % b));
return try_binary(left, right, |a, b| {
if b.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(a % b)
}
});
}

/// Perform `left / right` operation on two arrays. If either left or right value is null
Expand Down Expand Up @@ -1225,12 +1143,17 @@ where
pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result<ArrayRef> {
match left.data_type() {
DataType::Dictionary(_, _) => {
typed_dict_math_op!(left, right, |a, b| a / b, math_divide_checked_op_dict)
typed_dict_math_op!(
left,
right,
|a, b| a.div_checked(b),
math_divide_checked_op_dict
)
}
_ => {
downcast_primitive_array!(
(left, right) => {
math_checked_divide_op(left, right, |a, b| Some(a / b)).map(|a| Arc::new(a) as ArrayRef)
math_checked_divide_op(left, right, |a, b| a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef)
}
_ => Err(ArrowError::CastError(format!(
"Unsupported data type {}, {}",
Expand Down Expand Up @@ -1331,15 +1254,8 @@ where
return Err(ArrowError::DivideByZero);
}

try_unary_dyn::<_, T>(array, |value| {
value.div_checked(divisor).ok_or_else(|| {
ArrowError::CastError(format!(
"Overflow: dividing {:?} by {:?}",
value, divisor
))
})
})
.map(|a| Arc::new(a) as ArrayRef)
try_unary_dyn::<_, T>(array, |value| value.div_checked(divisor))
.map(|a| Arc::new(a) as ArrayRef)
}

#[cfg(test)]
Expand Down Expand Up @@ -2134,31 +2050,57 @@ mod tests {

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_primitive_array_divide_by_zero_with_checked() {
fn test_int_array_divide_by_zero_with_checked() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide_checked(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_f32_array_divide_by_zero_with_checked() {
let a = Float32Array::from(vec![15.0]);
let b = Float32Array::from(vec![0.0]);
divide_checked(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "attempt to divide by zero")]
fn test_primitive_array_divide_by_zero() {
fn test_int_array_divide_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide(&a, &b).unwrap();
}

#[test]
fn test_f32_array_divide_by_zero() {
let a = Float32Array::from(vec![1.5, 0.0, -1.5]);
let b = Float32Array::from(vec![0.0, 0.0, 0.0]);
let result = divide(&a, &b).unwrap();
assert_eq!(result.value(0), f32::INFINITY);
assert!(result.value(1).is_nan());
assert_eq!(result.value(2), f32::NEG_INFINITY);
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_primitive_array_divide_dyn_by_zero() {
fn test_int_array_divide_dyn_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
divide_dyn(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_primitive_array_divide_dyn_by_zero_dict() {
fn test_f32_array_divide_dyn_by_zero() {
let a = Float32Array::from(vec![1.5]);
let b = Float32Array::from(vec![0.0]);
divide_dyn(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_int_array_divide_dyn_by_zero_dict() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::with_capacity(1, 1);
builder.append(15).unwrap();
Expand All @@ -2174,14 +2116,38 @@ mod tests {

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_primitive_array_modulus_by_zero() {
fn test_f32_dict_array_divide_dyn_by_zero() {
let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Float32Type>::with_capacity(1, 1);
builder.append(1.5).unwrap();
let a = builder.finish();

let mut builder =
PrimitiveDictionaryBuilder::<Int8Type, Float32Type>::with_capacity(1, 1);
builder.append(0.0).unwrap();
let b = builder.finish();

divide_dyn(&a, &b).unwrap();
}

#[test]
#[should_panic(expected = "DivideByZero")]
fn test_i32_array_modulus_by_zero() {
let a = Int32Array::from(vec![15]);
let b = Int32Array::from(vec![0]);
modulus(&a, &b).unwrap();
}

#[test]
fn test_primitive_array_divide_f64() {
#[should_panic(expected = "DivideByZero")]
fn test_f32_array_modulus_by_zero() {
let a = Float32Array::from(vec![1.5]);
let b = Float32Array::from(vec![0.0]);
modulus(&a, &b).unwrap();
}

#[test]
fn test_f64_array_divide() {
let a = Float64Array::from(vec![15.0, 15.0, 8.0]);
let b = Float64Array::from(vec![5.0, 6.0, 8.0]);
let c = divide(&a, &b).unwrap();
Expand Down

0 comments on commit f572ec1

Please sign in to comment.