Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arbitrary size concat elements utf8 #1787

Merged
merged 15 commits into from Jun 29, 2022
131 changes: 91 additions & 40 deletions arrow/src/compute/kernels/concat_elements.rs
Expand Up @@ -29,54 +29,76 @@ use crate::error::{ArrowError, Result};
///
/// ["Hello"] + ["World"] = ["HelloWorld"]
///
/// ["a", "b"] + [None, "c"] = [None, "bc"]
/// ["a", "b"] + [None, "c"] + [None, "d"] = [None, "bcd"]
/// ```
///
/// An error will be returned if `left` and `right` have different lengths
/// An error will be returned if the [`StringArray`]s do not all have the same length.
pub fn concat_elements_utf8<Offset: OffsetSizeTrait>(
left: &GenericStringArray<Offset>,
right: &GenericStringArray<Offset>,
arrays: &[&GenericStringArray<Offset>],
) -> Result<GenericStringArray<Offset>> {
if left.len() != right.len() {
if arrays.len() < 2 {
ismailmaj marked this conversation as resolved.
Show resolved Hide resolved
return Err(ArrowError::ComputeError(format!(
"Arrays must have the same length: {} != {}",
left.len(),
right.len()
"Arrays must have at least two elements length: {}",
arrays.len(),
)));
}

let output_bitmap = combine_option_bitmap(&[left.data(), right.data()], left.len())?;

let left_offsets = left.value_offsets();
let right_offsets = right.value_offsets();
let size = arrays[0].len();
if !arrays.iter().all(|array| array.len() == size) {
return Err(ArrowError::ComputeError(format!(
"Arrays must have the same length of {}",
size,
)));
}

let left_buffer = left.value_data();
let right_buffer = right.value_data();
let left_values = left_buffer.as_slice();
let right_values = right_buffer.as_slice();
let output_bitmap = combine_option_bitmap(
arrays
.iter()
.map(|a| a.data())
.collect::<Vec<_>>()
.as_slice(),
size,
)?;

let data_buffers = arrays
.iter()
.map(|array| array.value_data())
.collect::<Vec<_>>();

let data_values = data_buffers
.iter()
.map(|buffer| buffer.as_slice())
.collect::<Vec<_>>();

let mut offsets = arrays
.iter()
.map(|a| a.value_offsets().iter().peekable())
.collect::<Vec<_>>();

let mut output_values = BufferBuilder::<u8>::new(
left_values.len() + right_values.len()
- left_offsets[0].to_usize().unwrap()
- right_offsets[0].to_usize().unwrap(),
data_values
.iter()
.zip(offsets.iter_mut())
.map(|(data, offset)| data.len() - offset.peek().unwrap().to_usize().unwrap())
ismailmaj marked this conversation as resolved.
Show resolved Hide resolved
.sum(),
);

let mut output_offsets = BufferBuilder::<Offset>::new(left_offsets.len());
let mut output_offsets = BufferBuilder::<Offset>::new(size + 1);
output_offsets.append(Offset::zero());
for (left_idx, right_idx) in left_offsets.windows(2).zip(right_offsets.windows(2)) {
output_values.append_slice(
&left_values
[left_idx[0].to_usize().unwrap()..left_idx[1].to_usize().unwrap()],
);
output_values.append_slice(
&right_values
[right_idx[0].to_usize().unwrap()..right_idx[1].to_usize().unwrap()],
);
for _ in 0..size {
data_values
.iter()
.zip(offsets.iter_mut())
.for_each(|(values, offset)| {
let index_start = offset.next().unwrap().to_usize().unwrap();
let index_end = offset.peek().unwrap().to_usize().unwrap();
output_values.append_slice(&values[index_start..index_end]);
});
output_offsets.append(Offset::from_usize(output_values.len()).unwrap());
}

let builder = ArrayDataBuilder::new(GenericStringArray::<Offset>::get_data_type())
.len(left.len())
.len(size)
.add_buffer(output_offsets.finish())
.add_buffer(output_values.finish())
.null_bit_buffer(output_bitmap);
Expand All @@ -97,7 +119,7 @@ mod tests {
.into_iter()
.collect::<StringArray>();

let output = concat_elements_utf8(&left, &right).unwrap();
let output = concat_elements_utf8(&[&left, &right]).unwrap();

let expected = [None, Some("baryyy"), None]
.into_iter()
Expand All @@ -115,7 +137,7 @@ mod tests {
.into_iter()
.collect::<StringArray>();

let output = concat_elements_utf8(&left, &right).unwrap();
let output = concat_elements_utf8(&[&left, &right]).unwrap();

let expected = [Some("foobaz"), Some(""), Some("bar")]
.into_iter()
Expand All @@ -129,21 +151,50 @@ mod tests {
let left = StringArray::from(vec!["foo", "bar"]);
let right = StringArray::from(vec!["bar", "baz"]);

let output = concat_elements_utf8(&left, &right).unwrap();
let output = concat_elements_utf8(&[&left, &right]).unwrap();

let expected = StringArray::from(vec!["foobar", "barbaz"]);

assert_eq!(output, expected);
}

#[test]
fn test_string_concat_error() {
fn test_string_concat_error_array_length() {
let left = StringArray::from(vec!["foo", "bar"]);
let right = StringArray::from(vec!["baz"]);

let output = concat_elements_utf8(&left, &right);
let output = concat_elements_utf8(&[&left, &right]);

assert!(output.is_err());
assert_eq!(
output.unwrap_err().to_string(),
"Compute error: Arrays must have the same length of 2".to_string()
);
}

#[test]
fn test_string_concat_error_slice_length() {
    // Passing a single array must be rejected: the kernel requires at least
    // two input arrays and reports how many it actually received.
    let only: StringArray = [Some("foo"), Some("bar"), None].into_iter().collect();

    let err = concat_elements_utf8(&[&only]).unwrap_err();

    assert_eq!(
        err.to_string(),
        "Compute error: Arrays must have at least two elements length: 1".to_string()
    );
}

#[test]
fn test_string_concat_multiple() {
    // Three inputs of equal length; each one carries a null in a different row.
    let inputs = [
        StringArray::from(vec![Some("f"), Some("o"), Some("o"), None]),
        StringArray::from(vec![None, Some("b"), Some("a"), Some("r")]),
        StringArray::from(vec![Some("b"), None, Some("a"), Some("z")]),
    ];

    let result = concat_elements_utf8(&[&inputs[0], &inputs[1], &inputs[2]]).unwrap();

    // Only row 2 is non-null in every input ("o" + "a" + "a"); every other
    // row has a null in at least one input and is therefore null in the output.
    let expected = StringArray::from(vec![None, None, Some("oaa"), None]);

    assert_eq!(result, expected);
}

#[test]
Expand All @@ -153,7 +204,7 @@ mod tests {

let left_slice = left.slice(0, 3);
let right_slice = right.slice(1, 3);
let output = concat_elements_utf8(
let output = concat_elements_utf8(&[
left_slice
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
Expand All @@ -162,7 +213,7 @@ mod tests {
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.unwrap(),
)
])
.unwrap();

let expected = [None, Some("foofar"), Some("barfaz")]
Expand All @@ -174,7 +225,7 @@ mod tests {
let left_slice = left.slice(2, 2);
let right_slice = right.slice(1, 2);

let output = concat_elements_utf8(
let output = concat_elements_utf8(&[
left_slice
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
Expand All @@ -183,7 +234,7 @@ mod tests {
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.unwrap(),
)
])
.unwrap();

let expected = [None, Some("bazfar")].into_iter().collect::<StringArray>();
Expand Down