Skip to content

Commit

Permalink
Infer values for inserts (#4977)
Browse files Browse the repository at this point in the history
* Infer values for updates

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
avantgardnerio and alamb committed Jan 23, 2023
1 parent 624f02d commit 5d4038a
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 2 deletions.
33 changes: 31 additions & 2 deletions datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use datafusion_expr::{
};
use sqlparser::ast;
use sqlparser::ast::{
Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query,
Assignment, Expr as SQLExpr, Expr, Ident, ObjectName, ObjectType, Query, SetExpr,
ShowCreateObject, ShowStatementFilter, Statement, TableFactor, TableWithJoins,
UnaryOperator, Value,
};
Expand Down Expand Up @@ -762,8 +762,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let arrow_schema = (*provider.schema()).clone();
let table_schema = Arc::new(DFSchema::try_from(arrow_schema)?);

// infer types for Values clause... other types should be resolvable the regular way
let mut prepare_param_data_types = BTreeMap::new();
if let SetExpr::Values(ast::Values { rows, .. }) = (*source.body).clone() {
for row in rows.iter() {
for (idx, val) in row.iter().enumerate() {
if let ast::Expr::Value(Value::Placeholder(name)) = val {
let name =
name.replace('$', "").parse::<usize>().map_err(|_| {
DataFusionError::Plan(format!(
"Can't parse placeholder: {name}"
))
})? - 1;
let col = columns.get(idx).ok_or_else(|| {
DataFusionError::Plan(format!(
"Placeholder ${} refers to a non existent column",
idx + 1
))
})?;
let field =
table_schema.field_with_name(None, col.value.as_str())?;
let dt = field.field().data_type().clone();
let _ = prepare_param_data_types.insert(name, dt);
}
}
}
}
let prepare_param_data_types = prepare_param_data_types.into_values().collect();

// Projection
let mut planner_context = PlannerContext::new();
let mut planner_context =
PlannerContext::new_with_prepare_param_data_types(prepare_param_data_types);
let source = self.query_to_plan(*source, &mut planner_context)?;
if columns.len() != source.schema().fields().len() {
Err(DataFusionError::Plan(
Expand Down
69 changes: 69 additions & 0 deletions datafusion/sql/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3390,6 +3390,75 @@ Dml: op=[Update] table=[person]
prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}

#[test]
fn test_prepare_statement_insert_infer() {
let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)";

let expected_plan = r#"
Dml: op=[Insert] table=[person]
Projection: column1 AS id, column2 AS first_name, column3 AS last_name
Values: ($1, $2, $3)
"#
.trim();

let expected_dt = "[Int32]";
let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);

let actual_types = plan.get_parameter_types().unwrap();
let expected_types = HashMap::from([
("$1".to_string(), Some(DataType::UInt32)),
("$2".to_string(), Some(DataType::Utf8)),
("$3".to_string(), Some(DataType::Utf8)),
]);
assert_eq!(actual_types, expected_types);

// replace params with values
let param_values = vec![
ScalarValue::UInt32(Some(1)),
ScalarValue::Utf8(Some("Alan".to_string())),
ScalarValue::Utf8(Some("Turing".to_string())),
];
let expected_plan = r#"
Dml: op=[Insert] table=[person]
Projection: column1 AS id, column2 AS first_name, column3 AS last_name
Values: (UInt32(1), Utf8("Alan"), Utf8("Turing"))
"#
.trim();
let plan = plan.replace_params_with_values(&param_values).unwrap();

prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
}

#[test]
#[should_panic(expected = "Placeholder $4 refers to a non existent column")]
fn test_prepare_statement_insert_infer_gt() {
let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3, $4)";

let expected_plan = r#""#.trim();
let expected_dt = "[Int32]";
let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
}

#[test]
#[should_panic(expected = "value: Plan(\"Column count doesn't match insert query!\")")]
fn test_prepare_statement_insert_infer_lt() {
let sql = "insert into person (id, first_name, last_name) values ($1, $2)";

let expected_plan = r#""#.trim();
let expected_dt = "[Int32]";
let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
}

#[test]
#[should_panic(expected = "value: Plan(\"Placeholder type could not be resolved\")")]
fn test_prepare_statement_insert_infer_gap() {
let sql = "insert into person (id, first_name, last_name) values ($2, $4, $6)";

let expected_plan = r#""#.trim();
let expected_dt = "[Int32]";
let _ = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
}

#[test]
fn test_prepare_statement_to_plan_one_param() {
let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1";
Expand Down

0 comments on commit 5d4038a

Please sign in to comment.