From a9d628050af4752dcf38a15c5625fc0d4db2211b Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 10 Feb 2020 17:44:22 +0200 Subject: [PATCH] stmt: add json.RawMessage for converter and prepared statement Following #1058, in order for the driver.Value to get as a json.RawMessage, the converter should accept it as a valid value, and handle it as bytes in case where interpolation is disabled --- AUTHORS | 1 + packets.go | 4 ++++ statement.go | 13 +++++++++---- statement_test.go | 15 +++++++++++++++ 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/AUTHORS b/AUTHORS index 0896ba1bc..7dafbe2ec 100644 --- a/AUTHORS +++ b/AUTHORS @@ -17,6 +17,7 @@ Alex Snast Alexey Palazhchenko Andrew Reid Arne Hormann +Ariel Mashraki Asta Xie Bulat Gaifullin Carlos Nieto diff --git a/packets.go b/packets.go index 82ad7a200..575202ea3 100644 --- a/packets.go +++ b/packets.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { continue } + if v, ok := arg.(json.RawMessage); ok { + arg = []byte(v) + } // cache types and values switch v := arg.(type) { case int64: diff --git a/statement.go b/statement.go index f7e370939..7c6dc1367 100644 --- a/statement.go +++ b/statement.go @@ -10,6 +10,7 @@ package mysql import ( "database/sql/driver" + "encoding/json" "fmt" "io" "reflect" @@ -129,6 +130,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } +var jsonType = reflect.TypeOf(json.RawMessage{}) + type converter struct{} // ConvertValue mirrors the reference/default converter in database/sql/driver @@ -151,7 +154,6 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { } return sv, nil } - rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: @@ -170,11 +172,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { case reflect.Bool: return rv.Bool(), nil case reflect.Slice: - ek := rv.Type().Elem().Kind() - if ek == reflect.Uint8 { + switch t := rv.Type(); { + case t == jsonType: + return v, nil + case t.Elem().Kind() == reflect.Uint8: return rv.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind()) } - return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 4b9914f8e..2f2eb348a 100644 --- a/statement_test.go +++ b/statement_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "encoding/json" "testing" ) @@ -124,3 +125,17 @@ func TestConvertUnsignedIntegers(t *testing.T) { t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) } } + +func TestConvertJSON(t *testing.T) { + raw := json.RawMessage("{}") + + out, err := converter{}.ConvertValue(&raw) + + if err != nil { + t.Fatal("json.RawMessage was failed in convert", err) + } + + if _, ok := out.(json.RawMessage); !ok { + t.Fatalf("json.RawMessage converted, got %#v %T", out, out) + } +}