Skip to content

Commit

Permalink
Add PrimitiveArray::unary_opt (#3110)
Browse files Browse the repository at this point in the history
* Add PrimitiveArray::unary_opt

* Format

* Clippy
  • Loading branch information
tustvold committed Nov 15, 2022
1 parent 81ce601 commit b0b5d8b
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 23 deletions.
66 changes: 66 additions & 0 deletions arrow-array/src/array/primitive_array.rs
Expand Up @@ -438,6 +438,57 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
build_primitive_array(len, buffer.finish(), null_count, null_buffer)
})
}

/// Maps each valid value of this array through the fallible function `op`,
/// returning a new primitive array of type `O`.
///
/// An output slot is null when the corresponding input slot was null, or when
/// `op` returned `None` for its value.
///
/// Contrast this with [`Self::unary`], which applies an infallible function to
/// every row regardless of validity. That is significantly faster in many cases,
/// so [`Self::unary`] should be preferred whenever `op` is infallible.
///
/// Note: LLVM is currently unable to effectively vectorize fallible operations
pub fn unary_opt<F, O>(&self, op: F) -> PrimitiveArray<O>
where
    O: ArrowPrimitiveType,
    F: Fn(T::Native) -> Option<O::Native>,
{
    let array_data = self.data();
    let len = array_data.len();
    let offset = array_data.offset();
    let in_null_count = array_data.null_count();
    let validity = array_data.null_buffer().map(|buf| buf.as_slice());

    // Seed the output validity from the input (all-valid when the input carries
    // no null buffer); bits are cleared below wherever `op` yields `None`.
    let mut validity_builder = BooleanBufferBuilder::new(len);
    if let Some(bits) = validity {
        validity_builder.append_packed_range(offset..offset + len, bits);
    } else {
        validity_builder.append_n(len, true);
    }

    // Reserve `len` zeroed output slots up front so results can be written by index.
    let mut values = BufferBuilder::<O::Native>::new(len);
    values.append_n_zeroed(len);
    let out = values.as_slice_mut();

    // Starts at the input null count and grows for every `None` produced by `op`.
    let mut total_nulls = in_null_count;

    // The closure always returns `Ok`, so the overall result carries no information.
    let _ = try_for_each_valid_idx(len, offset, in_null_count, validity, |idx| {
        // NOTE(review): soundness of both unchecked accesses relies on
        // `try_for_each_valid_idx` only yielding indices in `0..len` - confirm
        // against that helper's contract.
        let mapped = op(unsafe { self.value_unchecked(idx) });
        if let Some(v) = mapped {
            unsafe { *out.get_unchecked_mut(idx) = v };
        } else {
            total_nulls += 1;
            validity_builder.set_bit(idx, false);
        }
        Ok::<_, ()>(())
    });

    // NOTE(review): `build_primitive_array` is unsafe; presumably it trusts that
    // the buffers and null count passed here are consistent - confirm.
    unsafe {
        build_primitive_array(
            len,
            values.finish(),
            total_nulls,
            Some(validity_builder.finish()),
        )
    }
}
}

#[inline]
Expand Down Expand Up @@ -1864,6 +1915,21 @@ mod tests {
assert!(!array.is_null(2));
}

#[test]
fn test_unary_opt() {
    // Input without nulls: every even value is mapped to null.
    let input = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7]);
    let keep_odd = |x: i32| if x % 2 != 0 { Some(x) } else { None };
    let odds = input.unary_opt::<_, Int32Type>(keep_odd);

    let expected_odds =
        Int32Array::from(vec![Some(1), None, Some(3), None, Some(5), None, Some(7)]);
    assert_eq!(odds, expected_odds);

    // Input that already contains nulls: existing nulls stay null and
    // multiples of three become null as well.
    let not_multiple_of_three = |x: i32| if x % 3 != 0 { Some(x) } else { None };
    let filtered = expected_odds.unary_opt::<_, Int32Type>(not_multiple_of_three);

    let expected_filtered =
        Int32Array::from(vec![Some(1), None, None, None, Some(5), None, Some(7)]);
    assert_eq!(filtered, expected_filtered);
}

#[test]
#[should_panic(
expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3"
Expand Down
30 changes: 7 additions & 23 deletions arrow-cast/src/cast.rs
Expand Up @@ -337,11 +337,8 @@ where
})?;

if cast_options.safe {
let iter = array
.iter()
.map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok()));
let casted_array = unsafe { PrimitiveArray::<D>::from_trusted_len_iter(iter) };
casted_array
array
.unary_opt::<_, D>(|v| v.as_().mul_checked(mul).ok())
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
Expand All @@ -364,12 +361,8 @@ where
let mul = 10_f64.powi(scale as i32);

if cast_options.safe {
let iter = array
.iter()
.map(|v| v.and_then(|v| (mul * v.as_()).round().to_i128()));
let casted_array =
unsafe { PrimitiveArray::<Decimal128Type>::from_trusted_len_iter(iter) };
casted_array
array
.unary_opt::<_, Decimal128Type>(|v| (mul * v.as_()).round().to_i128())
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
Expand Down Expand Up @@ -407,12 +400,8 @@ where
let mul = 10_f64.powi(scale as i32);

if cast_options.safe {
let iter = array
.iter()
.map(|v| v.and_then(|v| i256::from_f64((v.as_() * mul).round())));
let casted_array =
unsafe { PrimitiveArray::<Decimal256Type>::from_trusted_len_iter(iter) };
casted_array
array
.unary_opt::<_, Decimal256Type>(|v| i256::from_f64((v.as_() * mul).round()))
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
Expand Down Expand Up @@ -2107,12 +2096,7 @@ where
T::Native: NumCast,
R::Native: NumCast,
{
let iter = from
.iter()
.map(|v| v.and_then(num::cast::cast::<T::Native, R::Native>));
// Soundness:
// The iterator is trustedLen because it comes from an `PrimitiveArray`.
unsafe { PrimitiveArray::<R>::from_trusted_len_iter(iter) }
from.unary_opt::<_, R>(num::cast::cast::<T::Native, R::Native>)
}

fn as_time_with_string_op<
Expand Down

0 comments on commit b0b5d8b

Please sign in to comment.