Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a mutable visitor #782

Merged
merged 3 commits into from Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -33,7 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
# of dev-dependencies because of
# https://github.com/rust-lang/cargo/issues/1596
serde_json = { version = "1.0", optional = true }
sqlparser_derive = { version = "0.1", path = "derive", optional = true }
sqlparser_derive = { version = "0.1.1", path = "derive", optional = true }

[dev-dependencies]
simple_logger = "4.0"
Expand Down
2 changes: 1 addition & 1 deletion derive/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "sqlparser_derive"
description = "proc macro for sqlparser"
version = "0.1.0"
version = "0.1.1"
authors = ["sqlparser-rs authors"]
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
documentation = "https://docs.rs/sqlparser_derive/"
Expand Down
6 changes: 3 additions & 3 deletions derive/README.md
Expand Up @@ -6,13 +6,13 @@ This crate contains a procedural macro that can automatically derive
implementations of the `Visit` trait in the [sqlparser](https://crates.io/crates/sqlparser) crate

```rust
#[derive(Visit)]
#[derive(Visit, VisitMut)]
struct Foo {
boolean: bool,
bar: Bar,
}

#[derive(Visit)]
#[derive(Visit, VisitMut)]
enum Bar {
A(),
B(String, bool),
Expand Down Expand Up @@ -51,7 +51,7 @@ impl Visit for Bar {
Additionally certain types may wish to call a corresponding method on visitor before recursing

```rust
#[derive(Visit)]
#[derive(Visit, VisitMut)]
#[visit(with = "visit_expr")]
enum Expr {
A(),
Expand Down
65 changes: 49 additions & 16 deletions derive/src/lib.rs
Expand Up @@ -6,25 +6,58 @@ use syn::{
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
};


/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(VisitMut, attributes(visit))]
pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_visit(input, &VisitType {
visit_trait: quote!(VisitMut),
visitor_trait: quote!(VisitorMut),
modifier: Some(quote!(mut)),
})
}

/// Implementation of `[#derive(Visit)]`
#[proc_macro_derive(Visit, attributes(visit))]
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_visit(input, &VisitType {
visit_trait: quote!(Visit),
visitor_trait: quote!(Visitor),
modifier: None,
})
}

struct VisitType {
visit_trait: TokenStream,
visitor_trait: TokenStream,
modifier: Option<TokenStream>,
}

fn derive_visit(
input: proc_macro::TokenStream,
visit_type: &VisitType,
) -> proc_macro::TokenStream {
// Parse the input tokens into a syntax tree.
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident;

let VisitType { visit_trait, visitor_trait, modifier } = visit_type;

let attributes = Attributes::parse(&input.attrs);
// Add a bound `T: HeapSize` to every type parameter T.
let generics = add_trait_bounds(input.generics);
// Add a bound `T: Visit` to every type parameter T.
let generics = add_trait_bounds(input.generics, visit_type);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let (pre_visit, post_visit) = attributes.visit(quote!(self));
let children = visit_children(&input.data);
let children = visit_children(&input.data, visit_type);

let expanded = quote! {
// The generated impl.
impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
fn visit<V: sqlparser::ast::#visitor_trait>(
&#modifier self,
visitor: &mut V
) -> ::std::ops::ControlFlow<V::Break> {
#pre_visit
#children
#post_visit
Expand Down Expand Up @@ -92,25 +125,25 @@ impl Attributes {
}

// Add a bound `T: Visit` to every type parameter T.
fn add_trait_bounds(mut generics: Generics) -> Generics {
fn add_trait_bounds(mut generics: Generics, VisitType{visit_trait, ..}: &VisitType) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(ref mut type_param) = *param {
type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
type_param.bounds.push(parse_quote!(sqlparser::ast::#visit_trait));
}
}
generics
}

// Generate the body of the visit implementation for the given type
fn visit_children(data: &Data) -> TokenStream {
fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType) -> TokenStream {
match data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => {
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
Expand All @@ -121,7 +154,7 @@ fn visit_children(data: &Data) -> TokenStream {
let index = Index::from(i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
});
quote! {
#(#recurse)*
Expand All @@ -140,8 +173,8 @@ fn visit_children(data: &Data) -> TokenStream {
let visit = fields.named.iter().map(|f| {
let name = &f.ident;
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});

quote!(
Expand All @@ -155,8 +188,8 @@ fn visit_children(data: &Data) -> TokenStream {
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
let name = format_ident!("_{}", i);
let attributes = Attributes::parse(&f.attrs);
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
});

quote! {
Expand Down
12 changes: 6 additions & 6 deletions src/ast/data_type.rs
Expand Up @@ -18,7 +18,7 @@ use core::fmt;
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};

use crate::ast::ObjectName;

Expand All @@ -27,7 +27,7 @@ use super::value::escape_single_quote_string;
/// SQL data types
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum DataType {
/// Fixed-length character type e.g. CHARACTER(10)
Character(Option<CharacterLength>),
Expand Down Expand Up @@ -341,7 +341,7 @@ fn format_datetime_precision_and_tz(
/// guarantee compatibility with the input query we must maintain its exact information.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TimezoneInfo {
/// No information about time zone. E.g., TIMESTAMP
None,
Expand Down Expand Up @@ -389,7 +389,7 @@ impl fmt::Display for TimezoneInfo {
/// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ExactNumberInfo {
/// No additional information e.g. `DECIMAL`
None,
Expand Down Expand Up @@ -420,7 +420,7 @@ impl fmt::Display for ExactNumberInfo {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CharacterLength {
/// Default (if VARYING) or maximum (if not VARYING) length
pub length: u64,
Expand All @@ -443,7 +443,7 @@ impl fmt::Display for CharacterLength {
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CharLengthUnits {
/// CHARACTERS unit
Characters,
Expand Down
22 changes: 11 additions & 11 deletions src/ast/ddl.rs
Expand Up @@ -21,7 +21,7 @@ use core::fmt;
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};

use crate::ast::value::escape_single_quote_string;
use crate::ast::{display_comma_separated, display_separated, DataType, Expr, Ident, ObjectName};
Expand All @@ -30,7 +30,7 @@ use crate::tokenizer::Token;
/// An `ALTER TABLE` (`Statement::AlterTable`) operation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterTableOperation {
/// `ADD <table_constraint>`
AddConstraint(TableConstraint),
Expand Down Expand Up @@ -100,7 +100,7 @@ pub enum AlterTableOperation {

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterIndexOperation {
RenameIndex { index_name: ObjectName },
}
Expand Down Expand Up @@ -224,7 +224,7 @@ impl fmt::Display for AlterIndexOperation {
/// An `ALTER COLUMN` (`Statement::AlterTable`) operation
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum AlterColumnOperation {
/// `SET NOT NULL`
SetNotNull,
Expand Down Expand Up @@ -268,7 +268,7 @@ impl fmt::Display for AlterColumnOperation {
/// `ALTER TABLE ADD <constraint>` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TableConstraint {
/// `[ CONSTRAINT <name> ] { PRIMARY KEY | UNIQUE } (<columns>)`
Unique {
Expand Down Expand Up @@ -433,7 +433,7 @@ impl fmt::Display for TableConstraint {
/// [1]: https://dev.mysql.com/doc/refman/8.0/en/create-table.html
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum KeyOrIndexDisplay {
/// Nothing to display
None,
Expand Down Expand Up @@ -469,7 +469,7 @@ impl fmt::Display for KeyOrIndexDisplay {
/// [3]: https://www.postgresql.org/docs/14/sql-createindex.html
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum IndexType {
BTree,
Hash,
Expand All @@ -488,7 +488,7 @@ impl fmt::Display for IndexType {
/// SQL column definition
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ColumnDef {
pub name: Ident,
pub data_type: DataType,
Expand Down Expand Up @@ -524,7 +524,7 @@ impl fmt::Display for ColumnDef {
/// "column options," and we allow any column option to be named.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ColumnOptionDef {
pub name: Option<Ident>,
pub option: ColumnOption,
Expand All @@ -540,7 +540,7 @@ impl fmt::Display for ColumnOptionDef {
/// TABLE` statement.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ColumnOption {
/// `NULL`
Null,
Expand Down Expand Up @@ -630,7 +630,7 @@ fn display_constraint_name(name: &'_ Option<Ident>) -> impl fmt::Display + '_ {
/// Used in foreign key constraints in `ON UPDATE` and `ON DELETE` options.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ReferentialAction {
Restrict,
Cascade,
Expand Down
4 changes: 2 additions & 2 deletions src/ast/helpers/stmt_create_table.rs
Expand Up @@ -5,7 +5,7 @@ use alloc::{boxed::Box, format, string::String, vec, vec::Vec};
use serde::{Deserialize, Serialize};

#[cfg(feature = "visitor")]
use sqlparser_derive::Visit;
use sqlparser_derive::{Visit, VisitMut};

use crate::ast::{
ColumnDef, FileFormat, HiveDistributionStyle, HiveFormat, ObjectName, OnCommit, Query,
Expand Down Expand Up @@ -43,7 +43,7 @@ use crate::parser::ParserError;
/// [1]: crate::ast::Statement::CreateTable
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CreateTableBuilder {
pub or_replace: bool,
pub temporary: bool,
Expand Down