Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #998 from chrishoffman/mssql
Sql Server (mssql) secret backend
- Loading branch information
Showing
50 changed files
with
70,367 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package mssql | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
"strings" | ||
"sync" | ||
|
||
_ "github.com/denisenkom/go-mssqldb" | ||
"github.com/hashicorp/vault/logical" | ||
"github.com/hashicorp/vault/logical/framework" | ||
) | ||
|
||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) { | ||
return Backend().Setup(conf) | ||
} | ||
|
||
func Backend() *framework.Backend { | ||
var b backend | ||
b.Backend = &framework.Backend{ | ||
Help: strings.TrimSpace(backendHelp), | ||
|
||
Paths: []*framework.Path{ | ||
pathConfigConnection(&b), | ||
pathConfigLease(&b), | ||
pathListRoles(&b), | ||
pathRoles(&b), | ||
pathCredsCreate(&b), | ||
}, | ||
|
||
Secrets: []*framework.Secret{ | ||
secretCreds(&b), | ||
}, | ||
} | ||
|
||
return b.Backend | ||
} | ||
|
||
type backend struct { | ||
*framework.Backend | ||
|
||
db *sql.DB | ||
defaultDb string | ||
lock sync.Mutex | ||
} | ||
|
||
// DB returns the default database connection. | ||
func (b *backend) DB(s logical.Storage) (*sql.DB, error) { | ||
b.lock.Lock() | ||
defer b.lock.Unlock() | ||
|
||
// If we already have a DB, we got it! | ||
if b.db != nil { | ||
return b.db, nil | ||
} | ||
|
||
// Otherwise, attempt to make connection | ||
entry, err := s.Get("config/connection") | ||
if err != nil { | ||
return nil, err | ||
} | ||
if entry == nil { | ||
return nil, fmt.Errorf("configure the DB connection with config/connection first") | ||
} | ||
|
||
var connConfig connectionConfig | ||
if err := entry.DecodeJSON(&connConfig); err != nil { | ||
return nil, err | ||
} | ||
connString := connConfig.ConnectionString | ||
|
||
db, err := sql.Open("mssql", connString) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Set some connection pool settings. We don't need much of this, | ||
// since the request rate shouldn't be high. | ||
db.SetMaxOpenConns(connConfig.MaxOpenConnections) | ||
|
||
stmt, err := db.Prepare("SELECT db_name();") | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer stmt.Close() | ||
|
||
err = stmt.QueryRow().Scan(&b.defaultDb) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
b.db = db | ||
return b.db, nil | ||
} | ||
|
||
// ResetDB forces a connection next time DB() is called. | ||
func (b *backend) ResetDB() { | ||
b.lock.Lock() | ||
defer b.lock.Unlock() | ||
|
||
if b.db != nil { | ||
b.db.Close() | ||
} | ||
|
||
b.db = nil | ||
} | ||
|
||
// LeaseConfig returns the lease configuration | ||
func (b *backend) LeaseConfig(s logical.Storage) (*configLease, error) { | ||
entry, err := s.Get("config/lease") | ||
if err != nil { | ||
return nil, err | ||
} | ||
if entry == nil { | ||
return nil, nil | ||
} | ||
|
||
var result configLease | ||
if err := entry.DecodeJSON(&result); err != nil { | ||
return nil, err | ||
} | ||
|
||
return &result, nil | ||
} | ||
|
||
const backendHelp = ` | ||
The MSSQL backend dynamically generates database users. | ||
After mounting this backend, configure it using the endpoints within | ||
the "config/" path. | ||
This backend does not support Azure SQL Databases. | ||
` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package mssql | ||
|
||
import ( | ||
"fmt" | ||
"log" | ||
"os" | ||
"testing" | ||
|
||
"github.com/hashicorp/vault/logical" | ||
logicaltest "github.com/hashicorp/vault/logical/testing" | ||
"github.com/mitchellh/mapstructure" | ||
) | ||
|
||
func TestBackend_basic(t *testing.T) { | ||
b, _ := Factory(logical.TestBackendConfig()) | ||
|
||
logicaltest.Test(t, logicaltest.TestCase{ | ||
PreCheck: func() { testAccPreCheck(t) }, | ||
Backend: b, | ||
Steps: []logicaltest.TestStep{ | ||
testAccStepConfig(t), | ||
testAccStepRole(t), | ||
testAccStepReadCreds(t, "web"), | ||
}, | ||
}) | ||
} | ||
|
||
func TestBackend_roleCrud(t *testing.T) { | ||
b := Backend() | ||
|
||
logicaltest.Test(t, logicaltest.TestCase{ | ||
PreCheck: func() { testAccPreCheck(t) }, | ||
Backend: b, | ||
Steps: []logicaltest.TestStep{ | ||
testAccStepConfig(t), | ||
testAccStepRole(t), | ||
testAccStepReadRole(t, "web", testRoleSQL), | ||
testAccStepDeleteRole(t, "web"), | ||
testAccStepReadRole(t, "web", ""), | ||
}, | ||
}) | ||
} | ||
|
||
func TestBackend_leaseWriteRead(t *testing.T) { | ||
b := Backend() | ||
|
||
logicaltest.Test(t, logicaltest.TestCase{ | ||
PreCheck: func() { testAccPreCheck(t) }, | ||
Backend: b, | ||
Steps: []logicaltest.TestStep{ | ||
testAccStepConfig(t), | ||
testAccStepWriteLease(t), | ||
testAccStepReadLease(t), | ||
}, | ||
}) | ||
|
||
} | ||
|
||
func testAccPreCheck(t *testing.T) { | ||
if v := os.Getenv("MSSQL_DSN"); v == "" { | ||
t.Fatal("MSSQL_DSN must be set for acceptance tests") | ||
} | ||
} | ||
|
||
func testAccStepConfig(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.UpdateOperation, | ||
Path: "config/connection", | ||
Data: map[string]interface{}{ | ||
"connection_string": os.Getenv("MSSQL_DSN"), | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepRole(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.UpdateOperation, | ||
Path: "roles/web", | ||
Data: map[string]interface{}{ | ||
"sql": testRoleSQL, | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.DeleteOperation, | ||
Path: "roles/" + n, | ||
} | ||
} | ||
|
||
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.ReadOperation, | ||
Path: "creds/" + name, | ||
Check: func(resp *logical.Response) error { | ||
var d struct { | ||
Username string `mapstructure:"username"` | ||
Password string `mapstructure:"password"` | ||
} | ||
if err := mapstructure.Decode(resp.Data, &d); err != nil { | ||
return err | ||
} | ||
log.Printf("[WARN] Generated credentials: %v", d) | ||
|
||
return nil | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepReadRole(t *testing.T, name, sql string) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.ReadOperation, | ||
Path: "roles/" + name, | ||
Check: func(resp *logical.Response) error { | ||
if resp == nil { | ||
if sql == "" { | ||
return nil | ||
} | ||
|
||
return fmt.Errorf("bad: %#v", resp) | ||
} | ||
|
||
var d struct { | ||
SQL string `mapstructure:"sql"` | ||
} | ||
if err := mapstructure.Decode(resp.Data, &d); err != nil { | ||
return err | ||
} | ||
|
||
if d.SQL != sql { | ||
return fmt.Errorf("bad: %#v", resp) | ||
} | ||
|
||
return nil | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepWriteLease(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.UpdateOperation, | ||
Path: "config/lease", | ||
Data: map[string]interface{}{ | ||
"ttl": "1h5m", | ||
"ttl_max": "24h", | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepReadLease(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.ReadOperation, | ||
Path: "config/lease", | ||
Check: func(resp *logical.Response) error { | ||
if resp.Data["ttl"] != "1h5m0s" || resp.Data["ttl_max"] != "24h0m0s" { | ||
return fmt.Errorf("bad: %#v", resp) | ||
} | ||
|
||
return nil | ||
}, | ||
} | ||
} | ||
|
||
const testRoleSQL = ` | ||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; | ||
CREATE USER [{{name}}] FOR LOGIN [{{name}}]; | ||
GRANT SELECT ON SCHEMA::dbo TO [{{name}}] | ||
` |
Oops, something went wrong.