Skip to content

Commit

Permalink
Update avro for new API
Browse files: browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 21, 2022
1 parent 0099a99 commit 924a40d
Showing 1 changed file with 50 additions and 52 deletions.
102 changes: 50 additions & 52 deletions datafusion/core/src/avro_to_arrow/arrow_array_reader.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -130,58 +130,52 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr).map(Some))
}

fn build_boolean_array(
&self,
rows: RecordSlice,
col_name: &str,
) -> ArrowResult<ArrayRef> {
fn build_boolean_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef {
let mut builder = BooleanBuilder::new(rows.len());
for row in rows {
if let Some(value) = self.field_lookup(col_name, row) {
if let Some(boolean) = resolve_boolean(value) {
builder.append_value(boolean)?
builder.append_value(boolean)
} else {
builder.append_null()?;
builder.append_null();
}
} else {
builder.append_null()?;
builder.append_null();
}
}
Ok(Arc::new(builder.finish()))
Arc::new(builder.finish())
}

#[allow(clippy::unnecessary_wraps)]
fn build_primitive_array<T: ArrowPrimitiveType + Resolver>(
&self,
rows: RecordSlice,
col_name: &str,
) -> ArrowResult<ArrayRef>
) -> ArrayRef
where
T: ArrowNumericType,
T::Native: num_traits::cast::NumCast,
{
Ok(Arc::new(
Arc::new(
rows.iter()
.map(|row| {
self.field_lookup(col_name, row)
.and_then(|value| resolve_item::<T>(value))
})
.collect::<PrimitiveArray<T>>(),
))
)
}

#[inline(always)]
#[allow(clippy::unnecessary_wraps)]
fn build_string_dictionary_builder<T>(
&self,
row_len: usize,
) -> ArrowResult<StringDictionaryBuilder<T>>
) -> StringDictionaryBuilder<T>
where
T: ArrowPrimitiveType + ArrowDictionaryKeyType,
{
let key_builder = PrimitiveBuilder::<T>::new(row_len);
let values_builder = StringBuilder::new(row_len * 5);
Ok(StringDictionaryBuilder::new(key_builder, values_builder))
StringDictionaryBuilder::new(key_builder, values_builder)
}

fn build_wrapped_list_array(
Expand Down Expand Up @@ -271,7 +265,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
}
DataType::Dictionary(_, _) => {
let values_builder =
self.build_string_dictionary_builder::<D>(rows.len() * 5)?;
self.build_string_dictionary_builder::<D>(rows.len() * 5);
Box::new(ListBuilder::new(values_builder))
}
e => {
Expand Down Expand Up @@ -316,14 +310,14 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
))?;
for val in vals {
if let Some(v) = val {
builder.values().append_value(&v)?
builder.values().append_value(&v)
} else {
builder.values().append_null()?
builder.values().append_null()
};
}

// Append to the list
builder.append(true)?;
builder.append(true);
}
DataType::Dictionary(_, _) => {
let builder = builder.as_any_mut().downcast_mut::<ListBuilder<StringDictionaryBuilder<D>>>().ok_or_else(||ArrowError::SchemaError(
Expand All @@ -333,12 +327,12 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
if let Some(v) = val {
let _ = builder.values().append(&v)?;
} else {
builder.values().append_null()?
builder.values().append_null()
};
}

// Append to the list
builder.append(true)?;
builder.append(true);
}
e => {
return Err(SchemaError(format!(
Expand All @@ -364,16 +358,16 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
T: ArrowPrimitiveType + ArrowDictionaryKeyType,
{
let mut builder: StringDictionaryBuilder<T> =
self.build_string_dictionary_builder(rows.len())?;
self.build_string_dictionary_builder(rows.len());
for row in rows {
if let Some(value) = self.field_lookup(col_name, row) {
if let Ok(str_v) = resolve_string(value) {
builder.append(str_v).map(drop)?
} else {
builder.append_null()?
builder.append_null()
}
} else {
builder.append_null()?
builder.append_null()
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
Expand Down Expand Up @@ -609,10 +603,8 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
.iter()
.filter(|field| projection.is_empty() || projection.contains(field.name()))
.map(|field| {
match field.data_type() {
DataType::Null => {
Ok(Arc::new(NullArray::new(rows.len())) as ArrayRef)
}
let arr = match field.data_type() {
DataType::Null => Arc::new(NullArray::new(rows.len())) as ArrayRef,
DataType::Boolean => self.build_boolean_array(rows, field.name()),
DataType::Float64 => {
self.build_primitive_array::<Float64Type>(rows, field.name())
Expand Down Expand Up @@ -684,10 +676,12 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
rows,
field.name(),
),
t => Err(ArrowError::SchemaError(format!(
"TimeUnit {:?} not supported with Time64",
t
))),
t => {
return Err(ArrowError::SchemaError(format!(
"TimeUnit {:?} not supported with Time64",
t
)))
}
},
DataType::Time32(unit) => match unit {
TimeUnit::Second => self
Expand All @@ -700,33 +694,35 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
rows,
field.name(),
),
t => Err(ArrowError::SchemaError(format!(
"TimeUnit {:?} not supported with Time32",
t
))),
t => {
return Err(ArrowError::SchemaError(format!(
"TimeUnit {:?} not supported with Time32",
t
)))
}
},
DataType::Utf8 | DataType::LargeUtf8 => Ok(Arc::new(
DataType::Utf8 | DataType::LargeUtf8 => Arc::new(
rows.iter()
.map(|row| {
let maybe_value = self.field_lookup(field.name(), row);
maybe_value.map(resolve_string).transpose()
})
.collect::<ArrowResult<StringArray>>()?,
)
as ArrayRef),
DataType::Binary | DataType::LargeBinary => Ok(Arc::new(
as ArrayRef,
DataType::Binary | DataType::LargeBinary => Arc::new(
rows.iter()
.map(|row| {
let maybe_value = self.field_lookup(field.name(), row);
maybe_value.and_then(resolve_bytes)
})
.collect::<BinaryArray>(),
)
as ArrayRef),
as ArrayRef,
DataType::List(ref list_field) => {
match list_field.data_type() {
DataType::Dictionary(ref key_ty, _) => {
self.build_wrapped_list_array(rows, field.name(), key_ty)
self.build_wrapped_list_array(rows, field.name(), key_ty)?
}
_ => {
// extract rows by name
Expand All @@ -740,7 +736,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
self.build_nested_list_array::<i32>(
extracted_rows.as_slice(),
list_field,
)
)?
}
}
}
Expand All @@ -750,7 +746,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
field.name(),
key_ty,
val_ty,
),
)?,
DataType::Struct(fields) => {
let len = rows.len();
let num_bytes = bit_util::ceil(len, 8);
Expand Down Expand Up @@ -778,15 +774,17 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
.child_data(
arrays.into_iter().map(|a| a.data().clone()).collect(),
)
.build()
.unwrap();
Ok(make_array(data))
.build()?;
make_array(data)
}
_ => Err(ArrowError::SchemaError(format!(
"type {:?} not supported",
field.data_type()
))),
}
_ => {
return Err(ArrowError::SchemaError(format!(
"type {:?} not supported",
field.data_type()
)))
}
};
Ok(arr)
})
.collect();
arrays
Expand Down

0 comments on commit 924a40d

Please sign in to comment.