Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- fix parsing macros in queries #74

Merged
merged 1 commit into from Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 85 additions & 11 deletions macros.go
Expand Up @@ -24,7 +24,8 @@ type Macros map[string]MacroFunc
// Default time filter for SQL based on the query time range.
// It requires one argument, the time column to filter.
// Example:
// $__timeFilter(time) => "time BETWEEN '2006-01-02T15:04:05Z07:00' AND '2006-01-02T15:04:05Z07:00'"
//
// $__timeFilter(time) => "time BETWEEN '2006-01-02T15:04:05Z07:00' AND '2006-01-02T15:04:05Z07:00'"
func macroTimeFilter(query *Query, args []string) (string, error) {
if len(args) != 1 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand All @@ -42,7 +43,8 @@ func macroTimeFilter(query *Query, args []string) (string, error) {
// Default time filter for SQL based on the starting query time range.
// It requires one argument, the time column to filter.
// Example:
// $__timeFrom(time) => "time > '2006-01-02T15:04:05Z07:00'"
//
// $__timeFrom(time) => "time > '2006-01-02T15:04:05Z07:00'"
func macroTimeFrom(query *Query, args []string) (string, error) {
if len(args) != 1 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand All @@ -55,7 +57,8 @@ func macroTimeFrom(query *Query, args []string) (string, error) {
// Default time filter for SQL based on the ending query time range.
// It requires one argument, the time column to filter.
// Example:
// $__timeTo(time) => "time < '2006-01-02T15:04:05Z07:00'"
//
// $__timeTo(time) => "time < '2006-01-02T15:04:05Z07:00'"
func macroTimeTo(query *Query, args []string) (string, error) {
if len(args) != 1 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand All @@ -68,7 +71,8 @@ func macroTimeTo(query *Query, args []string) (string, error) {
// This basic example is meant to be customized with more complex periods.
// It requires two arguments, the column to filter and the period.
// Example:
// $__timeGroup(time, month) => "datepart(year, time), datepart(month, time)"
//
// $__timeGroup(time, month) => "datepart(year, time), datepart(month, time)"
func macroTimeGroup(query *Query, args []string) (string, error) {
if len(args) != 2 {
return "", fmt.Errorf("%w: expected 1 argument, received %d", ErrorBadArgumentCount, len(args))
Expand Down Expand Up @@ -97,14 +101,16 @@ func macroTimeGroup(query *Query, args []string) (string, error) {

// macroTable is the default macro that returns the query table name.
// It returns query.Table unchanged; args is not used and the returned
// error is always nil.
// Example:
//
//	$__table => "my_table"
func macroTable(query *Query, args []string) (string, error) {
return query.Table, nil
}

// macroColumn is the default macro that returns the query column name.
// It returns query.Column unchanged; args is not used and the returned
// error is always nil.
// Example:
//
//	$__column => "my_col"
func macroColumn(query *Query, args []string) (string, error) {
return query.Column, nil
}
Expand All @@ -127,8 +133,73 @@ func trimAll(s []string) []string {
return r
}

func getMacroRegex(name string) string {
return fmt.Sprintf("\\$__%s\\b(?:\\((.*?\\)?)\\))?", name)
// pair maps a closing bracket rune to its matching opening bracket rune.
var pair = map[rune]rune{
	')': '(',
}

// getMacroMatches extracts every occurrence of the macro `$__<name>` from the
// SQL input, together with its raw argument string.
//
// Each result entry is a two-element slice shaped like a regexp submatch:
// element 0 is the full macro text (name plus the parenthesised argument list,
// if any) and element 1 is the text between the outermost parentheses, or ""
// when the macro has no argument list.
//
// Arguments are recognised only when an opening parenthesis directly follows
// the macro name; the matching closing parenthesis is located by counting
// nesting depth, which a regular expression cannot do. A macro whose argument
// list is never closed is reported as the bare name with empty arguments, and
// parentheses that belong to the surrounding SQL are left alone.
//
// It returns (nil, nil) when the macro does not occur in input, and an error
// only when name does not form a valid regular expression.
func getMacroMatches(input string, name string) ([][]string, error) {
	rgx, err := regexp.Compile(fmt.Sprintf("\\$__%s\\b", name))
	if err != nil {
		return nil, err
	}

	// Locate all occurrences of the macro name itself.
	matched := rgx.FindAllStringIndex(input, -1)
	if matched == nil {
		return nil, nil
	}

	matchedMacros := make([][]string, 0, len(matched))
	for _, loc := range matched {
		macroStart, nameEnd := loc[0], loc[1]

		// Default: the macro is just its name, without arguments.
		macroEnd := nameEnd
		args := ""

		// Only a parenthesis immediately after the name opens an argument list;
		// anything further away is ordinary SQL and must not be consumed.
		if nameEnd < len(input) && input[nameEnd] == '(' {
			if closeIdx := findClosingParen(input, nameEnd); closeIdx >= 0 {
				args = input[nameEnd+1 : closeIdx]
				macroEnd = closeIdx + 1
			}
		}

		matchedMacros = append(matchedMacros, []string{input[macroStart:macroEnd], args})
	}
	return matchedMacros, nil
}

// findClosingParen returns the byte index of the parenthesis that closes the
// one at input[open], honouring nesting, or -1 if it is never closed.
func findClosingParen(input string, open int) int {
	depth := 0
	// Parentheses are single-byte in UTF-8, so scanning bytes is safe.
	for i := open; i < len(input); i++ {
		switch input[i] {
		case '(':
			depth++
		case ')':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}

// Interpolate returns an interpolated query string given a backend.DataQuery
Expand All @@ -141,8 +212,10 @@ func Interpolate(driver Driver, query *Query) (string, error) {
}
}
rawSQL := query.RawSQL

for key, macro := range macros {
matches, err := getMatches(key, rawSQL)

if err != nil {
return rawSQL, err
}
Expand All @@ -165,16 +238,17 @@ func Interpolate(driver Driver, query *Query) (string, error) {

rawSQL = strings.Replace(rawSQL, match[0], res, -1)
}

}

return rawSQL, nil
}

func getMatches(macroName, rawSQL string) ([][]string, error) {
rgx, err := regexp.Compile(getMacroRegex(macroName))
parsedInput, err := getMacroMatches(rawSQL, macroName)

if err != nil {
return nil, err
}
return rgx.FindAllStringSubmatch(rawSQL, -1), nil

return parsedInput, err
}
7 changes: 3 additions & 4 deletions macros_test.go
Expand Up @@ -56,11 +56,14 @@ func TestInterpolate(t *testing.T) {
{input: "select * from $__foo", output: "select * from bar", name: "macro without paranthesis"},
{input: "select * from $__params()", output: "select * from bar", name: "macro without params"},
{input: "select * from $__params(hello)", output: "select * from bar_hello", name: "with param"},
{input: "select * from $__params(h)", output: "select * from bar_h", name: "with short param"},
{input: "select * from $__params(hello) AND $__params(hello)", output: "select * from bar_hello AND bar_hello", name: "same macro multiple times with same param"},
{input: "(select * from $__params(hello) AND $__params(hello))", output: "(select * from bar_hello AND bar_hello)", name: "same macro multiple times with same param and additional parentheses"},
{input: "select * from $__params(hello) AND $__params(world)", output: "select * from bar_hello AND bar_world", name: "same macro multiple times with different param"},
{input: "select * from $__params(world) AND $__foo() AND $__params(hello)", output: "select * from bar_world AND bar AND bar_hello", name: "different macros with different params"},
{input: "select * from foo where $__timeFilter(time)", output: "select * from foo where time >= '0001-01-01T00:00:00Z' AND time <= '0001-01-01T00:00:00Z'", name: "default timeFilter"},
{input: "select * from foo where $__timeFilter(cast(sth as timestamp))", output: "select * from foo where cast(sth as timestamp) >= '0001-01-01T00:00:00Z' AND cast(sth as timestamp) <= '0001-01-01T00:00:00Z'", name: "default timeFilter"},
{input: "select * from foo where $__timeFilter(cast(sth as timestamp) )", output: "select * from foo where cast(sth as timestamp) >= '0001-01-01T00:00:00Z' AND cast(sth as timestamp) <= '0001-01-01T00:00:00Z'", name: "default timeFilter with empty spaces"},
{input: "select * from foo where $__timeTo(time)", output: "select * from foo where time <= '0001-01-01T00:00:00Z'", name: "default timeTo macro"},
{input: "select * from foo where $__timeFrom(time)", output: "select * from foo where time >= '0001-01-01T00:00:00Z'", name: "default timeFrom macro"},
{input: "select * from foo where $__timeFrom(cast(sth as timestamp))", output: "select * from foo where cast(sth as timestamp) >= '0001-01-01T00:00:00Z'", name: "default timeFrom macro"},
Expand All @@ -82,10 +85,6 @@ func TestInterpolate(t *testing.T) {
}
}

// TestGetMacroRegex_returns_composed_regular_expression checks that
// getMacroRegex("some_string") yields the expected pattern: the escaped
// "$__some_string" prefix, a word boundary, and an optional parenthesised
// argument group.
func TestGetMacroRegex_returns_composed_regular_expression(t *testing.T) {
assert.Equal(t, `\$__some_string\b(?:\((.*?\)?)\))?`, getMacroRegex("some_string"))
}

func TestGetMatches(t *testing.T) {
t.Run("FindAllStringSubmatch returns DefaultMacros", func(t *testing.T) {
for macroName := range DefaultMacros {
Expand Down