Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support postgres CREATE FUNCTION #722

Merged
merged 7 commits into from Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
150 changes: 143 additions & 7 deletions src/ast/mod.rs
Expand Up @@ -1405,11 +1405,15 @@ pub enum Statement {
/// CREATE FUNCTION
///
/// Hive: https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction
/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
CreateFunction {
or_replace: bool,
temporary: bool,
name: ObjectName,
class_name: String,
using: Option<CreateFunctionUsing>,
args: Option<Vec<CreateFunctionArg>>,
return_type: Option<DataType>,
/// Optional parameters.
params: CreateFunctionBody,
},
/// `ASSERT <condition> [AS <message>]`
Assert {
Expand Down Expand Up @@ -1866,19 +1870,26 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::CreateFunction {
or_replace,
temporary,
name,
class_name,
using,
args,
return_type,
params,
} => {
write!(
f,
"CREATE {temp}FUNCTION {name} AS '{class_name}'",
"CREATE {or_replace}{temp}FUNCTION {name}",
temp = if *temporary { "TEMPORARY " } else { "" },
or_replace = if *or_replace { "OR REPLACE " } else { "" },
)?;
if let Some(u) = using {
write!(f, " {}", u)?;
if let Some(args) = args {
write!(f, "({})", display_comma_separated(args))?;
}
if let Some(return_type) = return_type {
write!(f, " RETURNS {}", return_type)?;
}
write!(f, "{params}")?;
Ok(())
}
Statement::CreateView {
Expand Down Expand Up @@ -3679,6 +3690,131 @@ impl fmt::Display for ContextModifier {
}
}

/// Function argument in CREATE FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CreateFunctionArg {
pub mode: Option<ArgMode>,
pub name: Option<Ident>,
pub data_type: DataType,
pub default_expr: Option<Expr>,
}

impl CreateFunctionArg {
/// Returns an unnamed argument.
pub fn unnamed(data_type: DataType) -> Self {
Self {
mode: None,
name: None,
data_type,
default_expr: None,
}
}

/// Returns an argument with name.
pub fn with_name(name: &str, data_type: DataType) -> Self {
Self {
mode: None,
name: Some(name.into()),
data_type,
default_expr: None,
}
}
}

impl fmt::Display for CreateFunctionArg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(mode) = &self.mode {
write!(f, "{} ", mode)?;
}
if let Some(name) = &self.name {
write!(f, "{} ", name)?;
}
write!(f, "{}", self.data_type)?;
if let Some(default_expr) = &self.default_expr {
write!(f, " = {}", default_expr)?;
}
Ok(())
}
}

/// The mode of an argument in CREATE FUNCTION.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArgMode {
In,
Out,
InOut,
}

impl fmt::Display for ArgMode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
ArgMode::In => write!(f, "IN"),
ArgMode::Out => write!(f, "OUT"),
ArgMode::InOut => write!(f, "INOUT"),
}
}
}

/// These attributes inform the query optimizer about the behavior of the function.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FunctionBehavior {
Immutable,
Stable,
Volatile,
}

impl fmt::Display for FunctionBehavior {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionBehavior::Immutable => write!(f, "IMMUTABLE"),
FunctionBehavior::Stable => write!(f, "STABLE"),
FunctionBehavior::Volatile => write!(f, "VOLATILE"),
}
}
}

/// Postgres: https://www.postgresql.org/docs/15/sql-createfunction.html
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CreateFunctionBody {
/// LANGUAGE lang_name
pub language: Option<Ident>,
/// IMMUTABLE | STABLE | VOLATILE
pub behavior: Option<FunctionBehavior>,
/// AS 'definition'
///
/// Note that Hive's `AS class_name` is also parsed here.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

pub as_: Option<String>,
/// RETURN expression
pub return_: Option<Expr>,
/// USING ... (Hive only)
pub using: Option<CreateFunctionUsing>,
}

impl fmt::Display for CreateFunctionBody {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(language) = &self.language {
write!(f, " LANGUAGE {language}")?;
}
if let Some(behavior) = &self.behavior {
write!(f, " {behavior}")?;
}
if let Some(definition) = &self.as_ {
write!(f, " AS '{definition}'")?;
}
if let Some(expr) = &self.return_ {
write!(f, " RETURN {expr}")?;
}
if let Some(using) = &self.using {
write!(f, " {using}")?;
}
Ok(())
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CreateFunctionUsing {
Expand Down
3 changes: 3 additions & 0 deletions src/keywords.rs
Expand Up @@ -284,6 +284,7 @@ define_keywords!(
IF,
IGNORE,
ILIKE,
IMMUTABLE,
IN,
INCREMENT,
INDEX,
Expand Down Expand Up @@ -518,6 +519,7 @@ define_keywords!(
SQLSTATE,
SQLWARNING,
SQRT,
STABLE,
START,
STATIC,
STATISTICS,
Expand Down Expand Up @@ -604,6 +606,7 @@ define_keywords!(
VERSIONING,
VIEW,
VIRTUAL,
VOLATILE,
WEEK,
WHEN,
WHENEVER,
Expand Down
130 changes: 118 additions & 12 deletions src/parser.rs
Expand Up @@ -2026,9 +2026,11 @@ impl<'a> Parser<'a> {
self.parse_create_view(or_replace)
} else if self.parse_keyword(Keyword::EXTERNAL) {
self.parse_create_external_table(or_replace)
} else if self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(or_replace, temporary)
} else if or_replace {
self.expected(
"[EXTERNAL] TABLE or [MATERIALIZED] VIEW after CREATE OR REPLACE",
"[EXTERNAL] TABLE or [MATERIALIZED] VIEW or FUNCTION after CREATE OR REPLACE",
self.peek_token(),
)
} else if self.parse_keyword(Keyword::INDEX) {
Expand All @@ -2041,8 +2043,6 @@ impl<'a> Parser<'a> {
self.parse_create_schema()
} else if self.parse_keyword(Keyword::DATABASE) {
self.parse_create_database()
} else if dialect_of!(self is HiveDialect) && self.parse_keyword(Keyword::FUNCTION) {
self.parse_create_function(temporary)
} else if self.parse_keyword(Keyword::ROLE) {
self.parse_create_role()
} else if self.parse_keyword(Keyword::SEQUENCE) {
Expand Down Expand Up @@ -2253,20 +2253,126 @@ impl<'a> Parser<'a> {
}
}

pub fn parse_create_function(&mut self, temporary: bool) -> Result<Statement, ParserError> {
let name = self.parse_object_name()?;
self.expect_keyword(Keyword::AS)?;
let class_name = self.parse_literal_string()?;
let using = self.parse_optional_create_function_using()?;
pub fn parse_create_function(
&mut self,
or_replace: bool,
temporary: bool,
) -> Result<Statement, ParserError> {
if dialect_of!(self is HiveDialect) {
let name = self.parse_object_name()?;
self.expect_keyword(Keyword::AS)?;
let class_name = self.parse_literal_string()?;
let params = CreateFunctionBody {
as_: Some(class_name),
using: self.parse_optional_create_function_using()?,
..Default::default()
};

Ok(Statement::CreateFunction {
temporary,
Ok(Statement::CreateFunction {
or_replace,
temporary,
name,
args: None,
return_type: None,
params,
})
} else if dialect_of!(self is PostgreSqlDialect) {
let name = self.parse_object_name()?;
self.expect_token(&Token::LParen)?;
let args = self.parse_comma_separated(Parser::parse_create_function_arg)?;
self.expect_token(&Token::RParen)?;

let return_type = if self.parse_keyword(Keyword::RETURNS) {
Some(self.parse_data_type()?)
} else {
None
};

let params = self.parse_create_function_body()?;

Ok(Statement::CreateFunction {
or_replace,
temporary,
name,
args: Some(args),
return_type,
params,
})
} else {
self.prev_token();
self.expected("an object type after CREATE", self.peek_token())
}
}

fn parse_create_function_arg(&mut self) -> Result<CreateFunctionArg, ParserError> {
let mode = if self.parse_keyword(Keyword::IN) {
Some(ArgMode::In)
} else if self.parse_keyword(Keyword::OUT) {
Some(ArgMode::Out)
} else if self.parse_keyword(Keyword::INOUT) {
Some(ArgMode::InOut)
} else {
None
};

// parse: [ argname ] argtype
let mut name = None;
let mut data_type = self.parse_data_type()?;
if let DataType::Custom(n, _) = &data_type {
// the first token is actually a name
name = Some(n.0[0].clone());
data_type = self.parse_data_type()?;
}

let default_expr = if self.parse_keyword(Keyword::DEFAULT) || self.consume_token(&Token::Eq)
{
Some(self.parse_expr()?)
} else {
None
};
Ok(CreateFunctionArg {
mode,
name,
class_name,
using,
data_type,
default_expr,
})
}

fn parse_create_function_body(&mut self) -> Result<CreateFunctionBody, ParserError> {
let mut body = CreateFunctionBody::default();
loop {
fn ensure_not_set<T>(field: &Option<T>, name: &str) -> Result<(), ParserError> {
if field.is_some() {
return Err(ParserError::ParserError(format!(
"{name} specified more than once",
)));
}
Ok(())
}
if self.parse_keyword(Keyword::AS) {
ensure_not_set(&body.as_, "AS")?;
body.as_ = Some(self.parse_literal_string()?);
} else if self.parse_keyword(Keyword::LANGUAGE) {
ensure_not_set(&body.language, "LANGUAGE")?;
body.language = Some(self.parse_identifier()?);
} else if self.parse_keyword(Keyword::IMMUTABLE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Immutable);
} else if self.parse_keyword(Keyword::STABLE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Stable);
} else if self.parse_keyword(Keyword::VOLATILE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Volatile);
} else if self.parse_keyword(Keyword::RETURN) {
ensure_not_set(&body.return_, "RETURN")?;
body.return_ = Some(self.parse_expr()?);
} else {
return Ok(body);
}
}
}

pub fn parse_create_external_table(
&mut self,
or_replace: bool,
Expand Down