Read/write nested dictionary in ipc stream reader/writer (#1566)
* Read dictionary inside dictionary

* Fix clippy
viirya committed Apr 15, 2022
1 parent 2bcc0cf commit bb65358
Showing 3 changed files with 97 additions and 7 deletions.
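
To show what the change enables end to end, here is a minimal round-trip sketch (not part of the commit): a List<Dictionary<Int8, Utf8>> column is written to the IPC stream format and read back, mirroring the new test added to arrow/src/ipc/reader.rs below. The standalone `fn main` wrapper, the in-memory buffer, and the single-argument `StreamReader::try_new` signature are assumptions tied to the arrow release of this era (the signature has changed in later versions); the array and field construction follows the new test.

use std::sync::Arc;

use arrow::array::{ArrayData, DictionaryArray, Int8Array, ListArray, StringArray};
use arrow::buffer::Buffer;
use arrow::datatypes::{DataType, Field, Int8Type, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> arrow::error::Result<()> {
    // Dictionary-encoded string values: each key indexes into ["a", "b", "c"].
    let values = StringArray::from_iter_values(["a", "b", "c"]);
    let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1]);
    let dict_array = DictionaryArray::<Int8Type>::try_new(&keys, &values)?;

    // Wrap the dictionary in a List<Dictionary<Int8, Utf8>>: three lists of two
    // items each. The dict_id of 1 is arbitrary for this sketch.
    let list_type = DataType::List(Box::new(Field::new_dict(
        "item",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
        false,
        1,
        false,
    )));
    let list_data = ArrayData::builder(list_type)
        .len(3)
        .add_buffer(Buffer::from_slice_ref(&[0, 2, 4, 6]))
        .add_child_data(dict_array.data().clone())
        .build()?;
    let list_array = ListArray::from(list_data);

    let schema = Arc::new(Schema::new(vec![Field::new(
        "f1",
        list_array.data_type().clone(),
        false,
    )]));
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(list_array)])?;

    // Write the batch to an in-memory IPC stream; the nested dictionary is now
    // emitted as a dictionary batch ahead of the record batch.
    let mut bytes = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut bytes, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // Read it back and check that the nested dictionary survived the round trip.
    let mut reader = StreamReader::try_new(bytes.as_slice())?;
    let roundtrip = reader.next().unwrap()?;
    assert_eq!(batch, roundtrip);
    Ok(())
}

Before this commit the writer only recursed into struct and union children (see the old TODO in writer.rs), so a dictionary nested inside a list or another dictionary was never written out as a dictionary batch, and Field::fields() did not descend into dictionary value types on the read side; the hunks below add both.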
13 changes: 12 additions & 1 deletion arrow/src/datatypes/field.rs
@@ -116,14 +116,25 @@ impl Field {
     /// Returns a (flattened) vector containing all fields contained within this field (including it self)
     pub(crate) fn fields(&self) -> Vec<&Field> {
         let mut collected_fields = vec![self];
-        match &self.data_type {
+        collected_fields.append(&mut self._fields(&self.data_type));
+
+        collected_fields
+    }
+
+    fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> {
+        let mut collected_fields = vec![];
+
+        match dt {
             DataType::Struct(fields) | DataType::Union(fields, _) => {
                 collected_fields.extend(fields.iter().flat_map(|f| f.fields()))
             }
             DataType::List(field)
             | DataType::LargeList(field)
             | DataType::FixedSizeList(field, _)
             | DataType::Map(field, _) => collected_fields.push(field),
+            DataType::Dictionary(_, value_field) => {
+                collected_fields.append(&mut self._fields(value_field.as_ref()))
+            }
             _ => (),
         }
 
39 changes: 39 additions & 0 deletions arrow/src/ipc/reader.rs
@@ -1019,6 +1019,7 @@ mod tests {
 
     use flate2::read::GzDecoder;
 
+    use crate::datatypes::Int8Type;
     use crate::{datatypes, util::integration_util::*};
 
     #[test]
@@ -1441,4 +1442,42 @@ mod tests {
         let output_batch = roundtrip_ipc_stream(&input_batch);
         assert_eq!(input_batch, output_batch);
     }
+
+    #[test]
+    fn test_roundtrip_stream_nested_dict_dict() {
+        let values = StringArray::from_iter_values(["a", "b", "c"]);
+        let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1]);
+        let dict_array = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap();
+        let dict_data = dict_array.data();
+
+        let value_offsets = Buffer::from_slice_ref(&[0, 2, 4, 6]);
+
+        let list_data_type = DataType::List(Box::new(Field::new_dict(
+            "item",
+            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
+            false,
+            1,
+            false,
+        )));
+        let list_data = ArrayData::builder(list_data_type)
+            .len(3)
+            .add_buffer(value_offsets)
+            .add_child_data(dict_data.clone())
+            .build()
+            .unwrap();
+        let list_array = ListArray::from(list_data);
+
+        let dict_dict_array =
+            DictionaryArray::<Int8Type>::try_new(&keys, &list_array).unwrap();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "f1",
+            dict_dict_array.data_type().clone(),
+            false,
+        )]));
+        let input_batch =
+            RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
+        let output_batch = roundtrip_ipc_stream(&input_batch);
+        assert_eq!(input_batch, output_batch);
+    }
 }
52 changes: 46 additions & 6 deletions arrow/src/ipc/writer.rs
@@ -25,7 +25,9 @@ use std::io::{BufWriter, Write};
 
 use flatbuffers::FlatBufferBuilder;
 
-use crate::array::{as_struct_array, as_union_array, ArrayData, ArrayRef};
+use crate::array::{
+    as_list_array, as_struct_array, as_union_array, make_array, ArrayData, ArrayRef,
+};
 use crate::buffer::{Buffer, MutableBuffer};
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
@@ -137,15 +139,14 @@ impl IpcDataGenerator {
         }
     }
 
-    fn encode_dictionaries(
+    fn _encode_dictionaries(
         &self,
-        field: &Field,
         column: &ArrayRef,
         encoded_dictionaries: &mut Vec<EncodedData>,
         dictionary_tracker: &mut DictionaryTracker,
         write_options: &IpcWriteOptions,
     ) -> Result<()> {
-        // TODO: Handle other nested types (map, list, etc)
+        // TODO: Handle other nested types (map, etc)
         match column.data_type() {
             DataType::Struct(fields) => {
                 let s = as_struct_array(column);
@@ -159,6 +160,16 @@ impl IpcDataGenerator {
                     )?;
                 }
             }
+            DataType::List(field) => {
+                let list = as_list_array(column);
+                self.encode_dictionaries(
+                    field,
+                    &list.values(),
+                    encoded_dictionaries,
+                    dictionary_tracker,
+                    write_options,
+                )?;
+            }
             DataType::Union(fields, _) => {
                 let union = as_union_array(column);
                 for (field, ref column) in fields
@@ -175,13 +186,37 @@ impl IpcDataGenerator {
                     )?;
                 }
             }
+            _ => (),
+        }
+
+        Ok(())
+    }
+
+    fn encode_dictionaries(
+        &self,
+        field: &Field,
+        column: &ArrayRef,
+        encoded_dictionaries: &mut Vec<EncodedData>,
+        dictionary_tracker: &mut DictionaryTracker,
+        write_options: &IpcWriteOptions,
+    ) -> Result<()> {
+        match column.data_type() {
             DataType::Dictionary(_key_type, _value_type) => {
                 let dict_id = field
                     .dict_id()
                     .expect("All Dictionary types have `dict_id`");
                 let dict_data = column.data();
                 let dict_values = &dict_data.child_data()[0];
 
+                let values = make_array(dict_data.child_data()[0].clone());
+
+                self._encode_dictionaries(
+                    &values,
+                    encoded_dictionaries,
+                    dictionary_tracker,
+                    write_options,
+                )?;
+
                 let emit = dictionary_tracker.insert(dict_id, column)?;
 
                 if emit {
@@ -192,7 +227,12 @@ impl IpcDataGenerator {
                     ));
                 }
             }
-            _ => (),
+            _ => self._encode_dictionaries(
+                column,
+                encoded_dictionaries,
+                dictionary_tracker,
+                write_options,
+            )?,
         }
 
         Ok(())
@@ -205,7 +245,7 @@ impl IpcDataGenerator {
         write_options: &IpcWriteOptions,
     ) -> Result<(Vec<EncodedData>, EncodedData)> {
         let schema = batch.schema();
-        let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len());
+        let mut encoded_dictionaries = Vec::with_capacity(schema.all_fields().len());
 
         for (i, field) in schema.fields().iter().enumerate() {
             let column = batch.column(i);
