Skip to content

Commit

Permalink
Support decimal negative scale (#3152)
Browse files Browse the repository at this point in the history
* Support decimal negative scale

* Fix casting from numeric to negative scale decimal

* Fix clippy
  • Loading branch information
viirya committed Nov 23, 2022
1 parent 6c466af commit 78ab0ef
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 51 deletions.
15 changes: 11 additions & 4 deletions arrow-array/src/array/primitive_array.rs
Expand Up @@ -1003,7 +1003,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
pub fn with_precision_and_scale(
self,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Self, ArrowError>
where
Self: Sized,
Expand All @@ -1024,7 +1024,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
fn validate_precision_scale(
&self,
precision: u8,
scale: u8,
scale: i8,
) -> Result<(), ArrowError> {
if precision == 0 {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -1046,7 +1046,14 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
T::MAX_SCALE
)));
}
if scale > precision {
if scale < -T::MAX_SCALE {
    // Report the lower bound of the array's own decimal type `T` (the same
    // value this guard compares against) instead of always printing the
    // Decimal128 limit, which would be wrong for other decimal types.
    return Err(ArrowError::InvalidArgumentError(format!(
        "scale {} is smaller than min {}",
        scale,
        -T::MAX_SCALE
    )));
}
if scale > 0 && scale as u8 > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than precision {}",
scale, precision
Expand Down Expand Up @@ -1102,7 +1109,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
}

/// Returns the decimal scale of this array
pub fn scale(&self) -> u8 {
pub fn scale(&self) -> i8 {
match T::BYTE_LENGTH {
16 => {
if let DataType::Decimal128(_, s) = self.data().data_type() {
Expand Down
33 changes: 18 additions & 15 deletions arrow-array/src/types.rs
Expand Up @@ -491,15 +491,15 @@ pub trait DecimalType:
{
const BYTE_LENGTH: usize;
const MAX_PRECISION: u8;
const MAX_SCALE: u8;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
const MAX_SCALE: i8;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType;
const DEFAULT_TYPE: DataType;

/// "Decimal128" or "Decimal256", for use in error messages
const PREFIX: &'static str;

/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String;

/// Validates that `value` contains no more than `precision` decimal digits
fn validate_decimal_precision(
Expand All @@ -515,14 +515,14 @@ pub struct Decimal128Type {}
impl DecimalType for Decimal128Type {
const BYTE_LENGTH: usize = 16;
const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION;
const MAX_SCALE: u8 = DECIMAL128_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal128";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
Expand All @@ -543,14 +543,14 @@ pub struct Decimal256Type {}
impl DecimalType for Decimal256Type {
const BYTE_LENGTH: usize = 32;
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
const MAX_SCALE: u8 = DECIMAL256_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal256";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
Expand All @@ -564,7 +564,7 @@ impl ArrowPrimitiveType for Decimal256Type {
const DATA_TYPE: DataType = <Self as DecimalType>::DEFAULT_TYPE;
}

fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String {
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
let (sign, rest) = match value_str.strip_prefix('-') {
Some(stripped) => ("-", stripped),
None => ("", value_str),
Expand All @@ -574,13 +574,16 @@ fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String

if scale == 0 {
value_str.to_string()
} else if rest.len() > scale {
} else if scale < 0 {
let padding = value_str.len() + scale.unsigned_abs() as usize;
format!("{:0<width$}", value_str, width = padding)
} else if rest.len() > scale as usize {
// Decimal separator is in the middle of the string
let (whole, decimal) = value_str.split_at(value_str.len() - scale);
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
format!("{}.{}", whole, decimal)
} else {
// String has to be padded
format!("{}0.{:0>width$}", sign, rest, width = scale)
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
}
}

Expand Down
102 changes: 86 additions & 16 deletions arrow-cast/src/cast.rs
Expand Up @@ -319,15 +319,15 @@ fn cast_integer_to_decimal<
>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
<T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
M: ArrowNativeTypeOp,
{
let mul: M = base.pow_checked(scale as u32).map_err(|_| {
let mul_or_div: M = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). The scale causes overflow.",
D::PREFIX,
Expand All @@ -336,14 +336,26 @@ where
))
})?;

if cast_options.safe {
if scale < 0 {
if cast_options.safe {
array
.unary_opt::<_, D>(|v| v.as_().div_checked(mul_or_div).ok())
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, D, _>(|v| v.as_().div_checked(mul_or_div))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
} else if cast_options.safe {
array
.unary_opt::<_, D>(|v| v.as_().mul_checked(mul).ok())
.unary_opt::<_, D>(|v| v.as_().mul_checked(mul_or_div).ok())
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
.try_unary::<_, D, _>(|v| v.as_().mul_checked(mul))
.try_unary::<_, D, _>(|v| v.as_().mul_checked(mul_or_div))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
Expand All @@ -352,7 +364,7 @@ where
fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -391,7 +403,7 @@ where
fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -437,7 +449,7 @@ fn cast_reinterpret_arrays<
fn cast_decimal_to_integer<D, T>(
array: &ArrayRef,
base: D::Native,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -1921,9 +1933,9 @@ fn cast_decimal_to_decimal_with_option<
const BYTE_WIDTH2: usize,
>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
if cast_options.safe {
Expand All @@ -1947,9 +1959,9 @@ fn cast_decimal_to_decimal_with_option<
/// the array values when cast failures happen.
fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
Expand Down Expand Up @@ -2062,9 +2074,9 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
/// cast failure happens.
fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
Expand Down Expand Up @@ -3540,7 +3552,7 @@ mod tests {
fn create_decimal_array(
array: Vec<Option<i128>>,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Decimal128Array, ArrowError> {
array
.into_iter()
Expand All @@ -3551,7 +3563,7 @@ mod tests {
fn create_decimal256_array(
array: Vec<Option<i256>>,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Decimal256Array, ArrowError> {
array
.into_iter()
Expand Down Expand Up @@ -7206,4 +7218,62 @@ mod tests {
err
);
}

#[test]
fn test_cast_decimal128_to_decimal128_negative_scale() {
    // Casting Decimal128(20, 0) -> Decimal128(20, -1): the expected stored
    // values lose their last digit, and the string form shows a trailing zero.
    let source_type = DataType::Decimal128(20, 0);
    let target_type = DataType::Decimal128(20, -1);
    assert!(can_cast_types(&source_type, &target_type));

    let raw_values = vec![Some(1123456), Some(2123456), Some(3123456), None];
    let source: ArrayRef = Arc::new(create_decimal_array(raw_values, 20, 0).unwrap());
    generate_cast_test_case!(
        &source,
        Decimal128Array,
        &target_type,
        vec![
            Some(112345_i128),
            Some(212345_i128),
            Some(312345_i128),
            None
        ]
    );

    // Check the human-readable rendering of every non-null row.
    let casted = cast(&source, &target_type).unwrap();
    let decimals = as_primitive_array::<Decimal128Type>(&casted);
    let rendered = ["1123450", "2123450", "3123450"];
    for (row, expected) in rendered.iter().enumerate() {
        assert_eq!(*expected, decimals.value_as_string(row));
    }
}

#[test]
fn test_cast_numeric_to_decimal128_negative() {
    let target_type = DataType::Decimal128(38, -1);

    // Integer input: the rendered decimal keeps the magnitude, with the last
    // digit replaced by zero.
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1123456),
        Some(2123456),
        Some(3123456),
    ]));
    let casted_ints = cast(&ints, &target_type).unwrap();
    let int_decimals = as_primitive_array::<Decimal128Type>(&casted_ints);
    for (row, expected) in ["1123450", "2123450", "3123450"].iter().enumerate() {
        assert_eq!(*expected, int_decimals.value_as_string(row));
    }

    // Floating point input: the fractional part and the last integer digit
    // do not survive the cast.
    let floats: ArrayRef = Arc::new(Float32Array::from(vec![
        Some(1123.456),
        Some(2123.456),
        Some(3123.456),
    ]));
    let casted_floats = cast(&floats, &target_type).unwrap();
    let float_decimals = as_primitive_array::<Decimal128Type>(&casted_floats);
    for (row, expected) in ["1120", "2120", "3120"].iter().enumerate() {
        assert_eq!(*expected, float_decimals.value_as_string(row));
    }
}
}
6 changes: 3 additions & 3 deletions arrow-csv/src/reader.rs
Expand Up @@ -721,7 +721,7 @@ fn build_decimal_array(
rows: &[StringRecord],
col_idx: usize,
precision: u8,
scale: u8,
scale: i8,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
for row in rows {
Expand Down Expand Up @@ -762,13 +762,13 @@ fn build_decimal_array(
fn parse_decimal_with_parameter(
s: &str,
precision: u8,
scale: u8,
scale: i8,
) -> Result<i128, ArrowError> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
let mut base = 1;
let scale_usize = usize::from(scale);
let scale_usize = usize::from(scale as u8);

// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
Expand Down
6 changes: 3 additions & 3 deletions arrow-data/src/decimal.rs
Expand Up @@ -728,17 +728,17 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const DECIMAL128_MAX_PRECISION: u8 = 38;

/// The maximum scale for [arrow_schema::DataType::Decimal128] values
pub const DECIMAL128_MAX_SCALE: u8 = 38;
pub const DECIMAL128_MAX_SCALE: i8 = 38;

/// The maximum precision for [arrow_schema::DataType::Decimal256] values
pub const DECIMAL256_MAX_PRECISION: u8 = 76;

/// The maximum scale for [arrow_schema::DataType::Decimal256] values
pub const DECIMAL256_MAX_SCALE: u8 = 76;
pub const DECIMAL256_MAX_SCALE: i8 = 76;

/// The default scale for [arrow_schema::DataType::Decimal128] and
/// [arrow_schema::DataType::Decimal256] values
pub const DECIMAL_DEFAULT_SCALE: u8 = 10;
pub const DECIMAL_DEFAULT_SCALE: i8 = 10;

/// Validates that the specified `i128` value can be properly
/// interpreted as a Decimal number with precision `precision`
Expand Down
4 changes: 2 additions & 2 deletions arrow-schema/src/datatype.rs
Expand Up @@ -190,14 +190,14 @@ pub enum DataType {
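/// * scale is the number of digits past the decimal; it may be negative,
///   in which case the formatted value is padded on the right with
///   `-scale` zeros
///
/// For example the number 123.45 has precision 5 and scale 2, and the
/// stored value 112345 with scale -1 is displayed as 1123450.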
Decimal128(u8, u8),
Decimal128(u8, i8),
/// Exact 256-bit width decimal value with precision and scale
///
/// * precision is the total number of digits
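/// * scale is the number of digits past the decimal; it may be negative,
///   in which case the formatted value is padded on the right with
///   `-scale` zeros
///
/// For example the number 123.45 has precision 5 and scale 2, and the
/// stored value 112345 with scale -1 is displayed as 1123450.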
Decimal256(u8, u8),
Decimal256(u8, i8),
/// A Map is a logical nested type that is represented as
///
/// `List<entries: Struct<key: K, value: V>>`
Expand Down

0 comments on commit 78ab0ef

Please sign in to comment.