Clean the create_array in IPC reader. (#2525)
* clean create_array

Signed-off-by: remzi <13716567376yh@gmail.com>

* remove useless code

Signed-off-by: remzi <13716567376yh@gmail.com>

* reintroduce checked build

Signed-off-by: remzi <13716567376yh@gmail.com>

Signed-off-by: remzi <13716567376yh@gmail.com>
HaoYang670 committed Aug 20, 2022
1 parent 98de2e3 commit 7cfe3cf
184 changes: 57 additions & 127 deletions arrow/src/ipc/reader.rs
@@ -82,17 +82,16 @@ fn create_array(
compression_codec: &Option<CompressionCodec>,
metadata: &ipc::MetadataVersion,
) -> Result<(ArrayRef, usize, usize)> {
use DataType::*;
let data_type = field.data_type();
let array = match data_type {
Utf8 | Binary | LargeBinary | LargeUtf8 => {
let array = create_primitive_array(
&nodes[node_index],
data_type,
buffers[buffer_index..buffer_index + 3]
&buffers[buffer_index..buffer_index + 3]
.iter()
.map(|buf| read_buffer(buf, data, compression_codec))
.collect::<Result<_>>()?,
.collect::<Result<Vec<Buffer>>>()?,
);
node_index += 1;
buffer_index += 3;
@@ -102,10 +101,10 @@ fn create_array(
let array = create_primitive_array(
&nodes[node_index],
data_type,
buffers[buffer_index..buffer_index + 2]
&buffers[buffer_index..buffer_index + 2]
.iter()
.map(|buf| read_buffer(buf, data, compression_codec))
.collect::<Result<_>>()?,
.collect::<Result<Vec<Buffer>>>()?,
);
node_index += 1;
buffer_index += 2;
@@ -301,10 +300,10 @@ fn create_array(
let array = create_primitive_array(
&nodes[node_index],
data_type,
buffers[buffer_index..buffer_index + 2]
&buffers[buffer_index..buffer_index + 2]
.iter()
.map(|buf| read_buffer(buf, data, compression_codec))
.collect::<Result<_>>()?,
.collect::<Result<Vec<Buffer>>>()?,
);
node_index += 1;
buffer_index += 2;
@@ -318,16 +317,10 @@ fn create_array(
/// This function should be called when doing projection in fn `read_record_batch`.
/// The advancement logic references fn `create_array`.
fn skip_field(
nodes: &[ipc::FieldNode],
field: &Field,
data: &[u8],
buffers: &[ipc::Buffer],
dictionaries_by_id: &HashMap<i64, ArrayRef>,
data_type: &DataType,
mut node_index: usize,
mut buffer_index: usize,
) -> Result<(usize, usize)> {
use DataType::*;
let data_type = field.data_type();
match data_type {
Utf8 | Binary | LargeBinary | LargeUtf8 => {
node_index += 1;
@@ -340,30 +333,14 @@ fn skip_field(
List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
node_index += 1;
buffer_index += 2;
let tuple = skip_field(
nodes,
list_field,
data,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
node_index = tuple.0;
buffer_index = tuple.1;
}
FixedSizeList(ref list_field, _) => {
node_index += 1;
buffer_index += 1;
let tuple = skip_field(
nodes,
list_field,
data,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
node_index = tuple.0;
buffer_index = tuple.1;
}
@@ -373,15 +350,8 @@ fn skip_field(

// skip for each field
for struct_field in struct_fields {
let tuple = skip_field(
nodes,
struct_field,
data,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
let tuple =
skip_field(struct_field.data_type(), node_index, buffer_index)?;
node_index = tuple.0;
buffer_index = tuple.1;
}
@@ -402,15 +372,7 @@ fn skip_field(
};

for field in fields {
let tuple = skip_field(
nodes,
field,
data,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
let tuple = skip_field(field.data_type(), node_index, buffer_index)?;

node_index = tuple.0;
buffer_index = tuple.1;
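
The hunks above show the heart of the cleanup: skipping a field needs only its `DataType` and the two running indices, not the nodes, raw data, buffers, or dictionaries that the old signature threaded through every recursive call. Below is a minimal, standalone sketch of the simplified recursion, with arm bodies abridged (the real function also handles `FixedSizeList`, `Struct`, `Union`, `Dictionary`, and `Null`):

```rust
use arrow::datatypes::DataType;
use arrow::error::Result;

// Sketch of the new shape: the type alone determines how many
// nodes and buffers a skipped field occupies.
fn skip_field_sketch(
    data_type: &DataType,
    mut node_index: usize,
    mut buffer_index: usize,
) -> Result<(usize, usize)> {
    use DataType::*;
    match data_type {
        Utf8 | Binary | LargeBinary | LargeUtf8 => {
            node_index += 1;
            buffer_index += 3; // validity + offsets + values
        }
        List(list_field) | LargeList(list_field) | Map(list_field, _) => {
            node_index += 1;
            buffer_index += 2; // validity + offsets
            // Recurse into the child purely by type.
            let (n, b) =
                skip_field_sketch(list_field.data_type(), node_index, buffer_index)?;
            node_index = n;
            buffer_index = b;
        }
        _ => {
            node_index += 1;
            buffer_index += 2; // validity + values for most fixed-width types
        }
    }
    Ok((node_index, buffer_index))
}

fn main() -> Result<()> {
    // e.g. a Utf8 column consumes one node and three buffers
    let (n, b) = skip_field_sketch(&DataType::Utf8, 0, 0)?;
    assert_eq!((n, b), (1, 3));
    Ok(())
}
```
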
@@ -433,28 +395,26 @@ fn skip_field(
fn create_primitive_array(
field_node: &ipc::FieldNode,
data_type: &DataType,
buffers: Vec<Buffer>,
buffers: &[Buffer],
) -> ArrayRef {
let length = field_node.length() as usize;
let null_count = field_node.null_count() as usize;
let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
let array_data = match data_type {
Utf8 | Binary | LargeBinary | LargeUtf8 => {
// read 3 buffers
// read 3 buffers: null buffer (optional), offsets buffer and data buffer
ArrayData::builder(data_type.clone())
.len(length)
.buffers(buffers[1..3].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()))
.null_bit_buffer(null_buffer)
.build()
.unwrap()
}
FixedSizeBinary(_) => {
// read 3 buffers
// read 2 buffers: null buffer (optional) and data buffer
let builder = ArrayData::builder(data_type.clone())
.len(length)
.buffers(buffers[1..2].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.null_bit_buffer(null_buffer);

unsafe { builder.build_unchecked() }
}
@@ -471,9 +431,8 @@ fn create_primitive_array(
// interpret as a signed i64, and cast appropriately
let builder = ArrayData::builder(DataType::Int64)
.len(length)
.buffers(buffers[1..].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.null_bit_buffer(null_buffer);

let data = unsafe { builder.build_unchecked() };
let values = Arc::new(Int64Array::from(data)) as ArrayRef;
@@ -483,9 +442,8 @@ fn create_primitive_array(
} else {
let builder = ArrayData::builder(data_type.clone())
.len(length)
.buffers(buffers[1..].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.null_bit_buffer(null_buffer);

unsafe { builder.build_unchecked() }
}
@@ -495,9 +453,8 @@ fn create_primitive_array(
// interpret as a f64, and cast appropriately
let builder = ArrayData::builder(DataType::Float64)
.len(length)
.buffers(buffers[1..].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.null_bit_buffer(null_buffer);

let data = unsafe { builder.build_unchecked() };
let values = Arc::new(Float64Array::from(data)) as ArrayRef;
@@ -507,9 +464,8 @@ fn create_primitive_array(
} else {
let builder = ArrayData::builder(data_type.clone())
.len(length)
.buffers(buffers[1..].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.null_bit_buffer(null_buffer);

unsafe { builder.build_unchecked() }
}
@@ -526,23 +482,21 @@ fn create_primitive_array(
| Interval(IntervalUnit::MonthDayNano) => {
let builder = ArrayData::builder(data_type.clone())
.len(length)
.buffers(buffers[1..].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.null_bit_buffer(null_buffer);

unsafe { builder.build_unchecked() }
}
Decimal128(_, _) | Decimal256(_, _) => {
// read 3 buffers
// read 2 buffers: null buffer (optional) and data buffer
let builder = ArrayData::builder(data_type.clone())
.len(length)
.buffers(buffers[1..2].to_vec())
.offset(0)
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.null_bit_buffer(null_buffer);

unsafe { builder.build_unchecked() }
}
t => panic!("Data type {:?} either unsupported or not primitive", t),
t => unreachable!("Data type {:?} either unsupported or not primitive", t),
};

make_array(array_data)
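
Throughout `create_primitive_array`, the builder calls likewise change from `.buffers(buffers[1..].to_vec())` to a single `.add_buffer(buffers[1].clone())`, which makes the expected buffer count explicit, and they drop the redundant `.offset(0)` (zero is the builder's default). A standalone sketch of constructing a fixed-width array through this API, with invented values:

```rust
use arrow::array::{make_array, Array, ArrayData};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

fn main() {
    // One values buffer and no validity buffer => an array with no nulls.
    let values = Buffer::from_slice_ref(&[1i32, 2, 3]);
    let data = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(values)
        .build() // the checked build; `build_unchecked` skips validation
        .unwrap();
    let array = make_array(data);
    assert_eq!(array.len(), 3);
    assert_eq!(array.null_count(), 0);
}
```
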
@@ -556,39 +510,24 @@ fn create_list_array(
buffers: &[Buffer],
child_array: ArrayRef,
) -> ArrayRef {
if let DataType::List(_) | DataType::LargeList(_) = *data_type {
let null_count = field_node.null_count() as usize;
let builder = ArrayData::builder(data_type.clone())
.len(field_node.length() as usize)
.buffers(buffers[1..2].to_vec())
.offset(0)
.child_data(vec![child_array.into_data()])
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));

make_array(unsafe { builder.build_unchecked() })
} else if let DataType::FixedSizeList(_, _) = *data_type {
let null_count = field_node.null_count() as usize;
let builder = ArrayData::builder(data_type.clone())
.len(field_node.length() as usize)
.buffers(buffers[1..1].to_vec())
.offset(0)
.child_data(vec![child_array.into_data()])
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));

make_array(unsafe { builder.build_unchecked() })
} else if let DataType::Map(_, _) = *data_type {
let null_count = field_node.null_count() as usize;
let builder = ArrayData::builder(data_type.clone())
.len(field_node.length() as usize)
.buffers(buffers[1..2].to_vec())
.offset(0)
.child_data(vec![child_array.into_data()])
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));

make_array(unsafe { builder.build_unchecked() })
} else {
panic!("Cannot create list or map array from {:?}", data_type)
}
let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
let length = field_node.length() as usize;
let child_data = child_array.into_data();
let builder = match data_type {
List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
.len(length)
.add_buffer(buffers[1].clone())
.add_child_data(child_data)
.null_bit_buffer(null_buffer),

FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
.len(length)
.add_child_data(child_data)
.null_bit_buffer(null_buffer),

_ => unreachable!("Cannot create list or map array from {:?}", data_type),
};
make_array(unsafe { builder.build_unchecked() })
}

/// Reads the correct number of buffers based on list type and null_count, and creates a
@@ -599,14 +538,13 @@ fn create_dictionary_array(
buffers: &[Buffer],
value_array: ArrayRef,
) -> ArrayRef {
if let DataType::Dictionary(_, _) = *data_type {
let null_count = field_node.null_count() as usize;
if let Dictionary(_, _) = *data_type {
let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
let builder = ArrayData::builder(data_type.clone())
.len(field_node.length() as usize)
.buffers(buffers[1..2].to_vec())
.offset(0)
.child_data(vec![value_array.into_data()])
.null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
.add_buffer(buffers[1].clone())
.add_child_data(value_array.into_data())
.null_bit_buffer(null_buffer);

make_array(unsafe { builder.build_unchecked() })
} else {
Expand Down Expand Up @@ -666,15 +604,7 @@ pub fn read_record_batch(
} else {
// Skip field.
// This must be called to advance `node_index` and `buffer_index`.
let tuple = skip_field(
field_nodes,
field,
buf,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
let tuple = skip_field(field.data_type(), node_index, buffer_index)?;
node_index = tuple.0;
buffer_index = tuple.1;
}
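
The projection path in `read_record_batch` is where `skip_field` is exercised: columns outside the projection are never materialized, but the node and buffer indices must still advance past them. A usage sketch, assuming a hypothetical `example.arrow` IPC file with at least three columns:

```rust
use std::fs::File;

use arrow::error::Result;
use arrow::ipc::reader::FileReader;

fn main() -> Result<()> {
    // Read only columns 0 and 2; `skip_field` advances past column 1.
    let file = File::open("example.arrow")?;
    let reader = FileReader::try_new(file, Some(vec![0, 2]))?;
    for batch in reader {
        assert_eq!(batch?.num_columns(), 2);
    }
    Ok(())
}
```
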
