Skip to content

Commit

Permalink
Do not check schema for equality in concat_batches (#4815)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Sep 16, 2023
1 parent d2be733 commit d960379
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions arrow-select/src/concat.rs
Expand Up @@ -159,7 +159,12 @@ fn concat_fallback(
Ok(make_array(mutable.freeze()))
}

/// Concatenates `batches` together into a single record batch.
/// Concatenates `batches` together into a single [`RecordBatch`].
///
/// The output batch has the specified `schema`; the schemas of the
/// input batches are ignored.
///
/// Returns an error if the types of underlying arrays are different.
pub fn concat_batches<'a>(
schema: &SchemaRef,
input_batches: impl IntoIterator<Item = &'a RecordBatch>,
Expand All @@ -176,20 +181,6 @@ pub fn concat_batches<'a>(
if batches.is_empty() {
return Ok(RecordBatch::new_empty(schema.clone()));
}
if let Some((i, _)) = batches
.iter()
.enumerate()
.find(|&(_, batch)| batch.schema() != *schema)
{
return Err(ArrowError::InvalidArgumentError(format!(
"batches[{i}] schema is different with argument schema.
batches[{i}] schema: {:?},
argument schema: {:?}
",
batches[i].schema(),
*schema
)));
}
let field_num = schema.fields().len();
let mut arrays = Vec::with_capacity(field_num);
for i in 0..field_num {
Expand Down Expand Up @@ -727,36 +718,45 @@ mod tests {
}

#[test]
fn concat_record_batches_of_different_schemas() {
let schema1 = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
]));
let schema2 = Arc::new(Schema::new(vec![
Field::new("c", DataType::Int32, false),
Field::new("d", DataType::Utf8, false),
]));
fn concat_record_batches_of_different_schemas_but_compatible_data() {
let schema1 =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
// column names differ
let schema2 =
Arc::new(Schema::new(vec![Field::new("c", DataType::Int32, false)]));
let batch1 = RecordBatch::try_new(
schema1.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StringArray::from(vec!["a", "b"])),
],
vec![Arc::new(Int32Array::from(vec![1, 2]))],
)
.unwrap();
let batch2 =
RecordBatch::try_new(schema2, vec![Arc::new(Int32Array::from(vec![3, 4]))])
.unwrap();
// concat_batches simply uses the schema provided
let batch = concat_batches(&schema1, [&batch1, &batch2]).unwrap();
assert_eq!(batch.schema().as_ref(), schema1.as_ref());
assert_eq!(4, batch.num_rows());
}

#[test]
fn concat_record_batches_of_different_schemas_incompatible_data() {
let schema1 =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
// column names match, but the data types differ
let schema2 = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)]));
let batch1 = RecordBatch::try_new(
schema1.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2]))],
)
.unwrap();
let batch2 = RecordBatch::try_new(
schema2,
vec![
Arc::new(Int32Array::from(vec![3, 4])),
Arc::new(StringArray::from(vec!["c", "d"])),
],
vec![Arc::new(StringArray::from(vec!["foo", "bar"]))],
)
.unwrap();

let error = concat_batches(&schema1, [&batch1, &batch2]).unwrap_err();
assert_eq!(
error.to_string(),
"Invalid argument error: batches[1] schema is different with argument schema.\n batches[1] schema: Schema { fields: [Field { name: \"c\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"d\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} },\n argument schema: Schema { fields: [Field { name: \"a\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }\n "
);
assert_eq!(error.to_string(), "Invalid argument error: It is not possible to concatenate arrays of different data types.");
}

#[test]
Expand Down

0 comments on commit d960379

Please sign in to comment.