Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reload capability for Vault listener certs #1196

Merged
merged 9 commits into from Mar 15, 2016
18 changes: 18 additions & 0 deletions cli/commands.go
Expand Up @@ -79,6 +79,7 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory {
"ssh": ssh.Factory,
},
ShutdownCh: makeShutdownCh(),
SighupCh: makeSighupCh(),
}, nil
},

Expand Down Expand Up @@ -325,3 +326,20 @@ func makeShutdownCh() <-chan struct{} {
}()
return resultCh
}

// makeSighupCh returns a channel that can be used for SIGHUP
// reloading. This channel will send a message for every
// SIGHUP received.
func makeSighupCh() <-chan struct{} {
resultCh := make(chan struct{})

signalCh := make(chan os.Signal, 4)
signal.Notify(signalCh, os.Interrupt, syscall.SIGHUP)
go func() {
for {
<-signalCh
resultCh <- struct{}{}
}
}()
return resultCh
}
74 changes: 68 additions & 6 deletions command/server.go
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/armon/go-metrics"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/logutils"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/command/server"
Expand All @@ -35,7 +36,11 @@ type ServerCommand struct {
LogicalBackends map[string]logical.Factory

ShutdownCh <-chan struct{}
SighupCh <-chan struct{}

Meta

ReloadFuncs map[string]server.ReloadFunc
}

func (c *ServerCommand) Run(args []string) int {
Expand Down Expand Up @@ -274,7 +279,7 @@ func (c *ServerCommand) Run(args []string) int {
// Initialize the listeners
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
ln, props, err := server.NewListener(lnConfig.Type, lnConfig.Config)
ln, props, reloadFactory, err := server.NewListener(lnConfig.Type, lnConfig.Config)
if err != nil {
c.Ui.Error(fmt.Sprintf(
"Error initializing listener of type %s: %s",
Expand All @@ -295,6 +300,11 @@ func (c *ServerCommand) Run(args []string) int {
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", "))

lns = append(lns, ln)

if reloadFactory != nil {
relId, relFunc := reloadFactory()
c.ReloadFuncs[relId] = relFunc
}
}

infoKeys = append(infoKeys, "version")
Expand Down Expand Up @@ -333,11 +343,23 @@ func (c *ServerCommand) Run(args []string) int {
logGate.Flush()

// Wait for shutdown
select {
case <-c.ShutdownCh:
c.Ui.Output("==> Vault shutdown triggered")
if err := core.Shutdown(); err != nil {
c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err))
shutdownTriggered := false
for {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May as well do for !shutdownTriggered { and remove the if block

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah, I've gotten so used to the infinite-loop-using-for-to-keep-selecting paradigm that it didn't even cross my mind.

select {
case <-c.ShutdownCh:
c.Ui.Output("==> Vault shutdown triggered")
if err := core.Shutdown(); err != nil {
c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err))
}
shutdownTriggered = true
case <-c.SighupCh:
c.Ui.Output("==> Vault reload triggered")
if err := c.Reload(configPath); err != nil {
c.Ui.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err))
}
}
if shutdownTriggered {
break
}
}
return 0
Expand Down Expand Up @@ -530,6 +552,46 @@ func (c *ServerCommand) setupTelementry(config *server.Config) error {
return nil
}

func (c *ServerCommand) Reload(configPath []string) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to expose this method?

Also, is it possible to use this method to load the configuration the first time as well?

// Read the new config
var config *server.Config
for _, path := range configPath {
current, err := server.LoadConfig(path)
if err != nil {
retErr := fmt.Errorf("Error loading configuration from %s: %s", path, err)
c.Ui.Error(retErr.Error())
return retErr
}

if config == nil {
config = current
} else {
config = config.Merge(current)
}
}

// Ensure at least one config was found.
if config == nil {
retErr := fmt.Errorf("No configuration files found")
c.Ui.Error(retErr.Error())
return retErr
}

var reloadErrors *multierror.Error
// Call reload on the listeners. This will call each listener with each
// config block, but they verify the ID.
for _, lnConfig := range config.Listeners {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to range over the listeners? It seems like we have one reloadFunc per anyways, seems like we can just loop over those

for id, relFunc := range c.ReloadFuncs {
if err := relFunc(id, lnConfig.Config); err != nil {
retErr := fmt.Errorf("Error encountered reloading configuration for %s: %s", id, err)
reloadErrors = multierror.Append(retErr)
}
}
}

return reloadErrors.ErrorOrNil()
}

func (c *ServerCommand) Synopsis() string {
return "Start a Vault server"
}
Expand Down
7 changes: 7 additions & 0 deletions command/server/config.go
Expand Up @@ -15,6 +15,13 @@ import (
"github.com/hashicorp/hcl/hcl/ast"
)

// ReloadFunc are functions that are called when a reload is requested.
type ReloadFunc func(string, map[string]string) error

// ReloadFactory can be called to return the desired ID and the associated
// reload function.
type ReloadFactory func() (string, ReloadFunc)

// Config is the configuration for the vault server.
type Config struct {
Listeners []*Listener `hcl:"-"`
Expand Down
75 changes: 59 additions & 16 deletions command/server/listener.go
Expand Up @@ -8,10 +8,11 @@ import (
"fmt"
"net"
"strconv"
"sync"
)

// ListenerFactory is the factory function to create a listener.
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, error)
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFactory, error)

// BuiltinListeners is the list of built-in listener types.
var BuiltinListeners = map[string]ListenerFactory{
Expand All @@ -27,10 +28,10 @@ var tlsLookup = map[string]uint16{

// NewListener creates a new listener of the given type with the given
// configuration. The type is looked up in the BuiltinListeners map.
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, error) {
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) {
f, ok := BuiltinListeners[t]
if !ok {
return nil, nil, fmt.Errorf("unknown listener type: %s", t)
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t)
}

return f(config)
Expand All @@ -39,32 +40,35 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s
func listenerWrapTLS(
ln net.Listener,
props map[string]string,
config map[string]string) (net.Listener, map[string]string, error) {
config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) {
props["tls"] = "disabled"

if v, ok := config["tls_disable"]; ok {
disabled, err := strconv.ParseBool(v)
if err != nil {
return nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
return nil, nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err)
}
if disabled {
return ln, props, nil
return ln, props, nil, nil
}
}

certFile, ok := config["tls_cert_file"]
_, ok := config["tls_cert_file"]
if !ok {
return nil, nil, fmt.Errorf("'tls_cert_file' must be set")
return nil, nil, nil, fmt.Errorf("'tls_cert_file' must be set")
}

keyFile, ok := config["tls_key_file"]
_, ok = config["tls_key_file"]
if !ok {
return nil, nil, fmt.Errorf("'tls_key_file' must be set")
return nil, nil, nil, fmt.Errorf("'tls_key_file' must be set")
}

cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
cg := &certificateGetter{
id: "listen|" + ln.Addr().String(),
}

if err := cg.reload(cg.id, config); err != nil {
return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err)
}

tlsvers, ok := config["tls_min_version"]
Expand All @@ -73,15 +77,54 @@ func listenerWrapTLS(
}

tlsConf := &tls.Config{}
tlsConf.Certificates = []tls.Certificate{cert}
tlsConf.GetCertificate = cg.getCertificate
tlsConf.NextProtos = []string{"http/1.1"}
tlsConf.MinVersion, ok = tlsLookup[tlsvers]
if !ok {
return nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
return nil, nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers)
}
tlsConf.ClientAuth = tls.RequestClientCert

ln = tls.NewListener(ln, tlsConf)
props["tls"] = "enabled"
return ln, props, nil
return ln, props, func() (string, ReloadFunc) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't inline the function like this, its kind of weird to read

return cg.id, cg.reload
}, nil
}

type certificateGetter struct {
sync.RWMutex

cert *tls.Certificate

id string
}

func (cg *certificateGetter) reload(id string, config map[string]string) error {
if id != cg.id {
return nil
}

cert, err := tls.LoadX509KeyPair(config["tls_cert_file"], config["tls_key_file"])
if err != nil {
return err
}

cg.Lock()
defer cg.Unlock()

cg.cert = &cert

return nil
}

func (cg *certificateGetter) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cg.RLock()
defer cg.RUnlock()

if cg.cert == nil {
return nil, fmt.Errorf("nil certificate")
}

return cg.cert, nil
}
4 changes: 2 additions & 2 deletions command/server/listener_tcp.go
Expand Up @@ -5,15 +5,15 @@ import (
"time"
)

func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, error) {
func tcpListenerFactory(config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) {
addr, ok := config["address"]
if !ok {
addr = "127.0.0.1:8200"
}

ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

ln = tcpKeepAliveListener{ln.(*net.TCPListener)}
Expand Down