Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support TVP identity #771

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 28 additions & 8 deletions tvp_go19.go
@@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9

package mssql
Expand All @@ -16,6 +17,7 @@ import (
const (
jsonTag = "json"
tvpTag = "tvp"
tvpIdentity = "@identity"
skipTagValue = "-"
sqlSeparator = "."
)
Expand All @@ -29,7 +31,7 @@ var (
ErrorWrongTyping = errors.New("the number of elements in columnStr and tvpFieldIndexes do not align")
)

//TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
// TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
type TVP struct {
//TypeName mustn't be default value
TypeName string
Expand Down Expand Up @@ -76,8 +78,8 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
binary.Write(buf, binary.LittleEndian, uint16(len(columnStr)))

for i, column := range columnStr {
binary.Write(buf, binary.LittleEndian, uint32(column.UserType))
binary.Write(buf, binary.LittleEndian, uint16(column.Flags))
binary.Write(buf, binary.LittleEndian, column.UserType)
binary.Write(buf, binary.LittleEndian, column.Flags)
writeTypeInfo(buf, &columnStr[i].ti)
writeBVarChar(buf, "")
}
Expand All @@ -96,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
refStr := reflect.ValueOf(val.Index(i).Interface())
buf.WriteByte(_TVP_ROW_TOKEN)
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
if columnStr[columnStrIdx].Flags == fDefault {
continue
}
field := refStr.Field(fieldIdx)
tvpVal := field.Interface()
if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
Expand Down Expand Up @@ -135,6 +140,11 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
}

func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
type fieldDetailStore struct {
defaultValue interface{}
isIdentity bool
}

val := reflect.ValueOf(tvp.Value)
var firstRow interface{}
if val.Len() != 0 {
Expand All @@ -145,7 +155,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {

tvpRow := reflect.TypeOf(firstRow)
columnCount := tvpRow.NumField()
defaultValues := make([]interface{}, 0, columnCount)
defaultValues := make([]fieldDetailStore, 0, columnCount)
tvpFieldIndexes := make([]int, 0, columnCount)
for i := 0; i < columnCount; i++ {
field := tvpRow.Field(i)
Expand All @@ -155,12 +165,19 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
continue
}
tvpFieldIndexes = append(tvpFieldIndexes, i)
isIdentity := tvpTagValue == tvpIdentity
if field.Type.Kind() == reflect.Ptr {
v := reflect.New(field.Type.Elem())
defaultValues = append(defaultValues, v.Interface())
defaultValues = append(defaultValues, fieldDetailStore{
defaultValue: v.Interface(),
isIdentity: isIdentity,
})
continue
}
defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface()))
defaultValues = append(defaultValues, fieldDetailStore{
defaultValue: tvp.createZeroType(reflect.Zero(field.Type).Interface()),
isIdentity: isIdentity,
})
}

if columnCount-len(tvpFieldIndexes) == columnCount {
Expand All @@ -176,9 +193,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {

columnConfiguration := make([]columnStruct, 0, columnCount)
for index, val := range defaultValues {
cval, err := convertInputParameter(val)
cval, err := convertInputParameter(val.defaultValue)
if err != nil {
return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err)
return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val.defaultValue, err)
}
param, err := stmt.makeParam(cval)
if err != nil {
Expand All @@ -187,6 +204,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
column := columnStruct{
ti: param.ti,
}
if val.isIdentity {
column.Flags = fDefault
}
switch param.ti.TypeId {
case typeNVarChar, typeBigVarBin:
column.ti.Size = 0
Expand Down
126 changes: 126 additions & 0 deletions tvp_go19_db_test.go
Expand Up @@ -1348,3 +1348,129 @@ func TestTVPUnsigned(t *testing.T) {
t.Errorf("third result set had wrong value expected: %s actual: %s", "test", result3)
}
}

func TestTVPIdentity(t *testing.T) {
type TvpIdentityExample struct {
ID int `tvp:"@identity"`
Message string
}

const (
crateSchema = `create schema TestTVPSchemaIdentity;`

dropSchema = `drop schema TestTVPSchemaIdentity;`

createTVP = `
CREATE TYPE TestTVPSchemaIdentity.exempleTVP AS TABLE
(
id int identity(1,1) not null,
message NVARCHAR(100)
)`

dropTVP = `DROP TYPE TestTVPSchemaIdentity.exempleTVP;`

procedureWithTVP = `
CREATE PROCEDURE ExecIdentityTVP
@param1 TestTVPSchemaIdentity.exempleTVP READONLY
AS
BEGIN
SET NOCOUNT ON;
SELECT * FROM @param1;
END;
`

dropProcedure = `drop PROCEDURE ExecIdentityTVP`

execTvp = `exec ExecIdentityTVP @param1;`
)

checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

p := makeConnStr(t).String()
conn, err := sql.Open("sqlserver", p)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()

_, err = conn.Exec(crateSchema)
if err != nil {
t.Fatal(err)
return
}
defer conn.Exec(dropSchema)

_, err = conn.Exec(createTVP)
if err != nil {
t.Fatal(err)
return
}
defer conn.Exec(dropTVP)

_, err = conn.Exec(procedureWithTVP)
if err != nil {
t.Fatal(err)
return
}
defer conn.Exec(dropProcedure)

exempleData := []TvpIdentityExample{
{
Message: "Hello",
},
{
Message: "World",
},
{
Message: "TVP",
},
}

tvpType := TVP{
TypeName: "TestTVPSchemaIdentity.exempleTVP",
Value: exempleData,
}

rows, err := conn.Query(execTvp,
sql.Named("param1", tvpType),
)
if err != nil {
t.Fatal(err)
}
defer rows.Close()

tvpResult := make([]TvpIdentityExample, 0)
for rows.Next() {
tvpExemple := TvpIdentityExample{}
err = rows.Scan(&tvpExemple.ID, &tvpExemple.Message)
if err != nil {
t.Fatal(err)
}
tvpResult = append(tvpResult, tvpExemple)
}

expectData := []TvpIdentityExample{
{
ID: 1,
Message: "Hello",
},
{
ID: 2,
Message: "World",
},
{
ID: 3,
Message: "TVP",
},
}

if len(expectData) != len(tvpResult) {
t.Fatal("TestTVPIdentity have to be len")
}
if !reflect.DeepEqual(expectData, tvpResult) {
t.Fatal("TestTVPIdentity have to be same")
}
}
32 changes: 20 additions & 12 deletions types.go
Expand Up @@ -79,6 +79,12 @@ const _PLP_TERMINATOR = 0x00000000
const _TVP_END_TOKEN = 0x00
const _TVP_ROW_TOKEN = 0x01

// TVP_COLMETADATA definition
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/0dfc5367-a388-4c92-9ba4-4d28e775acbc
const (
fDefault = 0x200
)

// TYPE_INFO rule
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
type typeInfo struct {
Expand Down Expand Up @@ -1353,12 +1359,13 @@ func makeGoLangTypeName(ti typeInfo) string {
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
switch ti.TypeId {
case typeInt1:
Expand Down Expand Up @@ -1476,12 +1483,13 @@ func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) {
switch ti.TypeId {
case typeInt1:
Expand Down