Skip to content

Commit

Permalink
fix panic on json.RawMessage parse
Browse files Browse the repository at this point in the history
  • Loading branch information
xiam committed Jul 4, 2022
1 parent bb6a386 commit d144469
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 152 deletions.
1 change: 1 addition & 0 deletions adapter/postgresql/database_pgx.go
@@ -1,3 +1,4 @@
//go:build !pq
// +build !pq

package postgresql
Expand Down
1 change: 1 addition & 0 deletions adapter/postgresql/helper_test.go
Expand Up @@ -194,6 +194,7 @@ func (h *Helper) TearUp() error {
, integer_array integer[]
, string_array text[]
, jsonb_map jsonb
, raw_jsonb_map jsonb
, integer_array_ptr integer[]
, string_array_ptr text[]
Expand Down
6 changes: 6 additions & 0 deletions adapter/postgresql/postgresql_test.go
Expand Up @@ -25,6 +25,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"math/rand"
"strings"
Expand Down Expand Up @@ -229,6 +230,8 @@ func testPostgreSQLTypes(t *testing.T, sess db.Session) {
StringArray StringArray `db:"string_array,stringarray"`
JSONBMap JSONBMap `db:"jsonb_map"`

RawJSONBMap json.RawMessage `db:"raw_jsonb_map"`

PGTypeInline `db:",inline"`

PGTypeAutoInline `db:",inline"`
Expand Down Expand Up @@ -329,6 +332,9 @@ func testPostgreSQLTypes(t *testing.T, sess db.Session) {
AutoJSONBMapInteger: map[string]interface{}{"a": 12.0, "b": 13.0},
},
},
PGType{
RawJSONBMap: json.RawMessage(`{"foo": "bar"}`),
},
PGType{
IntegerValue: integerValue,
StringValue: stringValue,
Expand Down
68 changes: 35 additions & 33 deletions internal/sqladapter/collection.go
Expand Up @@ -67,42 +67,43 @@ type condsFilter interface {

// collection is the implementation of Collection.
type collection struct {
name string
sess Session

name string
adapter CollectionAdapter
}

// NewCollection initializes a Collection by wrapping a CollectionAdapter.
func NewCollection(sess Session, name string, adapter CollectionAdapter) Collection {
type collectionWithSession struct {
*collection

session Session
}

func newCollection(name string, adapter CollectionAdapter) *collection {
if adapter == nil {
panic("upper: received nil adapter")
panic("upper: nil adapter")
}
c := &collection{
sess: sess,
return &collection{
name: name,
adapter: adapter,
}
return c
}

func (c *collection) SQL() db.SQL {
return c.sess.SQL()
func (c *collectionWithSession) SQL() db.SQL {
return c.session.SQL()
}

func (c *collection) Session() db.Session {
return c.sess
func (c *collectionWithSession) Session() db.Session {
return c.session
}

func (c *collection) Name() string {
func (c *collectionWithSession) Name() string {
return c.name
}

func (c *collection) Count() (uint64, error) {
func (c *collectionWithSession) Count() (uint64, error) {
return c.Find().Count()
}

func (c *collection) Insert(item interface{}) (db.InsertResult, error) {
func (c *collectionWithSession) Insert(item interface{}) (db.InsertResult, error) {
id, err := c.adapter.Insert(c, item)
if err != nil {
return nil, err
Expand All @@ -111,11 +112,11 @@ func (c *collection) Insert(item interface{}) (db.InsertResult, error) {
return db.NewInsertResult(id), nil
}

func (c *collection) PrimaryKeys() ([]string, error) {
return c.sess.PrimaryKeys(c.Name())
func (c *collectionWithSession) PrimaryKeys() ([]string, error) {
return c.session.PrimaryKeys(c.Name())
}

func (c *collection) filterConds(conds ...interface{}) ([]interface{}, error) {
func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) {
pk, err := c.PrimaryKeys()
if err != nil {
return nil, err
Expand All @@ -131,15 +132,16 @@ func (c *collection) filterConds(conds ...interface{}) ([]interface{}, error) {
return conds, nil
}

func (c *collection) Find(conds ...interface{}) db.Result {
func (c *collectionWithSession) Find(conds ...interface{}) db.Result {
filteredConds, err := c.filterConds(conds...)
if err != nil {
res := &Result{}
res.setErr(err)
return res
}

res := NewResult(
c.sess.SQL(),
c.session.SQL(),
c.Name(),
filteredConds,
)
Expand All @@ -149,14 +151,14 @@ func (c *collection) Find(conds ...interface{}) db.Result {
return res
}

func (c *collection) Exists() (bool, error) {
if err := c.sess.TableExists(c.Name()); err != nil {
func (c *collectionWithSession) Exists() (bool, error) {
if err := c.session.TableExists(c.Name()); err != nil {
return false, err
}
return true, nil
}

func (c *collection) InsertReturning(item interface{}) error {
func (c *collectionWithSession) InsertReturning(item interface{}) error {
if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
return fmt.Errorf("Expecting a pointer but got %T", item)
}
Expand All @@ -175,12 +177,12 @@ func (c *collection) InsertReturning(item interface{}) error {
}

var tx Session
isTransaction := c.sess.IsTransaction()
isTransaction := c.session.IsTransaction()
if isTransaction {
tx = c.sess
tx = c.session
} else {
var err error
tx, err = c.sess.NewTransaction(c.sess.Context(), nil)
tx, err = c.session.NewTransaction(c.session.Context(), nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -261,7 +263,7 @@ cancel:
return err
}

func (c *collection) UpdateReturning(item interface{}) error {
func (c *collectionWithSession) UpdateReturning(item interface{}) error {
if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
return fmt.Errorf("Expecting a pointer but got %T", item)
}
Expand All @@ -280,14 +282,14 @@ func (c *collection) UpdateReturning(item interface{}) error {
}

var tx Session
isTransaction := c.sess.IsTransaction()
isTransaction := c.session.IsTransaction()

if isTransaction {
tx = c.sess
tx = c.session
} else {
// Not within a transaction, let's create one.
var err error
tx, err = c.sess.NewTransaction(c.sess.Context(), nil)
tx, err = c.session.NewTransaction(c.session.Context(), nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -355,12 +357,12 @@ cancel:
return err
}

func (c *collection) Truncate() error {
func (c *collectionWithSession) Truncate() error {
stmt := exql.Statement{
Type: exql.Truncate,
Table: exql.TableWithName(c.Name()),
}
if _, err := c.sess.SQL().Exec(&stmt); err != nil {
if _, err := c.session.SQL().Exec(&stmt); err != nil {
return err
}
return nil
Expand Down
15 changes: 8 additions & 7 deletions internal/sqladapter/result.go
Expand Up @@ -213,7 +213,7 @@ func (r *Result) Select(fields ...interface{}) db.Result {

// String satisfies fmt.Stringer
func (r *Result) String() string {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
panic(err.Error())
}
Expand All @@ -222,7 +222,7 @@ func (r *Result) String() string {

// All dumps all Results into a pointer to an slice of structs or maps.
func (r *Result) All(dst interface{}) error {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return err
Expand All @@ -235,11 +235,12 @@ func (r *Result) All(dst interface{}) error {
// One fetches only one Result from the set.
func (r *Result) One(dst interface{}) error {
one := r.Limit(1).(*Result)
query, err := one.buildPaginator()
query, err := one.Paginator()
if err != nil {
r.setErr(err)
return err
}

err = query.Iterator().One(dst)
r.setErr(err)
return err
Expand All @@ -251,7 +252,7 @@ func (r *Result) Next(dst interface{}) bool {
defer r.iterMu.Unlock()

if r.iter == nil {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return false
Expand Down Expand Up @@ -309,7 +310,7 @@ func (r *Result) Update(values interface{}) error {
}

func (r *Result) TotalPages() (uint, error) {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return 0, err
Expand All @@ -325,7 +326,7 @@ func (r *Result) TotalPages() (uint, error) {
}

func (r *Result) TotalEntries() (uint64, error) {
query, err := r.buildPaginator()
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return 0, err
Expand Down Expand Up @@ -391,7 +392,7 @@ func (r *Result) Count() (uint64, error) {
return counter.Count, nil
}

func (r *Result) buildPaginator() (db.Paginator, error) {
func (r *Result) Paginator() (db.Paginator, error) {
if err := r.Err(); err != nil {
return nil, err
}
Expand Down

0 comments on commit d144469

Please sign in to comment.