diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index ef57be27f338..7b6704eb5e6b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -59,8 +59,8 @@ use sqlparser::ast::{ BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator, - ShowCreateObject, ShowStatementFilter, TableFactor, TableWithJoins, TimezoneInfo, - TrimWhereField, UnaryOperator, Value, Values as SQLValues, + ShowCreateObject, ShowStatementFilter, TableAlias, TableFactor, TableWithJoins, + TimezoneInfo, TrimWhereField, UnaryOperator, Value, Values as SQLValues, }; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{ObjectType, OrderByExpr, Statement}; @@ -370,6 +370,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &mut ctes.clone(), outer_query_schema, )?; + + // Each `WITH` block can change the column names in the last + // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). + let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + ctes.insert(cte_name, logical_plan); } } @@ -779,33 +784,40 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; if let Some(alias) = alias { - let columns_alias = alias.clone().columns; - if columns_alias.is_empty() { - // sqlparser-rs encodes AS t as an empty list of column alias - Ok(plan) - } else if columns_alias.len() != plan.schema().fields().len() { - Err(DataFusionError::Plan(format!( - "Source table contains {} columns but only {} names given as column alias", - plan.schema().fields().len(), - columns_alias.len(), - ))) - } else { - Ok(LogicalPlanBuilder::from(plan.clone()) - .project_with_alias( - plan.schema().fields().iter().zip(columns_alias.iter()).map( - |(field, ident)| { - col(field.name()).alias(&normalize_ident(ident)) - }, - ), - Some(normalize_ident(&alias.name)), - )? 
- .build()?) - } + self.apply_table_alias(plan, alias) } else { Ok(plan) } } + /// Apply the given TableAlias to the plan: an empty column list returns the plan unchanged, a column count mismatch is a planning error, otherwise the fields are re-aliased positionally in a new projection. + fn apply_table_alias( + &self, + plan: LogicalPlan, + alias: TableAlias, + ) -> Result { + let columns_alias = alias.clone().columns; + if columns_alias.is_empty() { + // sqlparser-rs encodes `AS t` (no column list) as an empty list of column aliases + Ok(plan) + } else if columns_alias.len() != plan.schema().fields().len() { + Err(DataFusionError::Plan(format!( + "Source table contains {} columns but only {} names given as column alias", + plan.schema().fields().len(), + columns_alias.len(), + ))) + } else { + Ok(LogicalPlanBuilder::from(plan.clone()) + .project_with_alias( + plan.schema().fields().iter().zip(columns_alias.iter()).map( + |(field, ident)| col(field.name()).alias(&normalize_ident(ident)), + ), + Some(normalize_ident(&alias.name)), + )? + .build()?) + } + } + /// Generate a logic plan from selection clause, the function contain optimization for cross join to inner join /// Related PR: fn plan_selection( @@ -5013,6 +5025,67 @@ mod tests { quick_test(sql, expected) } + #[test] + fn cte_with_no_column_names() { + let sql = "WITH \ + numbers AS ( \ + SELECT 1 as a, 2 as b, 3 as c \ + ) \ + SELECT * FROM numbers;"; + + let expected = "Projection: #numbers.a, #numbers.b, #numbers.c\ + \n Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c, alias=numbers\ + \n EmptyRelation"; + + quick_test(sql, expected) + } + + #[test] + fn cte_with_column_names() { + let sql = "WITH \ + numbers(a, b, c) AS ( \ + SELECT 1, 2, 3 \ + ) \ + SELECT * FROM numbers;"; + + let expected = "Projection: #numbers.a, #numbers.b, #numbers.c\ + \n Projection: #numbers.Int64(1) AS a, #numbers.Int64(2) AS b, #numbers.Int64(3) AS c, alias=numbers\ + \n Projection: Int64(1), Int64(2), Int64(3), alias=numbers\ + \n EmptyRelation"; + + quick_test(sql, expected) + } + + #[test] + fn cte_with_column_aliases_precedence() { + // The end result should always be what the CTE 
specification says + let sql = "WITH \ + numbers(a, b, c) AS ( \ + SELECT 1 as x, 2 as y, 3 as z \ + ) \ + SELECT * FROM numbers;"; + + let expected = "Projection: #numbers.a, #numbers.b, #numbers.c\ + \n Projection: #numbers.x AS a, #numbers.y AS b, #numbers.z AS c, alias=numbers\ + \n Projection: Int64(1) AS x, Int64(2) AS y, Int64(3) AS z, alias=numbers\ + \n EmptyRelation"; + + quick_test(sql, expected) + } + + #[test] + fn cte_unbalanced_number_of_columns() { + let sql = "WITH \ + numbers(a) AS ( \ + SELECT 1, 2, 3 \ + ) \ + SELECT * FROM numbers;"; + + let expected = "Error during planning: Source table contains 3 columns but only 1 names given as column alias"; + let result = logical_plan(sql).err().unwrap(); + assert_eq!(expected, format!("{}", result)); + } + #[test] fn aggregate_with_rollup() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)";