New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add reload capability for Vault listener certs #1196
Changes from 4 commits
7e52796
6430cd9
9e49463
9f2f5b1
92088f0
ca40e06
14f5385
0c56385
3a878c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ import ( | |
"time" | ||
|
||
"github.com/armon/go-metrics" | ||
"github.com/hashicorp/go-multierror" | ||
"github.com/hashicorp/logutils" | ||
"github.com/hashicorp/vault/audit" | ||
"github.com/hashicorp/vault/command/server" | ||
|
@@ -35,7 +36,11 @@ type ServerCommand struct { | |
LogicalBackends map[string]logical.Factory | ||
|
||
ShutdownCh <-chan struct{} | ||
SighupCh <-chan struct{} | ||
|
||
Meta | ||
|
||
ReloadFuncs map[string]server.ReloadFunc | ||
} | ||
|
||
func (c *ServerCommand) Run(args []string) int { | ||
|
@@ -274,7 +279,7 @@ func (c *ServerCommand) Run(args []string) int { | |
// Initialize the listeners | ||
lns := make([]net.Listener, 0, len(config.Listeners)) | ||
for i, lnConfig := range config.Listeners { | ||
ln, props, err := server.NewListener(lnConfig.Type, lnConfig.Config) | ||
ln, props, reloadFactory, err := server.NewListener(lnConfig.Type, lnConfig.Config) | ||
if err != nil { | ||
c.Ui.Error(fmt.Sprintf( | ||
"Error initializing listener of type %s: %s", | ||
|
@@ -295,6 +300,11 @@ func (c *ServerCommand) Run(args []string) int { | |
"%s (%s)", lnConfig.Type, strings.Join(propsList, ", ")) | ||
|
||
lns = append(lns, ln) | ||
|
||
if reloadFactory != nil { | ||
relId, relFunc := reloadFactory() | ||
c.ReloadFuncs[relId] = relFunc | ||
} | ||
} | ||
|
||
infoKeys = append(infoKeys, "version") | ||
|
@@ -333,11 +343,23 @@ func (c *ServerCommand) Run(args []string) int { | |
logGate.Flush() | ||
|
||
// Wait for shutdown | ||
select { | ||
case <-c.ShutdownCh: | ||
c.Ui.Output("==> Vault shutdown triggered") | ||
if err := core.Shutdown(); err != nil { | ||
c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err)) | ||
shutdownTriggered := false | ||
for { | ||
select { | ||
case <-c.ShutdownCh: | ||
c.Ui.Output("==> Vault shutdown triggered") | ||
if err := core.Shutdown(); err != nil { | ||
c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err)) | ||
} | ||
shutdownTriggered = true | ||
case <-c.SighupCh: | ||
c.Ui.Output("==> Vault reload triggered") | ||
if err := c.Reload(configPath); err != nil { | ||
c.Ui.Error(fmt.Sprintf("Error(s) were encountered during reload: %s", err)) | ||
} | ||
} | ||
if shutdownTriggered { | ||
break | ||
} | ||
} | ||
return 0 | ||
|
@@ -530,6 +552,46 @@ func (c *ServerCommand) setupTelementry(config *server.Config) error { | |
return nil | ||
} | ||
|
||
func (c *ServerCommand) Reload(configPath []string) error { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to expose this method? Also, is it possible to use this method to load the configuration the first time as well? |
||
// Read the new config | ||
var config *server.Config | ||
for _, path := range configPath { | ||
current, err := server.LoadConfig(path) | ||
if err != nil { | ||
retErr := fmt.Errorf("Error loading configuration from %s: %s", path, err) | ||
c.Ui.Error(retErr.Error()) | ||
return retErr | ||
} | ||
|
||
if config == nil { | ||
config = current | ||
} else { | ||
config = config.Merge(current) | ||
} | ||
} | ||
|
||
// Ensure at least one config was found. | ||
if config == nil { | ||
retErr := fmt.Errorf("No configuration files found") | ||
c.Ui.Error(retErr.Error()) | ||
return retErr | ||
} | ||
|
||
var reloadErrors *multierror.Error | ||
// Call reload on the listeners. This will call each listener with each | ||
// config block, but they verify the ID. | ||
for _, lnConfig := range config.Listeners { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to range over the listeners? It seems like we have one reloadFunc per anyways, seems like we can just loop over those |
||
for id, relFunc := range c.ReloadFuncs { | ||
if err := relFunc(id, lnConfig.Config); err != nil { | ||
retErr := fmt.Errorf("Error encountered reloading configuration for %s: %s", id, err) | ||
reloadErrors = multierror.Append(retErr) | ||
} | ||
} | ||
} | ||
|
||
return reloadErrors.ErrorOrNil() | ||
} | ||
|
||
func (c *ServerCommand) Synopsis() string { | ||
return "Start a Vault server" | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,10 +8,11 @@ import ( | |
"fmt" | ||
"net" | ||
"strconv" | ||
"sync" | ||
) | ||
|
||
// ListenerFactory is the factory function to create a listener. | ||
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, error) | ||
type ListenerFactory func(map[string]string) (net.Listener, map[string]string, ReloadFactory, error) | ||
|
||
// BuiltinListeners is the list of built-in listener types. | ||
var BuiltinListeners = map[string]ListenerFactory{ | ||
|
@@ -27,10 +28,10 @@ var tlsLookup = map[string]uint16{ | |
|
||
// NewListener creates a new listener of the given type with the given | ||
// configuration. The type is looked up in the BuiltinListeners map. | ||
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, error) { | ||
func NewListener(t string, config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) { | ||
f, ok := BuiltinListeners[t] | ||
if !ok { | ||
return nil, nil, fmt.Errorf("unknown listener type: %s", t) | ||
return nil, nil, nil, fmt.Errorf("unknown listener type: %s", t) | ||
} | ||
|
||
return f(config) | ||
|
@@ -39,32 +40,35 @@ func NewListener(t string, config map[string]string) (net.Listener, map[string]s | |
func listenerWrapTLS( | ||
ln net.Listener, | ||
props map[string]string, | ||
config map[string]string) (net.Listener, map[string]string, error) { | ||
config map[string]string) (net.Listener, map[string]string, ReloadFactory, error) { | ||
props["tls"] = "disabled" | ||
|
||
if v, ok := config["tls_disable"]; ok { | ||
disabled, err := strconv.ParseBool(v) | ||
if err != nil { | ||
return nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err) | ||
return nil, nil, nil, fmt.Errorf("invalid value for 'tls_disable': %v", err) | ||
} | ||
if disabled { | ||
return ln, props, nil | ||
return ln, props, nil, nil | ||
} | ||
} | ||
|
||
certFile, ok := config["tls_cert_file"] | ||
_, ok := config["tls_cert_file"] | ||
if !ok { | ||
return nil, nil, fmt.Errorf("'tls_cert_file' must be set") | ||
return nil, nil, nil, fmt.Errorf("'tls_cert_file' must be set") | ||
} | ||
|
||
keyFile, ok := config["tls_key_file"] | ||
_, ok = config["tls_key_file"] | ||
if !ok { | ||
return nil, nil, fmt.Errorf("'tls_key_file' must be set") | ||
return nil, nil, nil, fmt.Errorf("'tls_key_file' must be set") | ||
} | ||
|
||
cert, err := tls.LoadX509KeyPair(certFile, keyFile) | ||
if err != nil { | ||
return nil, nil, fmt.Errorf("error loading TLS cert: %s", err) | ||
cg := &certificateGetter{ | ||
id: "listen|" + ln.Addr().String(), | ||
} | ||
|
||
if err := cg.reload(cg.id, config); err != nil { | ||
return nil, nil, nil, fmt.Errorf("error loading TLS cert: %s", err) | ||
} | ||
|
||
tlsvers, ok := config["tls_min_version"] | ||
|
@@ -73,15 +77,54 @@ func listenerWrapTLS( | |
} | ||
|
||
tlsConf := &tls.Config{} | ||
tlsConf.Certificates = []tls.Certificate{cert} | ||
tlsConf.GetCertificate = cg.getCertificate | ||
tlsConf.NextProtos = []string{"http/1.1"} | ||
tlsConf.MinVersion, ok = tlsLookup[tlsvers] | ||
if !ok { | ||
return nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers) | ||
return nil, nil, nil, fmt.Errorf("'tls_min_version' value %s not supported, please specify one of [tls10,tls11,tls12]", tlsvers) | ||
} | ||
tlsConf.ClientAuth = tls.RequestClientCert | ||
|
||
ln = tls.NewListener(ln, tlsConf) | ||
props["tls"] = "enabled" | ||
return ln, props, nil | ||
return ln, props, func() (string, ReloadFunc) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wouldn't inline the function like this, its kind of weird to read |
||
return cg.id, cg.reload | ||
}, nil | ||
} | ||
|
||
type certificateGetter struct { | ||
sync.RWMutex | ||
|
||
cert *tls.Certificate | ||
|
||
id string | ||
} | ||
|
||
func (cg *certificateGetter) reload(id string, config map[string]string) error { | ||
if id != cg.id { | ||
return nil | ||
} | ||
|
||
cert, err := tls.LoadX509KeyPair(config["tls_cert_file"], config["tls_key_file"]) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
cg.Lock() | ||
defer cg.Unlock() | ||
|
||
cg.cert = &cert | ||
|
||
return nil | ||
} | ||
|
||
func (cg *certificateGetter) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { | ||
cg.RLock() | ||
defer cg.RUnlock() | ||
|
||
if cg.cert == nil { | ||
return nil, fmt.Errorf("nil certificate") | ||
} | ||
|
||
return cg.cert, nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May as well do
for !shutdownTriggered {
and remove the if blockThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hah, I've gotten so used to the infinite-loop-using-for-to-keep-selecting paradigm that it didn't even cross my mind.