Allow concat_batches to work with different metadata
alamb committed Sep 8, 2023
1 parent 1d0093c commit 45da217
Showing 1 changed file with 90 additions and 3 deletions.
93 changes: 90 additions & 3 deletions arrow-select/src/concat.rs
@@ -34,7 +34,7 @@ use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use arrow_schema::{ArrowError, DataType, Schema, SchemaRef};

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
let mut item_capacity = 0;
@@ -112,7 +112,7 @@ pub fn concat_batches<'a>(
if let Some((i, _)) = batches
.iter()
.enumerate()
.find(|&(_, batch)| batch.schema() != *schema)
.find(|&(_, batch)| !concatable_schema(schema.as_ref(), batch.schema().as_ref()))
{
return Err(ArrowError::InvalidArgumentError(format!(
"batches[{i}] schema is different with argument schema.
@@ -137,12 +137,31 @@
RecordBatch::try_new(schema.clone(), arrays)
}

/// Returns true if data with the `source` Schema can be placed in a
/// record batch with `target` Schema
fn concatable_schema(target: &Schema, source: &Schema) -> bool {
// ignore metadata
// https://github.com/apache/arrow-rs/issues/4799
if source.fields().len() != target.fields().len() {
return false;
}

source.fields().iter().zip(target.fields().iter()).all(
|(source_field, target_field)| {
// also ignore nullability as `RecordBatch::try_new()`
// will validate that
source_field.name() == target_field.name()
&& source_field.data_type() == target_field.data_type()
},
)
}

#[cfg(test)]
mod tests {
use super::*;
use arrow_array::cast::AsArray;
use arrow_schema::{Field, Schema};
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

#[test]
fn test_concat_empty_vec() {
@@ -680,6 +699,74 @@ mod tests {
);
}

#[test]
fn concat_record_batches_of_different_metadata() {
let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
let field = Field::new("a", DataType::Int32, false);

let schema1 = Arc::new(Schema::new(vec![field.clone()]));

let batch1 =
RecordBatch::try_new(schema1, vec![Arc::new(Int32Array::from(vec![1]))])
.unwrap();

let schema2 = Arc::new(Schema::new(vec![field.with_metadata(metadata)]));

let batch2 =
RecordBatch::try_new(schema2, vec![Arc::new(Int32Array::from(vec![3]))])
.unwrap();

// should be able to concat batches with different metadata
let new_batch = concat_batches(&batch1.schema(), [&batch1, &batch2]).unwrap();
assert_eq!(new_batch.schema(), batch1.schema());
assert_eq!(2, new_batch.num_rows());

// using batch2 schema should also work
let new_batch = concat_batches(&batch2.schema(), [&batch1, &batch2]).unwrap();
assert_eq!(new_batch.schema(), batch2.schema());
assert_eq!(2, new_batch.num_rows());
}

#[test]
fn concat_record_batches_of_different_nullability() {
// is nullable
let field = Field::new("a", DataType::Int32, true);
let nullable_schema = Arc::new(Schema::new(vec![field.clone()]));

let batch_with_nulls = RecordBatch::try_new(
nullable_schema,
vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
)
.unwrap();

let non_nullable_schema = Arc::new(Schema::new(vec![field.with_nullable(false)]));

let batch_without_nulls = RecordBatch::try_new(
non_nullable_schema,
vec![Arc::new(Int32Array::from(vec![3]))],
)
.unwrap();

// should be able to concat batches if the schema says it is
// nullable
let new_batch = concat_batches(
&batch_with_nulls.schema(),
[&batch_with_nulls, &batch_without_nulls],
)
.unwrap();
assert_eq!(new_batch.schema(), batch_with_nulls.schema());
assert_eq!(3, new_batch.num_rows());

// should not be able to concat batches with nulls together if
// the schema says it is not nullable
let err = concat_batches(
&batch_without_nulls.schema(),
[&batch_with_nulls, &batch_without_nulls],
)
.unwrap_err();
assert_eq!(err.to_string(), "Invalid argument error: Column 'a' is declared as non-nullable but contains null values");
}

#[test]
fn concat_capacity() {
let a = Int32Array::from_iter_values(0..100);
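For context, below is a minimal usage sketch of what this change enables. It is not part of the commit itself; it mirrors the concat_record_batches_of_different_metadata test above and assumes the arrow-array, arrow-schema, and arrow-select crates as dependencies.

use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{ArrowError, DataType, Field, Schema};
use arrow_select::concat::concat_batches;

fn main() -> Result<(), ArrowError> {
    // Two schemas that differ only in field-level metadata; the fields
    // otherwise match in name and data type.
    let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let plain_field = Field::new("a", DataType::Int32, false);
    let annotated_field = plain_field.clone().with_metadata(metadata);

    let plain_schema = Arc::new(Schema::new(vec![plain_field]));
    let annotated_schema = Arc::new(Schema::new(vec![annotated_field]));

    let batch1 = RecordBatch::try_new(
        plain_schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2]))],
    )?;
    let batch2 = RecordBatch::try_new(
        annotated_schema,
        vec![Arc::new(Int32Array::from(vec![3]))],
    )?;

    // Before this commit the schema comparison included metadata, so this
    // call returned an InvalidArgumentError; now the metadata difference is
    // ignored and the result uses the target schema passed as the first
    // argument.
    let combined = concat_batches(&plain_schema, [&batch1, &batch2])?;
    assert_eq!(combined.num_rows(), 3);
    assert_eq!(combined.schema(), plain_schema);
    Ok(())
}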
