From ed60bbc421d3c0a193c2b361ad52c73dd4c8f8a5 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Sun, 14 Aug 2022 20:30:05 +0800 Subject: [PATCH] collation the validate precision code for decimal array --- arrow/src/array/array_decimal.rs | 167 +++++++++++++------------------ 1 file changed, 72 insertions(+), 95 deletions(-) diff --git a/arrow/src/array/array_decimal.rs b/arrow/src/array/array_decimal.rs index ed1d3102a13..5844913b6ce 100644 --- a/arrow/src/array/array_decimal.rs +++ b/arrow/src/array/array_decimal.rs @@ -251,39 +251,6 @@ impl BasicDecimalArray { fn raw_value_data_ptr(&self) -> *const u8 { self.value_data.as_ptr() } -} - -impl Decimal128Array { - /// Creates a [Decimal128Array] with default precision and scale, - /// based on an iterator of `i128` values without nulls - pub fn from_iter_values>(iter: I) -> Self { - let val_buf: Buffer = iter.into_iter().collect(); - let data = unsafe { - ArrayData::new_unchecked( - Self::default_type(), - val_buf.len() / std::mem::size_of::(), - None, - None, - 0, - vec![val_buf], - vec![], - ) - }; - Decimal128Array::from(data) - } - - // Validates decimal values in this array can be properly interpreted - // with the specified precision. - fn validate_decimal_precision(&self, precision: usize) -> Result<()> { - (0..self.len()).try_for_each(|idx| { - if self.is_valid(idx) { - let decimal = unsafe { self.value_unchecked(idx) }; - validate_decimal_precision(decimal.as_i128(), precision) - } else { - Ok(()) - } - }) - } /// Returns a Decimal array with the same data as self, with the /// specified precision. @@ -296,6 +263,23 @@ impl Decimal128Array { where Self: Sized, { + // validate precision and scale + self.validate_precision_scale(precision, scale)?; + + // Ensure that all values are within the requested + // precision. For performance, only check if the precision is + // decreased + if precision < self.precision { + self.validate_data(precision)?; + } + + // safety: self.data is valid DataType::Decimal as checked above + let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale); + Ok(self.data().clone().with_data_type(new_data_type).into()) + } + + // validate that the new precision and scale are valid or not + fn validate_precision_scale(&self, precision: usize, scale: usize) -> Result<()> { if precision > Self::MAX_PRECISION { return Err(ArrowError::InvalidArgumentError(format!( "precision {} is greater than max {}", @@ -316,26 +300,67 @@ impl Decimal128Array { scale, precision ))); } - - // Ensure that all values are within the requested - // precision. For performance, only check if the precision is - // decreased - if precision < self.precision { - self.validate_decimal_precision(precision)?; - } - let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale); assert_eq!(self.data().data_type(), &data_type); - // safety: self.data is valid DataType::Decimal as checked above - let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale); + Ok(()) + } + + // validate all the data in the array are valid within the new precision or not + fn validate_data(&self, precision: usize) -> Result<()> { + match BYTE_WIDTH { + 16 => self + .as_any() + .downcast_ref::() + .unwrap() + .validate_decimal_precision(precision), + 32 => self + .as_any() + .downcast_ref::() + .unwrap() + .validate_decimal_precision(precision), + other_width => { + panic!("invalid byte width {}", other_width); + } + } + } +} - Ok(self.data().clone().with_data_type(new_data_type).into()) +impl Decimal128Array { + /// Creates a [Decimal128Array] with default precision and scale, + /// based on an iterator of `i128` values without nulls + pub fn from_iter_values>(iter: I) -> Self { + let val_buf: Buffer = iter.into_iter().collect(); + let data = unsafe { + ArrayData::new_unchecked( + Self::default_type(), + val_buf.len() / std::mem::size_of::(), + None, + None, + 0, + vec![val_buf], + vec![], + ) + }; + Decimal128Array::from(data) + } + + // Validates decimal128 values in this array can be properly interpreted + // with the specified precision. + fn validate_decimal_precision(&self, precision: usize) -> Result<()> { + (0..self.len()).try_for_each(|idx| { + if self.is_valid(idx) { + let decimal = unsafe { self.value_unchecked(idx) }; + validate_decimal_precision(decimal.as_i128(), precision) + } else { + Ok(()) + } + }) } } impl Decimal256Array { - // Validates decimal values in this array can be properly interpreted + // Validates decimal256 values in this array can be properly interpreted // with the specified precision. fn validate_decimal_precision(&self, precision: usize) -> Result<()> { (0..self.len()).try_for_each(|idx| { @@ -353,54 +378,6 @@ impl Decimal256Array { } }) } - - /// Returns a Decimal array with the same data as self, with the - /// specified precision. - /// - /// Returns an Error if: - /// 1. `precision` is larger than [`Self::MAX_PRECISION`] - /// 2. `scale` is larger than [`Self::MAX_SCALE`]; - /// 3. `scale` is > `precision` - pub fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result - where - Self: Sized, - { - if precision > Self::MAX_PRECISION { - return Err(ArrowError::InvalidArgumentError(format!( - "precision {} is greater than max {}", - precision, - Self::MAX_PRECISION - ))); - } - if scale > Self::MAX_SCALE { - return Err(ArrowError::InvalidArgumentError(format!( - "scale {} is greater than max {}", - scale, - Self::MAX_SCALE - ))); - } - if scale > precision { - return Err(ArrowError::InvalidArgumentError(format!( - "scale {} is greater than precision {}", - scale, precision - ))); - } - - // Ensure that all values are within the requested - // precision. For performance, only check if the precision is - // decreased - if precision < self.precision { - self.validate_decimal_precision(precision)?; - } - - let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale); - assert_eq!(self.data().data_type(), &data_type); - - // safety: self.data is valid DataType::Decimal as checked above - let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale); - - Ok(self.data().clone().with_data_type(new_data_type).into()) - } } impl From for BasicDecimalArray { @@ -950,7 +927,7 @@ mod tests { Decimal256::from_big_int( &value1, DECIMAL256_MAX_PRECISION, - DECIMAL_DEFAULT_SCALE + DECIMAL_DEFAULT_SCALE, ) .unwrap(), array.value(0) @@ -961,7 +938,7 @@ mod tests { Decimal256::from_big_int( &value2, DECIMAL256_MAX_PRECISION, - DECIMAL_DEFAULT_SCALE + DECIMAL_DEFAULT_SCALE, ) .unwrap(), array.value(2)