Clean the create_array in IPC reader. #2525

Merged: 5 commits, Aug 20, 2022
arrow/src/ipc/reader.rs: 57 additions & 127 deletions
@@ -82,17 +82,16 @@ fn create_array(
     compression_codec: &Option<CompressionCodec>,
     metadata: &ipc::MetadataVersion,
 ) -> Result<(ArrayRef, usize, usize)> {
-    use DataType::*;
     let data_type = field.data_type();
     let array = match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
             let array = create_primitive_array(
                 &nodes[node_index],
                 data_type,
-                buffers[buffer_index..buffer_index + 3]
+                &buffers[buffer_index..buffer_index + 3]
                     .iter()
                     .map(|buf| read_buffer(buf, data, compression_codec))
-                    .collect::<Result<_>>()?,
+                    .collect::<Result<Vec<Buffer>>>()?,
             );
             node_index += 1;
             buffer_index += 3;
@@ -102,10 +101,10 @@ fn create_array(
             let array = create_primitive_array(
                 &nodes[node_index],
                 data_type,
-                buffers[buffer_index..buffer_index + 2]
+                &buffers[buffer_index..buffer_index + 2]
                     .iter()
                     .map(|buf| read_buffer(buf, data, compression_codec))
-                    .collect::<Result<_>>()?,
+                    .collect::<Result<Vec<Buffer>>>()?,
             );
             node_index += 1;
             buffer_index += 2;
@@ -301,10 +300,10 @@ fn create_array(
             let array = create_primitive_array(
                 &nodes[node_index],
                 data_type,
-                buffers[buffer_index..buffer_index + 2]
+                &buffers[buffer_index..buffer_index + 2]
                     .iter()
                     .map(|buf| read_buffer(buf, data, compression_codec))
-                    .collect::<Result<_>>()?,
+                    .collect::<Result<Vec<Buffer>>>()?,
             );
             node_index += 1;
             buffer_index += 2;
@@ -318,16 +317,10 @@ fn create_array(
 /// This function should be called when doing projection in fn `read_record_batch`.
 /// The advancement logic references fn `create_array`.
 fn skip_field(
-    nodes: &[ipc::FieldNode],
-    field: &Field,
-    data: &[u8],
-    buffers: &[ipc::Buffer],
-    dictionaries_by_id: &HashMap<i64, ArrayRef>,
+    data_type: &DataType,
     mut node_index: usize,
     mut buffer_index: usize,
 ) -> Result<(usize, usize)> {
-    use DataType::*;
-    let data_type = field.data_type();
     match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
             node_index += 1;
@@ -340,30 +333,14 @@ fn skip_field(
         List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
             node_index += 1;
             buffer_index += 2;
-            let tuple = skip_field(
-                nodes,
-                list_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-            )?;
+            let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
             node_index = tuple.0;
             buffer_index = tuple.1;
         }
         FixedSizeList(ref list_field, _) => {
             node_index += 1;
             buffer_index += 1;
-            let tuple = skip_field(
-                nodes,
-                list_field,
-                data,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-            )?;
+            let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?;
             node_index = tuple.0;
             buffer_index = tuple.1;
         }
@@ -373,15 +350,8 @@ fn skip_field(

             // skip for each field
             for struct_field in struct_fields {
-                let tuple = skip_field(
-                    nodes,
-                    struct_field,
-                    data,
-                    buffers,
-                    dictionaries_by_id,
-                    node_index,
-                    buffer_index,
-                )?;
+                let tuple =
+                    skip_field(struct_field.data_type(), node_index, buffer_index)?;
                 node_index = tuple.0;
                 buffer_index = tuple.1;
             }
@@ -402,15 +372,7 @@ fn skip_field(
             };

             for field in fields {
-                let tuple = skip_field(
-                    nodes,
-                    field,
-                    data,
-                    buffers,
-                    dictionaries_by_id,
-                    node_index,
-                    buffer_index,
-                )?;
+                let tuple = skip_field(field.data_type(), node_index, buffer_index)?;

                 node_index = tuple.0;
                 buffer_index = tuple.1;
@@ -433,28 +395,26 @@ fn skip_field(
 fn create_primitive_array(
     field_node: &ipc::FieldNode,
     data_type: &DataType,
-    buffers: Vec<Buffer>,
+    buffers: &[Buffer],
 ) -> ArrayRef {
     let length = field_node.length() as usize;
-    let null_count = field_node.null_count() as usize;
+    let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
     let array_data = match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
-            // read 3 buffers
+            // read 3 buffers: null buffer (optional), offsets buffer and data buffer
             ArrayData::builder(data_type.clone())
                 .len(length)
                 .buffers(buffers[1..3].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()))
+                .null_bit_buffer(null_buffer)
                 .build()
Review thread on this line:

Contributor Author: Why do we only validate the data when building these 4 array types, but use `build_unchecked` for the other types?

Contributor: IMO we should probably be validating all of them; I'm not sure why we aren't validating some of them... IPC is inherently not trusted.

Contributor Author: Will file an issue to track it.
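(For reference, a minimal standalone sketch of the distinction discussed above. The `offsets`/`values` buffers are invented for illustration and are not from this PR: `build()` runs `ArrayData`'s validation and returns a `Result`, while `build_unchecked()` is `unsafe` and trusts the caller.)

```rust
use arrow::array::ArrayData;
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

fn main() {
    // Hypothetical buffers standing in for data decoded from an IPC message:
    // offsets [0, 2, 5] into the values "hiabc" encode the strings "hi", "abc".
    let offsets = Buffer::from_slice_ref(&[0i32, 2, 5]);
    let values = Buffer::from(b"hiabc".as_ref());

    // `build()` validates the layout (buffer lengths, offset bounds, ...)
    // and returns a Result, so malformed input surfaces as an error.
    let checked = ArrayData::builder(DataType::Utf8)
        .len(2)
        .add_buffer(offsets.clone())
        .add_buffer(values.clone())
        .build()
        .unwrap();

    // `build_unchecked()` skips those checks; it is `unsafe` because the
    // caller must guarantee the buffers are well formed.
    let unchecked = unsafe {
        ArrayData::builder(DataType::Utf8)
            .len(2)
            .add_buffer(offsets)
            .add_buffer(values)
            .build_unchecked()
    };

    assert_eq!(checked, unchecked);
}
```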

                 .unwrap()
         }
         FixedSizeBinary(_) => {
-            // read 3 buffers
+            // read 2 buffers: null buffer (optional) and data buffer
             let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..2].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
@@ -471,9 +431,8 @@ fn create_primitive_array(
             // interpret as a signed i64, and cast appropriately
             let builder = ArrayData::builder(DataType::Int64)
                 .len(length)
-                .buffers(buffers[1..].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             let data = unsafe { builder.build_unchecked() };
             let values = Arc::new(Int64Array::from(data)) as ArrayRef;
@@ -483,9 +442,8 @@ fn create_primitive_array(
         } else {
             let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
@@ -495,9 +453,8 @@ fn create_primitive_array(
             // interpret as a f64, and cast appropriately
             let builder = ArrayData::builder(DataType::Float64)
                 .len(length)
-                .buffers(buffers[1..].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             let data = unsafe { builder.build_unchecked() };
             let values = Arc::new(Float64Array::from(data)) as ArrayRef;
@@ -507,9 +464,8 @@ fn create_primitive_array(
         } else {
             let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
@@ -526,23 +482,21 @@ fn create_primitive_array(
         | Interval(IntervalUnit::MonthDayNano) => {
             let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
         Decimal128(_, _) | Decimal256(_, _) => {
-            // read 3 buffers
+            // read 2 buffers: null buffer (optional) and data buffer
             let builder = ArrayData::builder(data_type.clone())
                 .len(length)
-                .buffers(buffers[1..2].to_vec())
-                .offset(0)
-                .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+                .add_buffer(buffers[1].clone())
+                .null_bit_buffer(null_buffer);

             unsafe { builder.build_unchecked() }
         }
-        t => panic!("Data type {:?} either unsupported or not primitive", t),
+        t => unreachable!("Data type {:?} either unsupported or not primitive", t),
     };

     make_array(array_data)
@@ -556,39 +510,24 @@ fn create_list_array(
     buffers: &[Buffer],
     child_array: ArrayRef,
 ) -> ArrayRef {
-    if let DataType::List(_) | DataType::LargeList(_) = *data_type {
-        let null_count = field_node.null_count() as usize;
-        let builder = ArrayData::builder(data_type.clone())
-            .len(field_node.length() as usize)
-            .buffers(buffers[1..2].to_vec())
-            .offset(0)
-            .child_data(vec![child_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
-
-        make_array(unsafe { builder.build_unchecked() })
-    } else if let DataType::FixedSizeList(_, _) = *data_type {
-        let null_count = field_node.null_count() as usize;
-        let builder = ArrayData::builder(data_type.clone())
-            .len(field_node.length() as usize)
-            .buffers(buffers[1..1].to_vec())
-            .offset(0)
-            .child_data(vec![child_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
-
-        make_array(unsafe { builder.build_unchecked() })
-    } else if let DataType::Map(_, _) = *data_type {
-        let null_count = field_node.null_count() as usize;
-        let builder = ArrayData::builder(data_type.clone())
-            .len(field_node.length() as usize)
-            .buffers(buffers[1..2].to_vec())
-            .offset(0)
-            .child_data(vec![child_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
-
-        make_array(unsafe { builder.build_unchecked() })
-    } else {
-        panic!("Cannot create list or map array from {:?}", data_type)
-    }
+    let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+    let length = field_node.length() as usize;
+    let child_data = child_array.into_data();
+    let builder = match data_type {
+        List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
+            .len(length)
+            .add_buffer(buffers[1].clone())
+            .add_child_data(child_data)
+            .null_bit_buffer(null_buffer),
+
+        FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
+            .len(length)
+            .add_child_data(child_data)
+            .null_bit_buffer(null_buffer),
+
+        _ => unreachable!("Cannot create list or map array from {:?}", data_type),
+    };
+    make_array(unsafe { builder.build_unchecked() })
 }

 /// Reads the correct number of buffers based on list type and null_count, and creates a
@@ -599,14 +538,13 @@ fn create_dictionary_array(
     buffers: &[Buffer],
     value_array: ArrayRef,
 ) -> ArrayRef {
-    if let DataType::Dictionary(_, _) = *data_type {
-        let null_count = field_node.null_count() as usize;
+    if let Dictionary(_, _) = *data_type {
+        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
         let builder = ArrayData::builder(data_type.clone())
             .len(field_node.length() as usize)
-            .buffers(buffers[1..2].to_vec())
-            .offset(0)
-            .child_data(vec![value_array.into_data()])
-            .null_bit_buffer((null_count > 0).then(|| buffers[0].clone()));
+            .add_buffer(buffers[1].clone())
+            .add_child_data(value_array.into_data())
+            .null_bit_buffer(null_buffer);

         make_array(unsafe { builder.build_unchecked() })
     } else {
@@ -666,15 +604,7 @@ pub fn read_record_batch(
         } else {
             // Skip field.
             // This must be called to advance `node_index` and `buffer_index`.
-            let tuple = skip_field(
-                field_nodes,
-                field,
-                buf,
-                buffers,
-                dictionaries_by_id,
-                node_index,
-                buffer_index,
-            )?;
+            let tuple = skip_field(field.data_type(), node_index, buffer_index)?;
             node_index = tuple.0;
             buffer_index = tuple.1;
         }
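As a footnote to the `skip_field` simplification above: because the traversal now depends only on the `DataType`, it can be sketched as a standalone toy. The sketch below is illustrative and not the library code; `skip` is a hypothetical helper, and the buffer counts follow the layout noted in the diff (validity + offsets + values for variable-length types, validity + offsets for lists, validity + values for fixed-width primitives).

```rust
use arrow::datatypes::{DataType, Field};

// Toy mirror of the simplified `skip_field`: advance the node/buffer counters
// purely from the DataType, with no nodes/buffers/dictionaries threaded through.
fn skip(data_type: &DataType, mut node_index: usize, mut buffer_index: usize) -> (usize, usize) {
    match data_type {
        // Variable-length types consume validity + offsets + values.
        DataType::Utf8 | DataType::Binary => {
            node_index += 1;
            buffer_index += 3;
        }
        // Lists consume validity + offsets, then recurse into the child.
        DataType::List(child) => {
            node_index += 1;
            buffer_index += 2;
            let (n, b) = skip(child.data_type(), node_index, buffer_index);
            node_index = n;
            buffer_index = b;
        }
        // Fixed-width primitives consume validity + values.
        _ => {
            node_index += 1;
            buffer_index += 2;
        }
    }
    (node_index, buffer_index)
}

fn main() {
    // Skipping a List<Utf8> column advances 2 nodes and 5 buffers.
    let ty = DataType::List(Box::new(Field::new("item", DataType::Utf8, true)));
    assert_eq!(skip(&ty, 0, 0), (2, 5));
}
```

This is also why the PR can drop the `nodes`, `data`, `buffers`, and `dictionaries_by_id` parameters: skipping never reads buffer contents, it only counts them.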