diff --git a/datafusion/core/src/physical_plan/join_utils.rs b/datafusion/core/src/physical_plan/join_utils.rs index b2cc5654a69..1c0440db80d 100644 --- a/datafusion/core/src/physical_plan/join_utils.rs +++ b/datafusion/core/src/physical_plan/join_utils.rs @@ -156,6 +156,27 @@ impl JoinFilter { } } +/// Returns the output field for `old_field` in a join of type `join_type`. Outer joins may +/// pad a side with nulls, so a field from that side is made nullable even if the input was not. +/// `is_left` is true when `old_field` comes from the left input and false for the right input. +fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field { + let force_nullable = match join_type { + JoinType::Inner => false, // field is returned unchanged + JoinType::Left => !is_left, // right input is padded with nulls + JoinType::Right => is_left, // left input is padded with nulls + JoinType::Full => true, // both inputs can be padded with nulls + JoinType::Semi => false, // doesn't introduce nulls + JoinType::Anti => false, // treated as not introducing nulls; NOTE(review): the only caller visible here passes Inner/Left/Right/Full, so confirm whether this arm is ever reached + }; + + if force_nullable { + // Rebuild the field as nullable from its name and data type (NOTE(review): nothing else from `old_field` is passed on, confirm that is intended). Could cleanup after https://github.com/apache/arrow-rs/issues/1934 + Field::new(old_field.name(), old_field.data_type().clone(), true) + } else { + old_field.clone() + } +} + /// Creates a schema for a join operation. 
/// The fields from the left side are first pub fn build_join_schema( @@ -165,8 +186,12 @@ pub fn build_join_schema( ) -> (Schema, Vec) { let (fields, column_indices): (Vec, Vec) = match join_type { JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let left_fields = - left.fields().iter().cloned().enumerate().map(|(index, f)| { + let left_fields = left + .fields() + .iter() + .map(|f| output_join_field(f, join_type, true)) // left side: is_left = true + .enumerate() + .map(|(index, f)| { ( f, ColumnIndex { @@ -175,21 +200,20 @@ pub fn build_join_schema( }, ) }); - let right_fields = - right - .fields() - .iter() - .cloned() - .enumerate() - .map(|(index, f)| { - ( - f, - ColumnIndex { - index, - side: JoinSide::Right, - }, - ) - }); + let right_fields = right + .fields() + .iter() + .map(|f| output_join_field(f, join_type, false)) // right side: is_left = false + .enumerate() + .map(|(index, f)| { + ( + f, + ColumnIndex { + index, + side: JoinSide::Right, + }, + ) + }); // left then right left_fields.chain(right_fields).unzip() @@ -323,6 +347,7 @@ impl OnceFut { #[cfg(test)] mod tests { use super::*; + use arrow::datatypes::DataType; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { let left = left @@ -382,4 +407,59 @@ mod tests { assert!(check(&left, &right, on).is_ok()); } + + #[test] + fn test_join_schema() -> Result<()> { + let a = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a_nulls = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let b = Schema::new(vec![Field::new("b", DataType::Int32, false)]); + let b_nulls = Schema::new(vec![Field::new("b", DataType::Int32, true)]); + + let cases = vec![ // each case: (left_in, right_in, join_type, expected left_out, expected right_out) + (&a, &b, JoinType::Inner, &a, &b), // `INNER` join keeps the nullability of both inputs + (&a, &b_nulls, JoinType::Inner, &a, &b_nulls), + (&a_nulls, &b, JoinType::Inner, &a_nulls, &b), + (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls), + // right side of a `LEFT` join becomes nullable in the output, regardless of input nullability + (&a, &b, JoinType::Left, &a, &b_nulls), + (&a, &b_nulls, 
JoinType::Left, &a, &b_nulls), + (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls), + (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls), + // left side of a `RIGHT` join becomes nullable in the output, regardless of input nullability + (&a, &b, JoinType::Right, &a_nulls, &b), + (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls), + (&a_nulls, &b, JoinType::Right, &a_nulls, &b), + (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls), + // both sides of a `FULL` join become nullable in the output + (&a, &b, JoinType::Full, &a_nulls, &b_nulls), + (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls), + (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls), + (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls), + ]; + + for (left_in, right_in, join_type, left_out, right_out) in cases { + let (schema, _) = build_join_schema(left_in, right_in, &join_type); // second tuple element is not checked here + + let expected_fields = left_out + .fields() + .iter() + .cloned() + .chain(right_out.fields().iter().cloned()) // expected order: left fields, then right fields + .collect(); + + let expected_schema = Schema::new(expected_fields); + assert_eq!( + schema, + expected_schema, + "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}", + left_in.fields()[0].name(), + left_in.fields()[0].is_nullable(), + right_in.fields()[0].name(), + right_in.fields()[0].is_nullable(), + join_type + ); + } + + Ok(()) + } }