Skip to content

Commit

Permalink
Add support TVP identity (#771)
Browse files Browse the repository at this point in the history
* add support identity type

Co-authored-by: Gavrilov.Nikita2 <gavrilov.nikita2@wildberries.ru>
  • Loading branch information
NikitaDef and Gavrilov.Nikita2 committed Oct 11, 2022
1 parent c7ddec1 commit 0461d46
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 20 deletions.
36 changes: 28 additions & 8 deletions tvp_go19.go
@@ -1,3 +1,4 @@
//go:build go1.9
// +build go1.9

package mssql
Expand All @@ -16,6 +17,7 @@ import (
const (
jsonTag = "json"
tvpTag = "tvp"
tvpIdentity = "@identity"
skipTagValue = "-"
sqlSeparator = "."
)
Expand All @@ -29,7 +31,7 @@ var (
ErrorWrongTyping = errors.New("the number of elements in columnStr and tvpFieldIndexes do not align")
)

//TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
// TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
type TVP struct {
//TypeName mustn't be default value
TypeName string
Expand Down Expand Up @@ -76,8 +78,8 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
binary.Write(buf, binary.LittleEndian, uint16(len(columnStr)))

for i, column := range columnStr {
binary.Write(buf, binary.LittleEndian, uint32(column.UserType))
binary.Write(buf, binary.LittleEndian, uint16(column.Flags))
binary.Write(buf, binary.LittleEndian, column.UserType)
binary.Write(buf, binary.LittleEndian, column.Flags)
writeTypeInfo(buf, &columnStr[i].ti)
writeBVarChar(buf, "")
}
Expand All @@ -96,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
refStr := reflect.ValueOf(val.Index(i).Interface())
buf.WriteByte(_TVP_ROW_TOKEN)
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
if columnStr[columnStrIdx].Flags == fDefault {
continue
}
field := refStr.Field(fieldIdx)
tvpVal := field.Interface()
if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
Expand Down Expand Up @@ -135,6 +140,11 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
}

func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
type fieldDetailStore struct {
defaultValue interface{}
isIdentity bool
}

val := reflect.ValueOf(tvp.Value)
var firstRow interface{}
if val.Len() != 0 {
Expand All @@ -145,7 +155,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {

tvpRow := reflect.TypeOf(firstRow)
columnCount := tvpRow.NumField()
defaultValues := make([]interface{}, 0, columnCount)
defaultValues := make([]fieldDetailStore, 0, columnCount)
tvpFieldIndexes := make([]int, 0, columnCount)
for i := 0; i < columnCount; i++ {
field := tvpRow.Field(i)
Expand All @@ -155,12 +165,19 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
continue
}
tvpFieldIndexes = append(tvpFieldIndexes, i)
isIdentity := tvpTagValue == tvpIdentity
if field.Type.Kind() == reflect.Ptr {
v := reflect.New(field.Type.Elem())
defaultValues = append(defaultValues, v.Interface())
defaultValues = append(defaultValues, fieldDetailStore{
defaultValue: v.Interface(),
isIdentity: isIdentity,
})
continue
}
defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface()))
defaultValues = append(defaultValues, fieldDetailStore{
defaultValue: tvp.createZeroType(reflect.Zero(field.Type).Interface()),
isIdentity: isIdentity,
})
}

if columnCount-len(tvpFieldIndexes) == columnCount {
Expand All @@ -176,9 +193,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {

columnConfiguration := make([]columnStruct, 0, columnCount)
for index, val := range defaultValues {
cval, err := convertInputParameter(val)
cval, err := convertInputParameter(val.defaultValue)
if err != nil {
return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err)
return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val.defaultValue, err)
}
param, err := stmt.makeParam(cval)
if err != nil {
Expand All @@ -187,6 +204,9 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
column := columnStruct{
ti: param.ti,
}
if val.isIdentity {
column.Flags = fDefault
}
switch param.ti.TypeId {
case typeNVarChar, typeBigVarBin:
column.ti.Size = 0
Expand Down
126 changes: 126 additions & 0 deletions tvp_go19_db_test.go
Expand Up @@ -1348,3 +1348,129 @@ func TestTVPUnsigned(t *testing.T) {
t.Errorf("third result set had wrong value expected: %s actual: %s", "test", result3)
}
}

func TestTVPIdentity(t *testing.T) {
type TvpIdentityExample struct {
ID int `tvp:"@identity"`
Message string
}

const (
crateSchema = `create schema TestTVPSchemaIdentity;`

dropSchema = `drop schema TestTVPSchemaIdentity;`

createTVP = `
CREATE TYPE TestTVPSchemaIdentity.exempleTVP AS TABLE
(
id int identity(1,1) not null,
message NVARCHAR(100)
)`

dropTVP = `DROP TYPE TestTVPSchemaIdentity.exempleTVP;`

procedureWithTVP = `
CREATE PROCEDURE ExecIdentityTVP
@param1 TestTVPSchemaIdentity.exempleTVP READONLY
AS
BEGIN
SET NOCOUNT ON;
SELECT * FROM @param1;
END;
`

dropProcedure = `drop PROCEDURE ExecIdentityTVP`

execTvp = `exec ExecIdentityTVP @param1;`
)

checkConnStr(t)
tl := testLogger{t: t}
defer tl.StopLogging()
SetLogger(&tl)

p := makeConnStr(t).String()
conn, err := sql.Open("sqlserver", p)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()

_, err = conn.Exec(crateSchema)
if err != nil {
t.Fatal(err)
return
}
defer conn.Exec(dropSchema)

_, err = conn.Exec(createTVP)
if err != nil {
t.Fatal(err)
return
}
defer conn.Exec(dropTVP)

_, err = conn.Exec(procedureWithTVP)
if err != nil {
t.Fatal(err)
return
}
defer conn.Exec(dropProcedure)

exempleData := []TvpIdentityExample{
{
Message: "Hello",
},
{
Message: "World",
},
{
Message: "TVP",
},
}

tvpType := TVP{
TypeName: "TestTVPSchemaIdentity.exempleTVP",
Value: exempleData,
}

rows, err := conn.Query(execTvp,
sql.Named("param1", tvpType),
)
if err != nil {
t.Fatal(err)
}
defer rows.Close()

tvpResult := make([]TvpIdentityExample, 0)
for rows.Next() {
tvpExemple := TvpIdentityExample{}
err = rows.Scan(&tvpExemple.ID, &tvpExemple.Message)
if err != nil {
t.Fatal(err)
}
tvpResult = append(tvpResult, tvpExemple)
}

expectData := []TvpIdentityExample{
{
ID: 1,
Message: "Hello",
},
{
ID: 2,
Message: "World",
},
{
ID: 3,
Message: "TVP",
},
}

if len(expectData) != len(tvpResult) {
t.Fatal("TestTVPIdentity have to be len")
}
if !reflect.DeepEqual(expectData, tvpResult) {
t.Fatal("TestTVPIdentity have to be same")
}
}
32 changes: 20 additions & 12 deletions types.go
Expand Up @@ -79,6 +79,12 @@ const _PLP_TERMINATOR = 0x00000000
const _TVP_END_TOKEN = 0x00
const _TVP_ROW_TOKEN = 0x01

// TVP_COLMETADATA definition
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/0dfc5367-a388-4c92-9ba4-4d28e775acbc
const (
fDefault = 0x200
)

// TYPE_INFO rule
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
type typeInfo struct {
Expand Down Expand Up @@ -1353,12 +1359,13 @@ func makeGoLangTypeName(ti typeInfo) string {
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
switch ti.TypeId {
case typeInt1:
Expand Down Expand Up @@ -1476,12 +1483,13 @@ func makeGoLangTypeLength(ti typeInfo) (int64, bool) {
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) {
switch ti.TypeId {
case typeInt1:
Expand Down

0 comments on commit 0461d46

Please sign in to comment.