Skip to content

Commit

Permalink
- rewrite parsing macros to do everything manually
Browse files Browse the repository at this point in the history
  • Loading branch information
siemiatj committed Oct 18, 2022
1 parent d531473 commit 519a1ba
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 49 deletions.
105 changes: 60 additions & 45 deletions macros.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,49 +135,71 @@ func trimAll(s []string) []string {

var pair = map[rune]rune{')': '('}

// parseInput finds the closing parenthesis of the macro string and adds an empty space,
// so that our macro regex test has an unambiguous result (because regex has no memory)
func parseInput(input string, name string) string {
cache := make([]rune, 0)

macroName := fmt.Sprintf("$__%s(", name)
var start = strings.Index(input, macroName)
var stringEnd = 0

// if can't find macro or macro doesn't contain parentheses, return
if start > -1 {
input = input[start:]
} else {
return input
// getMacroMatches extracts macro strings with their respective arguments from the given SQL input.
// It manually parses the string to find the closing parenthesis of the macro (because regex has no memory).
func getMacroMatches(input string, name string) ([][]string, error) {
macroName := fmt.Sprintf("\\$__%s\\b", name)
matchedMacros := [][]string{}
rgx, err := regexp.Compile(macroName)

if err != nil {
return nil, err
}

for idx, r := range input {
switch r {
case '(':
cache = append(cache, r)
case ')':
l := len(cache)
if l == 0 || pair[r] != cache[l-1] {
// get all matching macro instances
matched := rgx.FindAllStringIndex(input, -1)

if matched == nil {
return nil, nil
}

for matchedIndex := 0; matchedIndex < len(matched); matchedIndex++ {
var macroEnd = 0
var argStart = 0
macroStart := matched[matchedIndex][0]
inputCopy := input[macroStart:]
cache := make([]rune, 0)

// find the opening and closing brackets of the macro arguments
for idx, r := range inputCopy {
if len(cache) == 0 && macroEnd > 0 {
break
}
cache = cache[:l-1]
stringEnd = idx
default:
continue
switch r {
case '(':
cache = append(cache, r)
if argStart == 0 {
argStart = idx + 1
}
case ')':
l := len(cache)
if l == 0 {
break
}
cache = cache[:l-1]
macroEnd = idx + 1
default:
continue
}
}
}

stringTail := ""
if stringEnd+2 <= len(input) {
stringTail = input[stringEnd+2:]
}

return input[0:stringEnd+1] + " " + stringTail
}
// macroEnd equal to 0 means there are no parentheses, so just set it
// to the end of the regex match
if macroEnd == 0 {
macroEnd = matched[matchedIndex][1] - macroStart
}
macroString := inputCopy[0:macroEnd]
macroMatch := []string{macroString}

// getMacroRegex returns a regex for finding macro name and arguments passed to it
func getMacroRegex(name string) string {
return fmt.Sprintf("\\$__%s\\b(?:\\((.*?\\)?)\\))?", name)
args := ""
// if opening parenthesis was found, extract contents as arguments
if argStart > 0 {
args = inputCopy[argStart : macroEnd-1]
}
macroMatch = append(macroMatch, args)
matchedMacros = append(matchedMacros, macroMatch)
}
return matchedMacros, nil
}

// Interpolate returns an interpolated query string given a backend.DataQuery
Expand All @@ -191,12 +213,7 @@ func Interpolate(driver Driver, query *Query) (string, error) {
}
rawSQL := query.RawSQL

// cleanup SQL stripping any empty spaces between the closing brackets
re := regexp.MustCompile("(?:\\s+\\))")
rawSQL = re.ReplaceAllLiteralString(rawSQL, ")")

for key, macro := range macros {

matches, err := getMatches(key, rawSQL)

if err != nil {
Expand Down Expand Up @@ -228,13 +245,11 @@ func Interpolate(driver Driver, query *Query) (string, error) {

func getMatches(macroName, rawSQL string) ([][]string, error) {
sqlCopy := rawSQL
parsedInput := parseInput(sqlCopy, macroName)
macroRegex := getMacroRegex(macroName)
rgx, err := regexp.Compile(macroRegex)
parsedInput, err := getMacroMatches(sqlCopy, macroName)

if err != nil {
return nil, err
}

return rgx.FindAllStringSubmatch(parsedInput, -1), nil
return parsedInput, err
}
5 changes: 1 addition & 4 deletions macros_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func TestInterpolate(t *testing.T) {
{input: "select * from $__foo", output: "select * from bar", name: "macro without paranthesis"},
{input: "select * from $__params()", output: "select * from bar", name: "macro without params"},
{input: "select * from $__params(hello)", output: "select * from bar_hello", name: "with param"},
{input: "select * from $__params(h)", output: "select * from bar_h", name: "with short param"},
{input: "select * from $__params(hello) AND $__params(hello)", output: "select * from bar_hello AND bar_hello", name: "same macro multiple times with same param"},
{input: "(select * from $__params(hello) AND $__params(hello))", output: "(select * from bar_hello AND bar_hello)", name: "same macro multiple times with same param and additional parentheses"},
{input: "select * from $__params(hello) AND $__params(world)", output: "select * from bar_hello AND bar_world", name: "same macro multiple times with different param"},
Expand Down Expand Up @@ -84,10 +85,6 @@ func TestInterpolate(t *testing.T) {
}
}

func TestGetMacroRegex_returns_composed_regular_expression(t *testing.T) {
assert.Equal(t, `\$__some_string\b(?:\((.*?\)?)\))?`, getMacroRegex("some_string"))
}

func TestGetMatches(t *testing.T) {
t.Run("FindAllStringSubmatch returns DefaultMacros", func(t *testing.T) {
for macroName := range DefaultMacros {
Expand Down

0 comments on commit 519a1ba

Please sign in to comment.