diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 1c28c989524..b2e95ad5e4a 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -332,9 +332,7 @@ where // process data in chunks of 64 elements since we also get 64 bits of validity information at a time - // safety: result is newly created above, always written as a T below - let mut result_chunks = - unsafe { result.typed_data_mut().chunks_exact_mut(64) }; + let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64); let mut left_chunks = left.values().chunks_exact(64); let mut right_chunks = right.values().chunks_exact(64); @@ -380,9 +378,7 @@ where )?; } None => { - // safety: result is newly created above, always written as a T below - let mut result_chunks = - unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; + let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); let mut left_chunks = left.values().chunks_exact(lanes); let mut right_chunks = right.values().chunks_exact(lanes); @@ -1611,6 +1607,7 @@ mod tests { use crate::array::Int32Array; use crate::datatypes::{Date64Type, Int32Type, Int8Type}; use chrono::NaiveDate; + use half::f16; #[test] fn test_primitive_array_add() { @@ -2898,4 +2895,26 @@ mod tests { let division_by_zero = divide_scalar_opt_dyn::(&a, 0); assert_eq!(&expected, &division_by_zero.unwrap()); } + + #[test] + fn test_sum_f16() { + let a = Float16Array::from_iter_values([ + f16::from_f32(0.1), + f16::from_f32(0.2), + f16::from_f32(1.5), + f16::from_f32(-0.1), + ]); + let b = Float16Array::from_iter_values([ + f16::from_f32(5.1), + f16::from_f32(6.2), + f16::from_f32(-1.), + f16::from_f32(-2.1), + ]); + let expected = Float16Array::from_iter_values( + a.values().iter().zip(b.values()).map(|(a, b)| a + b), + ); + + let c = add(&a, &b).unwrap(); + assert_eq!(c, expected); + } } diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 49aecfb67fa..1ea433150f0 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -1792,7 +1792,6 @@ where .iter() .map(|key| { key.map(|key| unsafe { - // safety lengths were verified above let key = key.as_usize(); dict_comparison.value_unchecked(key) }) @@ -1845,8 +1844,7 @@ where let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); let mut right_chunks = right.values().chunks_exact(CHUNK_SIZE); - // safety: result is newly created above, always written as a T below - let result_chunks = unsafe { result.typed_data_mut() }; + let result_chunks = result.typed_data_mut(); let result_remainder = left_chunks .borrow_mut() .zip(right_chunks.borrow_mut()) @@ -1937,8 +1935,7 @@ where let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); let simd_right = T::init(right); - // safety: result is newly created above, always written as a T below - let result_chunks = unsafe { result.typed_data_mut() }; + let result_chunks = result.typed_data_mut(); let result_remainder = left_chunks .borrow_mut() diff --git a/arrow/src/datatypes/numeric.rs b/arrow/src/datatypes/numeric.rs index b8fa87197c3..e74764d4c0e 100644 --- a/arrow/src/datatypes/numeric.rs +++ b/arrow/src/datatypes/numeric.rs @@ -366,6 +366,102 @@ make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8); make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8); make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8); +#[cfg(not(feature = "simd"))] +impl ArrowNumericType for Float16Type {} + +#[cfg(feature = "simd")] +impl ArrowNumericType for Float16Type { + type Simd = ::Simd; + type SimdMask = ::SimdMask; + + fn lanes() -> usize { + Float32Type::lanes() + } + + fn init(value: Self::Native) -> Self::Simd { + Float32Type::init(value.to_f32()) + } + + fn load(slice: &[Self::Native]) -> Self::Simd { + let mut s = [0_f32; Self::Simd::lanes()]; + s.iter_mut().zip(slice).for_each(|(o, a)| *o = a.to_f32()); + Float32Type::load(&s) + } + + fn mask_init(value: bool) -> Self::SimdMask { + Float32Type::mask_init(value) + } + + fn mask_from_u64(mask: u64) -> Self::SimdMask { + Float32Type::mask_from_u64(mask) + } + + fn mask_to_u64(mask: &Self::SimdMask) -> u64 { + Float32Type::mask_to_u64(mask) + } + + fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool { + Float32Type::mask_get(mask, idx) + } + + fn mask_set(mask: Self::SimdMask, idx: usize, value: bool) -> Self::SimdMask { + Float32Type::mask_set(mask, idx, value) + } + + fn mask_select(mask: Self::SimdMask, a: Self::Simd, b: Self::Simd) -> Self::Simd { + Float32Type::mask_select(mask, a, b) + } + + fn mask_any(mask: Self::SimdMask) -> bool { + Float32Type::mask_any(mask) + } + + fn bin_op Self::Simd>( + left: Self::Simd, + right: Self::Simd, + op: F, + ) -> Self::Simd { + op(left, right) + } + + fn eq(left: Self::Simd, right: Self::Simd) -> Self::SimdMask { + Float32Type::eq(left, right) + } + + fn ne(left: Self::Simd, right: Self::Simd) -> Self::SimdMask { + Float32Type::ne(left, right) + } + + fn lt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask { + Float32Type::lt(left, right) + } + + fn le(left: Self::Simd, right: Self::Simd) -> Self::SimdMask { + Float32Type::le(left, right) + } + + fn gt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask { + Float32Type::gt(left, right) + } + + fn ge(left: Self::Simd, right: Self::Simd) -> Self::SimdMask { + Float32Type::ge(left, right) + } + + fn write(simd_result: Self::Simd, slice: &mut [Self::Native]) { + let mut s = [0_f32; Self::Simd::lanes()]; + Float32Type::write(simd_result, &mut s); + slice + .iter_mut() + .zip(s) + .for_each(|(o, i)| *o = half::f16::from_f32(i)) + } + + fn unary_op Self::Simd>(a: Self::Simd, op: F) -> Self::Simd { + Float32Type::unary_op(a, op) + } +} + #[cfg(feature = "simd")] pub trait ArrowFloatNumericType: ArrowNumericType { fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd;