diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 276189f95..94f44edf7 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1405,11 +1405,15 @@ pub enum Statement { /// CREATE FUNCTION /// /// Hive: https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction + /// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html CreateFunction { + or_replace: bool, temporary: bool, name: ObjectName, - class_name: String, - using: Option, + args: Option>, + return_type: Option, + /// Optional parameters. + params: CreateFunctionBody, }, /// `ASSERT [AS ]` Assert { @@ -1866,19 +1870,26 @@ impl fmt::Display for Statement { Ok(()) } Statement::CreateFunction { + or_replace, temporary, name, - class_name, - using, + args, + return_type, + params, } => { write!( f, - "CREATE {temp}FUNCTION {name} AS '{class_name}'", + "CREATE {or_replace}{temp}FUNCTION {name}", temp = if *temporary { "TEMPORARY " } else { "" }, + or_replace = if *or_replace { "OR REPLACE " } else { "" }, )?; - if let Some(u) = using { - write!(f, " {}", u)?; + if let Some(args) = args { + write!(f, "({})", display_comma_separated(args))?; } + if let Some(return_type) = return_type { + write!(f, " RETURNS {}", return_type)?; + } + write!(f, "{params}")?; Ok(()) } Statement::CreateView { @@ -3679,6 +3690,131 @@ impl fmt::Display for ContextModifier { } } +/// Function argument in CREATE FUNCTION. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CreateFunctionArg { + pub mode: Option, + pub name: Option, + pub data_type: DataType, + pub default_expr: Option, +} + +impl CreateFunctionArg { + /// Returns an unnamed argument. + pub fn unnamed(data_type: DataType) -> Self { + Self { + mode: None, + name: None, + data_type, + default_expr: None, + } + } + + /// Returns an argument with name. + pub fn with_name(name: &str, data_type: DataType) -> Self { + Self { + mode: None, + name: Some(name.into()), + data_type, + default_expr: None, + } + } +} + +impl fmt::Display for CreateFunctionArg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(mode) = &self.mode { + write!(f, "{} ", mode)?; + } + if let Some(name) = &self.name { + write!(f, "{} ", name)?; + } + write!(f, "{}", self.data_type)?; + if let Some(default_expr) = &self.default_expr { + write!(f, " = {}", default_expr)?; + } + Ok(()) + } +} + +/// The mode of an argument in CREATE FUNCTION. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ArgMode { + In, + Out, + InOut, +} + +impl fmt::Display for ArgMode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ArgMode::In => write!(f, "IN"), + ArgMode::Out => write!(f, "OUT"), + ArgMode::InOut => write!(f, "INOUT"), + } + } +} + +/// These attributes inform the query optimizer about the behavior of the function. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum FunctionBehavior { + Immutable, + Stable, + Volatile, +} + +impl fmt::Display for FunctionBehavior { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FunctionBehavior::Immutable => write!(f, "IMMUTABLE"), + FunctionBehavior::Stable => write!(f, "STABLE"), + FunctionBehavior::Volatile => write!(f, "VOLATILE"), + } + } +} + +/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html +#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CreateFunctionBody { + /// LANGUAGE lang_name + pub language: Option, + /// IMMUTABLE | STABLE | VOLATILE + pub behavior: Option, + /// AS 'definition' + /// + /// Note that Hive's `AS class_name` is also parsed here. + pub as_: Option, + /// RETURN expression + pub return_: Option, + /// USING ... (Hive only) + pub using: Option, +} + +impl fmt::Display for CreateFunctionBody { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(language) = &self.language { + write!(f, " LANGUAGE {language}")?; + } + if let Some(behavior) = &self.behavior { + write!(f, " {behavior}")?; + } + if let Some(definition) = &self.as_ { + write!(f, " AS '{definition}'")?; + } + if let Some(expr) = &self.return_ { + write!(f, " RETURN {expr}")?; + } + if let Some(using) = &self.using { + write!(f, " {using}")?; + } + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum CreateFunctionUsing { diff --git a/src/keywords.rs b/src/keywords.rs index 9246411aa..1fc15fda0 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -284,6 +284,7 @@ define_keywords!( IF, IGNORE, ILIKE, + IMMUTABLE, IN, INCREMENT, INDEX, @@ -518,6 +519,7 @@ define_keywords!( SQLSTATE, SQLWARNING, SQRT, + STABLE, START, STATIC, STATISTICS, @@ -604,6 +606,7 @@ define_keywords!( VERSIONING, VIEW, VIRTUAL, + VOLATILE, WEEK, WHEN, WHENEVER, diff --git a/src/parser.rs b/src/parser.rs index 260dbfbeb..f37f3d1e6 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2026,9 +2026,11 @@ impl<'a> Parser<'a> { self.parse_create_view(or_replace) } else if self.parse_keyword(Keyword::EXTERNAL) { self.parse_create_external_table(or_replace) + } else if self.parse_keyword(Keyword::FUNCTION) { + self.parse_create_function(or_replace, temporary) } else if or_replace { self.expected( - "[EXTERNAL] TABLE or [MATERIALIZED] VIEW after CREATE OR REPLACE", + "[EXTERNAL] TABLE or [MATERIALIZED] VIEW or FUNCTION after CREATE OR REPLACE", self.peek_token(), ) } else if self.parse_keyword(Keyword::INDEX) { @@ -2041,8 +2043,6 @@ impl<'a> Parser<'a> { self.parse_create_schema() } else if self.parse_keyword(Keyword::DATABASE) { self.parse_create_database() - } else if dialect_of!(self is HiveDialect) && self.parse_keyword(Keyword::FUNCTION) { - self.parse_create_function(temporary) } else if self.parse_keyword(Keyword::ROLE) { self.parse_create_role() } else if self.parse_keyword(Keyword::SEQUENCE) { @@ -2253,20 +2253,126 @@ impl<'a> Parser<'a> { } } - pub fn parse_create_function(&mut self, temporary: bool) -> Result { - let name = self.parse_object_name()?; - self.expect_keyword(Keyword::AS)?; - let class_name = self.parse_literal_string()?; - let using = self.parse_optional_create_function_using()?; + pub fn parse_create_function( + &mut self, + or_replace: bool, + temporary: bool, + ) -> Result { + if dialect_of!(self is HiveDialect) { + let name = self.parse_object_name()?; + self.expect_keyword(Keyword::AS)?; + let class_name = self.parse_literal_string()?; + let params = CreateFunctionBody { + as_: Some(class_name), + using: self.parse_optional_create_function_using()?, + ..Default::default() + }; - Ok(Statement::CreateFunction { - temporary, + Ok(Statement::CreateFunction { + or_replace, + temporary, + name, + args: None, + return_type: None, + params, + }) + } else if dialect_of!(self is PostgreSqlDialect) { + let name = self.parse_object_name()?; + self.expect_token(&Token::LParen)?; + let args = self.parse_comma_separated(Parser::parse_create_function_arg)?; + self.expect_token(&Token::RParen)?; + + let return_type = if self.parse_keyword(Keyword::RETURNS) { + Some(self.parse_data_type()?) + } else { + None + }; + + let params = self.parse_create_function_body()?; + + Ok(Statement::CreateFunction { + or_replace, + temporary, + name, + args: Some(args), + return_type, + params, + }) + } else { + self.prev_token(); + self.expected("an object type after CREATE", self.peek_token()) + } + } + + fn parse_create_function_arg(&mut self) -> Result { + let mode = if self.parse_keyword(Keyword::IN) { + Some(ArgMode::In) + } else if self.parse_keyword(Keyword::OUT) { + Some(ArgMode::Out) + } else if self.parse_keyword(Keyword::INOUT) { + Some(ArgMode::InOut) + } else { + None + }; + + // parse: [ argname ] argtype + let mut name = None; + let mut data_type = self.parse_data_type()?; + if let DataType::Custom(n, _) = &data_type { + // the first token is actually a name + name = Some(n.0[0].clone()); + data_type = self.parse_data_type()?; + } + + let default_expr = if self.parse_keyword(Keyword::DEFAULT) || self.consume_token(&Token::Eq) + { + Some(self.parse_expr()?) + } else { + None + }; + Ok(CreateFunctionArg { + mode, name, - class_name, - using, + data_type, + default_expr, }) } + fn parse_create_function_body(&mut self) -> Result { + let mut body = CreateFunctionBody::default(); + loop { + fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { + if field.is_some() { + return Err(ParserError::ParserError(format!( + "{name} specified more than once", + ))); + } + Ok(()) + } + if self.parse_keyword(Keyword::AS) { + ensure_not_set(&body.as_, "AS")?; + body.as_ = Some(self.parse_literal_string()?); + } else if self.parse_keyword(Keyword::LANGUAGE) { + ensure_not_set(&body.language, "LANGUAGE")?; + body.language = Some(self.parse_identifier()?); + } else if self.parse_keyword(Keyword::IMMUTABLE) { + ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?; + body.behavior = Some(FunctionBehavior::Immutable); + } else if self.parse_keyword(Keyword::STABLE) { + ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?; + body.behavior = Some(FunctionBehavior::Stable); + } else if self.parse_keyword(Keyword::VOLATILE) { + ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?; + body.behavior = Some(FunctionBehavior::Volatile); + } else if self.parse_keyword(Keyword::RETURN) { + ensure_not_set(&body.return_, "RETURN")?; + body.return_ = Some(self.parse_expr()?); + } else { + return Ok(body); + } + } + } + pub fn parse_create_external_table( &mut self, or_replace: bool, diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 070f55089..99a81eff2 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -16,8 +16,8 @@ //! is also tested (on the inputs it can handle). use sqlparser::ast::{ - CreateFunctionUsing, Expr, Function, Ident, ObjectName, SelectItem, Statement, TableFactor, - UnaryOperator, Value, + CreateFunctionBody, CreateFunctionUsing, Expr, Function, Ident, ObjectName, SelectItem, + Statement, TableFactor, UnaryOperator, Value, }; use sqlparser::dialect::{GenericDialect, HiveDialect}; use sqlparser::parser::ParserError; @@ -244,17 +244,20 @@ fn parse_create_function() { Statement::CreateFunction { temporary, name, - class_name, - using, + params, + .. } => { assert!(temporary); - assert_eq!("mydb.myfunc", name.to_string()); - assert_eq!("org.random.class.Name", class_name); + assert_eq!(name.to_string(), "mydb.myfunc"); assert_eq!( - using, - Some(CreateFunctionUsing::Jar( - "hdfs://somewhere.com:8020/very/far".to_string() - )) + params, + CreateFunctionBody { + as_: Some("org.random.class.Name".to_string()), + using: Some(CreateFunctionUsing::Jar( + "hdfs://somewhere.com:8020/very/far".to_string() + )), + ..Default::default() + } ) } _ => unreachable!(), diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index e32a7b027..980530530 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -2232,3 +2232,57 @@ fn parse_similar_to() { chk(false); chk(true); } + +#[test] +fn parse_create_function() { + let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'"; + assert_eq!( + pg().verified_stmt(sql), + Statement::CreateFunction { + or_replace: false, + temporary: false, + name: ObjectName(vec![Ident::new("add")]), + args: Some(vec![ + CreateFunctionArg::unnamed(DataType::Integer(None)), + CreateFunctionArg::unnamed(DataType::Integer(None)), + ]), + return_type: Some(DataType::Integer(None)), + params: CreateFunctionBody { + language: Some("SQL".into()), + behavior: Some(FunctionBehavior::Immutable), + as_: Some("select $1 + $2;".into()), + ..Default::default() + }, + } + ); + + let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURN a + b"; + assert_eq!( + pg().verified_stmt(sql), + Statement::CreateFunction { + or_replace: true, + temporary: false, + name: ObjectName(vec![Ident::new("add")]), + args: Some(vec![ + CreateFunctionArg::with_name("a", DataType::Integer(None)), + CreateFunctionArg { + mode: Some(ArgMode::In), + name: Some("b".into()), + data_type: DataType::Integer(None), + default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))), + } + ]), + return_type: Some(DataType::Integer(None)), + params: CreateFunctionBody { + language: Some("SQL".into()), + behavior: Some(FunctionBehavior::Immutable), + return_: Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier("a".into())), + op: BinaryOperator::Plus, + right: Box::new(Expr::Identifier("b".into())), + }), + ..Default::default() + }, + } + ); +}