Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast: should get the round result for decimal to a decimal with smaller scale #3139

Merged
merged 2 commits into from Nov 25, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
192 changes: 179 additions & 13 deletions arrow-cast/src/cast.rs
Expand Up @@ -1967,12 +1967,26 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
// For example, input_scale is 4 and output_scale is 3;
// Original value is 11234_i128, and will be cast to 1123_i128.
let div = 10_i128.pow((input_scale - output_scale) as u32);
let half = div / 2;
let neg_half = half.wrapping_neg();
if BYTE_WIDTH1 == 16 {
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
if BYTE_WIDTH2 == 16 {
let iter = array
.iter()
.map(|v| v.and_then(|v| v.div_checked(div).ok()));
// rounding the result
let iter = array.iter().map(|v| {
v.map(|v| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
if v >= 0 && r >= half {
d.wrapping_add(1)
} else if v < 0 && r <= neg_half {
d.wrapping_sub(1)
} else {
d
}
})
});
let casted_array = unsafe {
PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter)
};
Expand All @@ -1981,7 +1995,17 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
.map(|a| Arc::new(a) as ArrayRef)
} else {
let iter = array.iter().map(|v| {
v.and_then(|v| v.div_checked(div).ok().map(i256::from_i128))
v.map(|v| {
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
i256::from_i128(if v >= 0 && r >= half {
d.wrapping_add(1)
} else if v < 0 && r <= neg_half {
d.wrapping_sub(1)
} else {
d
})
})
});
let casted_array = unsafe {
PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter)
Expand All @@ -1993,9 +2017,22 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
} else {
let array = array.as_any().downcast_ref::<Decimal256Array>().unwrap();
let div = i256::from_i128(div);
let half = div / i256::from_i128(2);
let neg_half = half.wrapping_neg();
if BYTE_WIDTH2 == 16 {
let iter = array.iter().map(|v| {
v.and_then(|v| v.div_checked(div).ok().and_then(|v| v.to_i128()))
v.and_then(|v| {
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
if v >= i256::ZERO && r >= half {
d.wrapping_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.wrapping_sub(i256::ONE)
} else {
d
}
.to_i128()
})
});
let casted_array = unsafe {
PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter)
Expand All @@ -2004,9 +2041,19 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
.with_precision_and_scale(*output_precision, *output_scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
let iter = array
.iter()
.map(|v| v.and_then(|v| v.div_checked(div).ok()));
let iter = array.iter().map(|v| {
v.map(|v| {
let d = v.wrapping_div(div);
let r = v.wrapping_rem(div);
if v >= i256::ZERO && r >= half {
d.wrapping_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.wrapping_sub(i256::ONE)
} else {
d
}
})
});
let casted_array = unsafe {
PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter)
};
Expand Down Expand Up @@ -3566,6 +3613,125 @@ mod tests {
.with_precision_and_scale(precision, scale)
}

#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_decimal_to_decimal_round() {
let array = vec![
Some(1123454),
Some(2123456),
Some(-3123453),
Some(-3123456),
None,
];
let input_decimal_array = create_decimal_array(array, 20, 4).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
// decimal128 to decimal128
let input_type = DataType::Decimal128(20, 4);
let output_type = DataType::Decimal128(20, 3);
assert!(can_cast_types(&input_type, &output_type));
generate_cast_test_case!(
&array,
Decimal128Array,
&output_type,
vec![
Some(112345_i128),
Some(212346_i128),
Some(-312345_i128),
Some(-312346_i128),
None
]
);

// decimal128 to decimal256
let input_type = DataType::Decimal128(20, 4);
let output_type = DataType::Decimal256(20, 3);
assert!(can_cast_types(&input_type, &output_type));
generate_cast_test_case!(
&array,
Decimal256Array,
&output_type,
vec![
Some(i256::from_i128(112345_i128)),
Some(i256::from_i128(212346_i128)),
Some(i256::from_i128(-312345_i128)),
Some(i256::from_i128(-312346_i128)),
None
]
);

// decimal256
let array = vec![
Some(i256::from_i128(1123454)),
Some(i256::from_i128(2123456)),
Some(i256::from_i128(-3123453)),
Some(i256::from_i128(-3123456)),
None,
];
let input_decimal_array = create_decimal256_array(array, 20, 4).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;

// decimal256 to decimal256
let input_type = DataType::Decimal256(20, 4);
let output_type = DataType::Decimal256(20, 3);
assert!(can_cast_types(&input_type, &output_type));
generate_cast_test_case!(
&array,
Decimal256Array,
&output_type,
vec![
Some(i256::from_i128(112345_i128)),
Some(i256::from_i128(212346_i128)),
Some(i256::from_i128(-312345_i128)),
Some(i256::from_i128(-312346_i128)),
None
]
);
// decimal256 to decimal128
let input_type = DataType::Decimal256(20, 4);
let output_type = DataType::Decimal128(20, 3);
assert!(can_cast_types(&input_type, &output_type));
generate_cast_test_case!(
&array,
Decimal128Array,
&output_type,
vec![
Some(112345_i128),
Some(212346_i128),
Some(-312345_i128),
Some(-312346_i128),
None
]
);

// decimal256 to decimal128 overflow
let array = vec![
Some(i256::from_i128(1123454)),
Some(i256::from_i128(2123456)),
Some(i256::from_i128(-3123453)),
Some(i256::from_i128(-3123456)),
None,
Some(i256::MAX),
Some(i256::MIN),
];
let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
assert!(can_cast_types(&input_type, &output_type));
generate_cast_test_case!(
&array,
Decimal128Array,
&output_type,
vec![
Some(112345_i128),
Some(212346_i128),
Some(-312345_i128),
Some(-312346_i128),
None,
None,
None
]
);
}

#[test]
#[cfg(not(feature = "force_validate"))]
fn test_cast_decimal128_to_decimal128() {
Expand Down Expand Up @@ -7219,7 +7385,7 @@ mod tests {
let input_type = DataType::Decimal128(20, 0);
let output_type = DataType::Decimal128(20, -1);
assert!(can_cast_types(&input_type, &output_type));
let array = vec![Some(1123456), Some(2123456), Some(3123456), None];
let array = vec![Some(1123450), Some(2123455), Some(3123456), None];
let input_decimal_array = create_decimal_array(array, 20, 0).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
generate_cast_test_case!(
Expand All @@ -7228,8 +7394,8 @@ mod tests {
&output_type,
vec![
Some(112345_i128),
Some(212345_i128),
Some(312345_i128),
Some(212346_i128),
Some(312346_i128),
None
]
);
Expand All @@ -7238,8 +7404,8 @@ mod tests {
let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);

assert_eq!("1123450", decimal_arr.value_as_string(0));
assert_eq!("2123450", decimal_arr.value_as_string(1));
assert_eq!("3123450", decimal_arr.value_as_string(2));
assert_eq!("2123460", decimal_arr.value_as_string(1));
assert_eq!("3123460", decimal_arr.value_as_string(2));
}

#[test]
Expand Down