From 0e7a5cd9bbf18535b899d18ef40d52be0dce74f9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 22 Jul 2022 13:42:14 -0700 Subject: [PATCH] Add precision/scale check. Add decimal128 and decimal256 iters. --- arrow/src/array/array_binary.rs | 1 - arrow/src/array/array_decimal.rs | 127 ++++++++++++++++++++++-------- arrow/src/array/iterator.rs | 14 +++- arrow/src/compute/kernels/cast.rs | 4 +- arrow/src/datatypes/datatype.rs | 4 +- arrow/src/util/decimal.rs | 76 ++++++++++++++++-- 6 files changed, 182 insertions(+), 44 deletions(-) diff --git a/arrow/src/array/array_binary.rs b/arrow/src/array/array_binary.rs index b01696b0334..4848a25a058 100644 --- a/arrow/src/array/array_binary.rs +++ b/arrow/src/array/array_binary.rs @@ -24,7 +24,6 @@ use super::{ FixedSizeListArray, GenericBinaryIter, GenericListArray, OffsetSizeTrait, }; use crate::array::array::ArrayAccessor; -pub use crate::array::DecimalIter; use crate::buffer::Buffer; use crate::error::{ArrowError, Result}; use crate::util::bit_util; diff --git a/arrow/src/array/array_decimal.rs b/arrow/src/array/array_decimal.rs index 494b4e9d1a9..47316085820 100644 --- a/arrow/src/array/array_decimal.rs +++ b/arrow/src/array/array_decimal.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::array::{ArrayAccessor, Decimal128Iter, Decimal256Iter}; use std::borrow::Borrow; use std::convert::From; use std::fmt; @@ -24,13 +25,11 @@ use super::{ array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, FixedSizeListArray, }; use super::{BooleanBufferBuilder, FixedSizeBinaryArray}; +#[allow(deprecated)] pub use crate::array::DecimalIter; use crate::buffer::Buffer; -use crate::datatypes::DataType; -use crate::datatypes::{ - validate_decimal_precision, DECIMAL_DEFAULT_SCALE, DECIMAL_MAX_PRECISION, - DECIMAL_MAX_SCALE, -}; +use crate::datatypes::{validate_decimal_precision, DECIMAL_DEFAULT_SCALE}; +use crate::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE}; use crate::error::{ArrowError, Result}; use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; @@ -103,11 +102,18 @@ pub trait BasicDecimalArray>: /// Returns the element at index `i`. fn value(&self, i: usize) -> T { - let data = self.data(); - assert!(i < data.len(), "Out of bounds access"); + assert!(i < self.data().len(), "Out of bounds access"); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i`. + /// # Safety + /// Caller is responsible for ensuring that the index is within the bounds of the array + unsafe fn value_unchecked(&self, i: usize) -> T { + let data = self.data(); let offset = i + data.offset(); - let raw_val = unsafe { + let raw_val = { let pos = self.value_offset_at(offset); std::slice::from_raw_parts( self.raw_value_data_ptr().offset(pos as isize), @@ -270,24 +276,24 @@ impl Decimal128Array { /// specified precision. /// /// Returns an Error if: - /// 1. `precision` is larger than [`DECIMAL_MAX_PRECISION`] - /// 2. `scale` is larger than [`DECIMAL_MAX_SCALE`]; + /// 1. `precision` is larger than [`DECIMAL128_MAX_PRECISION`] + /// 2. `scale` is larger than [`DECIMAL128_MAX_SCALE`]; /// 3. `scale` is > `precision` pub fn with_precision_and_scale( mut self, precision: usize, scale: usize, ) -> Result { - if precision > DECIMAL_MAX_PRECISION { + if precision > DECIMAL128_MAX_PRECISION { return Err(ArrowError::InvalidArgumentError(format!( "precision {} is greater than max {}", - precision, DECIMAL_MAX_PRECISION + precision, DECIMAL128_MAX_PRECISION ))); } - if scale > DECIMAL_MAX_SCALE { + if scale > DECIMAL128_MAX_SCALE { return Err(ArrowError::InvalidArgumentError(format!( "scale {} is greater than max {}", - scale, DECIMAL_MAX_SCALE + scale, DECIMAL128_MAX_SCALE ))); } if scale > precision { @@ -302,7 +308,7 @@ impl Decimal128Array { // decreased if precision < self.precision { for v in self.iter().flatten() { - validate_decimal_precision(v, precision)?; + validate_decimal_precision(v.as_i128(), precision)?; } } @@ -322,7 +328,7 @@ impl Decimal128Array { /// The default precision and scale used when not specified. pub fn default_type() -> DataType { // Keep maximum precision - DataType::Decimal(DECIMAL_MAX_PRECISION, DECIMAL_DEFAULT_SCALE) + DataType::Decimal(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE) } } @@ -368,19 +374,13 @@ impl From for Decimal256Array { } } -impl<'a> IntoIterator for &'a Decimal128Array { - type Item = Option; - type IntoIter = DecimalIter<'a>; - - fn into_iter(self) -> Self::IntoIter { - DecimalIter::<'a>::new(self) - } -} - impl<'a> Decimal128Array { - /// constructs a new iterator - pub fn iter(&'a self) -> DecimalIter<'a> { - DecimalIter::new(self) + /// Constructs a new iterator that iterates `Decimal128` values as i128 values. + /// This is kept mostly for back-compatibility purpose. + /// Suggests to use `iter()` that returns `Decimal128Iter`. + #[allow(deprecated)] + pub fn i128_iter(&'a self) -> DecimalIter<'a> { + DecimalIter::<'a>::new(self) } } @@ -421,7 +421,7 @@ impl>> FromIterator for Decimal128Array { } macro_rules! def_decimal_array { - ($ty:ident, $array_name:expr) => { + ($ty:ident, $array_name:expr, $decimal_ty:ident, $iter_ty:ident) => { impl private_decimal::DecimalArrayPrivate for $ty { fn raw_value_data_ptr(&self) -> *const u8 { self.value_data.as_ptr() @@ -463,15 +463,55 @@ macro_rules! def_decimal_array { write!(f, "]") } } + + impl<'a> ArrayAccessor for &'a $ty { + type Item = $decimal_ty; + + fn value(&self, index: usize) -> Self::Item { + $ty::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + $ty::value_unchecked(self, index) + } + } + + impl<'a> IntoIterator for &'a $ty { + type Item = Option<$decimal_ty>; + type IntoIter = $iter_ty<'a>; + + fn into_iter(self) -> Self::IntoIter { + $iter_ty::<'a>::new(self) + } + } + + impl<'a> $ty { + /// constructs a new iterator + pub fn iter(&'a self) -> $iter_ty<'a> { + $iter_ty::<'a>::new(self) + } + } }; } -def_decimal_array!(Decimal128Array, "Decimal128Array"); -def_decimal_array!(Decimal256Array, "Decimal256Array"); +def_decimal_array!( + Decimal128Array, + "Decimal128Array", + Decimal128, + Decimal128Iter +); +def_decimal_array!( + Decimal256Array, + "Decimal256Array", + Decimal256, + Decimal256Iter +); #[cfg(test)] mod tests { + use crate::array::Decimal256Builder; use crate::{array::Decimal128Builder, datatypes::Field}; + use num::{BigInt, Num}; use super::*; @@ -567,7 +607,7 @@ mod tests { let data = vec![Some(-100), None, Some(101)]; let array: Decimal128Array = data.clone().into_iter().collect(); - let collected: Vec<_> = array.iter().collect(); + let collected: Vec<_> = array.iter().map(|d| d.map(|v| v.as_i128())).collect(); assert_eq!(data, collected); } @@ -576,7 +616,8 @@ mod tests { let data = vec![Some(-100), None, Some(101)]; let array: Decimal128Array = data.clone().into_iter().collect(); - let collected: Vec<_> = array.into_iter().collect(); + let collected: Vec<_> = + array.into_iter().map(|d| d.map(|v| v.as_i128())).collect(); assert_eq!(data, collected); } @@ -750,4 +791,24 @@ mod tests { assert!(decimal.is_null(0)); assert_eq!(decimal.value_as_string(1), "56".to_string()); } + + #[test] + fn test_decimal256_iter() { + // TODO: Impl FromIterator for Decimal256Array + let mut builder = Decimal256Builder::new(30, 76, 6); + let value = BigInt::from_str_radix("12345", 10).unwrap(); + let decimal1 = Decimal256::from_big_int(&value, 76, 6).unwrap(); + builder.append_value(&decimal1).unwrap(); + + builder.append_null(); + + let value = BigInt::from_str_radix("56789", 10).unwrap(); + let decimal2 = Decimal256::from_big_int(&value, 76, 6).unwrap(); + builder.append_value(&decimal2).unwrap(); + + let array: Decimal256Array = builder.finish(); + + let collected: Vec<_> = array.iter().collect(); + assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected); + } } diff --git a/arrow/src/array/iterator.rs b/arrow/src/array/iterator.rs index a4853d7d73b..e277ec82ad9 100644 --- a/arrow/src/array/iterator.rs +++ b/arrow/src/array/iterator.rs @@ -16,7 +16,7 @@ // under the License. use crate::array::array::ArrayAccessor; -use crate::array::BasicDecimalArray; +use crate::array::{BasicDecimalArray, Decimal256Array}; use super::{ Array, BooleanArray, Decimal128Array, GenericBinaryArray, GenericListArray, @@ -104,15 +104,25 @@ pub type GenericStringIter<'a, T> = ArrayIter<&'a GenericStringArray>; pub type GenericBinaryIter<'a, T> = ArrayIter<&'a GenericBinaryArray>; pub type GenericListArrayIter<'a, O> = ArrayIter<&'a GenericListArray>; +/// an iterator that returns `Some(Decimal128)` or `None`, that can be used on a +/// [`Decimal128Array`] +pub type Decimal128Iter<'a> = ArrayIter<&'a Decimal128Array>; + +/// an iterator that returns `Some(Decimal256)` or `None`, that can be used on a +/// [`Decimal256Array`] +pub type Decimal256Iter<'a> = ArrayIter<&'a Decimal256Array>; + /// an iterator that returns `Some(i128)` or `None`, that can be used on a /// [`Decimal128Array`] #[derive(Debug)] +#[deprecated(note = "Please use `Decimal128Iter` instead")] pub struct DecimalIter<'a> { array: &'a Decimal128Array, current: usize, current_end: usize, } +#[allow(deprecated)] impl<'a> DecimalIter<'a> { pub fn new(array: &'a Decimal128Array) -> Self { Self { @@ -123,6 +133,7 @@ impl<'a> DecimalIter<'a> { } } +#[allow(deprecated)] impl<'a> std::iter::Iterator for DecimalIter<'a> { type Item = Option; @@ -150,6 +161,7 @@ impl<'a> std::iter::Iterator for DecimalIter<'a> { } /// iterator has known size. +#[allow(deprecated)] impl<'a> std::iter::ExactSizeIterator for DecimalIter<'a> {} #[cfg(test)] diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 954acef763c..781f199a691 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -1205,7 +1205,7 @@ fn cast_decimal_to_decimal( let div = 10_i128.pow((input_scale - output_scale) as u32); array .iter() - .map(|v| v.map(|v| v / div)) + .map(|v| v.map(|v| v.as_i128() / div)) .collect::() } else { // For example, input_scale is 3 and output_scale is 4; @@ -1213,7 +1213,7 @@ fn cast_decimal_to_decimal( let mul = 10_i128.pow((output_scale - input_scale) as u32); array .iter() - .map(|v| v.map(|v| v * mul)) + .map(|v| v.map(|v| v.as_i128() * mul)) .collect::() } .with_precision_and_scale(*output_precision, *output_scale)?; diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index 74a2ab45080..429a94f24b9 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -431,10 +431,10 @@ pub const MIN_DECIMAL_FOR_LARGER_PRECISION: [&str; 38] = [ ]; /// The maximum precision for [DataType::Decimal] values -pub const DECIMAL_MAX_PRECISION: usize = 38; +pub const DECIMAL128_MAX_PRECISION: usize = 38; /// The maximum scale for [DataType::Decimal] values -pub const DECIMAL_MAX_SCALE: usize = 38; +pub const DECIMAL128_MAX_SCALE: usize = 38; /// The maximum precision for [DataType::Decimal256] values pub const DECIMAL256_MAX_PRECISION: usize = 76; diff --git a/arrow/src/util/decimal.rs b/arrow/src/util/decimal.rs index 4d67245647d..8f9d394efd9 100644 --- a/arrow/src/util/decimal.rs +++ b/arrow/src/util/decimal.rs @@ -17,13 +17,22 @@ //! Decimal related utils +use crate::datatypes::{ + DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, +}; use crate::error::{ArrowError, Result}; use num::bigint::BigInt; +use num::Signed; use std::cmp::{min, Ordering}; pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { /// The bit-width of the internal representation. const BIT_WIDTH: usize; + /// The maximum precision. + const MAX_PRECISION: usize; + /// The maximum scale. + const MAX_SCALE: usize; /// Tries to create a decimal value from precision, scale and bytes. /// If the length of bytes isn't same as the bit width of this decimal, @@ -36,6 +45,21 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { where Self: Sized, { + if precision > Self::MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, + Self::MAX_PRECISION + ))); + } + if scale > Self::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, + Self::MAX_SCALE + ))); + } + if precision < scale { return Err(ArrowError::InvalidArgumentError(format!( "Precision {} is less than scale {}", @@ -138,10 +162,30 @@ pub struct Decimal256 { value: [u8; 32], } +impl Decimal256 { + /// Constructs a `Decimal256` value from a `BigInt`. + pub fn from_big_int( + num: &BigInt, + precision: usize, + scale: usize, + ) -> Result { + let mut bytes = if num.is_negative() { + vec![255; 32] + } else { + vec![0; 32] + }; + let num_bytes = &num.to_signed_bytes_le(); + bytes[0..num_bytes.len()].clone_from_slice(num_bytes); + Decimal256::try_new_from_bytes(precision, scale, &bytes) + } +} + macro_rules! def_decimal { - ($ty:ident, $bit:expr) => { + ($ty:ident, $bit:expr, $max_p:expr, $max_s:expr) => { impl BasicDecimal for $ty { const BIT_WIDTH: usize = $bit; + const MAX_PRECISION: usize = $max_p; + const MAX_SCALE: usize = $max_s; fn new(precision: usize, scale: usize, bytes: &[u8]) -> Self { $ty { @@ -201,12 +245,23 @@ macro_rules! def_decimal { }; } -def_decimal!(Decimal128, 128); -def_decimal!(Decimal256, 256); +def_decimal!( + Decimal128, + 128, + DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE +); +def_decimal!( + Decimal256, + 256, + DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE +); #[cfg(test)] mod tests { use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; + use num::{BigInt, Num}; #[test] fn decimal_128_to_string() { @@ -281,10 +336,10 @@ mod tests { // smaller than i128 minimum bytes = vec![255; 32]; bytes[31] = 128; - let value = Decimal256::try_new_from_bytes(79, 4, &bytes).unwrap(); + let value = Decimal256::try_new_from_bytes(76, 4, &bytes).unwrap(); assert_eq!( value.to_string(), - "-5744373177007483132341216834415376678658315645522012356644966081642565415.7313" + "-574437317700748313234121683441537667865831564552201235664496608164256541.5731" ); bytes = vec![255; 32]; @@ -302,4 +357,15 @@ mod tests { let integer = i128_func(value); assert_eq!(integer, 100); } + + #[test] + fn bigint_to_decimal256() { + let num = BigInt::from_str_radix("123456789", 10).unwrap(); + let value = Decimal256::from_big_int(&num, 30, 2).unwrap(); + assert_eq!(value.to_string(), "1234567.89"); + + let num = BigInt::from_str_radix("-5744373177007483132341216834415376678658315645522012356644966081642565415731", 10).unwrap(); + let value = Decimal256::from_big_int(&num, 76, 4).unwrap(); + assert_eq!(value.to_string(), "-574437317700748313234121683441537667865831564552201235664496608164256541.5731"); + } }