diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 781f199a691..25aa525b452 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -38,7 +38,6 @@ use std::str; use std::sync::Arc; -use crate::array::BasicDecimalArray; use crate::buffer::MutableBuffer; use crate::compute::divide_scalar; use crate::compute::kernels::arithmetic::{divide, multiply}; @@ -48,6 +47,7 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::{array::*, compute::take}; use crate::{buffer::Buffer, util::serialization::lexical_to_string}; +use num::cast::AsPrimitive; use num::{NumCast, ToPrimitive}; /// CastOptions provides a way to override the default cast behaviors @@ -270,45 +270,60 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS) } -// cast the integer array to defined decimal data type array -macro_rules! cast_integer_to_decimal { - ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - let mul: i128 = 10_i128.pow(*$SCALE as u32); - let decimal_array = array - .iter() - .map(|v| { - v.map(|v| { - let v = v as i128; - // with_precision_and_scale validates the - // value is within range for the output precision - mul * v - }) - }) - .collect::() - .with_precision_and_scale(*$PRECISION, *$SCALE)?; - Ok(Arc::new(decimal_array)) - }}; +/// Cast the primitive array to defined decimal data type array +fn cast_primitive_to_decimal( + array: T, + op: F, + precision: usize, + scale: usize, +) -> Result> +where + F: Fn(T::Item) -> i128, +{ + #[allow(clippy::redundant_closure)] + let decimal_array = ArrayIter::new(array) + .map(|v| v.map(|v| op(v))) + .collect::() + .with_precision_and_scale(precision, scale)?; + + Ok(Arc::new(decimal_array)) } -// cast the floating-point array to defined decimal data type array -macro_rules! cast_floating_point_to_decimal { - ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - let mul = 10_f64.powi(*$SCALE as i32); - let decimal_array = array - .iter() - .map(|v| { - v.map(|v| { - // with_precision_and_scale validates the - // value is within range for the output precision - ((v as f64) * mul) as i128 - }) - }) - .collect::() - .with_precision_and_scale(*$PRECISION, *$SCALE)?; - Ok(Arc::new(decimal_array)) - }}; +fn cast_integer_to_decimal( + array: &PrimitiveArray, + precision: usize, + scale: usize, +) -> Result> +where + ::Native: AsPrimitive, +{ + let mul: i128 = 10_i128.pow(scale as u32); + + // with_precision_and_scale validates the + // value is within range for the output precision + cast_primitive_to_decimal(array, |v| v.as_() * mul, precision, scale) +} + +fn cast_floating_point_to_decimal( + array: &PrimitiveArray, + precision: usize, + scale: usize, +) -> Result> +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + cast_primitive_to_decimal( + array, + |v| { + // with_precision_and_scale validates the + // value is within range for the output precision + (v.as_() * mul) as i128 + }, + precision, + scale, + ) } // cast the decimal array to integer array @@ -428,24 +443,36 @@ pub fn cast_with_options( // cast data to decimal match from_type { // TODO now just support signed numeric to decimal, support decimal to numeric later - Int8 => { - cast_integer_to_decimal!(array, Int8Array, precision, scale) - } - Int16 => { - cast_integer_to_decimal!(array, Int16Array, precision, scale) - } - Int32 => { - cast_integer_to_decimal!(array, Int32Array, precision, scale) - } - Int64 => { - cast_integer_to_decimal!(array, Int64Array, precision, scale) - } - Float32 => { - cast_floating_point_to_decimal!(array, Float32Array, precision, scale) - } - Float64 => { - cast_floating_point_to_decimal!(array, Float64Array, precision, scale) - } + Int8 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Int16 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Int32 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Int64 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Float32 => cast_floating_point_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Float64 => cast_floating_point_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported",