Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with unsighted types #713

Merged
merged 4 commits into from Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion tvp_go19.go
Expand Up @@ -105,7 +105,8 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
elemKind := field.Kind()
if elemKind == reflect.Ptr && valOf.IsNil() {
switch tvpVal.(type) {
case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int:
case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int,
*uint8, *uint16, *uint32, *uint64, *uint:
binary.Write(buf, binary.LittleEndian, uint8(0))
continue
default:
Expand Down
187 changes: 187 additions & 0 deletions tvp_go19_db_test.go
@@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9

package mssql
Expand Down Expand Up @@ -1161,3 +1162,189 @@ func TestTVPObject(t *testing.T) {
})
}
}

// fix pointer uint in tvp https://github.com/denisenkom/go-mssqldb/issues/703
func TestTVPUnsigned(t *testing.T) {
checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

c := makeConnStr(t).String()
db, err := sql.Open("sqlserver", c)
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sqltextcreatetable := `
CREATE TYPE unsignedTvpTableTypes AS TABLE
(
p_tinyint TINYINT,
p_tinyintNull TINYINT,
p_smallint SMALLINT,
p_smallintNull SMALLINT,
p_int INT,
p_intNull INT,
p_bigint BIGINT,
p_bigintNull BIGINT,
pInt INT,
pIntNull INT
); `

sqltextdroptable := `DROP TYPE unsignedTvpTableTypes;`

sqltextcreatesp := `
CREATE PROCEDURE spwithtvpUnsigned
@param1 unsignedTvpTableTypes READONLY,
@param2 unsignedTvpTableTypes READONLY,
@param3 NVARCHAR(10)
AS
BEGIN
SET NOCOUNT ON;
SELECT * FROM @param1;
SELECT * FROM @param2;
SELECT @param3;
END;`

type TvptableRow struct {
PTinyint uint8 `db:"p_tinyint"`
PTinyintNull *uint8 `db:"p_tinyintNull"`
PSmallint uint16 `db:"p_smallint"`
PSmallintNull *uint16 `db:"p_smallintNull"`
PInt uint32 `db:"p_int"`
PIntNull *uint32 `db:"p_intNull"`
PBigint uint64 `db:"p_bigint"`
PBigintNull *uint64 `db:"p_bigintNull"`
Pint uint `db:"pInt"`
PintNull *uint `db:"pIntNull"`
}

sqltextdropsp := `DROP PROCEDURE spwithtvpUnsigned;`

_, err = db.ExecContext(ctx, sqltextcreatetable)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdroptable)

_, err = db.ExecContext(ctx, sqltextcreatesp)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdropsp)
i8 := uint8(1)
i16 := uint16(2)
i32 := uint32(3)
i64 := uint64(4)
i := uint(5)
param1 := []TvptableRow{
{
PTinyint: i8,
PSmallint: i16,
PInt: i32,
PBigint: i64,
Pint: 355,
},
{
PTinyint: 5,
PSmallint: 16000,
PInt: 20000000,
PBigint: 2000000020000000,
Pint: 455,
},
{
PTinyintNull: &i8,
PSmallintNull: &i16,
PIntNull: &i32,
PBigintNull: &i64,
PintNull: &i,
},
{
PTinyint: 5,
PSmallint: 16000,
PInt: 20000000,
PBigint: 2000000020000000,
PTinyintNull: &i8,
PSmallintNull: &i16,
PIntNull: &i32,
PBigintNull: &i64,
PintNull: &i,
},
}

tvpType := TVP{
TypeName: "unsignedTvpTableTypes",
Value: param1,
}
tvpTypeEmpty := TVP{
TypeName: "unsignedTvpTableTypes",
Value: []TvptableRow{},
}

rows, err := db.QueryContext(ctx,
"exec spwithtvpUnsigned @param1, @param2, @param3",
sql.Named("param1", tvpType),
sql.Named("param2", tvpTypeEmpty),
sql.Named("param3", "test"),
)
if err != nil {
t.Fatal(err)
}
defer rows.Close()

var result1 []TvptableRow
for rows.Next() {
var val TvptableRow
err := rows.Scan(
&val.PTinyint,
&val.PTinyintNull,
&val.PSmallint,
&val.PSmallintNull,
&val.PInt,
&val.PIntNull,
&val.PBigint,
&val.PBigintNull,
&val.Pint,
&val.PintNull,
)
if err != nil {
t.Fatalf("scan failed with error: %s", err)
}

result1 = append(result1, val)
}

if !reflect.DeepEqual(param1, result1) {
t.Logf("expected: %+v", param1)
t.Logf("actual: %+v", result1)
t.Errorf("first resultset did not match param1")
}

if !rows.NextResultSet() {
t.Errorf("second resultset did not exist")
}

if rows.Next() {
t.Errorf("second resultset was not empty")
}

if !rows.NextResultSet() {
t.Errorf("third resultset did not exist")
}

if !rows.Next() {
t.Errorf("third resultset was empty")
}

var result3 string
if err := rows.Scan(&result3); err != nil {
t.Errorf("error scanning third result set: %s", err)
}
if result3 != "test" {
t.Errorf("third result set had wrong value expected: %s actual: %s", "test", result3)
}
}