From 3d8a0293423afe714a98d549f0a8015b2d0930b7 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki <7413593+a8m@users.noreply.github.com> Date: Tue, 18 Feb 2020 17:16:20 +0200 Subject: [PATCH] stmt: add json.RawMessage for converter and prepared statement (#1059) Following #1058, in order for the driver.Value to get as a json.RawMessage, the converter should accept it as a valid value, and handle it as bytes in case where interpolation is disabled --- AUTHORS | 1 + driver_test.go | 24 ++++++++++++++++++++++++ packets.go | 4 ++++ statement.go | 13 +++++++++---- statement_test.go | 15 +++++++++++++++ 5 files changed, 53 insertions(+), 4 deletions(-) diff --git a/AUTHORS b/AUTHORS index 0896ba1bc..7dafbe2ec 100644 --- a/AUTHORS +++ b/AUTHORS @@ -17,6 +17,7 @@ Alex Snast Alexey Palazhchenko Andrew Reid Arne Hormann +Ariel Mashraki Asta Xie Bulat Gaifullin Carlos Nieto diff --git a/driver_test.go b/driver_test.go index ace083dfc..8edd17c47 100644 --- a/driver_test.go +++ b/driver_test.go @@ -14,6 +14,7 @@ import ( "crypto/tls" "database/sql" "database/sql/driver" + "encoding/json" "fmt" "io" "io/ioutil" @@ -559,6 +560,29 @@ func TestRawBytes(t *testing.T) { }) } +func TestRawMessage(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := json.RawMessage("{}") + v2 := json.RawMessage("[]") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 json.RawMessage + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + type testValuer struct { value string } diff --git a/packets.go b/packets.go index 82ad7a200..575202ea3 100644 --- a/packets.go +++ b/packets.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { continue } + if v, ok := arg.(json.RawMessage); ok { + arg = []byte(v) + } // cache types and values switch v := arg.(type) { case int64: diff --git a/statement.go b/statement.go index f7e370939..7c6dc1367 100644 --- a/statement.go +++ b/statement.go @@ -10,6 +10,7 @@ package mysql import ( "database/sql/driver" + "encoding/json" "fmt" "io" "reflect" @@ -129,6 +130,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } +var jsonType = reflect.TypeOf(json.RawMessage{}) + type converter struct{} // ConvertValue mirrors the reference/default converter in database/sql/driver @@ -151,7 +154,6 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { } return sv, nil } - rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: @@ -170,11 +172,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { case reflect.Bool: return rv.Bool(), nil case reflect.Slice: - ek := rv.Type().Elem().Kind() - if ek == reflect.Uint8 { + switch t := rv.Type(); { + case t == jsonType: + return v, nil + case t.Elem().Kind() == reflect.Uint8: return rv.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind()) } - return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 4b9914f8e..2cc022bf5 100644 --- a/statement_test.go +++ b/statement_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "encoding/json" "testing" ) @@ -124,3 +125,17 @@ func TestConvertUnsignedIntegers(t *testing.T) { t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) } } + +func TestConvertJSON(t *testing.T) { + raw := json.RawMessage("{}") + + out, err := converter{}.ConvertValue(raw) + + if err != nil { + t.Fatal("json.RawMessage was failed in convert", err) + } + + if _, ok := out.(json.RawMessage); !ok { + t.Fatalf("json.RawMessage converted, got %#v %T", out, out) + } +}