Skip to content

Commit

Permalink
feat(mysql): Implement cast function parser (#2473)
Browse files Browse the repository at this point in the history
What is this

As the title said, this PR wants to add support for CAST function in MySQL.

This PR is based from PR by @ryanpbrewster here (which unfortunately he didn't send here, and only exist in his repository).
Why is this PR created

Currently sqlc unable to infer the correct type from SQL function like MAX, MIN, SUM, etc. For those function, sqlc will return its value as interface{}. This behavior can be seen in this playground.

As workaround, it advised to use CAST function to explicitly tell what is the type for that column, as mentioned in #1574.

Unfortunately, currently sqlc only support CAST function in PostgreSQL and not in MySQL. Thanks to this, right now MySQL users have to parse the interface{} manually, which is not really desirable.
What does this PR do?

    Implement convertFuncCast function for MySQL.
    Add better nil pointer check in some functions that related to convertFuncCast.

I haven't write any test because I'm not sure how and where to put it. However, as far as I know the code that handle ast.TypeCast for PostgreSQL also don't have any test, so I guess it's fine 🤷‍♂️
Related issues

Support CAST ... AS #687, which currently is the oldest MySQL issue that still opened.
Using MYSQL functions ( CONVERT and CAST) result in removing column from struct #1622
Unable to Type Alias #1866
sum in select result in model field type interface{} #1901
MIN() returns an interface{} #1965
  • Loading branch information
RadhiFadlillah committed Jul 30, 2023
1 parent d354c0e commit 39f16cc
Show file tree
Hide file tree
Showing 14 changed files with 197 additions and 9 deletions.
25 changes: 19 additions & 6 deletions internal/compiler/compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ type Relation struct {

func parseRelation(node ast.Node) (*Relation, error) {
switch n := node.(type) {

case *ast.Boolean:
return &Relation{
Name: "bool",
}, nil
if n == nil {
return nil, fmt.Errorf("unexpected nil in %T node", n)
}
return &Relation{Name: "bool"}, nil

case *ast.List:
if n == nil {
return nil, fmt.Errorf("unexpected nil in %T node", n)
}
parts := stringSlice(n)
switch len(parts) {
case 1:
Expand All @@ -61,6 +64,9 @@ func parseRelation(node ast.Node) (*Relation, error) {
}

case *ast.RangeVar:
if n == nil {
return nil, fmt.Errorf("unexpected nil in %T node", n)
}
name := Relation{}
if n.Catalogname != nil {
name.Catalog = *n.Catalogname
Expand All @@ -74,10 +80,17 @@ func parseRelation(node ast.Node) (*Relation, error) {
return &name, nil

case *ast.TypeName:
return parseRelation(n.Names)
if n == nil {
return nil, fmt.Errorf("unexpected nil in %T node", n)
}
if n.Names != nil {
return parseRelation(n.Names)
} else {
return &Relation{Name: n.Name}, nil
}

default:
return nil, fmt.Errorf("unexpected node type: %T", n)
return nil, fmt.Errorf("unexpected node type: %T", node)
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/compiler/to_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

func isArray(n *ast.TypeName) bool {
if n == nil {
if n == nil || n.ArrayBounds == nil {
return false
}
return len(n.ArrayBounds.Items) > 0
Expand Down
31 changes: 31 additions & 0 deletions internal/endtoend/testdata/func_call_cast/mysql/go/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions internal/endtoend/testdata/func_call_cast/mysql/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions internal/endtoend/testdata/func_call_cast/mysql/go/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions internal/endtoend/testdata/func_call_cast/mysql/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- name: Demo :one
SELECT CAST(GREATEST(1,2,3,4,5) AS UNSIGNED) as col1
12 changes: 12 additions & 0 deletions internal/endtoend/testdata/func_call_cast/mysql/sqlc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"version": "1",
"packages": [
{
"path": "go",
"engine": "mysql",
"name": "querytest",
"schema": "query.sql",
"queries": "query.sql"
}
]
}
31 changes: 31 additions & 0 deletions internal/endtoend/testdata/select_column_cast/mysql/go/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions internal/endtoend/testdata/select_column_cast/mysql/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions internal/endtoend/testdata/select_column_cast/mysql/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
CREATE TABLE foo (bar BOOLEAN NOT NULL);

-- name: SelectColumnCast :many
SELECT CAST(bar AS UNSIGNED) FROM foo;
12 changes: 12 additions & 0 deletions internal/endtoend/testdata/select_column_cast/mysql/sqlc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"version": "1",
"packages": [
{
"path": "go",
"engine": "mysql",
"name": "querytest",
"schema": "query.sql",
"queries": "query.sql"
}
]
}
5 changes: 4 additions & 1 deletion internal/engine/dolphin/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,10 @@ func (c *cc) convertFrameClause(n *pcast.FrameClause) ast.Node {
}

func (c *cc) convertFuncCastExpr(n *pcast.FuncCastExpr) ast.Node {
return todo(n)
return &ast.TypeCast{
Arg: c.convert(n.Expr),
TypeName: &ast.TypeName{Name: types.TypeStr(n.Tp.GetType())},
}
}

func (c *cc) convertGetFormatSelectorExpr(n *pcast.GetFormatSelectorExpr) ast.Node {
Expand Down
6 changes: 5 additions & 1 deletion internal/sql/astutils/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import (
)

func Join(list *ast.List, sep string) string {
items := []string{}
if list == nil {
return ""
}

var items []string
for _, item := range list.Items {
if n, ok := item.(*ast.String); ok {
items = append(items, n.Str)
Expand Down

0 comments on commit 39f16cc

Please sign in to comment.