Read and skip validity buffer of UnionType Array for V4 ipc message #1789

Merged · 3 commits · Jun 5, 2022
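Before V5 of the Arrow IPC format, union arrays were serialized with a validity bitmap; V5 removed it (a union's validity is determined by its children). A reader that assumes the V5 layout therefore misreads V4 streams, such as those produced by Arrow 0.17.x: the buffer index drifts by one as soon as it reaches a union. This PR threads the message's `MetadataVersion` through `read_record_batch`, `read_dictionary`, and `create_array`, and reads and discards the extra validity buffer when the version is older than V5.

A minimal sketch of the version check, under simplified assumptions: `MetadataVersion` below is a stand-in for the crate's `ipc::MetadataVersion`, and `type_ids_buffer_index` is a hypothetical helper, not part of the real reader.

```rust
#[derive(Debug, PartialEq, PartialOrd)]
enum MetadataVersion {
    V4,
    V5,
}

/// For a union array, return the index of its type-ids buffer, skipping
/// the validity bitmap that pre-V5 writers emit ahead of it.
fn type_ids_buffer_index(version: &MetadataVersion, mut buffer_index: usize) -> usize {
    if version < &MetadataVersion::V5 {
        // V4: a validity buffer precedes the type ids; consume it so the
        // following reads (type ids, then offsets for dense unions) stay aligned.
        buffer_index += 1;
    }
    buffer_index
}

fn main() {
    assert_eq!(type_ids_buffer_index(&MetadataVersion::V4, 0), 1);
    assert_eq!(type_ids_buffer_index(&MetadataVersion::V5, 0), 0);
}
```

The diffs below make the same comparison inside `create_array` and update every call site to pass `message.version()` down.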
1 change: 1 addition & 0 deletions arrow-flight/src/utils.rs
@@ -71,6 +71,7 @@ pub fn flight_data_to_arrow_batch(
schema,
dictionaries_by_id,
None,
&message.version(),
)
})?
}
31 changes: 28 additions & 3 deletions arrow/src/ipc/reader.rs
@@ -52,6 +52,7 @@ fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer {
/// - check if the bit width of non-64-bit numbers is 64, and
/// - read the buffer as 64-bit (signed integer or float), and
/// - cast the 64-bit array to the appropriate data type
#[allow(clippy::too_many_arguments)]
fn create_array(
nodes: &[ipc::FieldNode],
field: &Field,
@@ -60,6 +61,7 @@ fn create_array(
dictionaries_by_id: &HashMap<i64, ArrayRef>,
mut node_index: usize,
mut buffer_index: usize,
metadata: &ipc::MetadataVersion,
) -> Result<(ArrayRef, usize, usize)> {
use DataType::*;
let data_type = field.data_type();
@@ -106,6 +108,7 @@
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
@@ -128,6 +131,7 @@
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
@@ -153,6 +157,7 @@
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
@@ -201,6 +206,13 @@

let len = union_node.length() as usize;

                // In V4, union types have a validity bitmap.
                // In V5 and later, union types have no validity bitmap.
                if metadata < &ipc::MetadataVersion::V5 {
read_buffer(&buffers[buffer_index], data);
buffer_index += 1;
}

let type_ids: Buffer =
read_buffer(&buffers[buffer_index], data)[..len].into();

@@ -226,6 +238,7 @@
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;

node_index = triple.1;
@@ -582,6 +595,7 @@ pub fn read_record_batch(
schema: SchemaRef,
dictionaries_by_id: &HashMap<i64, ArrayRef>,
projection: Option<&[usize]>,
metadata: &ipc::MetadataVersion,
) -> Result<RecordBatch> {
let buffers = batch.buffers().ok_or_else(|| {
ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -607,6 +621,7 @@
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
@@ -640,6 +655,7 @@
dictionaries_by_id,
node_index,
buffer_index,
metadata,
)?;
node_index = triple.1;
buffer_index = triple.2;
@@ -656,6 +672,7 @@ pub fn read_dictionary(
batch: ipc::DictionaryBatch,
schema: &Schema,
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
metadata: &ipc::MetadataVersion,
) -> Result<()> {
if batch.isDelta() {
return Err(ArrowError::IoError(
@@ -686,6 +703,7 @@
Arc::new(schema),
dictionaries_by_id,
None,
metadata,
)?;
Some(record_batch.column(0).clone())
}
@@ -816,7 +834,13 @@ impl<R: Read + Seek> FileReader<R> {
))?;
reader.read_exact(&mut buf)?;

read_dictionary(&buf, batch, &schema, &mut dictionaries_by_id)?;
read_dictionary(
&buf,
batch,
&schema,
&mut dictionaries_by_id,
&message.version(),
)?;
}
t => {
return Err(ArrowError::IoError(format!(
@@ -925,6 +949,7 @@ impl<R: Read + Seek> FileReader<R> {
self.schema(),
&self.dictionaries_by_id,
self.projection.as_ref().map(|x| x.0.as_ref()),
                    &message.version(),
                ).map(Some)
}
@@ -1099,7 +1124,7 @@ impl<R: Read> StreamReader<R> {
let mut buf = vec![0; message.bodyLength() as usize];
self.reader.read_exact(&mut buf)?;

read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some)
read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some)
}
ipc::MessageHeader::DictionaryBatch => {
let batch = message.header_as_dictionary_batch().ok_or_else(|| {
@@ -1112,7 +1137,7 @@ impl<R: Read> StreamReader<R> {
self.reader.read_exact(&mut buf)?;

read_dictionary(
&buf, batch, &self.schema, &mut self.dictionaries_by_id
&buf, batch, &self.schema, &mut self.dictionaries_by_id, &message.version()
)?;

// read the next message until we encounter a RecordBatch
48 changes: 48 additions & 0 deletions arrow/src/ipc/writer.rs
@@ -1385,4 +1385,52 @@ mod tests {
// Dictionary with id 2 should have been written to the dict tracker
assert!(dict_tracker.written.contains_key(&2));
}

#[test]
fn read_union_017() {
let testdata = crate::util::test_util::arrow_test_data();
let version = "0.17.1";
let data_file = File::open(format!(
"{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream",
testdata,
))
.unwrap();

let reader = StreamReader::try_new(data_file, None).unwrap();

// read and rewrite the stream to a temp location
{
let file = File::create(format!(
"target/debug/testdata/{}-generated_union.stream",
version
))
.unwrap();
let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap();
reader.for_each(|batch| {
writer.write(&batch.unwrap()).unwrap();
});
writer.finish().unwrap();
}

        // Compare the original file with the rewritten file
let file = File::open(format!(
"target/debug/testdata/{}-generated_union.stream",
version
))
.unwrap();
let rewrite_reader = StreamReader::try_new(file, None).unwrap();

let data_file = File::open(format!(
"{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream",
testdata,
))
.unwrap();
let reader = StreamReader::try_new(data_file, None).unwrap();

reader.into_iter().zip(rewrite_reader.into_iter()).for_each(
|(batch1, batch2)| {
assert_eq!(batch1.unwrap(), batch2.unwrap());
},
);
}
}
@@ -270,6 +270,7 @@ async fn receive_batch_flight_data(
.expect("Error parsing dictionary"),
&schema,
dictionaries_by_id,
&message.version(),
)
.expect("Error reading dictionary");

@@ -296,6 +296,7 @@ async fn record_batch_from_message(
schema_ref,
dictionaries_by_id,
None,
&message.version(),
);

arrow_batch_result.map_err(|e| {
@@ -313,8 +314,13 @@ async fn dictionary_from_message(
Status::internal("Could not parse message header as dictionary batch")
})?;

let dictionary_batch_result =
reader::read_dictionary(data_body, ipc_batch, &schema_ref, dictionaries_by_id);
let dictionary_batch_result = reader::read_dictionary(
data_body,
ipc_batch,
&schema_ref,
dictionaries_by_id,
&message.version(),
);
dictionary_batch_result.map_err(|e| {
Status::internal(format!("Could not convert to Dictionary: {:?}", e))
})