Skip to content

Commit

Permalink
fix: correctly calculate join output schema nullability
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jun 27, 2022
1 parent 64d0d02 commit 63673aa
Showing 1 changed file with 97 additions and 17 deletions.
114 changes: 97 additions & 17 deletions datafusion/core/src/physical_plan/join_utils.rs
Expand Up @@ -156,6 +156,27 @@ impl JoinFilter {
}
}

/// Derives the field that appears in a join's output schema for one field
/// of one of its inputs.
///
/// Outer joins can emit rows in which one side is filled with nulls, so a
/// field coming from that side has to be nullable in the output even when
/// the input declared it non-nullable. `is_left` tells which input
/// `old_field` belongs to (`true` for the left input, `false` for the right).
fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
    let padded_with_nulls = match join_type {
        // both inputs can be padded with nulls
        JoinType::Full => true,
        // only the right input is padded with nulls
        JoinType::Left => !is_left,
        // only the left input is padded with nulls
        JoinType::Right => is_left,
        // these join types keep the nullability declared by the input
        JoinType::Inner | JoinType::Semi | JoinType::Anti => false,
    };

    if !padded_with_nulls {
        return old_field.clone();
    }

    // The nullable copy is rebuilt from the name and data type only.
    // Could cleanup after https://github.com/apache/arrow-rs/issues/1934
    Field::new(old_field.name(), old_field.data_type().clone(), true)
}

/// Creates a schema for a join operation.
/// The fields from the left side are first
pub fn build_join_schema(
Expand All @@ -165,8 +186,12 @@ pub fn build_join_schema(
) -> (Schema, Vec<ColumnIndex>) {
let (fields, column_indices): (Vec<Field>, Vec<ColumnIndex>) = match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
let left_fields =
left.fields().iter().cloned().enumerate().map(|(index, f)| {
let left_fields = left
.fields()
.iter()
.map(|f| output_join_field(f, join_type, true))
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
Expand All @@ -175,21 +200,20 @@ pub fn build_join_schema(
},
)
});
let right_fields =
right
.fields()
.iter()
.cloned()
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Right,
},
)
});
let right_fields = right
.fields()
.iter()
.map(|f| output_join_field(f, join_type, false))
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Right,
},
)
});

// left then right
left_fields.chain(right_fields).unzip()
Expand Down Expand Up @@ -323,6 +347,7 @@ impl<T: 'static> OnceFut<T> {
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::DataType;

fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
let left = left
Expand Down Expand Up @@ -382,4 +407,59 @@ mod tests {

assert!(check(&left, &right, on).is_ok());
}

#[test]
fn test_join_schema() -> Result<()> {
    // One-column Int32 schemas; each column exists in a non-nullable and a
    // nullable variant so every input combination can be exercised.
    let int_schema = |name: &str, nullable: bool| {
        Schema::new(vec![Field::new(name, DataType::Int32, nullable)])
    };
    let a = int_schema("a", false);
    let a_nullable = int_schema("a", true);
    let b = int_schema("b", false);
    let b_nullable = int_schema("b", true);

    // Each row: (left input, right input, join type, expected left half of
    // the output, expected right half of the output).
    let cases = vec![
        // `INNER` joins keep the nullability of both inputs
        (&a, &b, JoinType::Inner, &a, &b),
        (&a, &b_nullable, JoinType::Inner, &a, &b_nullable),
        (&a_nullable, &b, JoinType::Inner, &a_nullable, &b),
        (&a_nullable, &b_nullable, JoinType::Inner, &a_nullable, &b_nullable),
        // `LEFT` joins: the right side is always nullable in the output
        (&a, &b, JoinType::Left, &a, &b_nullable),
        (&a, &b_nullable, JoinType::Left, &a, &b_nullable),
        (&a_nullable, &b, JoinType::Left, &a_nullable, &b_nullable),
        (&a_nullable, &b_nullable, JoinType::Left, &a_nullable, &b_nullable),
        // `RIGHT` joins: the left side is always nullable in the output
        (&a, &b, JoinType::Right, &a_nullable, &b),
        (&a, &b_nullable, JoinType::Right, &a_nullable, &b_nullable),
        (&a_nullable, &b, JoinType::Right, &a_nullable, &b),
        (&a_nullable, &b_nullable, JoinType::Right, &a_nullable, &b_nullable),
        // `FULL` joins: both sides are always nullable in the output
        (&a, &b, JoinType::Full, &a_nullable, &b_nullable),
        (&a, &b_nullable, JoinType::Full, &a_nullable, &b_nullable),
        (&a_nullable, &b, JoinType::Full, &a_nullable, &b_nullable),
        (&a_nullable, &b_nullable, JoinType::Full, &a_nullable, &b_nullable),
    ];

    for (left_in, right_in, join_type, left_out, right_out) in cases {
        let (actual, _) = build_join_schema(left_in, right_in, &join_type);

        // The expected output is the left half followed by the right half.
        let expected = Schema::new(
            left_out
                .fields()
                .iter()
                .chain(right_out.fields().iter())
                .cloned()
                .collect(),
        );

        // Only used to describe the failing case in the assertion message.
        let left_field = &left_in.fields()[0];
        let right_field = &right_in.fields()[0];

        assert_eq!(
            actual,
            expected,
            "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
            left_field.name(),
            left_field.is_nullable(),
            right_field.name(),
            right_field.is_nullable(),
            join_type
        );
    }

    Ok(())
}
}

0 comments on commit 63673aa

Please sign in to comment.