diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index ad9f0838832..30ae278ebe0 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -2133,6 +2133,7 @@ fn cast_decimal_to_decimal( if BYTE_WIDTH1 == 16 { let array = array.as_any().downcast_ref::().unwrap(); if BYTE_WIDTH2 == 16 { + // the div must be greater or equal than 10 let div = 10_i128 .pow_checked((input_scale - output_scale) as u32) .map_err(|_| { @@ -2141,10 +2142,23 @@ fn cast_decimal_to_decimal( *output_scale, )) })?; + let half = div / 2; + let neg_half = -half; array .try_unary::<_, Decimal128Type, _>(|v| { - v.checked_div(div).ok_or_else(|| { + // cast to smaller scale, need to round the result + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= 0 && r >= half { + d.checked_add(1) + } else if v < 0 && r <= neg_half { + d.checked_sub(1) + } else { + Some(d) + } + .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal128Type::PREFIX, @@ -2168,9 +2182,23 @@ fn cast_decimal_to_decimal( )) })?; + let half = div / i256::from_i128(2_i128); + let neg_half = -half; + array .try_unary::<_, Decimal256Type, _>(|v| { - i256::from_i128(v).checked_div(div).ok_or_else(|| { + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let v = i256::from_i128(v); + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= i256::ZERO && r >= half { + d.checked_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.checked_sub(i256::ONE) + } else { + Some(d) + } + .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal256Type::PREFIX, @@ -2195,10 +2223,21 @@ fn cast_decimal_to_decimal( *output_scale, )) })?; + let half = div / i256::from_i128(2_i128); + let neg_half = -half; if BYTE_WIDTH2 == 16 { array .try_unary::<_, Decimal128Type, _>(|v| { - v.checked_div(div).ok_or_else(|| { + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= i256::ZERO && r >= half { + d.checked_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.checked_sub(i256::ONE) + } else { + Some(d) + }.ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal128Type::PREFIX, @@ -2219,7 +2258,17 @@ fn cast_decimal_to_decimal( } else { array .try_unary::<_, Decimal256Type, _>(|v| { - v.checked_div(div).ok_or_else(|| { + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= i256::ZERO && r >= half { + d.checked_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.checked_sub(i256::ONE) + } else { + Some(d) + } + .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal256Type::PREFIX, @@ -3590,6 +3639,26 @@ mod tests { } } } + + let cast_option = CastOptions { safe: false }; + let casted_array_with_option = + cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); + let result_array = casted_array_with_option + .as_any() + .downcast_ref::<$OUTPUT_TYPE_ARRAY>() + .unwrap(); + assert_eq!($OUTPUT_TYPE, result_array.data_type()); + assert_eq!(result_array.len(), $OUTPUT_VALUES.len()); + for (i, x) in $OUTPUT_VALUES.iter().enumerate() { + match x { + Some(x) => { + assert_eq!(result_array.value(i), *x); + } + None => { + assert!(result_array.is_null(i)); + } + } + } }; } @@ -3616,6 +3685,44 @@ mod tests { } #[test] + #[cfg(not(feature = "force_validate"))] + #[should_panic( + expected = "5789604461865809771178549250434395392663499233282028201972879200395656481997 cannot be casted to 128-bit integer for Decimal128" + )] + fn test_cast_decimal_to_decimal_round_with_error() { + // decimal256 to decimal128 overflow + let array = vec![ + Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + Some(i256::MAX), + Some(i256::MIN), + ]; + let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + let input_type = DataType::Decimal256(76, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None, + None, + None, + ] + ); + } + + #[test] + #[cfg(not(feature = "force_validate"))] fn test_cast_decimal_to_decimal_round() { let array = vec![ Some(1123454), @@ -3703,34 +3810,6 @@ mod tests { None ] ); - - // decimal256 to decimal128 overflow - let array = vec![ - Some(i256::from_i128(1123454)), - Some(i256::from_i128(2123456)), - Some(i256::from_i128(-3123453)), - Some(i256::from_i128(-3123456)), - None, - Some(i256::MAX), - Some(i256::MIN), - ]; - let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap(); - let array = Arc::new(input_decimal_array) as ArrayRef; - assert!(can_cast_types(&input_type, &output_type)); - generate_cast_test_case!( - &array, - Decimal128Array, - &output_type, - vec![ - Some(112345_i128), - Some(212346_i128), - Some(-312345_i128), - Some(-312346_i128), - None, - None, - None - ] - ); } #[test]