diff --git a/src/ast/data_type.rs b/src/ast/data_type.rs index e906383fc..c0b2f977b 100644 --- a/src/ast/data_type.rs +++ b/src/ast/data_type.rs @@ -141,7 +141,7 @@ pub enum DataType { /// Custom type such as enums Custom(ObjectName, Vec), /// Arrays - Array(Box), + Array(Option>), /// Enums Enum(Vec), /// Set @@ -232,7 +232,13 @@ impl fmt::Display for DataType { DataType::Text => write!(f, "TEXT"), DataType::String => write!(f, "STRING"), DataType::Bytea => write!(f, "BYTEA"), - DataType::Array(ty) => write!(f, "{}[]", ty), + DataType::Array(ty) => { + if let Some(t) = &ty { + write!(f, "{}[]", t) + } else { + write!(f, "ARRAY") + } + } DataType::Custom(ty, modifiers) => { if modifiers.is_empty() { write!(f, "{}", ty) diff --git a/src/parser.rs b/src/parser.rs index 161062458..f7cadbef1 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -3672,13 +3672,17 @@ impl<'a> Parser<'a> { Keyword::ENUM => Ok(DataType::Enum(self.parse_string_values()?)), Keyword::SET => Ok(DataType::Set(self.parse_string_values()?)), Keyword::ARRAY => { - // Hive array syntax. Note that nesting arrays - or other Hive syntax - // that ends with > will fail due to "C++" problem - >> is parsed as - // Token::ShiftRight - self.expect_token(&Token::Lt)?; - let inside_type = self.parse_data_type()?; - self.expect_token(&Token::Gt)?; - Ok(DataType::Array(Box::new(inside_type))) + if dialect_of!(self is SnowflakeDialect) { + Ok(DataType::Array(None)) + } else { + // Hive array syntax. Note that nesting arrays - or other Hive syntax + // that ends with > will fail due to "C++" problem - >> is parsed as + // Token::ShiftRight + self.expect_token(&Token::Lt)?; + let inside_type = self.parse_data_type()?; + self.expect_token(&Token::Gt)?; + Ok(DataType::Array(Some(Box::new(inside_type)))) + } } _ => { self.prev_token(); @@ -3697,7 +3701,7 @@ impl<'a> Parser<'a> { // Keyword::ARRAY syntax from above while self.consume_token(&Token::LBracket) { self.expect_token(&Token::RBracket)?; - data = DataType::Array(Box::new(data)) + data = DataType::Array(Some(Box::new(data))) } Ok(data) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index f6dfe1590..a3d6ea766 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -24,7 +24,7 @@ use sqlparser::ast::SelectItem::UnnamedExpr; use sqlparser::ast::*; use sqlparser::dialect::{ AnsiDialect, BigQueryDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect, - PostgreSqlDialect, SQLiteDialect, SnowflakeDialect, + MySqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect, }; use sqlparser::keywords::ALL_KEYWORDS; use sqlparser::parser::{Parser, ParserError}; @@ -2099,7 +2099,7 @@ fn parse_create_table_hive_array() { }, ColumnDef { name: Ident::new("val"), - data_type: DataType::Array(Box::new(DataType::Int(None))), + data_type: DataType::Array(Some(Box::new(DataType::Int(None)))), collation: None, options: vec![], }, @@ -2109,12 +2109,20 @@ fn parse_create_table_hive_array() { _ => unreachable!(), } - let res = - parse_sql_statements("CREATE TABLE IF NOT EXISTS something (name int, val array, found: )")); + // SnowflakeDialect using array diffrent + let dialects = TestedDialects { + dialects: vec![ + Box::new(PostgreSqlDialect {}), + Box::new(HiveDialect {}), + Box::new(MySqlDialect {}), + ], + }; + let sql = "CREATE TABLE IF NOT EXISTS something (name int, val array, found: )".to_string()) + ); } #[test] diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index c5d4bd0fc..869106220 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1281,9 +1281,9 @@ fn parse_array_index_expr() { })], named: true, })), - data_type: DataType::Array(Box::new(DataType::Array(Box::new(DataType::Int( - None - ))))) + data_type: DataType::Array(Some(Box::new(DataType::Array(Some(Box::new( + DataType::Int(None) + )))))) }))), indexes: vec![num[1].clone(), num[2].clone()], }, diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index 64fff62f9..b182becb0 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -143,6 +143,19 @@ fn test_single_table_in_parenthesis_with_alias() { ); } +#[test] +fn parse_array() { + let sql = "SELECT CAST(a AS ARRAY) FROM customer"; + let select = snowflake().verified_only_select(sql); + assert_eq!( + &Expr::Cast { + expr: Box::new(Expr::Identifier(Ident::new("a"))), + data_type: DataType::Array(None), + }, + expr_from_projection(only(&select.projection)) + ); +} + #[test] fn parse_json_using_colon() { let sql = "SELECT a:b FROM t";