diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 34abfeb0a3d..7cf7de72161 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -438,6 +438,57 @@ impl PrimitiveArray { build_primitive_array(len, buffer.finish(), null_count, null_buffer) }) } + + /// Applies a unary and nullable function to all valid values in a primitive array + /// + /// This is unlike [`Self::unary`] which will apply an infallible function to all rows + /// regardless of validity, in many cases this will be significantly faster and should + /// be preferred if `op` is infallible. + /// + /// Note: LLVM is currently unable to effectively vectorize fallible operations + pub fn unary_opt(&self, op: F) -> PrimitiveArray + where + O: ArrowPrimitiveType, + F: Fn(T::Native) -> Option, + { + let data = self.data(); + let len = data.len(); + let offset = data.offset(); + let null_count = data.null_count(); + let nulls = data.null_buffer().map(|x| x.as_slice()); + + let mut null_builder = BooleanBufferBuilder::new(len); + match nulls { + Some(b) => null_builder.append_packed_range(offset..offset + len, b), + None => null_builder.append_n(len, true), + } + + let mut buffer = BufferBuilder::::new(len); + buffer.append_n_zeroed(len); + let slice = buffer.as_slice_mut(); + + let mut out_null_count = null_count; + + let _ = try_for_each_valid_idx(len, offset, null_count, nulls, |idx| { + match op(unsafe { self.value_unchecked(idx) }) { + Some(v) => unsafe { *slice.get_unchecked_mut(idx) = v }, + None => { + out_null_count += 1; + null_builder.set_bit(idx, false); + } + } + Ok::<_, ()>(()) + }); + + unsafe { + build_primitive_array( + len, + buffer.finish(), + out_null_count, + Some(null_builder.finish()), + ) + } + } } #[inline] @@ -1864,6 +1915,21 @@ mod tests { assert!(!array.is_null(2)); } + #[test] + fn test_unary_opt() { + let array = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7]); + let r = array.unary_opt::<_, Int32Type>(|x| (x % 2 != 0).then_some(x)); + + let expected = + Int32Array::from(vec![Some(1), None, Some(3), None, Some(5), None, Some(7)]); + assert_eq!(r, expected); + + let r = expected.unary_opt::<_, Int32Type>(|x| (x % 3 != 0).then_some(x)); + let expected = + Int32Array::from(vec![Some(1), None, None, None, Some(5), None, Some(7)]); + assert_eq!(r, expected); + } + #[test] #[should_panic( expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3" diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 8504a8167b3..d6dbf3061bb 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -337,11 +337,8 @@ where })?; if cast_options.safe { - let iter = array - .iter() - .map(|v| v.and_then(|v| v.as_().mul_checked(mul).ok())); - let casted_array = unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }; - casted_array + array + .unary_opt::<_, D>(|v| v.as_().mul_checked(mul).ok()) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } else { @@ -364,12 +361,8 @@ where let mul = 10_f64.powi(scale as i32); if cast_options.safe { - let iter = array - .iter() - .map(|v| v.and_then(|v| (mul * v.as_()).round().to_i128())); - let casted_array = - unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }; - casted_array + array + .unary_opt::<_, Decimal128Type>(|v| (mul * v.as_()).round().to_i128()) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } else { @@ -407,12 +400,8 @@ where let mul = 10_f64.powi(scale as i32); if cast_options.safe { - let iter = array - .iter() - .map(|v| v.and_then(|v| i256::from_f64((v.as_() * mul).round()))); - let casted_array = - unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }; - casted_array + array + .unary_opt::<_, Decimal256Type>(|v| i256::from_f64((v.as_() * mul).round())) .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) } else { @@ -2107,12 +2096,7 @@ where T::Native: NumCast, R::Native: NumCast, { - let iter = from - .iter() - .map(|v| v.and_then(num::cast::cast::)); - // Soundness: - // The iterator is trustedLen because it comes from an `PrimitiveArray`. - unsafe { PrimitiveArray::::from_trusted_len_iter(iter) } + from.unary_opt::<_, R>(num::cast::cast::) } fn as_time_with_string_op<