diff --git a/conn_go19.go b/conn_go19.go new file mode 100644 index 00000000..00360156 --- /dev/null +++ b/conn_go19.go @@ -0,0 +1,35 @@ +//go:build go1.9 +// +build go1.9 + +package pq + +import ( + "database/sql/driver" + "reflect" +) + +var _ driver.NamedValueChecker = (*conn)(nil) + +func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { + if _, ok := nv.Value.(driver.Valuer); ok { + // Ignore Valuer, for backward compatibility with pq.Array(). + return driver.ErrSkip + } + + // Ignoring []byte / []uint8. + if _, ok := nv.Value.([]uint8); ok { + return driver.ErrSkip + } + + v := reflect.ValueOf(nv.Value) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + var err error + nv.Value, err = Array(v.Interface()).Value() + return err + } + + return driver.ErrSkip +} diff --git a/conn_go19_test.go b/conn_go19_test.go new file mode 100644 index 00000000..43c98263 --- /dev/null +++ b/conn_go19_test.go @@ -0,0 +1,83 @@ +//go:build go1.9 +// +build go1.9 + +package pq + +import ( + "fmt" + "reflect" + "testing" +) + +func TestArrayArg(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tc := range []struct { + pgType string + in, out interface{} + }{ + { + pgType: "int[]", + in: []int{245, 231}, + out: []int64{245, 231}, + }, + { + pgType: "int[]", + in: &[]int{245, 231}, + out: []int64{245, 231}, + }, + { + pgType: "int[]", + in: []int64{245, 231}, + }, + { + pgType: "int[]", + in: &[]int64{245, 231}, + out: []int64{245, 231}, + }, + { + pgType: "varchar[]", + in: []string{"hello", "world"}, + }, + { + pgType: "varchar[]", + in: &[]string{"hello", "world"}, + out: []string{"hello", "world"}, + }, + } { + if tc.out == nil { + tc.out = tc.in + } + t.Run(fmt.Sprintf("%#v", tc.in), func(t *testing.T) { + r, err := db.Query(fmt.Sprintf("SELECT $1::%s", tc.pgType), tc.in) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected row") + } + + defer func() { + if r.Next() { + t.Fatal("unexpected row") + } + }() + + got := reflect.New(reflect.TypeOf(tc.out)) + if err := r.Scan(Array(got.Interface())); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(tc.out, got.Elem().Interface()) { + t.Errorf("got %v, want %v", got, tc.out) + } + }) + } + +}