Skip to content

Commit

Permalink
Conn.LoadType supports range and multirange types (jackc#1393)
Browse files Browse the repository at this point in the history
Closes jackc#1393
  • Loading branch information
mcdoker18 committed Nov 28, 2022
1 parent 8eb062f commit a063f79
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
46 changes: 46 additions & 0 deletions conn.go
Expand Up @@ -1196,6 +1196,30 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
case "e": // enum
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
case "r": // range
elementOID, err := c.getRangeElementOID(ctx, oid)
if err != nil {
return nil, err
}

dt, ok := c.TypeMap().TypeForOID(elementOID)
if !ok {
return nil, errors.New("range element OID not registered")
}

return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
case "m": // multirange
elementOID, err := c.getMultiRangeElementOID(ctx, oid)
if err != nil {
return nil, err
}

dt, ok := c.TypeMap().TypeForOID(elementOID)
if !ok {
return nil, errors.New("multirange element OID not registered")
}

return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
default:
return &pgtype.Type{}, errors.New("unknown typtype")
}
Expand All @@ -1212,6 +1236,28 @@ func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, erro
return typelem, nil
}

func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
var typelem uint32

err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem)
if err != nil {
return 0, err
}

return typelem, nil
}

func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
var typelem uint32

err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem)
if err != nil {
return 0, err
}

return typelem, nil
}

func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
var typrelid uint32

Expand Down
59 changes: 59 additions & 0 deletions conn_test.go
Expand Up @@ -903,6 +903,65 @@ create type pgx_b.point as (c text);
})
}

func TestLoadRangeType(t *testing.T) {
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does support range types")

tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)

_, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi, multirange_type_name=examplefloatmultirange)")
require.NoError(t, err)

// Register types
newRangeType, err := conn.LoadType(ctx, "examplefloatrange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")

newMultiRangeType, err := conn.LoadType(ctx, "examplefloatmultirange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newMultiRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange")

// Test range type
var inputRangeType = pgtype.Range[float64]{
Lower: 1.0,
Upper: 2.0,
LowerType: pgtype.Inclusive,
UpperType: pgtype.Inclusive,
Valid: true,
}
var outputRangeType pgtype.Range[float64]
err = tx.QueryRow(ctx, "SELECT $1::examplefloatrange", inputRangeType).Scan(&outputRangeType)
require.NoError(t, err)
require.Equal(t, inputRangeType, outputRangeType)

// Test multi range type
var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{
{
Lower: 1.0,
Upper: 2.0,
LowerType: pgtype.Inclusive,
UpperType: pgtype.Inclusive,
Valid: true,
},
{
Lower: 3.0,
Upper: 4.0,
LowerType: pgtype.Exclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
}
var outputMultiRangeType pgtype.Multirange[pgtype.Range[float64]]
err = tx.QueryRow(ctx, "SELECT $1::examplefloatmultirange", inputMultiRangeType).Scan(&outputMultiRangeType)
require.NoError(t, err)
require.Equal(t, inputMultiRangeType, outputMultiRangeType)
})
}

func TestStmtCacheInvalidationConn(t *testing.T) {
ctx := context.Background()

Expand Down

0 comments on commit a063f79

Please sign in to comment.