Skip to content

Commit

Permalink
Adding mssql secret backend
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishoffman committed Mar 3, 2016
1 parent f88c6c1 commit ed5ca17
Show file tree
Hide file tree
Showing 10 changed files with 1,381 additions and 0 deletions.
138 changes: 138 additions & 0 deletions builtin/logical/mssql/backend.go
@@ -0,0 +1,138 @@
package mssql

import (
"database/sql"
"fmt"
"strings"
"sync"

_ "github.com/denisenkom/go-mssqldb"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)

func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return Backend().Setup(conf)
}

func Backend() *framework.Backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),

PathsSpecial: &logical.Paths{
Root: []string{
"config/*",
},
},

Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathCredsCreate(&b),
},

Secrets: []*framework.Secret{
secretCreds(&b),
},
}

return b.Backend
}

type backend struct {
*framework.Backend

db *sql.DB
defaultDb string
lock sync.Mutex
}

// DB returns the default database connection.
func (b *backend) DB(s logical.Storage) (*sql.DB, error) {
b.lock.Lock()
defer b.lock.Unlock()

// If we already have a DB, we got it!
if b.db != nil {
return b.db, nil
}

// Otherwise, attempt to make connection
entry, err := s.Get("config/connection")
if err != nil {
return nil, err
}
if entry == nil {
return nil, fmt.Errorf("configure the DB connection with config/connection first")
}

var connConfig connectionConfig
if err := entry.DecodeJSON(&connConfig); err != nil {
return nil, err
}
conn := connConfig.ConnectionParams

b.db, err = sql.Open("mssql", BuildDsn(conn))
if err != nil {
return nil, err
}

// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
b.db.SetMaxOpenConns(connConfig.MaxOpenConnections)

stmt, err := b.db.Prepare("SELECT db_name();")
if err != nil {
return nil, err
}
defer stmt.Close()

err = stmt.QueryRow().Scan(&b.defaultDb)
if err != nil {
return nil, err
}

return b.db, nil
}

// ResetDB forces a connection next time DB() is called.
func (b *backend) ResetDB() {
b.lock.Lock()
defer b.lock.Unlock()

if b.db != nil {
b.db.Close()
}

b.db = nil
}

// Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}

var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}

return &result, nil
}

const backendHelp = `
The MSSQL backend dynamically generates database users.
After mounting this backend, configure it using the endpoints within
the "config/" path.
This backend does not support Azure SQL Databases
`
169 changes: 169 additions & 0 deletions builtin/logical/mssql/backend_test.go
@@ -0,0 +1,169 @@
package mssql

import (
"fmt"
"log"
"os"
"testing"

"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
)

func TestBackend_basic(t *testing.T) {
b, _ := Factory(logical.TestBackendConfig())

logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepRole(t),
testAccStepReadCreds(t, "web"),
},
})
}

func TestBackend_roleCrud(t *testing.T) {
b := Backend()

logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepRole(t),
testAccStepReadRole(t, "web", testRoleSQL),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web", ""),
},
})
}

func TestBackend_leaseWriteRead(t *testing.T) {
b := Backend()

logicaltest.Test(t, logicaltest.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t),
testAccStepWriteLease(t),
testAccStepReadLease(t),
},
})

}

func testAccPreCheck(t *testing.T) {
if v := os.Getenv("MSSQL_PARAMS"); v == "" {
t.Fatal("MSSQL_PARAMS must be set for acceptance tests")
}
}

func testAccStepConfig(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: map[string]interface{}{
"connection_params": os.Getenv("MSSQL_PARAMS"),
},
}
}

func testAccStepRole(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/web",
Data: map[string]interface{}{
"sql": testRoleSQL,
},
}
}

func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "roles/" + n,
}
}

func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)

return nil
},
}
}

func testAccStepReadRole(t *testing.T, name, sql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if sql == "" {
return nil
}

return fmt.Errorf("bad: %#v", resp)
}

var d struct {
SQL string `mapstructure:"sql"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}

if d.SQL != sql {
return fmt.Errorf("bad: %#v", resp)
}

return nil
},
}
}

func testAccStepWriteLease(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/lease",
Data: map[string]interface{}{
"lease": "1h5m",
"lease_max": "24h",
},
}
}

func testAccStepReadLease(t *testing.T) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "config/lease",
Check: func(resp *logical.Response) error {
if resp.Data["lease"] != "1h5m0s" || resp.Data["lease_max"] != "24h0m0s" {
return fmt.Errorf("bad: %#v", resp)
}

return nil
},
}
}

const testRoleSQL = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
GRANT SELECT ON SCHEMA::dbo TO [{{name}}]
`
88 changes: 88 additions & 0 deletions builtin/logical/mssql/path_config_connection.go
@@ -0,0 +1,88 @@
package mssql

import (
"database/sql"
"fmt"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)

func pathConfigConnection(b *backend) *framework.Path {
return &framework.Path{
Pattern: "config/connection",
Fields: map[string]*framework.FieldSchema{
"connection_params": &framework.FieldSchema{
Type: framework.TypeString,
Description: "DB connection parameters",
},
"max_open_connections": &framework.FieldSchema{
Type: framework.TypeInt,
Description: "Maximum number of open connections to database",
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
},

HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}

func (b *backend) pathConnectionWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connParams := data.Get("connection_params").(string)

maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}

// Verify the string
db, err := sql.Open("mssql", BuildDsn(connParams))

if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
defer db.Close()
if err := db.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}

// Store it
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
ConnectionParams: connParams,
MaxOpenConnections: maxOpenConns,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}

// Reset the DB connection
b.ResetDB()
return nil, nil
}

type connectionConfig struct {
ConnectionParams string `json:"connection_params"`
MaxOpenConnections int `json:"max_open_connections"`
}

const pathConfigConnectionHelpSyn = `
Configure the connection string to talk to Microsoft Sql Server.
`

const pathConfigConnectionHelpDesc = `
This path configures the connection string used to connect to Sql Server.
The value of the string is a Data Source Name (DSN). An example is
using "server=<hostname>;port=<port>;user id=<username>;password=<password>;database=<database>;"
When configuring the connection string, the backend will verify its validity.
`

0 comments on commit ed5ca17

Please sign in to comment.