Skip to content

Commit

Permalink
Added explicit Athena dialect, support for empty array literals
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Allen committed Jul 5, 2022
1 parent 6876853 commit e16a085
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 3 deletions.
28 changes: 28 additions & 0 deletions src/dialect/athena.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use crate::dialect::Dialect;

#[derive(Debug)]
pub struct AthenaSqlDialect {}

// Basically a knock-off of the Hive SQL Dialect
impl Dialect for AthenaSqlDialect {
fn is_delimited_identifier_start(&self, ch: char) -> bool {
(ch == '"') || (ch == '`')
}

fn is_identifier_start(&self, ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| ch == '$'
}

fn is_identifier_part(&self, ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| ch == '_'
|| ch == '$'
|| ch == '{'
|| ch == '}'
}
}
2 changes: 2 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// limitations under the License.

mod ansi;
mod athena;
mod bigquery;
mod clickhouse;
mod generic;
Expand All @@ -28,6 +29,7 @@ use core::iter::Peekable;
use core::str::Chars;

pub use self::ansi::AnsiDialect;
pub use self::athena::AthenaSqlDialect;
pub use self::bigquery::BigQueryDialect;
pub use self::clickhouse::ClickHouseDialect;
pub use self::generic::GenericDialect;
Expand Down
14 changes: 11 additions & 3 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,9 +883,17 @@ impl<'a> Parser<'a> {
/// Parses an array expression `[ex1, ex2, ..]`
/// if `named` is `true`, came from an expression like `ARRAY[ex1, ex2]`
pub fn parse_array_expr(&mut self, named: bool) -> Result<Expr, ParserError> {
let exprs = self.parse_comma_separated(Parser::parse_expr)?;
self.expect_token(&Token::RBracket)?;
Ok(Expr::Array(Array { elem: exprs, named }))
if self.peek_token() == Token::RBracket {
let _ = self.next_token();
Ok(Expr::Array( Array {
elem: vec![],
named,
}))
} else {
let exprs = self.parse_comma_separated(Parser::parse_expr)?;
self.expect_token(&Token::RBracket)?;
Ok(Expr::Array(Array { elem: exprs, named }))
}
}

/// Parse a SQL LISTAGG expression, e.g. `LISTAGG(...) WITHIN GROUP (ORDER BY ...)`.
Expand Down
243 changes: 243 additions & 0 deletions tests/sqlparser_athena.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![warn(clippy::all)]

//! Test SQL syntax specific to Athena. The parser based on the generic dialect
//! is also tested (on the inputs it can handle).

#[macro_use]
mod test_utils;
use test_utils::*;

use sqlparser::ast::{Expr, Value};
use sqlparser::dialect::{AthenaSqlDialect};

fn athena() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(AthenaSqlDialect {})],
}
}

#[test]
fn parse_array_value_expr() {
#[cfg(feature = "bigdecimal")]
let num: Vec<Expr> = (0..=10)
.into_iter()
.map(|s| Expr::Value(Value::Number(bigdecimal::BigDecimal::from(s), false)))
.collect();
#[cfg(not(feature = "bigdecimal"))]
let num: Vec<Expr> = (0..=10)
.into_iter()
.map(|s| Expr::Value(Value::Number(s.to_string(), false)))
.collect();
let sql = "SELECT ARRAY[1, 2, 3]";
let select = athena().verified_only_select(sql);
assert_eq!(
// Why is this named: true?
&Expr::Array(sqlparser::ast::Array { elem: vec![num[1].clone(), num[2].clone(), num[3].clone()], named: true}),
expr_from_projection(only(&select.projection)),
);
let sql = "SELECT ARRAY[]";
let select = athena().verified_only_select(sql);
assert_eq!(
&Expr::Array(sqlparser::ast::Array { elem: vec![], named: true}),
expr_from_projection(only(&select.projection)),
);
}

#[test]
fn parse_table_create() {
let sql = r#"CREATE TABLE IF NOT EXISTS db.table (a BIGINT, b STRING, c TIMESTAMP) PARTITIONED BY (d STRING, e TIMESTAMP) STORED AS ORC LOCATION 's3://...' TBLPROPERTIES ("prop" = "2", "asdf" = '1234', 'asdf' = "1234", "asdf" = 2)"#;
let iof = r#"CREATE TABLE IF NOT EXISTS db.table (a BIGINT, b STRING, c TIMESTAMP) PARTITIONED BY (d STRING, e TIMESTAMP) STORED AS INPUTFORMAT 'org.apache.hadoop.hive.ql.io.orc.OrcInputFormat' OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat' LOCATION 's3://...'"#;

athena().verified_stmt(sql);
athena().verified_stmt(iof);
}

#[test]
fn parse_insert_overwrite() {
let insert_partitions = r#"INSERT OVERWRITE TABLE db.new_table PARTITION (a = '1', b) SELECT a, b, c FROM db.table"#;
athena().verified_stmt(insert_partitions);
}

#[test]
fn test_truncate() {
let truncate = r#"TRUNCATE TABLE db.table"#;
athena().verified_stmt(truncate);
}

#[test]
fn parse_analyze() {
let analyze = r#"ANALYZE TABLE db.table_name PARTITION (a = '1234', b) COMPUTE STATISTICS NOSCAN CACHE METADATA"#;
athena().verified_stmt(analyze);
}

#[test]
fn parse_analyze_for_columns() {
let analyze =
r#"ANALYZE TABLE db.table_name PARTITION (a = '1234', b) COMPUTE STATISTICS FOR COLUMNS"#;
athena().verified_stmt(analyze);
}

#[test]
fn parse_msck() {
let msck = r#"MSCK REPAIR TABLE db.table_name ADD PARTITIONS"#;
let msck2 = r#"MSCK REPAIR TABLE db.table_name"#;
athena().verified_stmt(msck);
athena().verified_stmt(msck2);
}

#[test]
fn parse_set() {
let set = "SET HIVEVAR:name = a, b, c_d";
athena().verified_stmt(set);
}

#[test]
fn test_spaceship() {
let spaceship = "SELECT * FROM db.table WHERE a <=> b";
athena().verified_stmt(spaceship);
}

#[test]
fn parse_with_cte() {
let with = "WITH a AS (SELECT * FROM b) INSERT INTO TABLE db.table_table PARTITION (a) SELECT * FROM b";
athena().verified_stmt(with);
}

#[test]
fn drop_table_purge() {
let purge = "DROP TABLE db.table_name PURGE";
athena().verified_stmt(purge);
}

#[test]
fn create_table_like() {
let like = "CREATE TABLE db.table_name LIKE db.other_table";
athena().verified_stmt(like);
}

// Turning off this test until we can parse identifiers starting with numbers :(
#[test]
fn test_identifier() {
let between = "SELECT a AS 3_barrr_asdf FROM db.table_name";
athena().verified_stmt(between);
}

#[test]
fn test_alter_partition() {
let alter = "ALTER TABLE db.table PARTITION (a = 2) RENAME TO PARTITION (a = 1)";
athena().verified_stmt(alter);
}

#[test]
fn test_add_partition() {
let add = "ALTER TABLE db.table ADD IF NOT EXISTS PARTITION (a = 'asdf', b = 2)";
athena().verified_stmt(add);
}

#[test]
fn test_drop_partition() {
let drop = "ALTER TABLE db.table DROP PARTITION (a = 1)";
athena().verified_stmt(drop);
}

#[test]
fn test_drop_if_exists() {
let drop = "ALTER TABLE db.table DROP IF EXISTS PARTITION (a = 'b', c = 'd')";
athena().verified_stmt(drop);
}

#[test]
fn test_cluster_by() {
let cluster = "SELECT a FROM db.table CLUSTER BY a, b";
athena().verified_stmt(cluster);
}

#[test]
fn test_distribute_by() {
let cluster = "SELECT a FROM db.table DISTRIBUTE BY a, b";
athena().verified_stmt(cluster);
}

#[test]
fn no_join_condition() {
let join = "SELECT a, b FROM db.table_name JOIN a";
athena().verified_stmt(join);
}

#[test]
fn columns_after_partition() {
let query = "INSERT INTO db.table_name PARTITION (a, b) (c, d) SELECT a, b, c, d FROM db.table";
athena().verified_stmt(query);
}

#[test]
fn long_numerics() {
let query = r#"SELECT MIN(MIN(10, 5), 1L) AS a"#;
athena().verified_stmt(query);
}

#[test]
fn decimal_precision() {
let query = "SELECT CAST(a AS DECIMAL(18,2)) FROM db.table";
let expected = "SELECT CAST(a AS NUMERIC(18,2)) FROM db.table";
athena().one_statement_parses_to(query, expected);
}

#[test]
fn create_temp_table() {
let query = "CREATE TEMPORARY TABLE db.table (a INT NOT NULL)";
let query2 = "CREATE TEMP TABLE db.table (a INT NOT NULL)";

athena().verified_stmt(query);
athena().one_statement_parses_to(query2, query);
}

#[test]
fn create_local_directory() {
let query =
"INSERT OVERWRITE LOCAL DIRECTORY '/home/blah' STORED AS TEXTFILE SELECT * FROM db.table";
athena().verified_stmt(query);
}

#[test]
fn lateral_view() {
let view = "SELECT a FROM db.table LATERAL VIEW explode(a) t AS j, P LATERAL VIEW OUTER explode(a) t AS a, b WHERE a = 1";
athena().verified_stmt(view);
}

#[test]
fn sort_by() {
let sort_by = "SELECT * FROM db.table SORT BY a";
athena().verified_stmt(sort_by);
}

#[test]
fn rename_table() {
let rename = "ALTER TABLE db.table_name RENAME TO db.table_2";
athena().verified_stmt(rename);
}

#[test]
fn map_access() {
let rename = r#"SELECT a.b["asdf"] FROM db.table WHERE a = 2"#;
athena().verified_stmt(rename);
}

#[test]
fn from_cte() {
let rename =
"WITH cte AS (SELECT * FROM a.b) FROM cte INSERT INTO TABLE a.b PARTITION (a) SELECT *";
println!("{}", athena().verified_stmt(rename));
}

0 comments on commit e16a085

Please sign in to comment.