Skip to content

Commit

Permalink
Expose ArrowNativeTypeOp (#2840)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Oct 7, 2022
1 parent 37c8679 commit 8dd94a9
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 92 deletions.
3 changes: 1 addition & 2 deletions arrow/src/compute/kernels/aggregate.rs
Expand Up @@ -27,8 +27,7 @@ use crate::array::{
as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use crate::datatypes::native_op::ArrowNativeTypeOp;
use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
use crate::datatypes::{ArrowNativeType, ArrowNativeTypeOp, ArrowNumericType, DataType};
use crate::error::Result;
use crate::util::bit_iterator::BitIndexIterator;

Expand Down
2 changes: 1 addition & 1 deletion arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -34,7 +34,7 @@ use crate::compute::{
binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
};
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
};
#[cfg(feature = "dyn_arith_dict")]
Expand Down
14 changes: 7 additions & 7 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -27,13 +27,13 @@ use crate::array::*;
use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{
native_op::ArrowNativeTypeOp, ArrowNativeType, ArrowNumericType, DataType,
Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
ArrowNativeType, ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type,
Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
};
#[allow(unused_imports)]
use crate::downcast_dictionary_array;
Expand Down
164 changes: 82 additions & 82 deletions arrow/src/datatypes/native.rs
Expand Up @@ -16,114 +16,113 @@
// under the License.

use crate::error::{ArrowError, Result};
pub use arrow_array::ArrowPrimitiveType;
pub use arrow_buffer::{ArrowNativeType, ToByteSlice};
use half::f16;
use num::Zero;
use std::ops::{Add, Div, Mul, Rem, Sub};

pub use arrow_array::ArrowPrimitiveType;
mod private {
pub trait Sealed {}
}

pub(crate) mod native_op {
use super::ArrowNativeType;
use crate::error::{ArrowError, Result};
use num::Zero;
use std::ops::{Add, Div, Mul, Rem, Sub};

/// Trait for ArrowNativeType to provide overflow-checking and non-overflow-checking
/// variants for arithmetic operations. For floating point types, this provides some
/// default implementations. Integer types that need to deal with overflow can implement
/// this trait.
///
/// The APIs with `_wrapping` suffix are the variant of non-overflow-checking. If overflow
/// occurred, they will supposedly wrap around the boundary of the type.
///
/// The APIs with `_checked` suffix are the variant of overflow-checking which return `None`
/// if overflow occurred.
pub trait ArrowNativeTypeOp:
ArrowNativeType
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ Rem<Output = Self>
+ Zero
{
fn add_checked(self, rhs: Self) -> Result<Self> {
Ok(self + rhs)
}
/// Trait for ArrowNativeType to provide overflow-checking and non-overflow-checking
/// variants for arithmetic operations. For floating point types, this provides some
/// default implementations. Integer types that need to deal with overflow can implement
/// this trait.
///
/// The APIs with `_wrapping` suffix are the variant of non-overflow-checking. If overflow
/// occurred, they will supposedly wrap around the boundary of the type.
///
/// The APIs with `_checked` suffix are the variant of overflow-checking which return `None`
/// if overflow occurred.
pub trait ArrowNativeTypeOp:
ArrowNativeType
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ Rem<Output = Self>
+ Zero
+ private::Sealed
{
fn add_checked(self, rhs: Self) -> Result<Self> {
Ok(self + rhs)
}

fn add_wrapping(self, rhs: Self) -> Self {
self + rhs
}
fn add_wrapping(self, rhs: Self) -> Self {
self + rhs
}

fn sub_checked(self, rhs: Self) -> Result<Self> {
Ok(self - rhs)
}
fn sub_checked(self, rhs: Self) -> Result<Self> {
Ok(self - rhs)
}

fn sub_wrapping(self, rhs: Self) -> Self {
self - rhs
}
fn sub_wrapping(self, rhs: Self) -> Self {
self - rhs
}

fn mul_checked(self, rhs: Self) -> Result<Self> {
Ok(self * rhs)
}
fn mul_checked(self, rhs: Self) -> Result<Self> {
Ok(self * rhs)
}

fn mul_wrapping(self, rhs: Self) -> Self {
self * rhs
}
fn mul_wrapping(self, rhs: Self) -> Self {
self * rhs
}

fn div_checked(self, rhs: Self) -> Result<Self> {
if rhs.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(self / rhs)
}
fn div_checked(self, rhs: Self) -> Result<Self> {
if rhs.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(self / rhs)
}
}

fn div_wrapping(self, rhs: Self) -> Self {
self / rhs
}
fn div_wrapping(self, rhs: Self) -> Self {
self / rhs
}

fn mod_checked(self, rhs: Self) -> Result<Self> {
if rhs.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(self % rhs)
}
fn mod_checked(self, rhs: Self) -> Result<Self> {
if rhs.is_zero() {
Err(ArrowError::DivideByZero)
} else {
Ok(self % rhs)
}
}

fn mod_wrapping(self, rhs: Self) -> Self {
self % rhs
}
fn mod_wrapping(self, rhs: Self) -> Self {
self % rhs
}

fn is_eq(self, rhs: Self) -> bool {
self == rhs
}
fn is_eq(self, rhs: Self) -> bool {
self == rhs
}

fn is_ne(self, rhs: Self) -> bool {
self != rhs
}
fn is_ne(self, rhs: Self) -> bool {
self != rhs
}

fn is_lt(self, rhs: Self) -> bool {
self < rhs
}
fn is_lt(self, rhs: Self) -> bool {
self < rhs
}

fn is_le(self, rhs: Self) -> bool {
self <= rhs
}
fn is_le(self, rhs: Self) -> bool {
self <= rhs
}

fn is_gt(self, rhs: Self) -> bool {
self > rhs
}
fn is_gt(self, rhs: Self) -> bool {
self > rhs
}

fn is_ge(self, rhs: Self) -> bool {
self >= rhs
}
fn is_ge(self, rhs: Self) -> bool {
self >= rhs
}
}

macro_rules! native_type_op {
($t:tt) => {
impl native_op::ArrowNativeTypeOp for $t {
impl private::Sealed for $t {}
impl ArrowNativeTypeOp for $t {
fn add_checked(self, rhs: Self) -> Result<Self> {
self.checked_add(rhs).ok_or_else(|| {
ArrowError::ComputeError(format!(
Expand Down Expand Up @@ -212,7 +211,8 @@ native_type_op!(u64);

macro_rules! native_type_float_op {
($t:tt) => {
impl native_op::ArrowNativeTypeOp for $t {
impl private::Sealed for $t {}
impl ArrowNativeTypeOp for $t {
fn is_eq(self, rhs: Self) -> bool {
self.total_cmp(&rhs).is_eq()
}
Expand Down

0 comments on commit 8dd94a9

Please sign in to comment.