Skip to content

Commit

Permalink
Move unary kernels to arrow-array (#2787) (#2789)
Browse files Browse the repository at this point in the history
* Move primitive arity kernels (#2787)

* Fix doc
  • Loading branch information
tustvold committed Sep 27, 2022
1 parent 333633e commit 2ba1307
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 87 deletions.
119 changes: 104 additions & 15 deletions arrow-array/src/array/primitive_array.rs
Expand Up @@ -15,14 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use crate::builder::{BooleanBufferBuilder, PrimitiveBuilder};
use crate::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder};
use crate::iterator::PrimitiveIter;
use crate::raw_pointer::RawPtrBox;
use crate::temporal_conversions::{as_date, as_datetime, as_duration, as_time};
use crate::trusted_len::trusted_len_unzip;
use crate::types::*;
use crate::{print_long_array, Array, ArrayAccessor};
use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_data::ArrayData;
use arrow_schema::DataType;
use chrono::{Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
Expand Down Expand Up @@ -298,20 +299,10 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {

/// Creates a PrimitiveArray based on a constant value with `count` elements
pub fn from_value(value: T::Native, count: usize) -> Self {
// # Safety: iterator (0..count) correctly reports its length
let val_buf = unsafe { Buffer::from_trusted_len_iter((0..count).map(|_| value)) };
let data = unsafe {
ArrayData::new_unchecked(
T::DATA_TYPE,
val_buf.len() / std::mem::size_of::<<T as ArrowPrimitiveType>::Native>(),
None,
None,
0,
vec![val_buf],
vec![],
)
};
PrimitiveArray::from(data)
unsafe {
let val_buf = Buffer::from_trusted_len_iter((0..count).map(|_| value));
build_primitive_array(count, val_buf, 0, None)
}
}

/// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i`
Expand All @@ -332,6 +323,104 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
) -> impl Iterator<Item = Option<T::Native>> + 'a {
indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index)))
}

/// Applies an unary and infallible function to a primitive array.
/// This is the fastest way to perform an operation on a primitive array when
/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls.
///
/// # Implementation
///
/// This will apply the function for all values, including those on null slots.
/// This implies that the operation must be infallible for any value of the corresponding type
/// or this function may panic.
/// # Example
/// ```rust
/// # use arrow_array::{Int32Array, types::Int32Type};
/// # fn main() {
/// let array = Int32Array::from(vec![Some(5), Some(7), None]);
/// let c = array.unary(|x| x * 2 + 1);
/// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
/// # }
/// ```
pub fn unary<F, O>(&self, op: F) -> PrimitiveArray<O>
where
O: ArrowPrimitiveType,
F: Fn(T::Native) -> O::Native,
{
let data = self.data();
let len = self.len();
let null_count = self.null_count();

let null_buffer = data.null_buffer().map(|b| b.bit_slice(data.offset(), len));
let values = self.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
}

/// Applies a unary and fallible function to all valid values in a primitive array
///
/// This is unlike [`Self::unary`] which will apply an infallible function to all rows
/// regardless of validity, in many cases this will be significantly faster and should
/// be preferred if `op` is infallible.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
pub fn try_unary<F, O, E>(&self, op: F) -> Result<PrimitiveArray<O>, E>
where
O: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<O::Native, E>,
{
let data = self.data();
let len = self.len();
let null_count = self.null_count();

if null_count == 0 {
let values = self.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
return Ok(unsafe { build_primitive_array(len, buffer, 0, None) });
}

let null_buffer = data.null_buffer().map(|b| b.bit_slice(data.offset(), len));
let mut buffer = BufferBuilder::<O::Native>::new(len);
buffer.append_n_zeroed(len);
let slice = buffer.as_slice_mut();

try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| {
unsafe { *slice.get_unchecked_mut(idx) = op(self.value_unchecked(idx))? };
Ok::<_, E>(())
})?;

Ok(unsafe {
build_primitive_array(len, buffer.finish(), null_count, null_buffer)
})
}
}

#[inline]
unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
len: usize,
buffer: Buffer,
null_count: usize,
null_buffer: Option<Buffer>,
) -> PrimitiveArray<O> {
PrimitiveArray::from(ArrayData::new_unchecked(
O::DATA_TYPE,
len,
Some(null_count),
null_buffer,
0,
vec![buffer],
vec![],
))
}

impl<T: ArrowPrimitiveType> From<PrimitiveArray<T>> for ArrayData {
Expand Down
76 changes: 4 additions & 72 deletions arrow/src/compute/kernels/arity.rs
Expand Up @@ -48,92 +48,24 @@ unsafe fn build_primitive_array<O: ArrowPrimitiveType>(
))
}

/// Applies an unary and infallible function to a primitive array.
/// This is the fastest way to perform an operation on a primitive array when
/// the benefits of a vectorized operation outweigh the cost of branching nulls and non-nulls.
///
/// # Implementation
///
/// This will apply the function for all values, including those on null slots.
/// This implies that the operation must be infallible for any value of the corresponding type
/// or this function may panic.
/// # Example
/// ```rust
/// # use arrow::array::Int32Array;
/// # use arrow::datatypes::Int32Type;
/// # use arrow::compute::kernels::arity::unary;
/// # fn main() {
/// let array = Int32Array::from(vec![Some(5), Some(7), None]);
/// let c = unary::<_, _, Int32Type>(&array, |x| x * 2 + 1);
/// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
/// # }
/// ```
/// See [`PrimitiveArray::unary`]
pub fn unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> O::Native,
{
let data = array.data();
let len = data.len();
let null_count = data.null_count();

let null_buffer = data
.null_buffer()
.map(|b| b.bit_slice(data.offset(), data.len()));

let values = array.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::from_trusted_len_iter(values) };
unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
array.unary(op)
}

/// Applies a unary and fallible function to all valid values in a primitive array
///
/// This is unlike [`unary`] which will apply an infallible function to all rows regardless
/// of validity, in many cases this will be significantly faster and should be preferred
/// if `op` is infallible.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
/// See [`PrimitiveArray::try_unary`]
pub fn try_unary<I, F, O>(array: &PrimitiveArray<I>, op: F) -> Result<PrimitiveArray<O>>
where
I: ArrowPrimitiveType,
O: ArrowPrimitiveType,
F: Fn(I::Native) -> Result<O::Native>,
{
let len = array.len();
let null_count = array.null_count();

if null_count == 0 {
let values = array.values().iter().map(|v| op(*v));
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size because arrays are sized.
let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
return Ok(unsafe { build_primitive_array(len, buffer, 0, None) });
}

let null_buffer = array
.data_ref()
.null_buffer()
.map(|b| b.bit_slice(array.offset(), array.len()));

let mut buffer = BufferBuilder::<O::Native>::new(len);
buffer.append_n_zeroed(array.len());
let slice = buffer.as_slice_mut();

try_for_each_valid_idx(array.len(), 0, null_count, null_buffer.as_deref(), |idx| {
unsafe { *slice.get_unchecked_mut(idx) = op(array.value_unchecked(idx))? };
Ok::<_, ArrowError>(())
})?;

Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) })
array.try_unary(op)
}

/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
Expand Down

0 comments on commit 2ba1307

Please sign in to comment.