From 490a653193fdc11437a4ed38aad0358ee4a0bf89 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Mon, 10 Feb 2020 17:44:22 +0200 Subject: [PATCH] stmt: add json.RawMessage for converter and prepared statement Following #1058, in order for the driver.Value to get as a json.RawMessage, the converter should accept it as a valid value, and handle it as bytes in case where interpolation is disabled --- packets.go | 20 ++++++++++++++------ statement.go | 13 +++++++++---- statement_test.go | 15 +++++++++++++++ 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/packets.go b/packets.go index 82ad7a200..d543fdb33 100644 --- a/packets.go +++ b/packets.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -1063,19 +1064,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramValues = append(paramValues, 0x00) } - case []byte: + case []byte, json.RawMessage: + var ( + ok bool + b []byte + ) + if b, ok = v.(json.RawMessage); !ok { + b = v.([]byte) + } // Common case (non-nil value) first - if v != nil { + if b != nil { paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < longDataSize { + if len(b) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, - uint64(len(v)), + uint64(len(b)), ) - paramValues = append(paramValues, v...) + paramValues = append(paramValues, b...) } else { - if err := stmt.writeCommandLongData(i, v); err != nil { + if err := stmt.writeCommandLongData(i, b); err != nil { return err } } diff --git a/statement.go b/statement.go index f7e370939..7c6dc1367 100644 --- a/statement.go +++ b/statement.go @@ -10,6 +10,7 @@ package mysql import ( "database/sql/driver" + "encoding/json" "fmt" "io" "reflect" @@ -129,6 +130,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } +var jsonType = reflect.TypeOf(json.RawMessage{}) + type converter struct{} // ConvertValue mirrors the reference/default converter in database/sql/driver @@ -151,7 +154,6 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { } return sv, nil } - rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: @@ -170,11 +172,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { case reflect.Bool: return rv.Bool(), nil case reflect.Slice: - ek := rv.Type().Elem().Kind() - if ek == reflect.Uint8 { + switch t := rv.Type(); { + case t == jsonType: + return v, nil + case t.Elem().Kind() == reflect.Uint8: return rv.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind()) } - return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/statement_test.go b/statement_test.go index 4b9914f8e..913b2abd5 100644 --- a/statement_test.go +++ b/statement_test.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "encoding/json" "testing" ) @@ -124,3 +125,17 @@ func TestConvertUnsignedIntegers(t *testing.T) { t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) } } + +func TestConvertJSON(t *testing.T) { + raw := json.RawMessage("{}") + + out, err := converter{}.ConvertValue(&raw) + + if err != nil { + t.Fatal("json.RawMessage was failed in covert", err) + } + + if _, ok := out.(json.RawMessage); !ok { + t.Fatalf("json.RawMessage converted, got %#v %T", out, out) + } +}