diff --git a/arrow/src/compute/kernels/concat_elements.rs b/arrow/src/compute/kernels/concat_elements.rs
index bc341df889c..7d460b21cb0 100644
--- a/arrow/src/compute/kernels/concat_elements.rs
+++ b/arrow/src/compute/kernels/concat_elements.rs
@@ -85,6 +85,91 @@ pub fn concat_elements_utf8<Offset: OffsetSizeTrait>(
     Ok(unsafe { builder.build_unchecked() }.into())
 }
 
+/// Returns the elementwise concatenation of [`StringArray`].
+/// ```text
+/// e.g:
+/// ["a", "b"] + [None, "c"] + [None, "d"] = [None, "bcd"]
+/// ```
+///
+/// A slot of the output is null if that slot is null in any of the inputs.
+///
+/// # Errors
+///
+/// Returns an [`ArrowError::ComputeError`] if `arrays` is empty or if the
+/// [`StringArray`] are of different lengths
+pub fn concat_elements_utf8_many<Offset: OffsetSizeTrait>(
+    arrays: &[&GenericStringArray<Offset>],
+) -> Result<GenericStringArray<Offset>> {
+    if arrays.is_empty() {
+        return Err(ArrowError::ComputeError(
+            "concat requires input of at least one array".to_string(),
+        ));
+    }
+
+    let size = arrays[0].len();
+    if !arrays.iter().all(|array| array.len() == size) {
+        return Err(ArrowError::ComputeError(format!(
+            "Arrays must have the same length of {}",
+            size,
+        )));
+    }
+
+    let output_bitmap = combine_option_bitmap(
+        arrays
+            .iter()
+            .map(|a| a.data())
+            .collect::<Vec<_>>()
+            .as_slice(),
+        size,
+    )?;
+
+    let data_buffers = arrays
+        .iter()
+        .map(|array| array.value_data())
+        .collect::<Vec<_>>();
+
+    let data_values = data_buffers
+        .iter()
+        .map(|buffer| buffer.as_slice())
+        .collect::<Vec<_>>();
+
+    let mut offsets = arrays
+        .iter()
+        .map(|a| a.value_offsets().iter().peekable())
+        .collect::<Vec<_>>();
+
+    let mut output_values = BufferBuilder::<u8>::new(
+        data_values
+            .iter()
+            .zip(offsets.iter_mut())
+            .map(|(data, offset)| data.len() - offset.peek().unwrap().to_usize().unwrap())
+            .sum(),
+    );
+
+    let mut output_offsets = BufferBuilder::<Offset>::new(size + 1);
+    output_offsets.append(Offset::zero());
+    for _ in 0..size {
+        data_values
+            .iter()
+            .zip(offsets.iter_mut())
+            .for_each(|(values, offset)| {
+                let index_start = offset.next().unwrap().to_usize().unwrap();
+                let index_end = offset.peek().unwrap().to_usize().unwrap();
+                output_values.append_slice(&values[index_start..index_end]);
+            });
+        output_offsets.append(Offset::from_usize(output_values.len()).unwrap());
+    }
+
+    let builder = ArrayDataBuilder::new(GenericStringArray::<Offset>::get_data_type())
+        .len(size)
+        .add_buffer(output_offsets.finish())
+        .add_buffer(output_values.finish())
+        .null_bit_buffer(output_bitmap);
+
+    // SAFETY - offsets valid by construction
+    Ok(unsafe { builder.build_unchecked() }.into())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -143,7 +228,10 @@ mod tests {
 
         let output = concat_elements_utf8(&left, &right);
 
-        assert!(output.is_err());
+        assert_eq!(
+            output.unwrap_err().to_string(),
+            "Compute error: Arrays must have the same length: 2 != 1".to_string()
+        );
     }
 
     #[test]
@@ -190,4 +278,40 @@ mod tests {
 
         assert_eq!(output, expected);
     }
+
+    #[test]
+    fn test_string_concat_error_empty() {
+        assert_eq!(
+            concat_elements_utf8_many::<i32>(&[])
+                .unwrap_err()
+                .to_string(),
+            "Compute error: concat requires input of at least one array".to_string()
+        );
+    }
+
+    #[test]
+    fn test_string_concat_one() {
+        let expected = [None, Some("baryyy"), None]
+            .into_iter()
+            .collect::<StringArray>();
+
+        let output = concat_elements_utf8_many(&[&expected]).unwrap();
+
+        assert_eq!(output, expected);
+    }
+
+    #[test]
+    fn test_string_concat_many() {
+        let foo = StringArray::from(vec![Some("f"), Some("o"), Some("o"), None]);
+        let bar = StringArray::from(vec![None, Some("b"), Some("a"), Some("r")]);
+        let baz = StringArray::from(vec![Some("b"), None, Some("a"), Some("z")]);
+
+        let output = concat_elements_utf8_many(&[&foo, &bar, &baz]).unwrap();
+
+        let expected = [None, None, Some("oaa"), None]
+            .into_iter()
+            .collect::<StringArray>();
+
+        assert_eq!(output, expected);
+    }
 }