Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support decimal negative scale #3152

Merged
merged 3 commits into from Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 11 additions & 4 deletions arrow-array/src/array/primitive_array.rs
Expand Up @@ -1002,7 +1002,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
pub fn with_precision_and_scale(
self,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Self, ArrowError>
where
Self: Sized,
Expand All @@ -1023,7 +1023,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
fn validate_precision_scale(
&self,
precision: u8,
scale: u8,
scale: i8,
) -> Result<(), ArrowError> {
if precision > T::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -1039,7 +1039,14 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
Decimal128Type::MAX_SCALE
)));
}
if scale > precision {
if scale < -T::MAX_SCALE {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure about this, as the rationale for enforcing the upper limit is if the positive scale exceeds the precision it would result in truncation. I'm not sure this applies to the negative direction. However, the C++ version has this same check, so 🤷 - https://github.com/apache/arrow/blob/91ee6dad722ee154d63eea86ce5644e1e658b53b/cpp/src/arrow/util/decimal.cc#L389

return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is smaller than min {}",
scale,
-Decimal128Type::MAX_SCALE
)));
}
if scale > 0 && scale as u8 > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than precision {}",
scale, precision
Expand Down Expand Up @@ -1095,7 +1102,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
}

/// Returns the decimal scale of this array
pub fn scale(&self) -> u8 {
pub fn scale(&self) -> i8 {
match T::BYTE_LENGTH {
16 => {
if let DataType::Decimal128(_, s) = self.data().data_type() {
Expand Down
33 changes: 18 additions & 15 deletions arrow-array/src/types.rs
Expand Up @@ -491,15 +491,15 @@ pub trait DecimalType:
{
const BYTE_LENGTH: usize;
const MAX_PRECISION: u8;
const MAX_SCALE: u8;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
const MAX_SCALE: i8;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType;
const DEFAULT_TYPE: DataType;

/// "Decimal128" or "Decimal256", for use in error messages
const PREFIX: &'static str;

/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String;

/// Validates that `value` contains no more than `precision` decimal digits
fn validate_decimal_precision(
Expand All @@ -515,14 +515,14 @@ pub struct Decimal128Type {}
impl DecimalType for Decimal128Type {
const BYTE_LENGTH: usize = 16;
const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION;
const MAX_SCALE: u8 = DECIMAL128_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal128";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
Expand All @@ -543,14 +543,14 @@ pub struct Decimal256Type {}
impl DecimalType for Decimal256Type {
const BYTE_LENGTH: usize = 32;
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
const MAX_SCALE: u8 = DECIMAL256_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal256";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
Expand All @@ -564,7 +564,7 @@ impl ArrowPrimitiveType for Decimal256Type {
const DATA_TYPE: DataType = <Self as DecimalType>::DEFAULT_TYPE;
}

fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String {
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
let (sign, rest) = match value_str.strip_prefix('-') {
Some(stripped) => ("-", stripped),
None => ("", value_str),
Expand All @@ -574,13 +574,16 @@ fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String

if scale == 0 {
value_str.to_string()
} else if rest.len() > scale {
} else if scale < 0 {
let padding = value_str.len() + scale.unsigned_abs() as usize;
format!("{:0<width$}", value_str, width = padding)
} else if rest.len() > scale as usize {
// Decimal separator is in the middle of the string
let (whole, decimal) = value_str.split_at(value_str.len() - scale);
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
format!("{}.{}", whole, decimal)
} else {
// String has to be padded
format!("{}0.{:0>width$}", sign, rest, width = scale)
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
}
}

Expand Down
102 changes: 86 additions & 16 deletions arrow-cast/src/cast.rs
Expand Up @@ -319,15 +319,15 @@ fn cast_integer_to_decimal<
>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
M: ArrowNativeTypeOp,
{
let mul: M = base.pow_checked(scale as u32).map_err(|_| {
let mul_or_div: M = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). The scale causes overflow.",
D::PREFIX,
Expand All @@ -336,14 +336,26 @@ where
))
})?;

if cast_options.safe {
if scale < 0 {
if cast_options.safe {
array
.unary_opt::<_, D>(|v| v.as_().div_checked(mul_or_div).ok())
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, D, _>(|v| v.as_().div_checked(mul_or_div))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
} else if cast_options.safe {
array
.unary_opt::<_, D>(|v| v.as_().mul_checked(mul).ok())
.unary_opt::<_, D>(|v| v.as_().mul_checked(mul_or_div).ok())
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, D, _>(|v| v.as_().mul_checked(mul))
.try_unary::<_, D, _>(|v| v.as_().mul_checked(mul_or_div))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
Expand All @@ -352,7 +364,7 @@ where
fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -391,7 +403,7 @@ where
fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -437,7 +449,7 @@ fn cast_reinterpret_arrays<
fn cast_decimal_to_integer<D, T>(
array: &ArrayRef,
base: D::Native,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -1921,9 +1933,9 @@ fn cast_decimal_to_decimal_with_option<
const BYTE_WIDTH2: usize,
>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
if cast_options.safe {
Expand All @@ -1947,9 +1959,9 @@ fn cast_decimal_to_decimal_with_option<
/// the array values when cast failures happen.
fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
Expand Down Expand Up @@ -2062,9 +2074,9 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
/// cast failure happens.
fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
Expand Down Expand Up @@ -3540,7 +3552,7 @@ mod tests {
fn create_decimal_array(
array: Vec<Option<i128>>,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Decimal128Array, ArrowError> {
array
.into_iter()
Expand All @@ -3551,7 +3563,7 @@ mod tests {
fn create_decimal256_array(
array: Vec<Option<i256>>,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Decimal256Array, ArrowError> {
array
.into_iter()
Expand Down Expand Up @@ -7206,4 +7218,62 @@ mod tests {
err
);
}

#[test]
fn test_cast_decimal128_to_decimal128_negative_scale() {
let input_type = DataType::Decimal128(20, 0);
let output_type = DataType::Decimal128(20, -1);
assert!(can_cast_types(&input_type, &output_type));
let array = vec![Some(1123456), Some(2123456), Some(3123456), None];
let input_decimal_array = create_decimal_array(array, 20, 0).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
generate_cast_test_case!(
&array,
Decimal128Array,
&output_type,
vec![
Some(112345_i128),
Some(212345_i128),
Some(312345_i128),
None
]
);

let casted_array = cast(&array, &output_type).unwrap();
let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);

assert_eq!("1123450", decimal_arr.value_as_string(0));
assert_eq!("2123450", decimal_arr.value_as_string(1));
assert_eq!("3123450", decimal_arr.value_as_string(2));
}

#[test]
fn test_cast_numeric_to_decimal128_negative() {
let decimal_type = DataType::Decimal128(38, -1);
let array = Arc::new(Int32Array::from(vec![
Some(1123456),
Some(2123456),
Some(3123456),
])) as ArrayRef;

let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);

assert_eq!("1123450", decimal_arr.value_as_string(0));
assert_eq!("2123450", decimal_arr.value_as_string(1));
assert_eq!("3123450", decimal_arr.value_as_string(2));

let array = Arc::new(Float32Array::from(vec![
Some(1123.456),
Some(2123.456),
Some(3123.456),
])) as ArrayRef;

let casted_array = cast(&array, &decimal_type).unwrap();
let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);

assert_eq!("1120", decimal_arr.value_as_string(0));
assert_eq!("2120", decimal_arr.value_as_string(1));
assert_eq!("3120", decimal_arr.value_as_string(2));
}
}
6 changes: 3 additions & 3 deletions arrow-csv/src/reader.rs
Expand Up @@ -721,7 +721,7 @@ fn build_decimal_array(
rows: &[StringRecord],
col_idx: usize,
precision: u8,
scale: u8,
scale: i8,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
for row in rows {
Expand Down Expand Up @@ -762,13 +762,13 @@ fn build_decimal_array(
fn parse_decimal_with_parameter(
s: &str,
precision: u8,
scale: u8,
scale: i8,
) -> Result<i128, ArrowError> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
let mut base = 1;
let scale_usize = usize::from(scale);
let scale_usize = usize::from(scale as u8);

// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
Expand Down
6 changes: 3 additions & 3 deletions arrow-data/src/decimal.rs
Expand Up @@ -728,17 +728,17 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const DECIMAL128_MAX_PRECISION: u8 = 38;

/// The maximum scale for [arrow_schema::DataType::Decimal128] values
pub const DECIMAL128_MAX_SCALE: u8 = 38;
pub const DECIMAL128_MAX_SCALE: i8 = 38;

/// The maximum precision for [arrow_schema::DataType::Decimal256] values
pub const DECIMAL256_MAX_PRECISION: u8 = 76;

/// The maximum scale for [arrow_schema::DataType::Decimal256] values
pub const DECIMAL256_MAX_SCALE: u8 = 76;
pub const DECIMAL256_MAX_SCALE: i8 = 76;

/// The default scale for [arrow_schema::DataType::Decimal128] and
/// [arrow_schema::DataType::Decimal256] values
pub const DECIMAL_DEFAULT_SCALE: u8 = 10;
pub const DECIMAL_DEFAULT_SCALE: i8 = 10;

/// Validates that the specified `i128` value can be properly
/// interpreted as a Decimal number with precision `precision`
Expand Down
4 changes: 2 additions & 2 deletions arrow-schema/src/datatype.rs
Expand Up @@ -190,14 +190,14 @@ pub enum DataType {
/// * scale is the number of digits past the decimal
///
/// For example the number 123.45 has precision 5 and scale 2.
Decimal128(u8, u8),
Decimal128(u8, i8),
/// Exact 256-bit width decimal value with precision and scale
///
/// * precision is the total number of digits
/// * scale is the number of digits past the decimal
///
/// For example the number 123.45 has precision 5 and scale 2.
Decimal256(u8, u8),
Decimal256(u8, i8),
/// A Map is a logical nested type that is represented as
///
/// `List<entries: Struct<key: K, value: V>>`
Expand Down