Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get the round result for decimal to a decimal with smaller scale #3224

Merged
merged 7 commits into from Dec 3, 2022
143 changes: 111 additions & 32 deletions arrow-cast/src/cast.rs
Expand Up @@ -2133,6 +2133,7 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
if BYTE_WIDTH1 == 16 {
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
if BYTE_WIDTH2 == 16 {
// the div must be greater or equal than 10
let div = 10_i128
.pow_checked((input_scale - output_scale) as u32)
.map_err(|_| {
Expand All @@ -2141,10 +2142,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
*output_scale,
))
})?;
let half = div / 2;
let neg_half = -half;

array
.try_unary::<_, Decimal128Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// cast to smaller scale, need to round the result
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
if v >= 0 && r >= half {
d.checked_add(1)
} else if v < 0 && r <= neg_half {
d.checked_sub(1)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
Expand All @@ -2168,9 +2182,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
))
})?;

let half = div / i256::from_i128(2_i128);
let neg_half = -half;

array
.try_unary::<_, Decimal256Type, _>(|v| {
i256::from_i128(v).checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let v = i256::from_i128(v);
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
Expand All @@ -2195,10 +2223,21 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
*output_scale,
))
})?;
let half = div / i256::from_i128(2_i128);
let neg_half = -half;
if BYTE_WIDTH2 == 16 {
array
.try_unary::<_, Decimal128Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
Expand All @@ -2219,7 +2258,17 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
} else {
array
.try_unary::<_, Decimal256Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
Expand Down Expand Up @@ -3590,6 +3639,26 @@ mod tests {
}
}
}

let cast_option = CastOptions { safe: false };
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add cast with safe is false

let casted_array_with_option =
cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap();
let result_array = casted_array_with_option
.as_any()
.downcast_ref::<$OUTPUT_TYPE_ARRAY>()
.unwrap();
assert_eq!($OUTPUT_TYPE, result_array.data_type());
assert_eq!(result_array.len(), $OUTPUT_VALUES.len());
for (i, x) in $OUTPUT_VALUES.iter().enumerate() {
match x {
Some(x) => {
assert_eq!(result_array.value(i), *x);
}
None => {
assert!(result_array.is_null(i));
}
}
}
};
}

Expand All @@ -3616,6 +3685,44 @@ mod tests {
}

#[test]
#[cfg(not(feature = "force_validate"))]
#[should_panic(
expected = "5789604461865809771178549250434395392663499233282028201972879200395656481997 cannot be casted to 128-bit integer for Decimal128"
)]
fn test_cast_decimal_to_decimal_round_with_error() {
// decimal256 to decimal128 overflow
let array = vec![
Some(i256::from_i128(1123454)),
Some(i256::from_i128(2123456)),
Some(i256::from_i128(-3123453)),
Some(i256::from_i128(-3123456)),
None,
Some(i256::MAX),
Some(i256::MIN),
];
let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
let input_type = DataType::Decimal256(76, 4);
let output_type = DataType::Decimal128(20, 3);
assert!(can_cast_types(&input_type, &output_type));
generate_cast_test_case!(
&array,
Decimal128Array,
&output_type,
vec![
Some(112345_i128),
Some(212346_i128),
Some(-312345_i128),
Some(-312346_i128),
None,
None,
None,
]
);
}

#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_decimal_to_decimal_round() {
let array = vec![
Some(1123454),
Expand Down Expand Up @@ -3703,34 +3810,6 @@ mod tests {
None
]
);

// decimal256 to decimal128 overflow
let array = vec![
Some(i256::from_i128(1123454)),
Some(i256::from_i128(2123456)),
Some(i256::from_i128(-3123453)),
Some(i256::from_i128(-3123456)),
None,
Some(i256::MAX),
Some(i256::MIN),
];
let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
assert!(can_cast_types(&input_type, &output_type));
generate_cast_test_case!(
&array,
Decimal128Array,
&output_type,
vec![
Some(112345_i128),
Some(212346_i128),
Some(-312345_i128),
Some(-312346_i128),
None,
None,
None
]
);
}

#[test]
Expand Down