Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f88c6c1
commit ed5ca17
Showing
10 changed files
with
1,381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
package mssql | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
"strings" | ||
"sync" | ||
|
||
_ "github.com/denisenkom/go-mssqldb" | ||
"github.com/hashicorp/vault/logical" | ||
"github.com/hashicorp/vault/logical/framework" | ||
) | ||
|
||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) { | ||
return Backend().Setup(conf) | ||
} | ||
|
||
func Backend() *framework.Backend { | ||
var b backend | ||
b.Backend = &framework.Backend{ | ||
Help: strings.TrimSpace(backendHelp), | ||
|
||
PathsSpecial: &logical.Paths{ | ||
Root: []string{ | ||
"config/*", | ||
}, | ||
}, | ||
|
||
Paths: []*framework.Path{ | ||
pathConfigConnection(&b), | ||
pathConfigLease(&b), | ||
pathListRoles(&b), | ||
pathRoles(&b), | ||
pathCredsCreate(&b), | ||
}, | ||
|
||
Secrets: []*framework.Secret{ | ||
secretCreds(&b), | ||
}, | ||
} | ||
|
||
return b.Backend | ||
} | ||
|
||
type backend struct { | ||
*framework.Backend | ||
|
||
db *sql.DB | ||
defaultDb string | ||
lock sync.Mutex | ||
} | ||
|
||
// DB returns the default database connection. | ||
func (b *backend) DB(s logical.Storage) (*sql.DB, error) { | ||
b.lock.Lock() | ||
defer b.lock.Unlock() | ||
|
||
// If we already have a DB, we got it! | ||
if b.db != nil { | ||
return b.db, nil | ||
} | ||
|
||
// Otherwise, attempt to make connection | ||
entry, err := s.Get("config/connection") | ||
if err != nil { | ||
return nil, err | ||
} | ||
if entry == nil { | ||
return nil, fmt.Errorf("configure the DB connection with config/connection first") | ||
} | ||
|
||
var connConfig connectionConfig | ||
if err := entry.DecodeJSON(&connConfig); err != nil { | ||
return nil, err | ||
} | ||
conn := connConfig.ConnectionParams | ||
|
||
b.db, err = sql.Open("mssql", BuildDsn(conn)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Set some connection pool settings. We don't need much of this, | ||
// since the request rate shouldn't be high. | ||
b.db.SetMaxOpenConns(connConfig.MaxOpenConnections) | ||
|
||
stmt, err := b.db.Prepare("SELECT db_name();") | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer stmt.Close() | ||
|
||
err = stmt.QueryRow().Scan(&b.defaultDb) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return b.db, nil | ||
} | ||
|
||
// ResetDB forces a connection next time DB() is called. | ||
func (b *backend) ResetDB() { | ||
b.lock.Lock() | ||
defer b.lock.Unlock() | ||
|
||
if b.db != nil { | ||
b.db.Close() | ||
} | ||
|
||
b.db = nil | ||
} | ||
|
||
// Lease returns the lease information | ||
func (b *backend) Lease(s logical.Storage) (*configLease, error) { | ||
entry, err := s.Get("config/lease") | ||
if err != nil { | ||
return nil, err | ||
} | ||
if entry == nil { | ||
return nil, nil | ||
} | ||
|
||
var result configLease | ||
if err := entry.DecodeJSON(&result); err != nil { | ||
return nil, err | ||
} | ||
|
||
return &result, nil | ||
} | ||
|
||
const backendHelp = ` | ||
The MSSQL backend dynamically generates database users. | ||
After mounting this backend, configure it using the endpoints within | ||
the "config/" path. | ||
This backend does not support Azure SQL Databases | ||
` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package mssql | ||
|
||
import ( | ||
"fmt" | ||
"log" | ||
"os" | ||
"testing" | ||
|
||
"github.com/hashicorp/vault/logical" | ||
logicaltest "github.com/hashicorp/vault/logical/testing" | ||
"github.com/mitchellh/mapstructure" | ||
) | ||
|
||
func TestBackend_basic(t *testing.T) { | ||
b, _ := Factory(logical.TestBackendConfig()) | ||
|
||
logicaltest.Test(t, logicaltest.TestCase{ | ||
PreCheck: func() { testAccPreCheck(t) }, | ||
Backend: b, | ||
Steps: []logicaltest.TestStep{ | ||
testAccStepConfig(t), | ||
testAccStepRole(t), | ||
testAccStepReadCreds(t, "web"), | ||
}, | ||
}) | ||
} | ||
|
||
func TestBackend_roleCrud(t *testing.T) { | ||
b := Backend() | ||
|
||
logicaltest.Test(t, logicaltest.TestCase{ | ||
PreCheck: func() { testAccPreCheck(t) }, | ||
Backend: b, | ||
Steps: []logicaltest.TestStep{ | ||
testAccStepConfig(t), | ||
testAccStepRole(t), | ||
testAccStepReadRole(t, "web", testRoleSQL), | ||
testAccStepDeleteRole(t, "web"), | ||
testAccStepReadRole(t, "web", ""), | ||
}, | ||
}) | ||
} | ||
|
||
func TestBackend_leaseWriteRead(t *testing.T) { | ||
b := Backend() | ||
|
||
logicaltest.Test(t, logicaltest.TestCase{ | ||
PreCheck: func() { testAccPreCheck(t) }, | ||
Backend: b, | ||
Steps: []logicaltest.TestStep{ | ||
testAccStepConfig(t), | ||
testAccStepWriteLease(t), | ||
testAccStepReadLease(t), | ||
}, | ||
}) | ||
|
||
} | ||
|
||
func testAccPreCheck(t *testing.T) { | ||
if v := os.Getenv("MSSQL_PARAMS"); v == "" { | ||
t.Fatal("MSSQL_PARAMS must be set for acceptance tests") | ||
} | ||
} | ||
|
||
func testAccStepConfig(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.UpdateOperation, | ||
Path: "config/connection", | ||
Data: map[string]interface{}{ | ||
"connection_params": os.Getenv("MSSQL_PARAMS"), | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepRole(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.UpdateOperation, | ||
Path: "roles/web", | ||
Data: map[string]interface{}{ | ||
"sql": testRoleSQL, | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.DeleteOperation, | ||
Path: "roles/" + n, | ||
} | ||
} | ||
|
||
func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.ReadOperation, | ||
Path: "creds/" + name, | ||
Check: func(resp *logical.Response) error { | ||
var d struct { | ||
Username string `mapstructure:"username"` | ||
Password string `mapstructure:"password"` | ||
} | ||
if err := mapstructure.Decode(resp.Data, &d); err != nil { | ||
return err | ||
} | ||
log.Printf("[WARN] Generated credentials: %v", d) | ||
|
||
return nil | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepReadRole(t *testing.T, name, sql string) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.ReadOperation, | ||
Path: "roles/" + name, | ||
Check: func(resp *logical.Response) error { | ||
if resp == nil { | ||
if sql == "" { | ||
return nil | ||
} | ||
|
||
return fmt.Errorf("bad: %#v", resp) | ||
} | ||
|
||
var d struct { | ||
SQL string `mapstructure:"sql"` | ||
} | ||
if err := mapstructure.Decode(resp.Data, &d); err != nil { | ||
return err | ||
} | ||
|
||
if d.SQL != sql { | ||
return fmt.Errorf("bad: %#v", resp) | ||
} | ||
|
||
return nil | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepWriteLease(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.UpdateOperation, | ||
Path: "config/lease", | ||
Data: map[string]interface{}{ | ||
"lease": "1h5m", | ||
"lease_max": "24h", | ||
}, | ||
} | ||
} | ||
|
||
func testAccStepReadLease(t *testing.T) logicaltest.TestStep { | ||
return logicaltest.TestStep{ | ||
Operation: logical.ReadOperation, | ||
Path: "config/lease", | ||
Check: func(resp *logical.Response) error { | ||
if resp.Data["lease"] != "1h5m0s" || resp.Data["lease_max"] != "24h0m0s" { | ||
return fmt.Errorf("bad: %#v", resp) | ||
} | ||
|
||
return nil | ||
}, | ||
} | ||
} | ||
|
||
const testRoleSQL = ` | ||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; | ||
CREATE USER [{{name}}] FOR LOGIN [{{name}}]; | ||
GRANT SELECT ON SCHEMA::dbo TO [{{name}}] | ||
` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package mssql | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
|
||
"github.com/hashicorp/vault/logical" | ||
"github.com/hashicorp/vault/logical/framework" | ||
) | ||
|
||
func pathConfigConnection(b *backend) *framework.Path { | ||
return &framework.Path{ | ||
Pattern: "config/connection", | ||
Fields: map[string]*framework.FieldSchema{ | ||
"connection_params": &framework.FieldSchema{ | ||
Type: framework.TypeString, | ||
Description: "DB connection parameters", | ||
}, | ||
"max_open_connections": &framework.FieldSchema{ | ||
Type: framework.TypeInt, | ||
Description: "Maximum number of open connections to database", | ||
}, | ||
}, | ||
|
||
Callbacks: map[logical.Operation]framework.OperationFunc{ | ||
logical.UpdateOperation: b.pathConnectionWrite, | ||
}, | ||
|
||
HelpSynopsis: pathConfigConnectionHelpSyn, | ||
HelpDescription: pathConfigConnectionHelpDesc, | ||
} | ||
} | ||
|
||
func (b *backend) pathConnectionWrite( | ||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||
connParams := data.Get("connection_params").(string) | ||
|
||
maxOpenConns := data.Get("max_open_connections").(int) | ||
if maxOpenConns == 0 { | ||
maxOpenConns = 2 | ||
} | ||
|
||
// Verify the string | ||
db, err := sql.Open("mssql", BuildDsn(connParams)) | ||
|
||
if err != nil { | ||
return logical.ErrorResponse(fmt.Sprintf( | ||
"Error validating connection info: %s", err)), nil | ||
} | ||
defer db.Close() | ||
if err := db.Ping(); err != nil { | ||
return logical.ErrorResponse(fmt.Sprintf( | ||
"Error validating connection info: %s", err)), nil | ||
} | ||
|
||
// Store it | ||
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{ | ||
ConnectionParams: connParams, | ||
MaxOpenConnections: maxOpenConns, | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if err := req.Storage.Put(entry); err != nil { | ||
return nil, err | ||
} | ||
|
||
// Reset the DB connection | ||
b.ResetDB() | ||
return nil, nil | ||
} | ||
|
||
type connectionConfig struct { | ||
ConnectionParams string `json:"connection_params"` | ||
MaxOpenConnections int `json:"max_open_connections"` | ||
} | ||
|
||
const pathConfigConnectionHelpSyn = ` | ||
Configure the connection string to talk to Microsoft Sql Server. | ||
` | ||
|
||
const pathConfigConnectionHelpDesc = ` | ||
This path configures the connection string used to connect to Sql Server. | ||
The value of the string is a Data Source Name (DSN). An example is | ||
using "server=<hostname>;port=<port>;user id=<username>;password=<password>;database=<database>;" | ||
When configuring the connection string, the backend will verify its validity. | ||
` |
Oops, something went wrong.