Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Federated authentication protocol support #625

Merged
merged 1 commit into from
Dec 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
/.idea
/.connstr
.vscode
.terraform
*.tfstate*
*.log
*.swp
*~
29 changes: 4 additions & 25 deletions accesstokenconnector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,8 @@ import (
"context"
"database/sql/driver"
"errors"
"fmt"
)

var _ driver.Connector = &accessTokenConnector{}

// accessTokenConnector wraps Connector and injects a
// fresh access token when connecting to the database
type accessTokenConnector struct {
Connector

accessTokenProvider func() (string, error)
}

// NewAccessTokenConnector creates a new connector from a DSN and a token provider.
// The token provider func will be called when a new connection is requested and should return a valid access token.
// The returned connector may be used with sql.OpenDB.
Expand All @@ -32,20 +21,10 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (
return nil, err
}

c := &accessTokenConnector{
Connector: *conn,
accessTokenProvider: tokenProvider,
}
return c, nil
}

// Connect returns a new database connection
func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider()
if err != nil {
return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err)
conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = func(ctx context.Context) (string, error) {
return tokenProvider()
}

return c.Connector.Connect(ctx)
return conn, nil
}
12 changes: 6 additions & 6 deletions accesstokenconnector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ func TestNewAccessTokenConnector(t *testing.T) {
dsn: dsn,
tokenProvider: tp},
want: func(c driver.Connector) error {
tc, ok := c.(*accessTokenConnector)
tc, ok := c.(*Connector)
if !ok {
return fmt.Errorf("Expected driver to be of type *accessTokenConnector, but got %T", c)
return fmt.Errorf("Expected driver to be of type *Connector, but got %T", c)
}
p := tc.Connector.params
p := tc.params
if p.database != "db" {
return fmt.Errorf("expected params.database=db, but got %v", p.database)
}
if p.host != "server.database.windows.net" {
return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.host)
}
if tc.accessTokenProvider == nil {
return fmt.Errorf("Expected tokenProvider to not be nil")
if tc.securityTokenProvider == nil {
return fmt.Errorf("Expected federated authentication provider to not be nil")
}
t, err := tc.accessTokenProvider()
t, err := tc.securityTokenProvider(context.TODO())
if t != "token" || err != nil {
return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err)
}
Expand Down
11 changes: 7 additions & 4 deletions conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ type connectParams struct {
failOverPartner string
failOverPort uint64
packetSize uint16
fedAuthAccessToken string
fedAuthLibrary int
fedAuthADALWorkflow byte
}

// default packet size for TDS buffer
const defaultPacketSize = 4096

func parseConnectParams(dsn string) (connectParams, error) {
var p connectParams
p := connectParams{
fedAuthLibrary: fedAuthLibraryReserved,
}

var params map[string]string
if strings.HasPrefix(dsn, "odbc:") {
Expand Down Expand Up @@ -247,8 +250,8 @@ func (p connectParams) toUrl() *url.URL {
}
res := url.URL{
Scheme: "sqlserver",
Host: p.host,
User: url.UserPassword(p.user, p.password),
Host: p.host,
User: url.UserPassword(p.user, p.password),
}
if p.instance != "" {
res.Path = p.instance
Expand Down
82 changes: 82 additions & 0 deletions fedauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package mssql

import (
"context"
"errors"
)

// Federated authentication library affects the login data structure and message sequence.
const (
wrosenuance marked this conversation as resolved.
Show resolved Hide resolved
// fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme
fedAuthLibraryLiveIDCompactToken = 0x00

// fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available
// without additional information provided during the login sequence.
fedAuthLibrarySecurityToken = 0x01

// fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the
// login sequence using the server SPN and STS URL provided by the server during login.
fedAuthLibraryADAL = 0x02

// fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies.
fedAuthLibraryReserved = 0x7F
)

// Federated authentication ADAL workflow affects the mechanism used to authenticate.
const (
// fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory
fedAuthADALWorkflowPassword = 0x01

// fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory
fedAuthADALWorkflowIntegrated = 0x02

// fedAuthADALWorkflowMSI uses the managed identity service to obtain a token
fedAuthADALWorkflowMSI = 0x03
)

// newSecurityTokenConnector creates a new connector from a DSN and a token provider.
// When invoked, token provider implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
// The returned connector may be used with sql.OpenDB.
func newSecurityTokenConnector(dsn string, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}

conn, err := NewConnector(dsn)
if err != nil {
return nil, err
}

conn.params.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = tokenProvider

return conn, nil
}

// newADALTokenConnector creates a new connector from a DSN and a Active Directory token provider.
// Token provider implementations are called during federated
// authentication login sequences where the server provides a service
// principal name and security token service endpoint that should be used
// to obtain the token. Implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
//
// The returned connector may be used with sql.OpenDB.
func newActiveDirectoryTokenConnector(dsn string, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}

conn, err := NewConnector(dsn)
if err != nil {
return nil, err
}

conn.params.fedAuthLibrary = fedAuthLibraryADAL
conn.params.fedAuthADALWorkflow = adalWorkflow
conn.adalTokenProvider = tokenProvider

return conn, nil
}
17 changes: 12 additions & 5 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
if err != nil {
return nil, err
}

return &Connector{
params: params,
driver: d,
Expand Down Expand Up @@ -100,6 +101,12 @@ type Connector struct {
params connectParams
driver *Driver

// callback that can provide a security token during login
securityTokenProvider func(ctx context.Context) (string, error)

// callback that can provide a security token during ADAL login
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)

// SessionInitSQL is executed after marking a given session to be reset.
// When not present, the next query will still reset the session to the
// database defaults.
Expand Down Expand Up @@ -148,7 +155,7 @@ type Conn struct {
processQueryText bool
connectionGood bool

outs map[string]interface{}
outs map[string]interface{}
}

func (c *Conn) checkBadConn(err error) error {
Expand Down Expand Up @@ -653,9 +660,9 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
}

type Rows struct {
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct

cancel func()
Expand All @@ -669,7 +676,7 @@ func (rc *Rows) Close() error {
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
if tok == nil {
return nil
} else {
// continue consuming tokens
Expand Down