Skip to content

Commit

Permalink
x/sqbuilder: Make the Marker upgradable to QuerySegment
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisMontagne committed Sep 20, 2023
1 parent eb42e45 commit 9770b35
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
6 changes: 5 additions & 1 deletion x/sqlbuilder/predicate_clause.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,12 @@ func (svpcw *staticValuePredicateClauseWrapper) WriteTo(w QueryWriter, _ map[str
return svpcw.svpc.WriteTo(w)
}

type PredicateClause interface {
type QuerySegment interface {
WriteTo(QueryWriter, map[string]interface{}) error
}

type PredicateClause interface {
QuerySegment
Clone() PredicateClause
}

Expand Down
13 changes: 12 additions & 1 deletion x/sqlbuilder/select_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ func (ss SelectStatement) Clone() SelectStatement {
}
}

func writeSelectClause(c Marker, qw *queryWriter, vs map[string]interface{}) error {
if qs, ok := c.(QuerySegment); ok {
return qs.WriteTo(qw, vs)
}

qw.WriteString(c.ToSQL())
return nil
}

func (ss SelectStatement) buildQuery(vs map[string]interface{}) (string, []interface{}, []string, error) {
var (
qw queryWriter
Expand All @@ -55,7 +64,9 @@ func (ss SelectStatement) buildQuery(vs map[string]interface{}) (string, []inter
qw.WriteString("SELECT ")

for i, c := range ss.SelectClauses {
qw.WriteString(c.ToSQL())
if err := writeSelectClause(c, &qw, vs); err != nil {
return "", nil, nil, err
}

if i < len(ss.SelectClauses)-1 {
qw.WriteString(", ")
Expand Down
30 changes: 23 additions & 7 deletions x/sqlbuilder/update_statement.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package sqlbuilder

import "fmt"
import (
"fmt"
"io"
)

type UpdateStatement struct {
Table string
Expand All @@ -17,6 +20,22 @@ func (us UpdateStatement) Clone() UpdateStatement {
}
}

func writeUpdateClause(f Marker, qw *queryWriter, vs map[string]interface{}) error {
if qs, ok := f.(QuerySegment); ok {
return qs.WriteTo(qw, vs)
}

k := f.Binding()
v, ok := vs[k]

if !ok {
return ErrMissingKey{Key: k}
}

_, err := io.WriteString(qw, qw.RedeemVariable(v))
return err
}

func (us UpdateStatement) buildQuery(vs map[string]interface{}) (string, []interface{}, error) {
var qw queryWriter

Expand All @@ -27,15 +46,12 @@ func (us UpdateStatement) buildQuery(vs map[string]interface{}) (string, []inter
fmt.Fprintf(&qw, "UPDATE %s SET ", us.Table)

for i, f := range us.Fields {
k := f.Binding()
v, ok := vs[k]
fmt.Fprintf(&qw, "%s = ", columnName(f))

if !ok {
return "", nil, ErrMissingKey{Key: k}
if err := writeUpdateClause(f, &qw, vs); err != nil {
return "", nil, err
}

fmt.Fprintf(&qw, "%s = %s", columnName(f), qw.RedeemVariable(v))

if i < len(us.Fields)-1 {
qw.WriteString(", ")
}
Expand Down

0 comments on commit 9770b35

Please sign in to comment.