Skip to content

Commit

Permalink
Make DecimalArray as PrimitiveArray
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Oct 11, 2022
1 parent 5cf46d4 commit de451bb
Show file tree
Hide file tree
Showing 27 changed files with 382 additions and 1,556 deletions.
980 changes: 0 additions & 980 deletions arrow-array/src/array/decimal_array.rs

This file was deleted.

9 changes: 0 additions & 9 deletions arrow-array/src/array/mod.rs
Expand Up @@ -31,9 +31,6 @@ pub use binary_array::*;
mod boolean_array;
pub use boolean_array::*;

mod decimal_array;
pub use decimal_array::*;

mod dictionary_array;
pub use dictionary_array::*;

Expand Down Expand Up @@ -449,12 +446,6 @@ impl PartialEq for FixedSizeBinaryArray {
}
}

impl PartialEq for Decimal128Array {
fn eq(&self, other: &Self) -> bool {
self.data().eq(other.data())
}
}

impl<OffsetSize: OffsetSizeTrait> PartialEq for GenericListArray<OffsetSize> {
fn eq(&self, other: &Self) -> bool {
self.data().eq(other.data())
Expand Down
174 changes: 173 additions & 1 deletion arrow-array/src/array/primitive_array.rs
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder};
use crate::decimal::Decimal;
use crate::iterator::PrimitiveIter;
use crate::raw_pointer::RawPtrBox;
use crate::temporal_conversions::{as_date, as_datetime, as_duration, as_time};
Expand All @@ -25,7 +26,7 @@ use crate::{print_long_array, Array, ArrayAccessor};
use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_data::ArrayData;
use arrow_schema::DataType;
use arrow_schema::{ArrowError, DataType};
use chrono::{Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
use half::f16;
use std::any::Any;
Expand Down Expand Up @@ -176,6 +177,9 @@ pub type DurationMillisecondArray = PrimitiveArray<DurationMillisecondType>;
pub type DurationMicrosecondArray = PrimitiveArray<DurationMicrosecondType>;
pub type DurationNanosecondArray = PrimitiveArray<DurationNanosecondType>;

pub type Decimal128Array = PrimitiveArray<Decimal128Type>;
pub type Decimal256Array = PrimitiveArray<Decimal256Type>;

/// Trait bridging the dynamic-typed nature of Arrow (via [`DataType`]) with the
/// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`].
pub trait ArrowPrimitiveType: 'static {
Expand Down Expand Up @@ -827,6 +831,174 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
}
}

impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
/// Returns a Decimal array with the same data as self, with the
/// specified precision.
///
/// Returns an Error if:
/// 1. `precision` is larger than `T:MAX_PRECISION`
/// 2. `scale` is larger than `T::MAX_SCALE`
/// 3. `scale` is > `precision`
pub fn with_precision_and_scale(
self,
precision: u8,
scale: u8,
) -> Result<Self, ArrowError>
where
Self: Sized,
{
// validate precision and scale
self.validate_precision_scale(precision, scale)?;

// Ensure that all values are within the requested
// precision. For performance, only check if the precision is
// decreased
let p = self.precision()?;
if precision < p {
self.validate_decimal_precision(precision)?;
}

// safety: self.data is valid DataType::Decimal as checked above
let new_data_type = T::TYPE_CONSTRUCTOR(precision, scale);
let data = self.data().clone().into_builder().data_type(new_data_type);

// SAFETY
// Validated data above
Ok(unsafe { data.build_unchecked().into() })
}

/// Returns a Decimal array with the same data as self, with the
/// specified precision.
///
/// # Safety
///
/// This doesn't validate decimal values with specified precision.
pub unsafe fn unchecked_with_precision_and_scale(
self,
precision: u8,
scale: u8,
) -> Result<Self, ArrowError>
where
Self: Sized,
{
// validate precision and scale
self.validate_precision_scale(precision, scale)?;

// safety: self.data is valid DataType::Decimal as checked above
let new_data_type = T::TYPE_CONSTRUCTOR(precision, scale);
let data = self.data().clone().into_builder().data_type(new_data_type);

Ok(data.build_unchecked().into())
}

// validate that the new precision and scale are valid or not
fn validate_precision_scale(
&self,
precision: u8,
scale: u8,
) -> Result<(), ArrowError> {
if precision > T::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
"precision {} is greater than max {}",
precision,
Decimal128Type::MAX_PRECISION
)));
}
if scale > T::MAX_SCALE {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than max {}",
scale,
Decimal128Type::MAX_SCALE
)));
}
if scale > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than precision {}",
scale, precision
)));
}

Ok(())
}

// Validates values in this array can be properly interpreted
// with the specified precision.
fn validate_decimal_precision(&self, precision: u8) -> Result<(), ArrowError> {
(0..self.len()).try_for_each(|idx| {
if self.is_valid(idx) {
let decimal = unsafe { self.value_unchecked(idx) };
T::validate_decimal_precision(decimal, precision)
} else {
Ok(())
}
})
}

pub fn value_as_string(&self, row: usize) -> Result<String, ArrowError> {
let p = self.precision()?;
let s = self.scale()?;
Ok(Decimal::<T>::new(p, s, &T::to_native(self.value(row))).to_string())
}

pub fn precision(&self) -> Result<u8, ArrowError> {
match T::BYTE_LENGTH {
16 => {
if let DataType::Decimal128(p, _) = self.data().data_type() {
Ok(*p)
} else {
Err(ArrowError::InvalidArgumentError(format!(
"Decimal128Array datatype is not DataType::Decimal128 but {}",
self.data_type()
)))
}
}
32 => {
if let DataType::Decimal256(p, _) = self.data().data_type() {
Ok(*p)
} else {
Err(ArrowError::InvalidArgumentError(format!(
"Decimal256Array datatype is not DataType::Decimal256 but {}",
self.data_type()
)))
}
}
other => Err(ArrowError::InvalidArgumentError(format!(
"Unsupported byte length for decimal array {}",
other
))),
}
}

pub fn scale(&self) -> Result<u8, ArrowError> {
match T::BYTE_LENGTH {
16 => {
if let DataType::Decimal128(_, s) = self.data().data_type() {
Ok(*s)
} else {
Err(ArrowError::InvalidArgumentError(format!(
"Decimal128Array datatype is not DataType::Decimal128 but {}",
self.data_type()
)))
}
}
32 => {
if let DataType::Decimal256(_, s) = self.data().data_type() {
Ok(*s)
} else {
Err(ArrowError::InvalidArgumentError(format!(
"Decimal256Array datatype is not DataType::Decimal256 but {}",
self.data_type()
)))
}
}
other => Err(ArrowError::InvalidArgumentError(format!(
"Unsupported byte length for decimal array {}",
other
))),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit de451bb

Please sign in to comment.