diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index e52940b4fc4..1580856dfc0 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -51,7 +51,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"], op indexmap = { version = "1.9", default-features = false, features = ["std"] } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } num = { version = "0.4", default-features = false, features = ["std"] } -half = { version = "2.0", default-features = false } +half = { version = "2.0", default-features = false, features = ["num-traits"]} hashbrown = { version = "0.12", default-features = false } csv_crate = { version = "1.1", default-features = false, optional = true, package = "csv" } regex = { version = "1.5.6", default-features = false, features = ["std", "unicode"] } diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 04fe2393ec4..7b91a261c7e 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -78,32 +78,6 @@ where Ok(binary(left, right, op)) } -/// This is similar to `math_op` as it performs given operation between two input primitive arrays. -/// But the given operation can return `None` if overflow is detected. For the case, this function -/// returns an `Err`. 
-fn math_checked_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - op: F, -) -> Result> -where - LT: ArrowNumericType, - RT: ArrowNumericType, - F: Fn(LT::Native, RT::Native) -> Option, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - - try_binary(left, right, |a, b| { - op(a, b).ok_or_else(|| { - ArrowError::ComputeError(format!("Overflow happened on: {:?}, {:?}", a, b)) - }) - }) -} - /// Helper function for operations where a valid `0` on the right array should /// result in an [ArrowError::DivideByZero], namely the division and modulo operations /// @@ -121,26 +95,9 @@ where LT: ArrowNumericType, RT: ArrowNumericType, RT::Native: One + Zero, - F: Fn(LT::Native, RT::Native) -> Option, + F: Fn(LT::Native, RT::Native) -> Result, { - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - - try_binary(left, right, |l, r| { - if r.is_zero() { - Err(ArrowError::DivideByZero) - } else { - op(l, r).ok_or_else(|| { - ArrowError::ComputeError(format!( - "Overflow happened on: {:?}, {:?}", - l, r - )) - }) - } - }) + try_binary(left, right, op) } /// Helper function for operations where a valid `0` on the right array should @@ -161,16 +118,12 @@ fn math_checked_divide_op_on_iters( where T: ArrowNumericType, T::Native: One + Zero, - F: Fn(T::Native, T::Native) -> T::Native, + F: Fn(T::Native, T::Native) -> Result, { let buffer = if null_bit_buffer.is_some() { let values = left.zip(right).map(|(left, right)| { if let (Some(l), Some(r)) = (left, right) { - if r.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(op(l, r)) - } + op(l, r) } else { Ok(T::default_value()) } @@ -179,15 +132,10 @@ where unsafe { Buffer::try_from_trusted_len_iter(values) } } else { // no value is null - let values = left.map(|l| l.unwrap()).zip(right.map(|r| 
r.unwrap())).map( - |(left, right)| { - if right.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(op(left, right)) - } - }, - ); + let values = left + .map(|l| l.unwrap()) + .zip(right.map(|r| r.unwrap())) + .map(|(left, right)| op(left, right)); // Safety: Iterator comes from a PrimitiveArray which reports its size correctly unsafe { Buffer::try_from_trusted_len_iter(values) } }?; @@ -654,7 +602,7 @@ where K: ArrowNumericType, T: ArrowNumericType, T::Native: One + Zero, - F: Fn(T::Native, T::Native) -> T::Native, + F: Fn(T::Native, T::Native) -> Result, { if left.len() != right.len() { return Err(ArrowError::ComputeError(format!( @@ -725,7 +673,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - math_checked_op(left, right, |a, b| a.add_checked(b)) + try_binary(left, right, |a, b| a.add_checked(b)) } /// Perform `left + right` operation on two arrays. If either left or right value is null @@ -826,11 +774,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - try_unary(array, |value| { - value.add_checked(scalar).ok_or_else(|| { - ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value)) - }) - }) + try_unary(array, |value| value.add_checked(scalar)) } /// Add every value in an array by a scalar. If any value in the array is null then the @@ -863,12 +807,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - try_unary_dyn::<_, T>(array, |value| { - value.add_checked(scalar).ok_or_else(|| { - ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value)) - }) - }) - .map(|a| Arc::new(a) as ArrayRef) + try_unary_dyn::<_, T>(array, |value| value.add_checked(scalar)) + .map(|a| Arc::new(a) as ArrayRef) } /// Perform `left - right` operation on two arrays. 
If either left or right value is null @@ -900,7 +840,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - math_checked_op(left, right, |a, b| a.sub_checked(b)) + try_binary(left, right, |a, b| a.sub_checked(b)) } /// Perform `left - right` operation on two arrays. If either left or right value is null @@ -953,14 +893,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp + Zero, { - try_unary(array, |value| { - value.sub_checked(scalar).ok_or_else(|| { - ArrowError::CastError(format!( - "Overflow: subtracting {:?} from {:?}", - scalar, value - )) - }) - }) + try_unary(array, |value| value.sub_checked(scalar)) } /// Subtract every value in an array by a scalar. If any value in the array is null then the @@ -991,15 +924,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - try_unary_dyn::<_, T>(array, |value| { - value.sub_checked(scalar).ok_or_else(|| { - ArrowError::CastError(format!( - "Overflow: subtracting {:?} from {:?}", - scalar, value - )) - }) - }) - .map(|a| Arc::new(a) as ArrayRef) + try_unary_dyn::<_, T>(array, |value| value.sub_checked(scalar)) + .map(|a| Arc::new(a) as ArrayRef) } /// Perform `-` operation on an array. If value is null then the result is also null. @@ -1052,7 +978,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - math_checked_op(left, right, |a, b| a.mul_checked(b)) + try_binary(left, right, |a, b| a.mul_checked(b)) } /// Perform `left * right` operation on two arrays. If either left or right value is null @@ -1105,14 +1031,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp + Zero + One, { - try_unary(array, |value| { - value.mul_checked(scalar).ok_or_else(|| { - ArrowError::CastError(format!( - "Overflow: multiplying {:?} by {:?}", - value, scalar, - )) - }) - }) + try_unary(array, |value| value.mul_checked(scalar)) } /// Multiply every value in an array by a scalar. 
If any value in the array is null then the @@ -1143,15 +1062,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - try_unary_dyn::<_, T>(array, |value| { - value.mul_checked(scalar).ok_or_else(|| { - ArrowError::CastError(format!( - "Overflow: multiplying {:?} by {:?}", - value, scalar - )) - }) - }) - .map(|a| Arc::new(a) as ArrayRef) + try_unary_dyn::<_, T>(array, |value| value.mul_checked(scalar)) + .map(|a| Arc::new(a) as ArrayRef) } /// Perform `left % right` operation on two arrays. If either left or right value is null @@ -1170,7 +1082,13 @@ where a % b }); #[cfg(not(feature = "simd"))] - return math_checked_divide_op(left, right, |a, b| Some(a % b)); + return try_binary(left, right, |a, b| { + if b.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(a % b) + } + }); } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1225,12 +1143,17 @@ where pub fn divide_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { DataType::Dictionary(_, _) => { - typed_dict_math_op!(left, right, |a, b| a / b, math_divide_checked_op_dict) + typed_dict_math_op!( + left, + right, + |a, b| a.div_checked(b), + math_divide_checked_op_dict + ) } _ => { downcast_primitive_array!( (left, right) => { - math_checked_divide_op(left, right, |a, b| Some(a / b)).map(|a| Arc::new(a) as ArrayRef) + math_checked_divide_op(left, right, |a, b| a.div_checked(b)).map(|a| Arc::new(a) as ArrayRef) } _ => Err(ArrowError::CastError(format!( "Unsupported data type {}, {}", @@ -1331,15 +1254,8 @@ where return Err(ArrowError::DivideByZero); } - try_unary_dyn::<_, T>(array, |value| { - value.div_checked(divisor).ok_or_else(|| { - ArrowError::CastError(format!( - "Overflow: dividing {:?} by {:?}", - value, divisor - )) - }) - }) - .map(|a| Arc::new(a) as ArrayRef) + try_unary_dyn::<_, T>(array, |value| value.div_checked(divisor)) + .map(|a| Arc::new(a) as ArrayRef) } #[cfg(test)] @@ -2134,23 +2050,41 @@ mod tests { 
#[test] #[should_panic(expected = "DivideByZero")] - fn test_primitive_array_divide_by_zero_with_checked() { + fn test_int_array_divide_by_zero_with_checked() { let a = Int32Array::from(vec![15]); let b = Int32Array::from(vec![0]); divide_checked(&a, &b).unwrap(); } + #[test] + #[should_panic(expected = "DivideByZero")] + fn test_f32_array_divide_by_zero_with_checked() { + let a = Float32Array::from(vec![15.0]); + let b = Float32Array::from(vec![0.0]); + divide_checked(&a, &b).unwrap(); + } + #[test] #[should_panic(expected = "attempt to divide by zero")] - fn test_primitive_array_divide_by_zero() { + fn test_int_array_divide_by_zero() { let a = Int32Array::from(vec![15]); let b = Int32Array::from(vec![0]); divide(&a, &b).unwrap(); } + #[test] + fn test_f32_array_divide_by_zero() { + let a = Float32Array::from(vec![1.5, 0.0, -1.5]); + let b = Float32Array::from(vec![0.0, 0.0, 0.0]); + let result = divide(&a, &b).unwrap(); + assert_eq!(result.value(0), f32::INFINITY); + assert!(result.value(1).is_nan()); + assert_eq!(result.value(2), f32::NEG_INFINITY); + } + #[test] #[should_panic(expected = "DivideByZero")] - fn test_primitive_array_divide_dyn_by_zero() { + fn test_int_array_divide_dyn_by_zero() { let a = Int32Array::from(vec![15]); let b = Int32Array::from(vec![0]); divide_dyn(&a, &b).unwrap(); @@ -2158,7 +2092,15 @@ mod tests { #[test] #[should_panic(expected = "DivideByZero")] - fn test_primitive_array_divide_dyn_by_zero_dict() { + fn test_f32_array_divide_dyn_by_zero() { + let a = Float32Array::from(vec![1.5]); + let b = Float32Array::from(vec![0.0]); + divide_dyn(&a, &b).unwrap(); + } + + #[test] + #[should_panic(expected = "DivideByZero")] + fn test_int_array_divide_dyn_by_zero_dict() { let mut builder = PrimitiveDictionaryBuilder::::with_capacity(1, 1); builder.append(15).unwrap(); @@ -2174,14 +2116,38 @@ mod tests { #[test] #[should_panic(expected = "DivideByZero")] - fn test_primitive_array_modulus_by_zero() { + fn test_f32_dict_array_divide_dyn_by_zero() 
{ + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(1.5).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(0.0).unwrap(); + let b = builder.finish(); + + divide_dyn(&a, &b).unwrap(); + } + + #[test] + #[should_panic(expected = "DivideByZero")] + fn test_i32_array_modulus_by_zero() { let a = Int32Array::from(vec![15]); let b = Int32Array::from(vec![0]); modulus(&a, &b).unwrap(); } #[test] - fn test_primitive_array_divide_f64() { + #[should_panic(expected = "DivideByZero")] + fn test_f32_array_modulus_by_zero() { + let a = Float32Array::from(vec![1.5]); + let b = Float32Array::from(vec![0.0]); + modulus(&a, &b).unwrap(); + } + + #[test] + fn test_f64_array_divide() { let a = Float64Array::from(vec![15.0, 15.0, 8.0]); let b = Float64Array::from(vec![5.0, 6.0, 8.0]); let c = divide(&a, &b).unwrap(); diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 21c633116ee..5060234c71b 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -261,9 +261,10 @@ where /// /// Like [`try_unary`] the function is only evaluated for non-null indices /// -/// # Panic +/// # Errors /// -/// Panics if the arrays have different lengths +/// Returns an error if the arrays have different lengths or +/// if the operation itself returns an error pub fn try_binary( a: &PrimitiveArray, b: &PrimitiveArray, @@ -275,13 +276,16 @@ where O: ArrowPrimitiveType, F: Fn(A::Native, B::Native) -> Result, { - assert_eq!(a.len(), b.len()); - let len = a.len(); - + if a.len() != b.len() { + return Err(ArrowError::ComputeError( + "Cannot perform a binary operation on arrays of different length".to_string(), + )); + } if a.is_empty() { return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); } + let len = a.len(); let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); let null_count = null_buffer 
.as_ref() diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index de35c4804fa..dec0cc4b53b 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -16,8 +16,10 @@ // under the License. use super::DataType; +use crate::error::{ArrowError, Result}; pub use arrow_buffer::{ArrowNativeType, ToByteSlice}; use half::f16; +use num::Zero; /// Trait bridging the dynamic-typed nature of Arrow (via [`DataType`]) with the /// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`]. @@ -43,6 +45,8 @@ pub trait ArrowPrimitiveType: 'static { pub(crate) mod native_op { use super::ArrowNativeType; + use crate::error::{ArrowError, Result}; + use num::Zero; use std::ops::{Add, Div, Mul, Sub}; /// Trait for ArrowNativeType to provide overflow-checking and non-overflow-checking @@ -61,33 +65,38 @@ pub(crate) mod native_op { + Sub + Mul + Div + + Zero { - fn add_checked(self, rhs: Self) -> Option { - Some(self + rhs) + fn add_checked(self, rhs: Self) -> Result { + Ok(self + rhs) } fn add_wrapping(self, rhs: Self) -> Self { self + rhs } - fn sub_checked(self, rhs: Self) -> Option { - Some(self - rhs) + fn sub_checked(self, rhs: Self) -> Result { + Ok(self - rhs) } fn sub_wrapping(self, rhs: Self) -> Self { self - rhs } - fn mul_checked(self, rhs: Self) -> Option { - Some(self * rhs) + fn mul_checked(self, rhs: Self) -> Result { + Ok(self * rhs) } fn mul_wrapping(self, rhs: Self) -> Self { self * rhs } - fn div_checked(self, rhs: Self) -> Option { - Some(self / rhs) + fn div_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + Ok(self / rhs) + } } fn div_wrapping(self, rhs: Self) -> Self { @@ -99,32 +108,56 @@ pub(crate) mod native_op { macro_rules! 
native_type_op { ($t:tt) => { impl native_op::ArrowNativeTypeOp for $t { - fn add_checked(self, rhs: Self) -> Option { - self.checked_add(rhs) + fn add_checked(self, rhs: Self) -> Result { + self.checked_add(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} + {:?}", + self, rhs + )) + }) } fn add_wrapping(self, rhs: Self) -> Self { self.wrapping_add(rhs) } - fn sub_checked(self, rhs: Self) -> Option { - self.checked_sub(rhs) + fn sub_checked(self, rhs: Self) -> Result { + self.checked_sub(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} - {:?}", + self, rhs + )) + }) } fn sub_wrapping(self, rhs: Self) -> Self { self.wrapping_sub(rhs) } - fn mul_checked(self, rhs: Self) -> Option { - self.checked_mul(rhs) + fn mul_checked(self, rhs: Self) -> Result { + self.checked_mul(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} * {:?}", + self, rhs + )) + }) } fn mul_wrapping(self, rhs: Self) -> Self { self.wrapping_mul(rhs) } - fn div_checked(self, rhs: Self) -> Option { - self.checked_div(rhs) + fn div_checked(self, rhs: Self) -> Result { + if rhs.is_zero() { + Err(ArrowError::DivideByZero) + } else { + self.checked_div(rhs).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Overflow happened on: {:?} / {:?}", + self, rhs + )) + }) + } } fn div_wrapping(self, rhs: Self) -> Self { @@ -138,6 +171,7 @@ native_type_op!(i8); native_type_op!(i16); native_type_op!(i32); native_type_op!(i64); +native_type_op!(i128); native_type_op!(u8); native_type_op!(u16); native_type_op!(u32);