Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use downcast_primitive_array in arithmetic kernels #2640

Merged
merged 2 commits into from
Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
249 changes: 249 additions & 0 deletions arrow/src/array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,255 @@ macro_rules! downcast_primitive_array {
$($p => $fallback,)*
}
};

(($values1:ident, $values2:ident) => $e:block $($p:pat => $fallback:expr $(,)*)*) => {
Copy link
Contributor

@tustvold tustvold Sep 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has an implicit assumption that $values1 and $values2 have the same type, not only is this potentially surprising as an API, but I think it changes the behaviour of the kernels which will now panic where previously they would return an error?

Adding $values2.data_type() to the match might work, but this still feels a bit confusing as an API? 🤔

I wonder if we could instead do something like this

downcast_primitive_array!(
      left => {
          let right = as_primitive_array(right);
          multiply(left, right).map(|a| Arc::new(a) as ArrayRef)
      }
      _ => Err(ArrowError::CastError(format!(
          "Unsupported data type {}, {}",
          left.data_type(), right.data_type()
      )))
  )

And rely on the fact the generic kernels constrain them to be the same type. I don't know, perhaps this is hack...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point. But I guess the suggested one will also panic on let right = as_primitive_array(right); if right is not same type?

Let me do a test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, math_op doesn't constrain two sides of op should be same type. So

downcast_primitive_array!(
    left => {
        let right = as_primitive_array(right);
        math_op(left, right, |a, b| a + b).map(|a| Arc::new(a) as ArrayRef)
    }
    _ => Err(ArrowError::CastError(format!(
       "Unsupported data type {}, {}",
        left.data_type(), right.data_type()
   )))
)

will not constrain the right side type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggested one will

Yeah, you will still need to check the type, it is what it is

doesn't constrain the type

That's why i suggested using the generic kernel not math_op directly 😃

Copy link
Member Author

@viirya viirya Sep 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh oh right. 😄

Good news is it works.

But as Float16Type doesn't implement ArrowNumericType, I need to remove Float16Type pattern from downcast_primitive_array! to make it work.

For simd feature, seems f16 related APIs are not available so it appears not easy to let Float16Type implement ArrowNumericType.

Currently I leave single argument downcast_primitive_array! untouched and stick with two argument one and make it constrain the two arguments must be same type.

match ($values1.data_type(), $values2.data_type()) {
($crate::datatypes::DataType::Int8, $crate::datatypes::DataType::Int8) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Int8Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Int8Type,
>($values2);
$e
}
($crate::datatypes::DataType::Int16, $crate::datatypes::DataType::Int16) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Int16Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Int16Type,
>($values2);
$e
}
($crate::datatypes::DataType::Int32, $crate::datatypes::DataType::Int32) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Int32Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Int32Type,
>($values2);
$e
}
($crate::datatypes::DataType::Int64, $crate::datatypes::DataType::Int64) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Int64Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Int64Type,
>($values2);
$e
}
($crate::datatypes::DataType::UInt8, $crate::datatypes::DataType::UInt8) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt8Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt8Type,
>($values2);
$e
}
($crate::datatypes::DataType::UInt16, $crate::datatypes::DataType::UInt16) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt16Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt16Type,
>($values2);
$e
}
($crate::datatypes::DataType::UInt32, $crate::datatypes::DataType::UInt32) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt32Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt32Type,
>($values2);
$e
}
($crate::datatypes::DataType::UInt64, $crate::datatypes::DataType::UInt64) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt64Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::UInt64Type,
>($values2);
$e
}
($crate::datatypes::DataType::Float32, $crate::datatypes::DataType::Float32) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Float32Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Float32Type,
>($values2);
$e
}
($crate::datatypes::DataType::Float64, $crate::datatypes::DataType::Float64) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Float64Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Float64Type,
>($values2);
$e
}
($crate::datatypes::DataType::Date32, $crate::datatypes::DataType::Date32) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Date32Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Date32Type,
>($values2);
$e
}
($crate::datatypes::DataType::Date64, $crate::datatypes::DataType::Date64) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Date64Type,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Date64Type,
>($values2);
$e
}
($crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Second), $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Second)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Time32SecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Time32SecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Millisecond), $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Millisecond)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Time32MillisecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Time32MillisecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Microsecond), $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Microsecond)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Time64MicrosecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Time64MicrosecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Nanosecond), $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Nanosecond)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::Time64NanosecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::Time64NanosecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Second, _), $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Second, _)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampSecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampSecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Millisecond, _), $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Millisecond, _)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampMillisecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampMillisecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Microsecond, _), $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Microsecond, _)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampMicrosecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampMicrosecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Nanosecond, _), $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Nanosecond, _)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampNanosecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::TimestampNanosecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::YearMonth), $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::YearMonth)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::IntervalYearMonthType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::IntervalYearMonthType,
>($values2);
$e
}
($crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::DayTime), $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::DayTime)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::IntervalDayTimeType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::IntervalDayTimeType,
>($values2);
$e
}
($crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::MonthDayNano), $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::MonthDayNano)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::IntervalMonthDayNanoType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::IntervalMonthDayNanoType,
>($values2);
$e
}
($crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Second), $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Second)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationSecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationSecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Millisecond), $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Millisecond)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationMillisecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationMillisecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Microsecond), $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Microsecond)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationMicrosecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationMicrosecondType,
>($values2);
$e
}
($crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Nanosecond), $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Nanosecond)) => {
let $values1 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationNanosecondType,
>($values1);
let $values2 = $crate::array::as_primitive_array::<
$crate::datatypes::DurationNanosecondType,
>($values2);
$e
}
$($p => $fallback,)*
}
};
}

/// Force downcast of an [`Array`], such as an [`ArrayRef`], to
Expand Down