diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index f540fa04800..3f766d8740c 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -2913,6 +2913,13 @@ func (p *parser) parseLit() (Expr, *parseError) { // TODO: Check IsKeyWord(tok.value), and return a good error? } + // Handle conditional expressions. + switch { + case tok.caseEqual("CASE"): + p.back() + return p.parseCaseExpr() + } + // Handle typed literals. switch { case tok.caseEqual("ARRAY") || tok.value == "[": @@ -2950,6 +2957,72 @@ func (p *parser) parseLit() (Expr, *parseError) { return pe, nil } +func (p *parser) parseCaseExpr() (Case, *parseError) { + if err := p.expect("CASE"); err != nil { + return Case{}, err + } + + var expr Expr + if !p.sniff("WHEN") { + var err *parseError + expr, err = p.parseExpr() + if err != nil { + return Case{}, err + } + } + + when, err := p.parseWhenClause() + if err != nil { + return Case{}, err + } + whens := []WhenClause{when} + for p.sniff("WHEN") { + when, err := p.parseWhenClause() + if err != nil { + return Case{}, err + } + whens = append(whens, when) + } + + var elseResult Expr + if p.sniff("ELSE") { + p.eat("ELSE") + var err *parseError + elseResult, err = p.parseExpr() + if err != nil { + return Case{}, err + } + } + + if err := p.expect("END"); err != nil { + return Case{}, err + } + + return Case{ + Expr: expr, + WhenClauses: whens, + ElseResult: elseResult, + }, nil +} + +func (p *parser) parseWhenClause() (WhenClause, *parseError) { + if err := p.expect("WHEN"); err != nil { + return WhenClause{}, err + } + cond, err := p.parseExpr() + if err != nil { + return WhenClause{}, err + } + if err := p.expect("THEN"); err != nil { + return WhenClause{}, err + } + result, err := p.parseExpr() + if err != nil { + return WhenClause{}, err + } + return WhenClause{Cond: cond, Result: result}, nil +} + func (p *parser) parseArrayLit() (Array, *parseError) { // ARRAY keyword is optional. // TODO: If it is present, consume any after it. diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index c9289c4b207..0c5a5644cf6 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -343,6 +343,26 @@ func TestParseExpr(t *testing.T) { {`EXTRACT(DATE FROM TIMESTAMP AT TIME ZONE "America/Los_Angeles")`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("TIMESTAMP"), Zone: "America/Los_Angeles", Type: Type{Base: Timestamp}}}}}}, {`EXTRACT(DAY FROM DATE)`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DAY", Expr: ID("DATE"), Type: Type{Base: Int64}}}}}, + // Conditional expressions + {`CASE X WHEN 1 THEN "X" WHEN 2 THEN "Y" ELSE NULL END`, + Case{ + Expr: ID("X"), + WhenClauses: []WhenClause{ + {Cond: IntegerLiteral(1), Result: StringLiteral("X")}, + {Cond: IntegerLiteral(2), Result: StringLiteral("Y")}, + }, + ElseResult: Null, + }, + }, + {`CASE WHEN TRUE THEN "X" WHEN FALSE THEN "Y" END`, + Case{ + WhenClauses: []WhenClause{ + {Cond: True, Result: StringLiteral("X")}, + {Cond: False, Result: StringLiteral("Y")}, + }, + }, + }, + // String literal: // Accept double quote and single quote. {`"hello"`, StringLiteral("hello")}, diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index ef877369e6f..6d64c35e33e 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -671,6 +671,21 @@ func (p Param) addSQL(sb *strings.Builder) { sb.WriteString(string(p)) } +func (c Case) SQL() string { return buildSQL(c) } +func (c Case) addSQL(sb *strings.Builder) { + sb.WriteString("CASE ") + if c.Expr != nil { + fmt.Fprintf(sb, "%s ", c.Expr.SQL()) + } + for _, w := range c.WhenClauses { + fmt.Fprintf(sb, "WHEN %s THEN %s ", w.Cond.SQL(), w.Result.SQL()) + } + if c.ElseResult != nil { + fmt.Fprintf(sb, "ELSE %s ", c.ElseResult.SQL()) + } + sb.WriteString("END") +} + func (b BoolLiteral) SQL() string { return buildSQL(b) } func (b BoolLiteral) addSQL(sb *strings.Builder) { if b { diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index b442506c719..8791cfec9bc 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -566,6 +566,38 @@ func TestSQL(t *testing.T) { "SELECT A, B FROM Table1 INNER JOIN Table2 ON Table1.A = Table2.A INNER JOIN Table3 USING (X)", reparseQuery, }, + { + Query{ + Select: Select{ + List: []Expr{ + Case{ + Expr: ID("X"), + WhenClauses: []WhenClause{ + {Cond: IntegerLiteral(1), Result: StringLiteral("X")}, + {Cond: IntegerLiteral(2), Result: StringLiteral("Y")}, + }, + ElseResult: Null, + }}, + }, + }, + `SELECT CASE X WHEN 1 THEN "X" WHEN 2 THEN "Y" ELSE NULL END`, + reparseQuery, + }, + { + Query{ + Select: Select{ + List: []Expr{ + Case{ + WhenClauses: []WhenClause{ + {Cond: True, Result: StringLiteral("X")}, + {Cond: False, Result: StringLiteral("Y")}, + }, + }}, + }, + }, + `SELECT CASE WHEN TRUE THEN "X" WHEN FALSE THEN "Y" END`, + reparseQuery, + }, } for _, test := range tests { sql := test.data.SQL() diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 9d099cc3505..4a9e8f7e431 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -692,6 +692,20 @@ func (Param) isBoolExpr() {} // possibly bool func (Param) isExpr() {} func (Param) isLiteralOrParam() {} +type Case struct { + Expr Expr + WhenClauses []WhenClause + ElseResult Expr +} + +func (Case) isBoolExpr() {} // possibly bool +func (Case) isExpr() {} + +type WhenClause struct { + Cond Expr + Result Expr +} + type BoolLiteral bool const (