Skip to content

Commit

Permalink
Add unary_mut as an example
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Nov 15, 2022
1 parent 4a028e6 commit 8311372
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 6 deletions.
62 changes: 60 additions & 2 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,54 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }
}

/// Applies an unary and infallible function to a mutable primitive array.
/// Mutable primitive array means that the buffer is not shared with other arrays.
/// As a result, this mutates the buffer directly without allocating new buffer.
///
/// # Implementation
///
/// This will apply the function for all values, including those on null slots.
/// This implies that the operation must be infallible for any value of the corresponding type
/// or this function may panic.
/// # Example
/// ```rust
/// # use arrow_array::{Int32Array, types::Int32Type};
/// # fn main() {
/// let array = Int32Array::from(vec![Some(5), Some(7), None]);
/// let c = array.unary_mut(|x| x * 2 + 1).unwrap();
/// assert_eq!(c, Int32Array::from(vec![Some(11), Some(15), None]));
/// # }
/// ```
pub fn unary_mut<F>(self, op: F) -> Result<PrimitiveArray<T>, ArrowError>
where
F: Fn(T::Native) -> T::Native,
{
let data = self.data();
let len = self.len();
let null_count = self.null_count();
let null_buffer = data.null_buffer().map(|b| b.bit_slice(data.offset(), len));

let mut buffers = self.data.get_buffers();
let buffer = buffers.remove(0);
let buffer_len = buffer.len();

let mutable_buffer = buffer.into_mutable(buffer_len);

let buffer = match mutable_buffer {
Ok(mut mutable_buffer) => {
mutable_buffer
.typed_data_mut()
.iter_mut()
.for_each(|l| *l = op(*l));
Ok(mutable_buffer.into())
}
Err(_) => Err(ArrowError::InvalidArgumentError(
"Not a mutable array because its buffer is shared.".to_string(),
)),
}?;
Ok(unsafe { build_primitive_array(len, buffer, null_count, null_buffer) })
}

/// Applies a unary and fallible function to all valid values in a primitive array
///
/// This is unlike [`Self::unary`] which will apply an infallible function to all rows
Expand Down Expand Up @@ -497,7 +545,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
.data
.null_buffer()
.cloned()
.and_then(|b| b.into_mutable().ok());
.and_then(|b| b.into_mutable(0).ok());

let len = self.len();
let null_bit_buffer = self.data.null_buffer().cloned();
Expand All @@ -506,7 +554,7 @@ impl<T: ArrowPrimitiveType> PrimitiveArray<T> {
let buffer = buffers.remove(0);

let builder = buffer
.into_mutable()
.into_mutable(0)
.map(|buffer| PrimitiveBuilder::<T>::new_from_buffer(buffer, null_buffer));

match builder {
Expand Down Expand Up @@ -2014,4 +2062,14 @@ mod tests {
}
}
}

#[test]
fn test_unary_mut() {
    // `unwrap` below requires `unary_mut` to return `Ok` for a freshly collected array.
    let want: Int32Array = vec![3, 5, 7].into_iter().map(Some).collect();
    let input: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect();

    let got = input.unary_mut(|v| v * 2 + 1).unwrap();

    assert_eq!(want, got);
}
}
4 changes: 2 additions & 2 deletions arrow-buffer/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl Buffer {
}

/// Returns `MutableBuffer` for mutating the buffer if this buffer is not shared.
pub fn into_mutable(self) -> Result<MutableBuffer, Self> {
pub fn into_mutable(self, len: usize) -> Result<MutableBuffer, Self> {
let offset_ptr = self.as_ptr();
let offset = self.offset;
let length = self.length;
Expand All @@ -239,7 +239,7 @@ impl Buffer {
assert_eq!(offset_ptr, bytes.ptr().as_ptr());

let mutable_buffer =
MutableBuffer::from_ptr(bytes.ptr(), bytes.capacity());
MutableBuffer::from_ptr(bytes.ptr(), len, bytes.capacity());
mem::forget(bytes);
mutable_buffer
})
Expand Down
4 changes: 2 additions & 2 deletions arrow-buffer/src/buffer/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ impl MutableBuffer {
}

/// Allocates a new [MutableBuffer] from given pointer `ptr`, `capacity`.
pub(crate) fn from_ptr(ptr: NonNull<u8>, capacity: usize) -> Self {
pub(crate) fn from_ptr(ptr: NonNull<u8>, len: usize, capacity: usize) -> Self {
Self {
data: ptr,
len: 0,
len,
capacity,
}
}
Expand Down

0 comments on commit 8311372

Please sign in to comment.