From 0c7cfb5f61069c9d0d983b25e0e76d8765510687 Mon Sep 17 00:00:00 2001
From: yangjiang
Date: Tue, 19 Jul 2022 12:06:33 +0800
Subject: [PATCH] add test for skip_values in DictionaryDecoder and fix it

---
 .../array_reader/byte_array_dictionary.rs     | 81 ++++++++++++++++++-
 1 file changed, 79 insertions(+), 2 deletions(-)

diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
index 181af09e8cd..39d920ef16d 100644
--- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs
+++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -346,6 +346,7 @@ where
                         // Keys will be validated on conversion to arrow
                         let keys_slice = keys.spare_capacity_mut(range.start + len);
                         let len = decoder.get_batch(&mut keys_slice[range.start..])?;
+                        *max_remaining_values -= len;
                         Ok(len)
                     }
                     None => {
@@ -368,7 +369,7 @@ where
                             dict_offsets,
                             dict_values,
                         )?;
-
+                        *max_remaining_values -= len;
                         Ok(len)
                     }
                 }
@@ -476,6 +477,68 @@ mod tests {
         )
     }
 
+    #[test]
+    fn test_dictionary_preservation_skip() {
+        let data_type = utf8_dictionary();
+
+        let data: Vec<_> = vec!["0", "1", "0", "1", "2", "1", "2"]
+            .into_iter()
+            .map(ByteArray::from)
+            .collect();
+        let (dict, encoded) = encode_dictionary(&data);
+
+        let column_desc = utf8_column();
+        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);
+
+        decoder
+            .set_dict(dict, 3, Encoding::RLE_DICTIONARY, false)
+            .unwrap();
+
+        decoder
+            .set_data(Encoding::RLE_DICTIONARY, encoded, 7, Some(data.len()))
+            .unwrap();
+
+        let mut output = DictionaryBuffer::<i32, i32>::default();
+
+        // read two skip one
+        assert_eq!(decoder.read(&mut output, 0..2).unwrap(), 2);
+        assert_eq!(decoder.skip_values(1).unwrap(), 1);
+
+        assert!(matches!(output, DictionaryBuffer::Dict { .. }));
+
+        // read two skip one
+        assert_eq!(decoder.read(&mut output, 2..4).unwrap(), 2);
+        assert_eq!(decoder.skip_values(1).unwrap(), 1);
+
+        // read one and test on skip at the end
+        assert_eq!(decoder.read(&mut output, 4..5).unwrap(), 1);
+        assert_eq!(decoder.skip_values(4).unwrap(), 0);
+
+        let valid = vec![true, true, true, true, true];
+        let valid_buffer = Buffer::from_iter(valid.iter().cloned());
+        output.pad_nulls(0, 5, 5, valid_buffer.as_slice());
+
+        assert!(matches!(output, DictionaryBuffer::Dict { .. }));
+
+        let array = output.into_array(Some(valid_buffer), &data_type).unwrap();
+        assert_eq!(array.data_type(), &data_type);
+
+        let array = cast(&array, &ArrowType::Utf8).unwrap();
+        let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
+        assert_eq!(strings.len(), 5);
+
+        assert_eq!(
+            strings.iter().collect::<Vec<_>>(),
+            vec![
+                Some("0"),
+                Some("1"),
+                Some("1"),
+                Some("2"),
+                Some("2"),
+            ]
+        )
+    }
+
     #[test]
     fn test_dictionary_fallback() {
         let data_type = utf8_dictionary();
@@ -599,7 +662,7 @@ mod tests {
             .set_dict(encoded_dictionary, 4, Encoding::PLAIN_DICTIONARY, false)
             .unwrap();
 
-        for (encoding, page) in pages {
+        for (encoding, page) in pages.clone() {
             let mut output = DictionaryBuffer::<i32, i32>::default();
             decoder.set_data(encoding, page, 8, None).unwrap();
             assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 0);
@@ -612,5 +675,19 @@ mod tests {
             assert_eq!(array.len(), 8);
             assert_eq!(array.null_count(), 8);
         }
+
+        for (encoding, page) in pages {
+            let mut output = DictionaryBuffer::<i32, i32>::default();
+            decoder.set_data(encoding, page, 8, None).unwrap();
+            assert_eq!(decoder.skip_values(1024).unwrap(), 0);
+
+            output.pad_nulls(0, 0, 8, &[0]);
+            let array = output
+                .into_array(Some(Buffer::from(&[0])), &data_type)
+                .unwrap();
+
+            assert_eq!(array.len(), 8);
+            assert_eq!(array.null_count(), 8);
+        }
     }
 }