Skip to content

Commit

Permalink
stmt: add json.RawMessage for converter and prepared statement (#1059)
Browse files Browse the repository at this point in the history
Following #1058, in order for the driver.Value to get as a json.RawMessage,
the converter should accept it as a valid value, and handle it as bytes in
case where interpolation is disabled
  • Loading branch information
a8m committed Feb 18, 2020
1 parent 5a8a207 commit 3d8a029
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 4 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Expand Up @@ -17,6 +17,7 @@ Alex Snast <alexsn at fb.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com>
Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il>
Asta Xie <xiemengjun at gmail.com>
Bulat Gaifullin <gaifullinbf at gmail.com>
Carlos Nieto <jose.carlos at menteslibres.net>
Expand Down
24 changes: 24 additions & 0 deletions driver_test.go
Expand Up @@ -14,6 +14,7 @@ import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -559,6 +560,29 @@ func TestRawBytes(t *testing.T) {
})
}

func TestRawMessage(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
v1 := json.RawMessage("{}")
v2 := json.RawMessage("[]")
rows := dbt.mustQuery("SELECT ?, ?", v1, v2)
defer rows.Close()
if rows.Next() {
var o1, o2 json.RawMessage
if err := rows.Scan(&o1, &o2); err != nil {
dbt.Errorf("Got error: %v", err)
}
if !bytes.Equal(v1, o1) {
dbt.Errorf("expected %v, got %v", v1, o1)
}
if !bytes.Equal(v2, o2) {
dbt.Errorf("expected %v, got %v", v2, o2)
}
} else {
dbt.Errorf("no data")
}
})
}

type testValuer struct {
value string
}
Expand Down
4 changes: 4 additions & 0 deletions packets.go
Expand Up @@ -13,6 +13,7 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
continue
}

if v, ok := arg.(json.RawMessage); ok {
arg = []byte(v)
}
// cache types and values
switch v := arg.(type) {
case int64:
Expand Down
13 changes: 9 additions & 4 deletions statement.go
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"database/sql/driver"
"encoding/json"
"fmt"
"io"
"reflect"
Expand Down Expand Up @@ -129,6 +130,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
return rows, err
}

var jsonType = reflect.TypeOf(json.RawMessage{})

type converter struct{}

// ConvertValue mirrors the reference/default converter in database/sql/driver
Expand All @@ -151,7 +154,6 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
}
return sv, nil
}

rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
Expand All @@ -170,11 +172,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
case reflect.Bool:
return rv.Bool(), nil
case reflect.Slice:
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
switch t := rv.Type(); {
case t == jsonType:
return v, nil
case t.Elem().Kind() == reflect.Uint8:
return rv.Bytes(), nil
default:
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return rv.String(), nil
}
Expand Down
15 changes: 15 additions & 0 deletions statement_test.go
Expand Up @@ -10,6 +10,7 @@ package mysql

import (
"bytes"
"encoding/json"
"testing"
)

Expand Down Expand Up @@ -124,3 +125,17 @@ func TestConvertUnsignedIntegers(t *testing.T) {
t.Fatalf("uint64 high-bit converted, got %#v %T", output, output)
}
}

func TestConvertJSON(t *testing.T) {
raw := json.RawMessage("{}")

out, err := converter{}.ConvertValue(raw)

if err != nil {
t.Fatal("json.RawMessage was failed in convert", err)
}

if _, ok := out.(json.RawMessage); !ok {
t.Fatalf("json.RawMessage converted, got %#v %T", out, out)
}
}

0 comments on commit 3d8a029

Please sign in to comment.