Skip to content

Commit

Permalink
collation the validate precision code for decimal array
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Aug 15, 2022
1 parent 5e27d93 commit ed60bbc
Showing 1 changed file with 72 additions and 95 deletions.
167 changes: 72 additions & 95 deletions arrow/src/array/array_decimal.rs
Expand Up @@ -251,39 +251,6 @@ impl<const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH> {
fn raw_value_data_ptr(&self) -> *const u8 {
self.value_data.as_ptr()
}
}

impl Decimal128Array {
/// Creates a [Decimal128Array] with default precision and scale,
/// based on an iterator of `i128` values without nulls
pub fn from_iter_values<I: IntoIterator<Item = i128>>(iter: I) -> Self {
let val_buf: Buffer = iter.into_iter().collect();
let data = unsafe {
ArrayData::new_unchecked(
Self::default_type(),
val_buf.len() / std::mem::size_of::<i128>(),
None,
None,
0,
vec![val_buf],
vec![],
)
};
Decimal128Array::from(data)
}

// Validates decimal values in this array can be properly interpreted
// with the specified precision.
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
(0..self.len()).try_for_each(|idx| {
if self.is_valid(idx) {
let decimal = unsafe { self.value_unchecked(idx) };
validate_decimal_precision(decimal.as_i128(), precision)
} else {
Ok(())
}
})
}

/// Returns a Decimal array with the same data as self, with the
/// specified precision.
Expand All @@ -296,6 +263,23 @@ impl Decimal128Array {
where
Self: Sized,
{
// validate precision and scale
self.validate_precision_scale(precision, scale)?;

// Ensure that all values are within the requested
// precision. For performance, only check if the precision is
// decreased
if precision < self.precision {
self.validate_data(precision)?;
}

// safety: self.data is valid DataType::Decimal as checked above
let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale);
Ok(self.data().clone().with_data_type(new_data_type).into())
}

// validate that the new precision and scale are valid or not
fn validate_precision_scale(&self, precision: usize, scale: usize) -> Result<()> {
if precision > Self::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
"precision {} is greater than max {}",
Expand All @@ -316,26 +300,67 @@ impl Decimal128Array {
scale, precision
)));
}

// Ensure that all values are within the requested
// precision. For performance, only check if the precision is
// decreased
if precision < self.precision {
self.validate_decimal_precision(precision)?;
}

let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale);
assert_eq!(self.data().data_type(), &data_type);

// safety: self.data is valid DataType::Decimal as checked above
let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale);
Ok(())
}

// validate all the data in the array are valid within the new precision or not
fn validate_data(&self, precision: usize) -> Result<()> {
match BYTE_WIDTH {
16 => self
.as_any()
.downcast_ref::<Decimal128Array>()
.unwrap()
.validate_decimal_precision(precision),
32 => self
.as_any()
.downcast_ref::<Decimal256Array>()
.unwrap()
.validate_decimal_precision(precision),
other_width => {
panic!("invalid byte width {}", other_width);
}
}
}
}

Ok(self.data().clone().with_data_type(new_data_type).into())
impl Decimal128Array {
/// Creates a [Decimal128Array] with default precision and scale,
/// based on an iterator of `i128` values without nulls
pub fn from_iter_values<I: IntoIterator<Item = i128>>(iter: I) -> Self {
let val_buf: Buffer = iter.into_iter().collect();
let data = unsafe {
ArrayData::new_unchecked(
Self::default_type(),
val_buf.len() / std::mem::size_of::<i128>(),
None,
None,
0,
vec![val_buf],
vec![],
)
};
Decimal128Array::from(data)
}

// Validates decimal128 values in this array can be properly interpreted
// with the specified precision.
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
(0..self.len()).try_for_each(|idx| {
if self.is_valid(idx) {
let decimal = unsafe { self.value_unchecked(idx) };
validate_decimal_precision(decimal.as_i128(), precision)
} else {
Ok(())
}
})
}
}

impl Decimal256Array {
// Validates decimal values in this array can be properly interpreted
// Validates decimal256 values in this array can be properly interpreted
// with the specified precision.
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
(0..self.len()).try_for_each(|idx| {
Expand All @@ -353,54 +378,6 @@ impl Decimal256Array {
}
})
}

/// Returns a Decimal array with the same data as self, with the
/// specified precision.
///
/// Returns an Error if:
/// 1. `precision` is larger than [`Self::MAX_PRECISION`]
/// 2. `scale` is larger than [`Self::MAX_SCALE`];
/// 3. `scale` is > `precision`
pub fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result<Self>
where
Self: Sized,
{
if precision > Self::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
"precision {} is greater than max {}",
precision,
Self::MAX_PRECISION
)));
}
if scale > Self::MAX_SCALE {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than max {}",
scale,
Self::MAX_SCALE
)));
}
if scale > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than precision {}",
scale, precision
)));
}

// Ensure that all values are within the requested
// precision. For performance, only check if the precision is
// decreased
if precision < self.precision {
self.validate_decimal_precision(precision)?;
}

let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale);
assert_eq!(self.data().data_type(), &data_type);

// safety: self.data is valid DataType::Decimal as checked above
let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale);

Ok(self.data().clone().with_data_type(new_data_type).into())
}
}

impl<const BYTE_WIDTH: usize> From<ArrayData> for BasicDecimalArray<BYTE_WIDTH> {
Expand Down Expand Up @@ -950,7 +927,7 @@ mod tests {
Decimal256::from_big_int(
&value1,
DECIMAL256_MAX_PRECISION,
DECIMAL_DEFAULT_SCALE
DECIMAL_DEFAULT_SCALE,
)
.unwrap(),
array.value(0)
Expand All @@ -961,7 +938,7 @@ mod tests {
Decimal256::from_big_int(
&value2,
DECIMAL256_MAX_PRECISION,
DECIMAL_DEFAULT_SCALE
DECIMAL_DEFAULT_SCALE,
)
.unwrap(),
array.value(2)
Expand Down

0 comments on commit ed60bbc

Please sign in to comment.