diff --git a/arrow-buffer/src/bigint.rs b/arrow-buffer/src/bigint.rs index 3518b85e4eb..51965977ced 100644 --- a/arrow-buffer/src/bigint.rs +++ b/arrow-buffer/src/bigint.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use num::cast::AsPrimitive; use num::BigInt; use std::cmp::Ordering; @@ -346,6 +347,21 @@ fn mulx(a: u128, b: u128) -> (u128, u128) { (low, high) } +macro_rules! define_as_primitive { + ($native_ty:ty) => { + impl AsPrimitive for $native_ty { + fn as_(self) -> i256 { + i256::from_i128(self as i128) + } + } + }; +} + +define_as_primitive!(i8); +define_as_primitive!(i16); +define_as_primitive!(i32); +define_as_primitive!(i64); + #[cfg(test)] mod tests { use super::*; diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index c0b08ecc57d..62cffe83cb5 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -88,6 +88,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Decimal256(_, _), Decimal128(_, _)) => true, // signed numeric to decimal (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) | + (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) | // decimal to signed numeric (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) | ( @@ -305,8 +306,8 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS) } -/// Cast the primitive array to defined decimal data type array -fn cast_primitive_to_decimal( +/// Cast the primitive array to defined decimal128 data type array +fn cast_primitive_to_decimal128( array: T, op: F, precision: u8, @@ -324,7 +325,26 @@ where Ok(Arc::new(decimal_array)) } -fn cast_integer_to_decimal( +/// Cast the primitive array to defined decimal256 data type array +fn cast_primitive_to_decimal256( + array: T, + op: F, + precision: u8, + scale: u8, +) -> 
Result +where + F: Fn(T::Item) -> i256, +{ + #[allow(clippy::redundant_closure)] + let decimal_array = ArrayIter::new(array) + .map(|v| v.map(|v| op(v))) + .collect::() + .with_precision_and_scale(precision, scale)?; + + Ok(Arc::new(decimal_array)) +} + +fn cast_integer_to_decimal128( array: &PrimitiveArray, precision: u8, scale: u8, @@ -336,10 +356,25 @@ where // with_precision_and_scale validates the // value is within range for the output precision - cast_primitive_to_decimal(array, |v| v.as_() * mul, precision, scale) + cast_primitive_to_decimal128(array, |v| v.as_() * mul, precision, scale) +} + +fn cast_integer_to_decimal256( + array: &PrimitiveArray, + precision: u8, + scale: u8, +) -> Result +where + ::Native: AsPrimitive, +{ + // 10_i128.pow(scale) overflows i128 once scale exceeds 38, so build 10^scale in i256 + let ten = i256::from_i128(10); + let mul = (0..scale).fold(i256::from_i128(1), |acc, _| acc.wrapping_mul(ten)); + // with_precision_and_scale validates the value is within range for the output precision + cast_primitive_to_decimal256(array, |v| v.as_().wrapping_mul(mul), precision, scale) } -fn cast_floating_point_to_decimal( +fn cast_floating_point_to_decimal128( array: &PrimitiveArray, precision: u8, scale: u8, @@ -349,7 +384,7 @@ where { let mul = 10_f64.powi(scale as i32); - cast_primitive_to_decimal( + cast_primitive_to_decimal128( array, |v| { // with_precision_and_scale validates the @@ -361,6 +396,28 @@ where ) } +fn cast_floating_point_to_decimal256( + array: &PrimitiveArray, + precision: u8, + scale: u8, +) -> Result +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + cast_primitive_to_decimal256( + array, + |v| { + // NOTE(review): `as i128` saturates, so scaled values outside the i128 + // range are clamped before with_precision_and_scale is applied + i256::from_i128((v.as_() * mul) as i128) + }, + precision, + scale, + ) +} + /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`] fn cast_reinterpret_arrays< I: ArrowPrimitiveType, @@ -545,32 +602,73 @@ pub fn cast_with_options( // cast data to decimal match from_type { // TODO now just support signed numeric to 
decimal, support decimal to numeric later - Int8 => cast_integer_to_decimal( + Int8 => cast_integer_to_decimal128( as_primitive_array::(array), *precision, *scale, ), - Int16 => cast_integer_to_decimal( + Int16 => cast_integer_to_decimal128( as_primitive_array::(array), *precision, *scale, ), - Int32 => cast_integer_to_decimal( + Int32 => cast_integer_to_decimal128( as_primitive_array::(array), *precision, *scale, ), - Int64 => cast_integer_to_decimal( + Int64 => cast_integer_to_decimal128( as_primitive_array::(array), *precision, *scale, ), - Float32 => cast_floating_point_to_decimal( + Float32 => cast_floating_point_to_decimal128( as_primitive_array::(array), *precision, *scale, ), - Float64 => cast_floating_point_to_decimal( + Float64 => cast_floating_point_to_decimal128( + as_primitive_array::(array), + *precision, + *scale, + ), + Null => Ok(new_null_array(to_type, array.len())), + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type + ))), + } + } + (_, Decimal256(precision, scale)) => { + // cast data to decimal + match from_type { + // TODO now just support signed numeric to decimal, support decimal to numeric later + Int8 => cast_integer_to_decimal256( + as_primitive_array::(array), + *precision, + *scale, + ), + Int16 => cast_integer_to_decimal256( + as_primitive_array::(array), + *precision, + *scale, + ), + Int32 => cast_integer_to_decimal256( + as_primitive_array::(array), + *precision, + *scale, + ), + Int64 => cast_integer_to_decimal256( + as_primitive_array::(array), + *precision, + *scale, + ), + Float32 => cast_floating_point_to_decimal256( + as_primitive_array::(array), + *precision, + *scale, + ), + Float64 => cast_floating_point_to_decimal256( as_primitive_array::(array), *precision, *scale, @@ -3071,7 +3169,7 @@ mod tests { #[test] #[cfg(not(feature = "force_validate"))] - fn test_cast_numeric_to_decimal() { + fn test_cast_numeric_to_decimal128() { // test negative cast type let decimal_type 
= DataType::Decimal128(38, 6); assert!(!can_cast_types(&DataType::UInt64, &decimal_type)); @@ -3184,6 +3282,121 @@ mod tests { ); } + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_cast_numeric_to_decimal256() { + // test negative cast type + let decimal_type = DataType::Decimal256(58, 6); + assert!(!can_cast_types(&DataType::UInt64, &decimal_type)); + + // i8, i16, i32, i64 + let input_datas = vec![ + Arc::new(Int8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i8 + Arc::new(Int16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i16 + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i32 + Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i64 + ]; + for array in input_datas { + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1000000_i128)), + Some(i256::from_i128(2000000_i128)), + Some(i256::from_i128(3000000_i128)), + None, + Some(i256::from_i128(5000000_i128)) + ] + ); + } + + // test casting i8 to a decimal type whose precision the result overflows + // 100 is scaled to 1000 (stored as i256), which exceeds 999, the max value for precision 3. + let array = Int8Array::from(vec![1, 2, 3, 4, 100]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &DataType::Decimal256(3, 1)); + assert!(casted_array.is_ok()); + let array = casted_array.unwrap(); + let array: &Decimal256Array = as_primitive_array(&array); + let err = array.validate_decimal_precision(3); + assert_eq!("Invalid argument error: 1000 is too large to store in a Decimal256 of precision 3. 
Max is 999", err.unwrap_err().to_string()); + + // test f32 to decimal type + let array = Float32Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_7), + Some(1.123_456_7), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1100000_i128)), + Some(i256::from_i128(2200000_i128)), + Some(i256::from_i128(4400000_i128)), + None, + Some(i256::from_i128(1123456_i128)), + Some(i256::from_i128(1123456_i128)), + ] + ); + + // test f64 to decimal type + let array = Float64Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_789_123_4), + Some(1.123_456_789_012_345_6), + Some(1.123_456_789_012_345_6), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal256Array, + &decimal_type, + vec![ + Some(i256::from_i128(1100000_i128)), + Some(i256::from_i128(2200000_i128)), + Some(i256::from_i128(4400000_i128)), + None, + Some(i256::from_i128(1123456_i128)), + Some(i256::from_i128(1123456_i128)), + Some(i256::from_i128(1123456_i128)), + ] + ); + } + #[test] fn test_cast_i32_to_f64() { let a = Int32Array::from(vec![5, 6, 7, 8, 9]);