Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ArrowNumericType for Float16Type #2810

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 25 additions & 6 deletions arrow/src/compute/kernels/arithmetic.rs
Expand Up @@ -332,9 +332,7 @@ where

// process data in chunks of 64 elements since we also get 64 bits of validity information at a time

// safety: result is newly created above, always written as a T below
let mut result_chunks =
unsafe { result.typed_data_mut().chunks_exact_mut(64) };
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are drive-by fixes that date from #1866

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 -- as below we probably can remove the old //safety comments too

let mut left_chunks = left.values().chunks_exact(64);
let mut right_chunks = right.values().chunks_exact(64);

Expand Down Expand Up @@ -380,9 +378,7 @@ where
)?;
}
None => {
// safety: result is newly created above, always written as a T below
let mut result_chunks =
unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);

Expand Down Expand Up @@ -1611,6 +1607,7 @@ mod tests {
use crate::array::Int32Array;
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use chrono::NaiveDate;
use half::f16;

#[test]
fn test_primitive_array_add() {
Expand Down Expand Up @@ -2898,4 +2895,26 @@ mod tests {
let division_by_zero = divide_scalar_opt_dyn::<Int32Type>(&a, 0);
assert_eq!(&expected, &division_by_zero.unwrap());
}

#[test]
fn test_sum_f16() {
    // Element-wise addition of two half-precision arrays. The expected
    // values are computed with the scalar `f16` `Add` impl, so the
    // arithmetic kernel's result must match it exactly.
    let left = Float16Array::from_iter_values(
        [0.1f32, 0.2, 1.5, -0.1].iter().map(|&v| f16::from_f32(v)),
    );
    let right = Float16Array::from_iter_values(
        [5.1f32, 6.2, -1., -2.1].iter().map(|&v| f16::from_f32(v)),
    );
    let expected = Float16Array::from_iter_values(
        left.values().iter().zip(right.values()).map(|(l, r)| l + r),
    );

    let sum = add(&left, &right).unwrap();
    assert_eq!(sum, expected);
}
}
7 changes: 2 additions & 5 deletions arrow/src/compute/kernels/comparison.rs
Expand Up @@ -1792,7 +1792,6 @@ where
.iter()
.map(|key| {
key.map(|key| unsafe {
// safety lengths were verified above
let key = key.as_usize();
dict_comparison.value_unchecked(key)
})
Expand Down Expand Up @@ -1845,8 +1844,7 @@ where
let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE);
let mut right_chunks = right.values().chunks_exact(CHUNK_SIZE);

// safety: result is newly created above, always written as a T below
let result_chunks = unsafe { result.typed_data_mut() };
let result_chunks = result.typed_data_mut();
let result_remainder = left_chunks
.borrow_mut()
.zip(right_chunks.borrow_mut())
Expand Down Expand Up @@ -1937,8 +1935,7 @@ where
let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE);
let simd_right = T::init(right);

// safety: result is newly created above, always written as a T below
let result_chunks = unsafe { result.typed_data_mut() };
let result_chunks = result.typed_data_mut();
let result_remainder =
left_chunks
.borrow_mut()
Expand Down
96 changes: 96 additions & 0 deletions arrow/src/datatypes/numeric.rs
Expand Up @@ -366,6 +366,102 @@ make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8);

// Without the `simd` feature, `ArrowNumericType` declares no SIMD associated
// items, so this is a plain marker impl enabling the numeric kernels for f16.
#[cfg(not(feature = "simd"))]
impl ArrowNumericType for Float16Type {}

/// SIMD support for `Float16Type`.
///
/// There is no native `f16` SIMD lane type available, so every lane is
/// widened to `f32` and all vector operations are delegated to the
/// `Float32Type` implementation; results are narrowed back to `f16`
/// only when written out.
#[cfg(feature = "simd")]
impl ArrowNumericType for Float16Type {
    type Simd = <Float32Type as ArrowNumericType>::Simd;
    type SimdMask = <Float32Type as ArrowNumericType>::SimdMask;

    fn lanes() -> usize {
        Float32Type::lanes()
    }

    fn init(value: Self::Native) -> Self::Simd {
        // Widen the scalar to f32 before splatting it across all lanes.
        Float32Type::init(value.to_f32())
    }

    fn load(slice: &[Self::Native]) -> Self::Simd {
        // Widen each f16 element into a stack buffer, then load f32 lanes.
        let mut widened = [0_f32; Self::Simd::lanes()];
        for (dst, src) in widened.iter_mut().zip(slice) {
            *dst = src.to_f32();
        }
        Float32Type::load(&widened)
    }

    fn mask_init(value: bool) -> Self::SimdMask {
        Float32Type::mask_init(value)
    }

    fn mask_from_u64(mask: u64) -> Self::SimdMask {
        Float32Type::mask_from_u64(mask)
    }

    fn mask_to_u64(mask: &Self::SimdMask) -> u64 {
        Float32Type::mask_to_u64(mask)
    }

    fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool {
        Float32Type::mask_get(mask, idx)
    }

    fn mask_set(mask: Self::SimdMask, idx: usize, value: bool) -> Self::SimdMask {
        Float32Type::mask_set(mask, idx, value)
    }

    fn mask_select(mask: Self::SimdMask, a: Self::Simd, b: Self::Simd) -> Self::Simd {
        Float32Type::mask_select(mask, a, b)
    }

    fn mask_any(mask: Self::SimdMask) -> bool {
        Float32Type::mask_any(mask)
    }

    fn bin_op<F: Fn(Self::Simd, Self::Simd) -> Self::Simd>(
        left: Self::Simd,
        right: Self::Simd,
        op: F,
    ) -> Self::Simd {
        // The operands are already f32 vectors, so the op applies directly.
        op(left, right)
    }

    fn eq(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
        Float32Type::eq(left, right)
    }

    fn ne(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
        Float32Type::ne(left, right)
    }

    fn lt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
        Float32Type::lt(left, right)
    }

    fn le(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
        Float32Type::le(left, right)
    }

    fn gt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
        Float32Type::gt(left, right)
    }

    fn ge(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
        Float32Type::ge(left, right)
    }

    fn write(simd_result: Self::Simd, slice: &mut [Self::Native]) {
        // Store the f32 lanes to a stack buffer, then narrow each back to f16.
        let mut widened = [0_f32; Self::Simd::lanes()];
        Float32Type::write(simd_result, &mut widened);
        for (dst, src) in slice.iter_mut().zip(widened) {
            *dst = half::f16::from_f32(src);
        }
    }

    fn unary_op<F: Fn(Self::Simd) -> Self::Simd>(a: Self::Simd, op: F) -> Self::Simd {
        Float32Type::unary_op(a, op)
    }
}

#[cfg(feature = "simd")]
pub trait ArrowFloatNumericType: ArrowNumericType {
fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd;
Expand Down