Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get the round result for decimal to a decimal with smaller scale #3224

Merged
merged 7 commits into from Dec 3, 2022
77 changes: 73 additions & 4 deletions arrow-cast/src/cast.rs
Expand Up @@ -2131,6 +2131,7 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
if BYTE_WIDTH1 == 16 {
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
if BYTE_WIDTH2 == 16 {
// the div must be greater or equal than 10
let div = 10_i128
.pow_checked((input_scale - output_scale) as u32)
.map_err(|_| {
Expand All @@ -2139,10 +2140,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
*output_scale,
))
})?;
let half = div / 2;
let neg_half = -half;

array
.try_unary::<_, Decimal128Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// cast to smaller scale, need to round the result
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v / div;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to make consistent with the usage and clear compilation result?

let r = v % div;
if v >= 0 && r >= half {
d.checked_add(1)
} else if v < 0 && r <= neg_half {
d.checked_sub(1)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
Expand All @@ -2166,9 +2180,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
))
})?;

let half = div / i256::from_i128(2_i128);
let neg_half = -half;

array
.try_unary::<_, Decimal256Type, _>(|v| {
i256::from_i128(v).checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let v = i256::from_i128(v);
let d = v / div;
let r = v % div;
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
Expand All @@ -2193,10 +2221,21 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
*output_scale,
))
})?;
let half = div / i256::from_i128(2_i128);
let neg_half = -half;
if BYTE_WIDTH2 == 16 {
array
.try_unary::<_, Decimal128Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v / div;
let r = v % div;
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
Expand All @@ -2217,7 +2256,17 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
} else {
array
.try_unary::<_, Decimal256Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v / div;
let r = v % div;
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
Expand Down Expand Up @@ -3588,6 +3637,26 @@ mod tests {
}
}
}

let cast_option = CastOptions { safe: false };
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add cast with safe is false

let casted_array_with_option =
cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap();
let result_array = casted_array_with_option
.as_any()
.downcast_ref::<$OUTPUT_TYPE_ARRAY>()
.unwrap();
assert_eq!($OUTPUT_TYPE, result_array.data_type());
assert_eq!(result_array.len(), $OUTPUT_VALUES.len());
for (i, x) in $OUTPUT_VALUES.iter().enumerate() {
match x {
Some(x) => {
assert_eq!(result_array.value(i), *x);
}
None => {
assert!(result_array.is_null(i));
}
}
}
};
}

Expand Down