diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs
index 853bc2b1898..96fe02d2dce 100644
--- a/parquet/src/arrow/array_reader/byte_array.rs
+++ b/parquet/src/arrow/array_reader/byte_array.rs
@@ -406,6 +406,7 @@ impl ByteArrayDecoderPlain {
             skip += 1;
             self.offset = self.offset + 4 + len;
         }
+        self.max_remaining_values -= skip;
         Ok(skip)
     }
 }
diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
index bfe55749991..181af09e8cd 100644
--- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs
+++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs
@@ -376,8 +376,20 @@ where
         }
     }
 
-    fn skip_values(&mut self, _num_values: usize) -> Result<usize> {
-        Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792"))
+    fn skip_values(&mut self, num_values: usize) -> Result<usize> {
+        match self.decoder.as_mut().expect("decoder set") {
+            MaybeDictionaryDecoder::Fallback(decoder) => {
+                decoder.skip::<V>(num_values, None)
+            }
+            MaybeDictionaryDecoder::Dict {
+                decoder,
+                max_remaining_values,
+            } => {
+                let num_values = num_values.min(*max_remaining_values);
+                *max_remaining_values -= num_values;
+                decoder.skip(num_values)
+            }
+        }
     }
 }
 
@@ -507,6 +519,51 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_dictionary_skip_fallback() {
+        let data_type = utf8_dictionary();
+        let data = vec!["hello", "world", "a", "b"];
+
+        let (pages, encoded_dictionary) = byte_array_all_encodings(data.clone());
+        let num_encodings = pages.len();
+
+        let column_desc = utf8_column();
+        let mut decoder = DictionaryDecoder::<i32, i32>::new(&column_desc);
+
+        decoder
+            .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false)
+            .unwrap();
+
+        // Read all pages into single buffer
+        let mut output = DictionaryBuffer::<i32, i32>::default();
+
+        for (encoding, page) in pages {
+            decoder.set_data(encoding, page, 4, Some(4)).unwrap();
+            decoder.skip_values(2).expect("skipping two values");
+            assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 2);
+        }
+        let array = output.into_array(None, &data_type).unwrap();
+        assert_eq!(array.data_type(), &data_type);
+
+        let array = cast(&array, &ArrowType::Utf8).unwrap();
+        let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
+        assert_eq!(strings.len(), (data.len() - 2) * num_encodings);
+
+        // Should have a copy of `data` for each encoding
+        for i in 0..num_encodings {
+            assert_eq!(
+                &strings
+                    .iter()
+                    .skip(i * (data.len() - 2))
+                    .take(data.len() - 2)
+                    .map(|x| x.unwrap())
+                    .collect::<Vec<_>>(),
+                &data[2..]
+            )
+        }
+    }
+
+    #[test]
     fn test_too_large_dictionary() {
         let data: Vec<_> = (0..128)