Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Seal the decimal type. #2439

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 32 additions & 20 deletions arrow/src/array/array_decimal.rs
Expand Up @@ -34,7 +34,7 @@ use crate::datatypes::{
validate_decimal_precision, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
};
use crate::error::{ArrowError, Result};
use crate::util::decimal::{BasicDecimal, Decimal256};
use crate::util::decimal::{BasicDecimal, Decimal256, ValidDecimal};

/// `Decimal128Array` stores fixed width decimal numbers,
/// with a fixed precision and scale.
Expand Down Expand Up @@ -81,9 +81,11 @@ pub struct BasicDecimalArray<const BYTE_WIDTH: usize> {
scale: usize,
}

impl<const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH> {
impl<const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH>
where
BasicDecimal<BYTE_WIDTH>: ValidDecimal,
{
pub const VALUE_LENGTH: i32 = BYTE_WIDTH as i32;
const DEFAULT_TYPE: DataType = BasicDecimal::<BYTE_WIDTH>::DEFAULT_TYPE;
pub const MAX_PRECISION: usize = BasicDecimal::<BYTE_WIDTH>::MAX_PRECISION;
pub const MAX_SCALE: usize = BasicDecimal::<BYTE_WIDTH>::MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(usize, usize) -> DataType =
Expand Down Expand Up @@ -174,11 +176,7 @@ impl<const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH> {
v.value_length(),
Self::VALUE_LENGTH,
);
let data_type = if Self::VALUE_LENGTH == 16 {
DataType::Decimal128(precision, scale)
} else {
DataType::Decimal256(precision, scale)
};
let data_type = Self::TYPE_CONSTRUCTOR(precision, scale);
let builder = v.into_data().into_builder().data_type(data_type);

let array_data = unsafe { builder.build_unchecked() };
Expand Down Expand Up @@ -228,11 +226,7 @@ impl<const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH> {

let list_offset = v.offset();
let child_offset = child_data.offset();
let data_type = if Self::VALUE_LENGTH == 16 {
DataType::Decimal128(precision, scale)
} else {
DataType::Decimal256(precision, scale)
};
let data_type = Self::TYPE_CONSTRUCTOR(precision, scale);
let builder = ArrayData::builder(data_type)
.len(v.len())
.add_buffer(child_data.buffers()[0].slice(child_offset))
Expand All @@ -245,7 +239,7 @@ impl<const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH> {

/// The default precision and scale used when not specified.
pub const fn default_type() -> DataType {
Self::DEFAULT_TYPE
BasicDecimal::<BYTE_WIDTH>::DEFAULT_TYPE
}

fn raw_value_data_ptr(&self) -> *const u8 {
Expand Down Expand Up @@ -445,7 +439,10 @@ impl From<BigInt> for Decimal256 {
fn build_decimal_array_from<const BYTE_WIDTH: usize>(
null_buf: BooleanBufferBuilder,
buffer: Buffer,
) -> BasicDecimalArray<BYTE_WIDTH> {
) -> BasicDecimalArray<BYTE_WIDTH>
where
BasicDecimal<BYTE_WIDTH>: ValidDecimal,
{
let data = unsafe {
ArrayData::new_unchecked(
BasicDecimalArray::<BYTE_WIDTH>::default_type(),
Expand Down Expand Up @@ -509,7 +506,10 @@ impl<Ptr: Borrow<Option<i128>>> FromIterator<Ptr> for Decimal128Array {
}
}

impl<const BYTE_WIDTH: usize> Array for BasicDecimalArray<BYTE_WIDTH> {
impl<const BYTE_WIDTH: usize> Array for BasicDecimalArray<BYTE_WIDTH>
where
BasicDecimal<BYTE_WIDTH>: ValidDecimal,
{
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -529,7 +529,10 @@ impl<const BYTE_WIDTH: usize> From<BasicDecimalArray<BYTE_WIDTH>> for ArrayData
}
}

impl<const BYTE_WIDTH: usize> fmt::Debug for BasicDecimalArray<BYTE_WIDTH> {
impl<const BYTE_WIDTH: usize> fmt::Debug for BasicDecimalArray<BYTE_WIDTH>
where
BasicDecimal<BYTE_WIDTH>: ValidDecimal,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
Expand All @@ -547,7 +550,10 @@ impl<const BYTE_WIDTH: usize> fmt::Debug for BasicDecimalArray<BYTE_WIDTH> {
}
}

impl<'a, const BYTE_WIDTH: usize> ArrayAccessor for &'a BasicDecimalArray<BYTE_WIDTH> {
impl<'a, const BYTE_WIDTH: usize> ArrayAccessor for &'a BasicDecimalArray<BYTE_WIDTH>
where
BasicDecimal<BYTE_WIDTH>: ValidDecimal,
{
type Item = BasicDecimal<BYTE_WIDTH>;

fn value(&self, index: usize) -> Self::Item {
Expand All @@ -559,7 +565,10 @@ impl<'a, const BYTE_WIDTH: usize> ArrayAccessor for &'a BasicDecimalArray<BYTE_W
}
}

impl<'a, const BYTE_WIDTH: usize> IntoIterator for &'a BasicDecimalArray<BYTE_WIDTH> {
impl<'a, const BYTE_WIDTH: usize> IntoIterator for &'a BasicDecimalArray<BYTE_WIDTH>
where
BasicDecimal<BYTE_WIDTH>: ValidDecimal,
{
type Item = Option<BasicDecimal<BYTE_WIDTH>>;
type IntoIter = BasicDecimalIter<'a, BYTE_WIDTH>;

Expand All @@ -568,7 +577,10 @@ impl<'a, const BYTE_WIDTH: usize> IntoIterator for &'a BasicDecimalArray<BYTE_WI
}
}

impl<'a, const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH> {
impl<'a, const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH>
where
BasicDecimal<BYTE_WIDTH>: ValidDecimal,
{
/// constructs a new iterator
pub fn iter(&'a self) -> BasicDecimalIter<'a, BYTE_WIDTH> {
BasicDecimalIter::<'a, BYTE_WIDTH>::new(self)
Expand Down
88 changes: 38 additions & 50 deletions arrow/src/util/decimal.rs
Expand Up @@ -26,43 +26,42 @@ use num::bigint::BigInt;
use num::Signed;
use std::cmp::{min, Ordering};

/// Indicate which [`BasicDecimal`]s are valid.
/// Currently we only support [`Decimal128`] and [`Decimal256`].
pub trait ValidDecimal {
const MAX_PRECISION: usize;
const MAX_SCALE: usize;
const TYPE_CONSTRUCTOR: fn(usize, usize) -> DataType;
const DEFAULT_TYPE: DataType;
}

impl ValidDecimal for Decimal128 {
const MAX_PRECISION: usize = DECIMAL128_MAX_PRECISION;
const MAX_SCALE: usize = DECIMAL128_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(usize, usize) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
}

impl ValidDecimal for Decimal256 {
const MAX_PRECISION: usize = DECIMAL256_MAX_PRECISION;
const MAX_SCALE: usize = DECIMAL256_MAX_SCALE;
const TYPE_CONSTRUCTOR: fn(usize, usize) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
}

#[derive(Debug)]
pub struct BasicDecimal<const BYTE_WIDTH: usize> {
precision: usize,
scale: usize,
value: [u8; BYTE_WIDTH],
}

impl<const BYTE_WIDTH: usize> BasicDecimal<BYTE_WIDTH> {
#[allow(clippy::type_complexity)]
const MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE: (
usize,
usize,
fn(usize, usize) -> DataType,
DataType,
) = match BYTE_WIDTH {
16 => (
DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE,
DataType::Decimal128,
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
),
32 => (
DECIMAL256_MAX_PRECISION,
DECIMAL256_MAX_SCALE,
DataType::Decimal256,
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
),
_ => panic!("invalid byte width"),
};

pub const MAX_PRECISION: usize = Self::MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.0;
pub const MAX_SCALE: usize = Self::MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.1;
pub const TYPE_CONSTRUCTOR: fn(usize, usize) -> DataType =
Self::MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.2;
pub const DEFAULT_TYPE: DataType =
Self::MAX_PRECISION_SCALE_CONSTRUCTOR_DEFAULT_TYPE.3;

impl<const BYTE_WIDTH: usize> BasicDecimal<BYTE_WIDTH>
where
Self: ValidDecimal,
{
/// Tries to create a decimal value from precision, scale and bytes.
/// The bytes should be stored in little-endian order.
///
Expand All @@ -78,35 +77,24 @@ impl<const BYTE_WIDTH: usize> BasicDecimal<BYTE_WIDTH> {
Self: Sized,
{
if precision > Self::MAX_PRECISION {
return Err(ArrowError::InvalidArgumentError(format!(
Err(ArrowError::InvalidArgumentError(format!(
"precision {} is greater than max {}",
precision,
Self::MAX_PRECISION
)));
}
if scale > Self::MAX_SCALE {
return Err(ArrowError::InvalidArgumentError(format!(
)))
} else if scale > Self::MAX_SCALE {
Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than max {}",
scale,
Self::MAX_SCALE
)));
}

if precision < scale {
return Err(ArrowError::InvalidArgumentError(format!(
)))
} else if precision < scale {
Err(ArrowError::InvalidArgumentError(format!(
"Precision {} is less than scale {}",
precision, scale
)));
}

if bytes.len() == BYTE_WIDTH {
Ok(Self::new(precision, scale, bytes))
} else {
Err(ArrowError::InvalidArgumentError(format!(
"Input to Decimal{} must be {} bytes",
BYTE_WIDTH * 8,
BYTE_WIDTH
)))
} else {
Ok(Self::new(precision, scale, bytes))
}
}

Expand Down