Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correctly calculate join output schema nullability #2803

Merged
merged 1 commit into from Jun 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
114 changes: 97 additions & 17 deletions datafusion/core/src/physical_plan/join_utils.rs
Expand Up @@ -156,6 +156,27 @@ impl JoinFilter {
}
}

/// Computes the field a join emits for one of its input fields.
///
/// An outer join can pad the non-preserved input with nulls, so a field
/// coming from that input must be reported as nullable even when the input
/// field itself was declared non-nullable.
///
/// * `old_field` - the field as it appears in the input schema
/// * `join_type` - the kind of join being planned
/// * `is_left` - `true` if `old_field` comes from the left input, `false`
///   if it comes from the right input
fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
    let may_be_padded_with_nulls = match join_type {
        // both inputs can be padded with nulls
        JoinType::Full => true,
        // only the right input is padded with nulls
        JoinType::Left => !is_left,
        // only the left input is padded with nulls
        JoinType::Right => is_left,
        // these do not introduce nulls (the original author marked `Anti`
        // as unconfirmed: "or can it??")
        JoinType::Inner | JoinType::Semi | JoinType::Anti => false,
    };

    if !may_be_padded_with_nulls {
        return old_field.clone();
    }

    // Could cleanup after https://github.com/apache/arrow-rs/issues/1934
    // NOTE(review): the field is rebuilt from its name and data type only;
    // confirm nothing else on `old_field` needs to be carried over.
    let name = old_field.name();
    let data_type = old_field.data_type().clone();
    Field::new(name, data_type, true)
}

/// Creates a schema for a join operation.
/// The fields from the left side are first
pub fn build_join_schema(
Expand All @@ -165,8 +186,12 @@ pub fn build_join_schema(
) -> (Schema, Vec<ColumnIndex>) {
let (fields, column_indices): (Vec<Field>, Vec<ColumnIndex>) = match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
let left_fields =
left.fields().iter().cloned().enumerate().map(|(index, f)| {
let left_fields = left
.fields()
.iter()
.map(|f| output_join_field(f, join_type, true))
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
Expand All @@ -175,21 +200,20 @@ pub fn build_join_schema(
},
)
});
let right_fields =
right
.fields()
.iter()
.cloned()
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Right,
},
)
});
let right_fields = right
.fields()
.iter()
.map(|f| output_join_field(f, join_type, false))
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Right,
},
)
});

// left then right
left_fields.chain(right_fields).unzip()
Expand Down Expand Up @@ -323,6 +347,7 @@ impl<T: 'static> OnceFut<T> {
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::DataType;

fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
let left = left
Expand Down Expand Up @@ -382,4 +407,59 @@ mod tests {

assert!(check(&left, &right, on).is_ok());
}

/// Checks the nullability of the schema produced by `build_join_schema`
/// for every combination of nullable / non-nullable single-column inputs
/// across `Inner`, `Left`, `Right` and `Full` joins.
#[test]
fn test_join_schema() -> Result<()> {
    // Single-column Int32 schema with the requested name and nullability.
    let single_column = |name: &str, nullable: bool| {
        Schema::new(vec![Field::new(name, DataType::Int32, nullable)])
    };

    let a = single_column("a", false);
    let a_nulls = single_column("a", true);
    let b = single_column("b", false);
    let b_nulls = single_column("b", true);

    // Each case is (left input, right input, join type, expected left
    // output, expected right output).
    let cases = vec![
        // an `INNER` join keeps the nullability of both inputs
        (&a, &b, JoinType::Inner, &a, &b),
        (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
        (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
        (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
        // right input of a `LEFT` join can be null, regardless of input nullness
        (&a, &b, JoinType::Left, &a, &b_nulls),
        (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
        (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
        (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
        // left input of a `RIGHT` join can be null, regardless of input nullness
        (&a, &b, JoinType::Right, &a_nulls, &b),
        (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
        (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
        (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
        // Either input of a `FULL` join can be null
        (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
        (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
        (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
        (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
    ];

    for (left_in, right_in, join_type, left_out, right_out) in cases {
        let (actual_schema, _) = build_join_schema(left_in, right_in, &join_type);

        // Expected output: left fields first, then right fields.
        let expected_schema = Schema::new(
            left_out
                .fields()
                .iter()
                .chain(right_out.fields().iter())
                .cloned()
                .collect(),
        );

        let left_field = &left_in.fields()[0];
        let right_field = &right_in.fields()[0];
        assert_eq!(
            actual_schema,
            expected_schema,
            "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
            left_field.name(),
            left_field.is_nullable(),
            right_field.name(),
            right_field.is_nullable(),
            join_type
        );
    }

    Ok(())
}
}