From 7cfe3cf6ab0d84b4f9d21c7ed43a04e7fdd6689b Mon Sep 17 00:00:00 2001
From: Remzi Yang <59198230+HaoYang670@users.noreply.github.com>
Date: Sat, 20 Aug 2022 22:43:30 +0800
Subject: [PATCH] Clean the `create_array` in IPC reader. (#2525)

* clean create_array

Signed-off-by: remzi <13716567376yh@gmail.com>

* remove useless code

Signed-off-by: remzi <13716567376yh@gmail.com>

* reintro checked build

Signed-off-by: remzi <13716567376yh@gmail.com>

Signed-off-by: remzi <13716567376yh@gmail.com>
---
 arrow/src/ipc/reader.rs | 184 +++++++++++++++---------------------------------
 1 file changed, 57 insertions(+), 127 deletions(-)

diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 086a5b15957..adfe4060084 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -82,17 +82,16 @@ fn create_array(
     compression_codec: &Option<CompressionCodec>,
     metadata: &ipc::MetadataVersion,
 ) -> Result<(ArrayRef, usize, usize)> {
-    use DataType::*;
     let data_type = field.data_type();
     let array = match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
             let array = create_primitive_array(
                 &nodes[node_index],
                 data_type,
-                buffers[buffer_index..buffer_index + 3]
+                &buffers[buffer_index..buffer_index + 3]
                     .iter()
                     .map(|buf| read_buffer(buf, data, compression_codec))
-                    .collect::<Result<_>>()?,
+                    .collect::<Result<Vec<_>>>()?,
             );
             node_index += 1;
             buffer_index += 3;
@@ -102,10 +101,10 @@ fn create_array(
             let array = create_primitive_array(
                 &nodes[node_index],
                 data_type,
-                buffers[buffer_index..buffer_index + 2]
+                &buffers[buffer_index..buffer_index + 2]
                     .iter()
                     .map(|buf| read_buffer(buf, data, compression_codec))
-                    .collect::<Result<_>>()?,
+                    .collect::<Result<Vec<_>>>()?,
             );
             node_index += 1;
             buffer_index += 2;
@@ -301,10 +300,10 @@ fn create_array(
             let array = create_primitive_array(
                 &nodes[node_index],
                 data_type,
-                buffers[buffer_index..buffer_index + 2]
+                &buffers[buffer_index..buffer_index + 2]
                     .iter()
                     .map(|buf| read_buffer(buf, data, compression_codec))
-                    .collect::<Result<_>>()?,
+                    .collect::<Result<Vec<_>>>()?,
             );
             node_index += 1;
             buffer_index += 2;
@@ -318,16 +317,10 @@ fn create_array(
 /// This function should be called when doing projection in fn `read_record_batch`.
 /// The advancement logic references fn `create_array`.
 fn skip_field(
-    nodes: &[ipc::FieldNode],
-    field: &Field,
-    data: &[u8],
-    buffers: &[ipc::Buffer],
-    dictionaries_by_id: &HashMap<i64, ArrayRef>,
+    data_type: &DataType,
     mut node_index: usize,
     mut buffer_index: usize,
 ) -> Result<(usize, usize)> {
-    use DataType::*;
-    let data_type = field.data_type();
     match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
             node_index += 1;
             buffer_index += 3;
         }
@@ -340,30 +333,14 @@ fn skip_field(
         List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
             node_index += 1;
             buffer_index += 2;
-            let tuple = skip_field(
-                nodes,
-                list_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-            )?;
+            let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
             node_index = tuple.0;
             buffer_index = tuple.1;
         }
         FixedSizeList(ref list_field, _) => {
             node_index += 1;
             buffer_index += 1;
-            let tuple = skip_field(
-                nodes,
-                list_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-            )?;
+            let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
             node_index = tuple.0;
             buffer_index = tuple.1;
         }
@@ -373,15 +350,8 @@ fn skip_field(
             // skip for each field
             for struct_field in struct_fields {
-                let tuple = skip_field(
-                    nodes,
-                    struct_field,
-                    data,
-                    buffers,
-                    dictionaries_by_id,
-                    node_index,
-                    buffer_index,
-                )?;
+                let tuple =
+                    skip_field(struct_field.data_type(), node_index, buffer_index)?;
                 node_index = tuple.0;
                 buffer_index = tuple.1;
             }
@@ -402,15 +372,7 @@ fn skip_field(
             };

             for field in fields {
-                let tuple = skip_field(
-                    nodes,
-                    field,
-                    data,
-                    buffers,
-                    dictionaries_by_id,
-                    node_index,
-                    buffer_index,
-                )?;
+                let tuple = skip_field(field.data_type(), node_index, buffer_index)?;
                 node_index = tuple.0;
                 buffer_index = tuple.1;

@@ -433,28 +395,26 @@ fn skip_field(
 fn create_primitive_array(
     field_node: &ipc::FieldNode,
     data_type: &DataType,
-    buffers: Vec<Buffer>,
+    buffers: &[Buffer],
 ) -> ArrayRef {
     let length = field_node.length() as usize;
-    let null_count = field_node.null_count() as usize;
+    let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
     let array_data = match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
-            // read 3 buffers
+            // read 3 buffers: null buffer (optional), offsets buffer and data buffer
             ArrayData::builder(data_type.clone())
                 .len(length)
                 .buffers(buffers[1..3].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()))
+                .null_bit_buffer(null_buffer)
                 .build()
                 .unwrap()
         }
         FixedSizeBinary(_) => {
-            // read 3 buffers
+            // read 2 buffers: null buffer (optional) and data buffer
             let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..2].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
@@ -471,9 +431,8 @@ fn create_primitive_array(
                 // interpret as a signed i64, and cast appropriately
                 let builder = ArrayData::builder(DataType::Int64)
                     .len(length)
-                    .buffers(buffers[1..].to_vec())
-                    .offset(0)
-                    .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                    .add_buffer(buffers[1].clone())
+                    .null_bit_buffer(null_buffer);
                 let data = unsafe { builder.build_unchecked() };
                 let values = Arc::new(Int64Array::from(data)) as ArrayRef;
@@ -483,9 +442,8 @@ fn create_primitive_array(
             } else {
                 let builder = ArrayData::builder(data_type.clone())
                     .len(length)
-                    .buffers(buffers[1..].to_vec())
-                    .offset(0)
-                    .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                    .add_buffer(buffers[1].clone())
+                    .null_bit_buffer(null_buffer);

                 unsafe { builder.build_unchecked() }
             }
@@ -495,9 +453,8 @@ fn create_primitive_array(
                 // interpret as a f64, and cast appropriately
                 let builder = ArrayData::builder(DataType::Float64)
                     .len(length)
-                    .buffers(buffers[1..].to_vec())
-                    .offset(0)
-                    .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                    .add_buffer(buffers[1].clone())
+                    .null_bit_buffer(null_buffer);
                 let data = unsafe { builder.build_unchecked() };
                 let values = Arc::new(Float64Array::from(data)) as ArrayRef;
@@ -507,9 +464,8 @@ fn create_primitive_array(
             } else {
                 let builder = ArrayData::builder(data_type.clone())
                     .len(length)
-                    .buffers(buffers[1..].to_vec())
-                    .offset(0)
-                    .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                    .add_buffer(buffers[1].clone())
+                    .null_bit_buffer(null_buffer);

                 unsafe { builder.build_unchecked() }
             }
@@ -526,23 +482,21 @@ fn create_primitive_array(
         | Interval(IntervalUnit::MonthDayNano) => {
             let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
         Decimal128(_, _) | Decimal256(_, _) => {
-            // read 3 buffers
+            // read 2 buffers: null buffer (optional) and data buffer
            let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..2].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
-        t => panic!("Data type {:?} either unsupported or not primitive", t),
+        t => unreachable!("Data type {:?} either unsupported or not primitive", t),
     };

     make_array(array_data)
@@ -556,39 +510,24 @@ fn create_list_array(
     buffers: &[Buffer],
     child_array: ArrayRef,
 ) -> ArrayRef {
-    if let DataType::List(_) | DataType::LargeList(_) = *data_type {
-        let null_count = field_node.null_count() as usize;
-        let builder = ArrayData::builder(data_type.clone())
-            .len(field_node.length() as usize)
-            .buffers(buffers[1..2].to_vec())
-            .offset(0)
-            .child_data(vec![child_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
-
-        make_array(unsafe { builder.build_unchecked() })
-    } else if let DataType::FixedSizeList(_, _) = *data_type {
-        let null_count = field_node.null_count() as usize;
-        let builder = ArrayData::builder(data_type.clone())
-            .len(field_node.length() as usize)
-            .buffers(buffers[1..1].to_vec())
-            .offset(0)
-            .child_data(vec![child_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
-
-        make_array(unsafe { builder.build_unchecked() })
-    } else if let DataType::Map(_, _) = *data_type {
-        let null_count = field_node.null_count() as usize;
-        let builder = ArrayData::builder(data_type.clone())
-            .len(field_node.length() as usize)
-            .buffers(buffers[1..2].to_vec())
-            .offset(0)
-            .child_data(vec![child_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
-
-        make_array(unsafe { builder.build_unchecked() })
-    } else {
-        panic!("Cannot create list or map array from {:?}", data_type)
-    }
+    let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+    let length = field_node.length() as usize;
+    let child_data = child_array.into_data();
+    let builder = match data_type {
+        List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
+            .len(length)
+            .add_buffer(buffers[1].clone())
+            .add_child_data(child_data)
+            .null_bit_buffer(null_buffer),
+
+        FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
+            .len(length)
+            .add_child_data(child_data)
+            .null_bit_buffer(null_buffer),
+
+        _ => unreachable!("Cannot create list or map array from {:?}", data_type),
+    };
+    make_array(unsafe { builder.build_unchecked() })
 }

 /// Reads the correct number of buffers based on list type and null_count, and creates a
@@ -599,14 +538,13 @@ fn create_dictionary_array(
     buffers: &[Buffer],
     value_array: ArrayRef,
 ) -> ArrayRef {
-    if let DataType::Dictionary(_, _) = *data_type {
-        let null_count = field_node.null_count() as usize;
+    if let Dictionary(_, _) = *data_type {
+        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
         let builder = ArrayData::builder(data_type.clone())
             .len(field_node.length() as usize)
-            .buffers(buffers[1..2].to_vec())
-            .offset(0)
-            .child_data(vec![value_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+            .add_buffer(buffers[1].clone())
+            .add_child_data(value_array.into_data())
+            .null_bit_buffer(null_buffer);

         make_array(unsafe { builder.build_unchecked() })
     } else {
@@ -666,15 +604,7 @@ pub fn read_record_batch(
         } else {
             // Skip field.
             // This must be called to advance `node_index` and `buffer_index`.
-            let tuple = skip_field(
-                field_nodes,
-                field,
-                buf,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-            )?;
+            let tuple = skip_field(field.data_type(), node_index, buffer_index)?;
             node_index = tuple.0;
             buffer_index = tuple.1;
         }
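
Reviewer note (not part of the patch): the cleanup leans on two ideas. First, the optional validity buffer is computed once per call with `bool::then_some` instead of repeating `(null_count > 0).then(|| buffers[0].clone())` in every match arm; `then_some` evaluates its argument eagerly, which is fine here because cloning an arrow `Buffer` only bumps a reference count. Second, `skip_field` now recurses on the `DataType` alone, because advancing `node_index` and `buffer_index` depends only on the shape of the type, never on the node or buffer contents. The sketch below illustrates that control flow with toy stand-ins (`DataType`, `Buffer`, and a simplified `skip_field` are hypothetical, not the real arrow types); the per-type buffer counts mirror the patch (3 for Utf8, 2 for primitives and lists, 1 for structs):

    // Toy stand-ins for arrow's Buffer and DataType; only the control
    // flow of the patched `skip_field` is illustrated here.
    #[derive(Clone, Debug)]
    struct Buffer(Vec<u8>);

    enum DataType {
        Int32,
        Utf8,
        List(Box<DataType>),
        Struct(Vec<DataType>),
    }

    // Mirrors the patched signature: the type alone determines how many
    // field nodes and buffers a column occupies, so no nodes/data/buffers
    // parameters are needed to skip it.
    fn skip_field(
        data_type: &DataType,
        mut node_index: usize,
        mut buffer_index: usize,
    ) -> (usize, usize) {
        match data_type {
            DataType::Utf8 => {
                node_index += 1;
                buffer_index += 3; // validity + offsets + values
            }
            DataType::Int32 => {
                node_index += 1;
                buffer_index += 2; // validity + values
            }
            DataType::List(child) => {
                node_index += 1;
                buffer_index += 2; // validity + offsets
                (node_index, buffer_index) =
                    skip_field(child, node_index, buffer_index);
            }
            DataType::Struct(children) => {
                node_index += 1;
                buffer_index += 1; // validity only
                for child in children {
                    (node_index, buffer_index) =
                        skip_field(child, node_index, buffer_index);
                }
            }
        }
        (node_index, buffer_index)
    }

    fn main() {
        // The `then_some` idiom from the patch: build the optional
        // validity buffer once instead of once per match arm.
        let buffers = vec![Buffer(vec![0b1011]), Buffer(vec![1, 2, 3])];
        let null_count = 1;
        let null_buffer = (null_count > 0).then_some(buffers[0].clone());
        assert!(null_buffer.is_some());

        // Skipping Struct<Int32, List<Utf8>> advances 4 nodes and 8 buffers.
        let ty = DataType::Struct(vec![
            DataType::Int32,
            DataType::List(Box::new(DataType::Utf8)),
        ]);
        assert_eq!(skip_field(&ty, 0, 0), (4, 8));
    }

The real function additionally returns `Result<(usize, usize)>` and covers the remaining IPC types (unions, dictionaries, fixed-size lists); the sketch keeps only enough cases to show why dropping the `nodes`, `data`, `buffers`, and `dictionaries_by_id` parameters is safe.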