From 54e91811bd774b3c2dc8fdcf15c07b751256a5ed Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 16 Jul 2022 11:22:58 -0700 Subject: [PATCH] type_id and value_offset are incorrect for sliced UnionArray (#2087) * Fix slice UnionArray * For review --- arrow/src/array/array_union.rs | 80 +++++++++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 7 deletions(-) diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs index 639b82ae980..e7ddec01fc7 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow/src/array/array_union.rs @@ -243,8 +243,8 @@ impl UnionArray { /// /// Panics if `index` is greater than the length of the array. pub fn type_id(&self, index: usize) -> i8 { - assert!(index - self.offset() < self.len()); - self.data().buffers()[0].as_slice()[index] as i8 + assert!(index < self.len()); + self.data().buffers()[0].as_slice()[self.offset() + index] as i8 } /// Returns the offset into the underlying values array for the array slot at `index`. @@ -253,11 +253,11 @@ impl UnionArray { /// /// Panics if `index` is greater than the length of the array. pub fn value_offset(&self, index: usize) -> i32 { - assert!(index - self.offset() < self.len()); + assert!(index < self.len()); if self.is_dense() { - self.data().buffers()[1].typed_data::()[index] + self.data().buffers()[1].typed_data::()[self.offset() + index] } else { - index as i32 + (self.offset() + index) as i32 } } @@ -267,8 +267,8 @@ impl UnionArray { /// /// Panics if `index` is greater than the length of the array. pub fn value(&self, index: usize) -> ArrayRef { - let type_id = self.type_id(self.offset() + index); - let value_offset = self.value_offset(self.offset() + index) as usize; + let type_id = self.type_id(index); + let value_offset = self.value_offset(index) as usize; let child_data = self.boxed_fields[type_id as usize].clone(); child_data.slice(value_offset, 1) } @@ -383,6 +383,7 @@ mod tests { use crate::array::*; use crate::buffer::Buffer; use crate::datatypes::{DataType, Field}; + use crate::record_batch::RecordBatch; #[test] fn test_dense_i32() { @@ -956,4 +957,69 @@ mod tests { let err = builder.append::("a", 1).unwrap_err().to_string(); assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err); } + + #[test] + fn slice_union_array() { + // [1, null, 3.0, null, 4] + fn create_union(mut builder: UnionBuilder) -> UnionArray { + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append_null::("c").unwrap(); + builder.append::("a", 4).unwrap(); + builder.build().unwrap() + } + + fn create_batch(union: UnionArray) -> RecordBatch { + let schema = Schema::new(vec![Field::new( + "struct_array", + union.data_type().clone(), + true, + )]); + + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap() + } + + fn test_slice_union(record_batch_slice: RecordBatch) { + let union_slice = record_batch_slice + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(union_slice.type_id(0), 0); + assert_eq!(union_slice.type_id(1), 1); + assert_eq!(union_slice.type_id(2), 1); + + let slot = union_slice.value(0); + let array = slot.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + + let slot = union_slice.value(1); + let array = slot.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_valid(0)); + assert_eq!(array.value(0), 3.0); + + let slot = union_slice.value(2); + let array = slot.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + } + + // Sparse Union + let builder = UnionBuilder::new_sparse(5); + let record_batch = create_batch(create_union(builder)); + // [null, 3.0, null] + let record_batch_slice = record_batch.slice(1, 3); + test_slice_union(record_batch_slice); + + // Dense Union + let builder = UnionBuilder::new_dense(5); + let record_batch = create_batch(create_union(builder)); + // [null, 3.0, null] + let record_batch_slice = record_batch.slice(1, 3); + test_slice_union(record_batch_slice); + } }