Skip to content

Commit

Permalink
Merge pull request #1196 from hashicorp/reload-listener-tls
Browse files Browse the repository at this point in the history
Add reload capability for Vault listener certs
  • Loading branch information
jefferai committed Mar 15, 2016
2 parents b953955 + 3a878c3 commit 6e3b771
Show file tree
Hide file tree
Showing 14 changed files with 477 additions and 58 deletions.
24 changes: 4 additions & 20 deletions cli/commands.go
Expand Up @@ -2,11 +2,10 @@ package cli

import (
"os"
"os/signal"
"syscall"

auditFile "github.com/hashicorp/vault/builtin/audit/file"
auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/version"

credAppId "github.com/hashicorp/vault/builtin/credential/app-id"
Expand Down Expand Up @@ -78,7 +77,9 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory {
"mysql": mysql.Factory,
"ssh": ssh.Factory,
},
ShutdownCh: makeShutdownCh(),
ShutdownCh: command.MakeShutdownCh(),
SighupCh: command.MakeSighupCh(),
ReloadFuncs: map[string][]server.ReloadFunc{},
}, nil
},

Expand Down Expand Up @@ -308,20 +309,3 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory {
},
}
}

// makeShutdownCh returns a channel that can be used for shutdown
// notifications for commands. This channel will send a message for every
// interrupt or SIGTERM received.
func makeShutdownCh() <-chan struct{} {
resultCh := make(chan struct{})

signalCh := make(chan os.Signal, 4)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
go func() {
for {
<-signalCh
resultCh <- struct{}{}
}
}()
return resultCh
}
110 changes: 103 additions & 7 deletions command/server.go
Expand Up @@ -8,13 +8,16 @@ import (
"net/http"
"net/url"
"os"
"os/signal"
"runtime"
"sort"
"strconv"
"strings"
"syscall"
"time"

"github.com/armon/go-metrics"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/logutils"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/command/server"
Expand All @@ -34,8 +37,12 @@ type ServerCommand struct {
CredentialBackends map[string]logical.Factory
LogicalBackends map[string]logical.Factory

ShutdownCh <-chan struct{}
ShutdownCh chan struct{}
SighupCh chan struct{}

Meta

ReloadFuncs map[string][]server.ReloadFunc
}

func (c *ServerCommand) Run(args []string) int {
Expand Down Expand Up @@ -274,7 +281,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the listeners
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
ln, props, err := server.NewListener(lnConfig.Type, lnConfig.Config)
ln, props, reloadFunc, err := server.NewListener(lnConfig.Type, lnConfig.Config)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing listener of type %s: %s",
Expand All @@ -295,6 +302,12 @@ func (c *ServerCommand) Run(args []string) int {
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", "))

lns = append(lns, ln)

if reloadFunc != nil {
relSlice := c.ReloadFuncs["listener|"+lnConfig.Type]
relSlice = append(relSlice, reloadFunc)
c.ReloadFuncs["listener|"+lnConfig.Type] = relSlice
}
}

infoKeys = append(infoKeys, "version")
Expand Down Expand Up @@ -333,11 +346,20 @@ func (c *ServerCommand) Run(args []string) int {
logGate.Flush()

// Wait for shutdown
select {
case <-c.ShutdownCh:
c.Ui.Output("==> Vault shutdown triggered")
if err := core.Shutdown(); err != nil {
c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err))
shutdownTriggered := false
for !shutdownTriggered {
select {
case <-c.ShutdownCh:
c.Ui.Output("==> Vault shutdown triggered")
if err := core.Shutdown(); err != nil {
c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err))
}
shutdownTriggered = true
case <-c.SighupCh:
c.Ui.Output("==> Vault reload triggered")
if err := c.Reload(configPath); err != nil {
c.Ui.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err))
}
}
}
return 0
Expand Down Expand Up @@ -530,6 +552,46 @@ func (c *ServerCommand) setupTelementry(config *server.Config) error {
return nil
}

func (c *ServerCommand) Reload(configPath []string) error {
// Read the new config
var config *server.Config
for _, path := range configPath {
current, err := server.LoadConfig(path)
if err != nil {
retErr := fmt.Errorf("Error loading configuration from %s: %s", path, err)
c.Ui.Error(retErr.Error())
return retErr
}

if config == nil {
config = current
} else {
config = config.Merge(current)
}
}

// Ensure at least one config was found.
if config == nil {
retErr := fmt.Errorf("No configuration files found")
c.Ui.Error(retErr.Error())
return retErr
}

var reloadErrors *multierror.Error
// Call reload on the listeners. This will call each listener with each
// config block, but they verify the address.
for _, lnConfig := range config.Listeners {
for _, relFunc := range c.ReloadFuncs["listener|"+lnConfig.Type] {
if err := relFunc(lnConfig.Config); err != nil {
retErr := fmt.Errorf("Error encountered reloading configuration: %s", err)
reloadErrors = multierror.Append(retErr)
}
}
}

return reloadErrors.ErrorOrNil()
}

func (c *ServerCommand) Synopsis() string {
return "Start a Vault server"
}
Expand Down Expand Up @@ -577,3 +639,37 @@ General Options:
`
return strings.TrimSpace(helpText)
}

// MakeShutdownCh returns a channel that can be used for shutdown
// notifications for commands. This channel will send a message for every
// interrupt or SIGTERM received.
func MakeShutdownCh() chan struct{} {
resultCh := make(chan struct{})

signalCh := make(chan os.Signal, 4)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
go func() {
for {
<-signalCh
resultCh <- struct{}{}
}
}()
return resultCh
}

// MakeSighupCh returns a channel that can be used for SIGHUP
// reloading. This channel will send a message for every
// SIGHUP received.
func MakeSighupCh() chan struct{} {
resultCh := make(chan struct{})

signalCh := make(chan os.Signal, 4)
signal.Notify(signalCh, os.Interrupt, syscall.SIGHUP)
go func() {
for {
<-signalCh
resultCh <- struct{}{}
}
}()
return resultCh
}
3 changes: 3 additions & 0 deletions command/server/config.go
Expand Up @@ -15,6 +15,9 @@ import (
"github.com/hashicorp/hcl/hcl/ast"
)

// ReloadFunc are functions that are called when a reload is requested.
type ReloadFunc func(map[string]string) error

// Config is the configuration for the vault server.
type Config struct {
Listeners []*Listener `hcl:"-"`
Expand Down
73 changes: 57 additions & 16 deletions command/server/listener.go
Expand Up @@ -8,10 +8,11 @@ import (
"fmt"
"net"
"strconv"
"sync"
)

// ListenerFactory is the factory function to create a listener.
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, error)
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFunc, error)

// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
Expand All @@ -27,10 +28,10 @@ var tlsLookup = map[string]uint16{

// NewListener creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, error) {
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
f, ok := BuiltinListeners[t]
if !ok {
return nil, nil, fmt.Errorf("unknown listener type: %s", t)
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
}

return f(config)
Expand All @@ -39,32 +40,35 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s
func listenerWrapTLS(
ln net.Listener,
props map[string]string,
config map[string]string) (net.Listener, map[string]string, error) {
config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
props["tls"] = "disabled"

if v, ok := config["tls_disable"]; ok {
disabled, err := strconv.ParseBool(v)
if err != nil {
return nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
return nil, nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
}
if disabled {
return ln, props, nil
return ln, props, nil, nil
}
}

certFile, ok := config["tls_cert_file"]
_, ok := config["tls_cert_file"]
if !ok {
return nil, nil, fmt.Errorf("'tls_cert_file' must be set")
return nil, nil, nil, fmt.Errorf("'tls_cert_file' must be set")
}

keyFile, ok := config["tls_key_file"]
_, ok = config["tls_key_file"]
if !ok {
return nil, nil, fmt.Errorf("'tls_key_file' must be set")
return nil, nil, nil, fmt.Errorf("'tls_key_file' must be set")
}

cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
cg := &certificateGetter{
id: config["address"],
}

if err := cg.reload(config); err != nil {
return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
}

tlsvers, ok := config["tls_min_version"]
Expand All @@ -73,15 +77,52 @@ func listenerWrapTLS(
}

tlsConf := &tls.Config{}
tlsConf.Certificates = []tls.Certificate{cert}
tlsConf.GetCertificate = cg.getCertificate
tlsConf.NextProtos = []string{"http/1.1"}
tlsConf.MinVersion, ok = tlsLookup[tlsvers]
if !ok {
return nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
return nil, nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
}
tlsConf.ClientAuth = tls.RequestClientCert

ln = tls.NewListener(ln, tlsConf)
props["tls"] = "enabled"
return ln, props, nil
return ln, props, cg.reload, nil
}

type certificateGetter struct {
sync.RWMutex

cert *tls.Certificate

id string
}

func (cg *certificateGetter) reload(config map[string]string) error {
if config["address"] != cg.id {
return nil
}

cert, err := tls.LoadX509KeyPair(config["tls_cert_file"], config["tls_key_file"])
if err != nil {
return err
}

cg.Lock()
defer cg.Unlock()

cg.cert = &cert

return nil
}

func (cg *certificateGetter) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cg.RLock()
defer cg.RUnlock()

if cg.cert == nil {
return nil, fmt.Errorf("nil certificate")
}

return cg.cert, nil
}
4 changes: 2 additions & 2 deletions command/server/listener_tcp.go
Expand Up @@ -5,15 +5,15 @@ import (
"time"
)

func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, error) {
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFunc, error) {
addr, ok := config["address"]
if !ok {
addr = "127.0.0.1:8200"
}

ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

ln = tcpKeepAliveListener{ln.(*net.TCPListener)}
Expand Down

0 comments on commit 6e3b771

Please sign in to comment.