Skip to content

Commit

Permalink
Support decimal negative scale
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 21, 2022
1 parent 5bce104 commit a782f91
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 43 deletions.
15 changes: 11 additions & 4 deletions arrow-array/src/array/primitive_array.rs
Expand Up @@ -1002,7 +1002,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
pub fn with_precision_and_scale(
self,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Self, ArrowError>
where
Self: Sized,
Expand All @@ -1023,7 +1023,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
fn validate_precision_scale(
&self,
precision: u8,
scale: u8,
scale: i8,
) -> Result<(), ArrowError> {
if precision > T::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
Expand All @@ -1039,7 +1039,14 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
Decimal128Type::MAX_SCALE
)));
}
if scale > precision {
if scale < -T::MAX_SCALE {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is smaller than min {}",
scale,
-Decimal128Type::MAX_SCALE
)));
}
if scale > 0 && scale as u8 > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than precision {}",
scale, precision
Expand Down Expand Up @@ -1095,7 +1102,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
}

/// Returns the decimal scale of this array
pub fn scale(&self) -> u8 {
pub fn scale(&self) -> i8 {
match T::BYTE_LENGTH {
16 => {
if let DataType::Decimal128(_, s) = self.data().data_type() {
Expand Down
33 changes: 18 additions & 15 deletions arrow-array/src/types.rs
Expand Up @@ -491,15 +491,15 @@ pub trait DecimalType:
{
const BYTE_LENGTH: usize;
const MAX_PRECISION: u8;
const MAX_SCALE: u8;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
const MAX_SCALE: i8;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType;
const DEFAULT_TYPE: DataType;

/// "Decimal128" or "Decimal256", for use in error messages
const PREFIX: &'static str;

/// Formats the decimal value with the provided precision and scale
fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String;

/// Validates that `value` contains no more than `precision` decimal digits
fn validate_decimal_precision(
Expand All @@ -515,14 +515,14 @@ pub struct Decimal128Type {}
impl DecimalType for Decimal128Type {
const BYTE_LENGTH: usize = 16;
const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION;
const MAX_SCALE: u8 = DECIMAL128_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal128";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
Expand All @@ -543,14 +543,14 @@ pub struct Decimal256Type {}
impl DecimalType for Decimal256Type {
const BYTE_LENGTH: usize = 32;
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
const MAX_SCALE: u8 = DECIMAL256_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal256";

fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale as usize)
fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
format_decimal_str(&value.to_string(), precision as usize, scale)
}

fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
Expand All @@ -564,7 +564,7 @@ impl ArrowPrimitiveType for Decimal256Type {
const DATA_TYPE: DataType = <Self as DecimalType>::DEFAULT_TYPE;
}

fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String {
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
let (sign, rest) = match value_str.strip_prefix('-') {
Some(stripped) => ("-", stripped),
None => ("", value_str),
Expand All @@ -574,13 +574,16 @@ fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String

if scale == 0 {
value_str.to_string()
} else if rest.len() > scale {
} else if scale < 0 {
let padding = value_str.len() + scale.unsigned_abs() as usize;
format!("{:0<width$}", value_str, width = padding)
} else if rest.len() > scale as usize {
// Decimal separator is in the middle of the string
let (whole, decimal) = value_str.split_at(value_str.len() - scale);
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
format!("{}.{}", whole, decimal)
} else {
// String has to be padded
format!("{}0.{:0>width$}", sign, rest, width = scale)
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
}
}

Expand Down
52 changes: 40 additions & 12 deletions arrow-cast/src/cast.rs
Expand Up @@ -319,7 +319,7 @@ fn cast_integer_to_decimal<
>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
Expand Down Expand Up @@ -352,7 +352,7 @@ where
fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -391,7 +391,7 @@ where
fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -437,7 +437,7 @@ fn cast_reinterpret_arrays<
fn cast_decimal_to_integer<D, T>(
array: &ArrayRef,
base: D::Native,
scale: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
Expand Down Expand Up @@ -1921,9 +1921,9 @@ fn cast_decimal_to_decimal_with_option<
const BYTE_WIDTH2: usize,
>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
if cast_options.safe {
Expand All @@ -1947,9 +1947,9 @@ fn cast_decimal_to_decimal_with_option<
/// the array values when cast failures happen.
fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
Expand Down Expand Up @@ -2062,9 +2062,9 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
/// cast failure happens.
fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
input_scale: &u8,
input_scale: &i8,
output_precision: &u8,
output_scale: &u8,
output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
Expand Down Expand Up @@ -3540,7 +3540,7 @@ mod tests {
fn create_decimal_array(
array: Vec<Option<i128>>,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Decimal128Array, ArrowError> {
array
.into_iter()
Expand All @@ -3551,7 +3551,7 @@ mod tests {
fn create_decimal256_array(
array: Vec<Option<i256>>,
precision: u8,
scale: u8,
scale: i8,
) -> Result<Decimal256Array, ArrowError> {
array
.into_iter()
Expand Down Expand Up @@ -7206,4 +7206,32 @@ mod tests {
err
);
}

#[test]
fn test_cast_decimal128_to_decimal128_negative_scale() {
let input_type = DataType::Decimal128(20, 0);
let output_type = DataType::Decimal128(20, -1);
assert!(can_cast_types(&input_type, &output_type));
let array = vec![Some(1123456), Some(2123456), Some(3123456), None];
let input_decimal_array = create_decimal_array(array, 20, 0).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
generate_cast_test_case!(
&array,
Decimal128Array,
&output_type,
vec![
Some(112345_i128),
Some(212345_i128),
Some(312345_i128),
None
]
);

let casted_array = cast(&array, &output_type).unwrap();
let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);

assert_eq!("1123450", decimal_arr.value_as_string(0));
assert_eq!("2123450", decimal_arr.value_as_string(1));
assert_eq!("3123450", decimal_arr.value_as_string(2));
}
}
6 changes: 3 additions & 3 deletions arrow-csv/src/reader.rs
Expand Up @@ -721,7 +721,7 @@ fn build_decimal_array(
rows: &[StringRecord],
col_idx: usize,
precision: u8,
scale: u8,
scale: i8,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
for row in rows {
Expand Down Expand Up @@ -762,13 +762,13 @@ fn build_decimal_array(
fn parse_decimal_with_parameter(
s: &str,
precision: u8,
scale: u8,
scale: i8,
) -> Result<i128, ArrowError> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
let mut base = 1;
let scale_usize = usize::from(scale);
let scale_usize = usize::from(scale as u8);

// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
Expand Down
6 changes: 3 additions & 3 deletions arrow-data/src/decimal.rs
Expand Up @@ -728,17 +728,17 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const DECIMAL128_MAX_PRECISION: u8 = 38;

/// The maximum scale for [arrow_schema::DataType::Decimal128] values
pub const DECIMAL128_MAX_SCALE: u8 = 38;
pub const DECIMAL128_MAX_SCALE: i8 = 38;

/// The maximum precision for [arrow_schema::DataType::Decimal256] values
pub const DECIMAL256_MAX_PRECISION: u8 = 76;

/// The maximum scale for [arrow_schema::DataType::Decimal256] values
pub const DECIMAL256_MAX_SCALE: u8 = 76;
pub const DECIMAL256_MAX_SCALE: i8 = 76;

/// The default scale for [arrow_schema::DataType::Decimal128] and
/// [arrow_schema::DataType::Decimal256] values
pub const DECIMAL_DEFAULT_SCALE: u8 = 10;
pub const DECIMAL_DEFAULT_SCALE: i8 = 10;

/// Validates that the specified `i128` value can be properly
/// interpreted as a Decimal number with precision `precision`
Expand Down
4 changes: 2 additions & 2 deletions arrow-schema/src/datatype.rs
Expand Up @@ -190,14 +190,14 @@ pub enum DataType {
/// * scale is the number of digits past the decimal
///
/// For example the number 123.45 has precision 5 and scale 2.
Decimal128(u8, u8),
Decimal128(u8, i8),
/// Exact 256-bit width decimal value with precision and scale
///
/// * precision is the total number of digits
/// * scale is the number of digits past the decimal
///
/// For example the number 123.45 has precision 5 and scale 2.
Decimal256(u8, u8),
Decimal256(u8, i8),
/// A Map is a logical nested type that is represented as
///
/// `List<entries: Struct<key: K, value: V>>`
Expand Down
6 changes: 3 additions & 3 deletions arrow-select/src/take.rs
Expand Up @@ -914,7 +914,7 @@ mod tests {
options: Option<TakeOptions>,
expected_data: Vec<Option<i128>>,
precision: &u8,
scale: &u8,
scale: &i8,
) -> Result<(), ArrowError> {
let output = data
.into_iter()
Expand Down Expand Up @@ -1032,7 +1032,7 @@ mod tests {
fn test_take_decimal128_non_null_indices() {
let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
let precision: u8 = 10;
let scale: u8 = 5;
let scale: i8 = 5;
test_take_decimal_arrays(
vec![None, Some(3), Some(5), Some(2), Some(3), None],
&index,
Expand All @@ -1048,7 +1048,7 @@ mod tests {
fn test_take_decimal128() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
let precision: u8 = 10;
let scale: u8 = 5;
let scale: i8 = 5;
test_take_decimal_arrays(
vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
&index,
Expand Down
2 changes: 1 addition & 1 deletion arrow/tests/array_transform.rs
Expand Up @@ -31,7 +31,7 @@ use std::sync::Arc;
fn create_decimal_array(
array: Vec<Option<i128>>,
precision: u8,
scale: u8,
scale: i8,
) -> Decimal128Array {
array
.into_iter()
Expand Down

0 comments on commit a782f91

Please sign in to comment.