diff --git a/src/ast/mod.rs b/src/ast/mod.rs index cf936d3c9..6003cb1d4 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1268,6 +1268,14 @@ pub enum Statement { /// deleted along with the dropped table purge: bool, }, + /// DROP Function + DropFunction { + if_exists: bool, + /// One or more function to drop + func_desc: Vec, + /// `CASCADE` or `RESTRICT` + option: Option, + }, /// DECLARE - Declaring Cursor Variables /// /// Note: this is a PostgreSQL-specific statement, @@ -1432,7 +1440,7 @@ pub enum Statement { or_replace: bool, temporary: bool, name: ObjectName, - args: Option>, + args: Option>, return_type: Option, /// Optional parameters. params: CreateFunctionBody, @@ -2284,6 +2292,22 @@ impl fmt::Display for Statement { if *restrict { " RESTRICT" } else { "" }, if *purge { " PURGE" } else { "" } ), + Statement::DropFunction { + if_exists, + func_desc, + option, + } => { + write!( + f, + "DROP FUNCTION{} {}", + if *if_exists { " IF EXISTS" } else { "" }, + display_comma_separated(func_desc), + )?; + if let Some(op) = option { + write!(f, " {}", op)?; + } + Ok(()) + } Statement::Discard { object_type } => { write!(f, "DISCARD {object_type}", object_type = object_type)?; Ok(()) @@ -3726,17 +3750,52 @@ impl fmt::Display for ContextModifier { } } -/// Function argument in CREATE FUNCTION. +/// Function describe in DROP FUNCTION. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum DropFunctionOption { + Restrict, + Cascade, +} + +impl fmt::Display for DropFunctionOption { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DropFunctionOption::Restrict => write!(f, "RESTRICT "), + DropFunctionOption::Cascade => write!(f, "CASCADE "), + } + } +} + +/// Function describe in DROP FUNCTION. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DropFunctionDesc { + pub name: ObjectName, + pub args: Option>, +} + +impl fmt::Display for DropFunctionDesc { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name)?; + if let Some(args) = &self.args { + write!(f, "({})", display_comma_separated(args))?; + } + Ok(()) + } +} + +/// Function argument in CREATE OR DROP FUNCTION. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct CreateFunctionArg { +pub struct OperateFunctionArg { pub mode: Option, pub name: Option, pub data_type: DataType, pub default_expr: Option, } -impl CreateFunctionArg { +impl OperateFunctionArg { /// Returns an unnamed argument. pub fn unnamed(data_type: DataType) -> Self { Self { @@ -3758,7 +3817,7 @@ impl CreateFunctionArg { } } -impl fmt::Display for CreateFunctionArg { +impl fmt::Display for OperateFunctionArg { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Some(mode) = &self.mode { write!(f, "{} ", mode)?; diff --git a/src/parser.rs b/src/parser.rs index 528b8c6fe..21ab6d2c6 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2343,7 +2343,13 @@ impl<'a> Parser<'a> { } else if dialect_of!(self is PostgreSqlDialect) { let name = self.parse_object_name()?; self.expect_token(&Token::LParen)?; - let args = self.parse_comma_separated(Parser::parse_create_function_arg)?; + let args = if self.consume_token(&Token::RParen) { + self.prev_token(); + None + } else { + Some(self.parse_comma_separated(Parser::parse_function_arg)?) + }; + self.expect_token(&Token::RParen)?; let return_type = if self.parse_keyword(Keyword::RETURNS) { @@ -2358,7 +2364,7 @@ impl<'a> Parser<'a> { or_replace, temporary, name, - args: Some(args), + args, return_type, params, }) @@ -2368,7 +2374,7 @@ impl<'a> Parser<'a> { } } - fn parse_create_function_arg(&mut self) -> Result { + fn parse_function_arg(&mut self) -> Result { let mode = if self.parse_keyword(Keyword::IN) { Some(ArgMode::In) } else if self.parse_keyword(Keyword::OUT) { @@ -2394,7 +2400,7 @@ impl<'a> Parser<'a> { } else { None }; - Ok(CreateFunctionArg { + Ok(OperateFunctionArg { mode, name, data_type, @@ -2767,9 +2773,11 @@ impl<'a> Parser<'a> { ObjectType::Schema } else if self.parse_keyword(Keyword::SEQUENCE) { ObjectType::Sequence + } else if self.parse_keyword(Keyword::FUNCTION) { + return self.parse_drop_function(); } else { return self.expected( - "TABLE, VIEW, INDEX, ROLE, SCHEMA, or SEQUENCE after DROP", + "TABLE, VIEW, INDEX, ROLE, SCHEMA, FUNCTION or SEQUENCE after DROP", self.peek_token(), ); }; @@ -2796,6 +2804,41 @@ impl<'a> Parser<'a> { }) } + /// DROP FUNCTION [ IF EXISTS ] name [ ( [ [ argmode ] [ argname ] argtype [, ...] ] ) ] [, ...] + /// [ CASCADE | RESTRICT ] + fn parse_drop_function(&mut self) -> Result { + let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); + let func_desc = self.parse_comma_separated(Parser::parse_drop_function_desc)?; + let option = match self.parse_one_of_keywords(&[Keyword::CASCADE, Keyword::RESTRICT]) { + Some(Keyword::CASCADE) => Some(ReferentialAction::Cascade), + Some(Keyword::RESTRICT) => Some(ReferentialAction::Restrict), + _ => None, + }; + Ok(Statement::DropFunction { + if_exists, + func_desc, + option, + }) + } + + fn parse_drop_function_desc(&mut self) -> Result { + let name = self.parse_object_name()?; + + let args = if self.consume_token(&Token::LParen) { + if self.consume_token(&Token::RParen) { + None + } else { + let args = self.parse_comma_separated(Parser::parse_function_arg)?; + self.expect_token(&Token::RParen)?; + Some(args) + } + } else { + None + }; + + Ok(DropFunctionDesc { name, args }) + } + /// DECLARE name [ BINARY ] [ ASENSITIVE | INSENSITIVE ] [ [ NO ] SCROLL ] // CURSOR [ { WITH | WITHOUT } HOLD ] FOR query pub fn parse_declare(&mut self) -> Result { diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 6b26fe48b..6e190a01b 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -2348,8 +2348,8 @@ fn parse_create_function() { temporary: false, name: ObjectName(vec![Ident::new("add")]), args: Some(vec![ - CreateFunctionArg::unnamed(DataType::Integer(None)), - CreateFunctionArg::unnamed(DataType::Integer(None)), + OperateFunctionArg::unnamed(DataType::Integer(None)), + OperateFunctionArg::unnamed(DataType::Integer(None)), ]), return_type: Some(DataType::Integer(None)), params: CreateFunctionBody { @@ -2371,8 +2371,8 @@ fn parse_create_function() { temporary: false, name: ObjectName(vec![Ident::new("add")]), args: Some(vec![ - CreateFunctionArg::with_name("a", DataType::Integer(None)), - CreateFunctionArg { + OperateFunctionArg::with_name("a", DataType::Integer(None)), + OperateFunctionArg { mode: Some(ArgMode::In), name: Some("b".into()), data_type: DataType::Integer(None), @@ -2400,7 +2400,7 @@ fn parse_create_function() { or_replace: true, temporary: false, name: ObjectName(vec![Ident::new("increment")]), - args: Some(vec![CreateFunctionArg::with_name( + args: Some(vec![OperateFunctionArg::with_name( "i", DataType::Integer(None) )]), @@ -2417,3 +2417,93 @@ fn parse_create_function() { } ); } + +#[test] +fn parse_drop_function() { + let sql = "DROP FUNCTION IF EXISTS test_func"; + assert_eq!( + pg().verified_stmt(sql), + Statement::DropFunction { + if_exists: true, + func_desc: vec![DropFunctionDesc { + name: ObjectName(vec![Ident { + value: "test_func".to_string(), + quote_style: None + }]), + args: None + }], + option: None + } + ); + + let sql = "DROP FUNCTION IF EXISTS test_func(a INTEGER, IN b INTEGER = 1)"; + assert_eq!( + pg().verified_stmt(sql), + Statement::DropFunction { + if_exists: true, + func_desc: vec![DropFunctionDesc { + name: ObjectName(vec![Ident { + value: "test_func".to_string(), + quote_style: None + }]), + args: Some(vec![ + OperateFunctionArg::with_name("a", DataType::Integer(None)), + OperateFunctionArg { + mode: Some(ArgMode::In), + name: Some("b".into()), + data_type: DataType::Integer(None), + default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))), + } + ]), + }], + option: None + } + ); + + let sql = "DROP FUNCTION IF EXISTS test_func1(a INTEGER, IN b INTEGER = 1), test_func2(a VARCHAR, IN b INTEGER = 1)"; + assert_eq!( + pg().verified_stmt(sql), + Statement::DropFunction { + if_exists: true, + func_desc: vec![ + DropFunctionDesc { + name: ObjectName(vec![Ident { + value: "test_func1".to_string(), + quote_style: None + }]), + args: Some(vec![ + OperateFunctionArg::with_name("a", DataType::Integer(None)), + OperateFunctionArg { + mode: Some(ArgMode::In), + name: Some("b".into()), + data_type: DataType::Integer(None), + default_expr: Some(Expr::Value(Value::Number( + "1".parse().unwrap(), + false + ))), + } + ]), + }, + DropFunctionDesc { + name: ObjectName(vec![Ident { + value: "test_func2".to_string(), + quote_style: None + }]), + args: Some(vec![ + OperateFunctionArg::with_name("a", DataType::Varchar(None)), + OperateFunctionArg { + mode: Some(ArgMode::In), + name: Some("b".into()), + data_type: DataType::Integer(None), + default_expr: Some(Expr::Value(Value::Number( + "1".parse().unwrap(), + false + ))), + } + ]), + } + ], + option: None + } + ); +}