Skip to content

Commit

Permalink
Cast numeric to decimal256
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Oct 25, 2022
1 parent bca8445 commit bf7f475
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 13 deletions.
16 changes: 16 additions & 0 deletions arrow-buffer/src/bigint.rs
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use num::cast::AsPrimitive;
use num::BigInt;
use std::cmp::Ordering;

Expand Down Expand Up @@ -346,6 +347,21 @@ fn mulx(a: u128, b: u128) -> (u128, u128) {
(low, high)
}

/// Implements `AsPrimitive<i256>` for each listed native integer type.
///
/// Each generated `as_` converts the value to `i128` with an `as` cast and
/// passes it to `i256::from_i128`.
macro_rules! define_as_primitive {
    ($($int_ty:ty),+ $(,)?) => {
        $(
            impl AsPrimitive<i256> for $int_ty {
                fn as_(self) -> i256 {
                    i256::from_i128(self as i128)
                }
            }
        )+
    };
}

// Signed integer widths that can be converted to `i256` through `AsPrimitive`.
define_as_primitive!(i8, i16, i32, i64);

#[cfg(test)]
mod tests {
use super::*;
Expand Down
239 changes: 226 additions & 13 deletions arrow/src/compute/kernels/cast.rs
Expand Up @@ -88,6 +88,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Decimal256(_, _), Decimal128(_, _)) => true,
// signed numeric to decimal
(Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) |
(Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) |
// decimal to signed numeric
(Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
| (
Expand Down Expand Up @@ -305,8 +306,8 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
}

/// Cast the primitive array to defined decimal data type array
fn cast_primitive_to_decimal<T: ArrayAccessor, F>(
/// Cast the primitive array to defined decimal128 data type array
fn cast_primitive_to_decimal128<T: ArrayAccessor, F>(
array: T,
op: F,
precision: u8,
Expand All @@ -324,7 +325,26 @@ where
Ok(Arc::new(decimal_array))
}

fn cast_integer_to_decimal<T: ArrowNumericType>(
/// Builds a `Decimal256Array` from `array` by passing every present element through `op`.
///
/// Absent (`None`) elements are forwarded unchanged to the collected array. The
/// collected array is then given the requested `precision` and `scale` via
/// `with_precision_and_scale`; any error from that call is returned to the caller.
fn cast_primitive_to_decimal256<T: ArrayAccessor, F>(
    array: T,
    op: F,
    precision: u8,
    scale: u8,
) -> Result<ArrayRef>
where
    F: Fn(T::Item) -> i256,
{
    // Borrowing `op` lets it be handed to `Option::map` for every element
    // without moving it out of the enclosing closure.
    let converted: Decimal256Array = ArrayIter::new(array)
        .map(|slot| slot.map(&op))
        .collect();

    let decimal_array = converted.with_precision_and_scale(precision, scale)?;
    Ok(Arc::new(decimal_array))
}

fn cast_integer_to_decimal128<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
Expand All @@ -336,10 +356,25 @@ where

// with_precision_and_scale validates the
// value is within range for the output precision
cast_primitive_to_decimal(array, |v| v.as_() * mul, precision, scale)
cast_primitive_to_decimal128(array, |v| v.as_() * mul, precision, scale)
}

/// Casts a signed integer array to a Decimal256 array with the given `precision` and `scale`.
///
/// Every present element is converted to `i256` and multiplied by `10^scale`, so the
/// integer `n` is stored as the unscaled decimal value `n * 10^scale`. The
/// multiplication uses `wrapping_mul`, so it never panics on overflow.
///
/// `precision` and `scale` are applied to the result by
/// `cast_primitive_to_decimal256`, whose error (if any) is returned.
fn cast_integer_to_decimal256<T: ArrowNumericType>(
    array: &PrimitiveArray<T>,
    precision: u8,
    scale: u8,
) -> Result<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<i256>,
{
    // Build 10^scale directly in i256. Computing it as `10_i128.pow(scale)` overflows
    // i128 once `scale` exceeds 38 (panic in debug builds, wrapped value in release),
    // although a 256-bit decimal can represent larger powers of ten.
    let ten = i256::from_i128(10);
    let mul = (0..scale).fold(i256::from_i128(1), |acc, _| acc.wrapping_mul(ten));

    // NOTE(review): range checking against `precision` is presumably left to
    // `with_precision_and_scale` / later validation - confirm against its implementation.
    cast_primitive_to_decimal256(array, |v| v.as_().wrapping_mul(mul), precision, scale)
}

fn cast_floating_point_to_decimal<T: ArrowNumericType>(
fn cast_floating_point_to_decimal128<T: ArrowNumericType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
Expand All @@ -349,7 +384,7 @@ where
{
let mul = 10_f64.powi(scale as i32);

cast_primitive_to_decimal(
cast_primitive_to_decimal128(
array,
|v| {
// with_precision_and_scale validates the
Expand All @@ -361,6 +396,28 @@ where
)
}

/// Casts a floating point array to a Decimal256 array with the given `precision` and `scale`.
///
/// Each present element is converted to `f64`, multiplied by `10^scale`, narrowed to
/// `i128` with an `as` cast and finally stored as an `i256`. The `as` cast truncates
/// toward zero, saturates at the `i128` bounds and maps NaN to zero.
///
/// `precision` and `scale` are applied to the result by
/// `cast_primitive_to_decimal256`, whose error (if any) is returned.
fn cast_floating_point_to_decimal256<T: ArrowNumericType>(
    array: &PrimitiveArray<T>,
    precision: u8,
    scale: u8,
) -> Result<ArrayRef>
where
    <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
{
    // Power of ten that shifts `scale` fractional digits into the integer part.
    let factor = 10_f64.powi(i32::from(scale));

    cast_primitive_to_decimal256(
        array,
        |value| {
            let as_float: f64 = value.as_();
            let shifted = as_float * factor;
            // NOTE(review): range checking against `precision` is presumably left to
            // `with_precision_and_scale` / later validation - confirm.
            i256::from_i128(shifted as i128)
        },
        precision,
        scale,
    )
}

/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
fn cast_reinterpret_arrays<
I: ArrowPrimitiveType,
Expand Down Expand Up @@ -545,32 +602,73 @@ pub fn cast_with_options(
// cast data to decimal
match from_type {
// TODO: only signed numeric to decimal is supported for now; support decimal to numeric later
Int8 => cast_integer_to_decimal(
Int8 => cast_integer_to_decimal128(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
),
Int16 => cast_integer_to_decimal(
Int16 => cast_integer_to_decimal128(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
),
Int32 => cast_integer_to_decimal(
Int32 => cast_integer_to_decimal128(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
),
Int64 => cast_integer_to_decimal(
Int64 => cast_integer_to_decimal128(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
),
Float32 => cast_floating_point_to_decimal(
Float32 => cast_floating_point_to_decimal128(
as_primitive_array::<Float32Type>(array),
*precision,
*scale,
),
Float64 => cast_floating_point_to_decimal(
Float64 => cast_floating_point_to_decimal128(
as_primitive_array::<Float64Type>(array),
*precision,
*scale,
),
Null => Ok(new_null_array(to_type, array.len())),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type
))),
}
}
(_, Decimal256(precision, scale)) => {
// cast data to decimal
match from_type {
// TODO: only signed numeric to decimal is supported for now; support decimal to numeric later
Int8 => cast_integer_to_decimal256(
as_primitive_array::<Int8Type>(array),
*precision,
*scale,
),
Int16 => cast_integer_to_decimal256(
as_primitive_array::<Int16Type>(array),
*precision,
*scale,
),
Int32 => cast_integer_to_decimal256(
as_primitive_array::<Int32Type>(array),
*precision,
*scale,
),
Int64 => cast_integer_to_decimal256(
as_primitive_array::<Int64Type>(array),
*precision,
*scale,
),
Float32 => cast_floating_point_to_decimal256(
as_primitive_array::<Float32Type>(array),
*precision,
*scale,
),
Float64 => cast_floating_point_to_decimal256(
as_primitive_array::<Float64Type>(array),
*precision,
*scale,
Expand Down Expand Up @@ -3071,7 +3169,7 @@ mod tests {

#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_numeric_to_decimal() {
fn test_cast_numeric_to_decimal128() {
// test negative cast type
let decimal_type = DataType::Decimal128(38, 6);
assert!(!can_cast_types(&DataType::UInt64, &decimal_type));
Expand Down Expand Up @@ -3184,6 +3282,121 @@ mod tests {
);
}

/// Exercises casting signed integers and floats to `Decimal256(58, 6)`, plus an
/// integer cast whose scaled value does not fit the target precision.
#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_numeric_to_decimal256() {
    // Shorthand for a present expected value.
    let dec = |raw: i128| Some(i256::from_i128(raw));

    // Negative case: an unsigned source type must be reported as not castable.
    let decimal_type = DataType::Decimal256(58, 6);
    assert!(!can_cast_types(&DataType::UInt64, &decimal_type));

    // The same values expressed in every signed integer width: i8, i16, i32, i64.
    let int8_input = Int8Array::from(vec![Some(1), Some(2), Some(3), None, Some(5)]);
    let int16_input = Int16Array::from(vec![Some(1), Some(2), Some(3), None, Some(5)]);
    let int32_input = Int32Array::from(vec![Some(1), Some(2), Some(3), None, Some(5)]);
    let int64_input = Int64Array::from(vec![Some(1), Some(2), Some(3), None, Some(5)]);
    let signed_inputs = vec![
        Arc::new(int8_input) as ArrayRef,
        Arc::new(int16_input) as ArrayRef,
        Arc::new(int32_input) as ArrayRef,
        Arc::new(int64_input) as ArrayRef,
    ];
    for array in signed_inputs {
        generate_cast_test_case!(
            &array,
            Decimal256Array,
            &decimal_type,
            vec![
                dec(1_000_000),
                dec(2_000_000),
                dec(3_000_000),
                None,
                dec(5_000_000)
            ]
        );
    }

    // i8 -> Decimal256(3, 1): 100 is scaled to 1000, which is above the precision-3
    // maximum of 999. The cast itself still succeeds; the problem is only reported
    // by the explicit precision validation afterwards.
    let overflowing = Arc::new(Int8Array::from(vec![1, 2, 3, 4, 100])) as ArrayRef;
    let cast_result = cast(&overflowing, &DataType::Decimal256(3, 1));
    assert!(cast_result.is_ok());
    let casted = cast_result.unwrap();
    let decimals: &Decimal256Array = as_primitive_array(&casted);
    let validation = decimals.validate_decimal_precision(3);
    assert_eq!(
        "Invalid argument error: 1000 is too large to store in a Decimal256 of precision 3. Max is 999",
        validation.unwrap_err().to_string()
    );

    // f32 -> decimal: digits beyond the sixth fractional place are dropped.
    let f32_input = Float32Array::from(vec![
        Some(1.1),
        Some(2.2),
        Some(4.4),
        None,
        Some(1.123_456_7),
        Some(1.123_456_7),
    ]);
    let array = Arc::new(f32_input) as ArrayRef;
    generate_cast_test_case!(
        &array,
        Decimal256Array,
        &decimal_type,
        vec![
            dec(1_100_000),
            dec(2_200_000),
            dec(4_400_000),
            None,
            dec(1_123_456),
            dec(1_123_456),
        ]
    );

    // f64 -> decimal: same expectation with longer fractional inputs.
    let f64_input = Float64Array::from(vec![
        Some(1.1),
        Some(2.2),
        Some(4.4),
        None,
        Some(1.123_456_789_123_4),
        Some(1.123_456_789_012_345_6),
        Some(1.123_456_789_012_345_6),
    ]);
    let array = Arc::new(f64_input) as ArrayRef;
    generate_cast_test_case!(
        &array,
        Decimal256Array,
        &decimal_type,
        vec![
            dec(1_100_000),
            dec(2_200_000),
            dec(4_400_000),
            None,
            dec(1_123_456),
            dec(1_123_456),
            dec(1_123_456),
        ]
    );
}

#[test]
fn test_cast_i32_to_f64() {
let a = Int32Array::from(vec![5, 6, 7, 8, 9]);
Expand Down

0 comments on commit bf7f475

Please sign in to comment.