Skip to content

Commit

Permalink
server+runtime: add TLS cert refreshing (#4107)
Browse files Browse the repository at this point in the history
This adds a new flag to `opa run`, intended for server usage with HTTPS listeners:
`--tls-cert-refresh-period`. If used with a positive duration, such as "5m" (5 minutes),
"24h", etc, the server will track the certificate and key files' contents. When their
content changes, the certificates will be reloaded.

On an error in reloading, it will log (info) the error and try again in the next round.

Fixes #2500.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Dec 9, 2021
1 parent 58b66a0 commit cc4816e
Show file tree
Hide file tree
Showing 14 changed files with 488 additions and 34 deletions.
32 changes: 17 additions & 15 deletions cmd/run.go
Expand Up @@ -12,6 +12,7 @@ import (
"io/ioutil"
"os"
"path"
"time"

"github.com/spf13/cobra"

Expand All @@ -30,6 +31,7 @@ type runCmdParams struct {
tlsCertFile string
tlsPrivateKeyFile string
tlsCACertFile string
tlsCertRefresh time.Duration
ignore []string
serverMode bool
skipVersionCheck bool
Expand Down Expand Up @@ -170,19 +172,20 @@ To skip bundle verification, use the --skip-verify flag.
runCommand.Flags().StringVarP(&cmdParams.rt.HistoryPath, "history", "H", historyPath(), "set path of history file")
cmdParams.rt.Addrs = runCommand.Flags().StringSliceP("addr", "a", []string{defaultAddr}, "set listening address of the server (e.g., [ip]:<port> for TCP, unix://<path> for UNIX domain socket)")
cmdParams.rt.DiagnosticAddrs = runCommand.Flags().StringSlice("diagnostic-addr", []string{}, "set read-only diagnostic listening address of the server for /health and /metric APIs (e.g., [ip]:<port> for TCP, unix://<path> for UNIX domain socket)")
runCommand.Flags().BoolVarP(&cmdParams.rt.H2CEnabled, "h2c", "", false, "enable H2C for HTTP listeners")
runCommand.Flags().BoolVar(&cmdParams.rt.H2CEnabled, "h2c", false, "enable H2C for HTTP listeners")
runCommand.Flags().StringVarP(&cmdParams.rt.OutputFormat, "format", "f", "pretty", "set shell output format, i.e, pretty, json")
runCommand.Flags().BoolVarP(&cmdParams.rt.Watch, "watch", "w", false, "watch command line files for changes")
addMaxErrorsFlag(runCommand.Flags(), &cmdParams.rt.ErrorLimit)
runCommand.Flags().BoolVarP(&cmdParams.rt.PprofEnabled, "pprof", "", false, "enables pprof endpoints")
runCommand.Flags().StringVarP(&cmdParams.tlsCertFile, "tls-cert-file", "", "", "set path of TLS certificate file")
runCommand.Flags().StringVarP(&cmdParams.tlsPrivateKeyFile, "tls-private-key-file", "", "", "set path of TLS private key file")
runCommand.Flags().StringVarP(&cmdParams.tlsCACertFile, "tls-ca-cert-file", "", "", "set path of TLS CA cert file")
runCommand.Flags().VarP(cmdParams.authentication, "authentication", "", "set authentication scheme")
runCommand.Flags().VarP(cmdParams.authorization, "authorization", "", "set authorization scheme")
runCommand.Flags().VarP(cmdParams.minTLSVersion, "min-tls-version", "", "set minimum TLS version to be used by OPA's server, default is 1.2")
runCommand.Flags().BoolVar(&cmdParams.rt.PprofEnabled, "pprof", false, "enables pprof endpoints")
runCommand.Flags().StringVar(&cmdParams.tlsCertFile, "tls-cert-file", "", "set path of TLS certificate file")
runCommand.Flags().StringVar(&cmdParams.tlsPrivateKeyFile, "tls-private-key-file", "", "set path of TLS private key file")
runCommand.Flags().StringVar(&cmdParams.tlsCACertFile, "tls-ca-cert-file", "", "set path of TLS CA cert file")
runCommand.Flags().DurationVar(&cmdParams.tlsCertRefresh, "tls-cert-refresh-period", 0, "set certificate refresh period")
runCommand.Flags().Var(cmdParams.authentication, "authentication", "set authentication scheme")
runCommand.Flags().Var(cmdParams.authorization, "authorization", "set authorization scheme")
runCommand.Flags().Var(cmdParams.minTLSVersion, "min-tls-version", "set minimum TLS version to be used by OPA's server")
runCommand.Flags().VarP(cmdParams.logLevel, "log-level", "l", "set log level")
runCommand.Flags().VarP(cmdParams.logFormat, "log-format", "", "set log format")
runCommand.Flags().Var(cmdParams.logFormat, "log-format", "set log format")
runCommand.Flags().IntVar(&cmdParams.rt.GracefulShutdownPeriod, "shutdown-grace-period", 10, "set the time (in seconds) that the server will wait to gracefully shut down")
runCommand.Flags().IntVar(&cmdParams.rt.ShutdownWaitPeriod, "shutdown-wait-period", 0, "set the time (in seconds) that the server will wait before initiating shutdown")
addConfigOverrides(runCommand.Flags(), &cmdParams.rt.ConfigOverrides)
Expand Down Expand Up @@ -235,6 +238,10 @@ func initRuntime(ctx context.Context, params runCmdParams, args []string) (*runt
return nil, err
}

params.rt.CertificateFile = params.tlsCertFile
params.rt.CertificateKeyFile = params.tlsPrivateKeyFile
params.rt.CertificateRefresh = params.tlsCertRefresh

if params.tlsCACertFile != "" {
pool, err := loadCertPool(params.tlsCACertFile)
if err != nil {
Expand Down Expand Up @@ -270,12 +277,7 @@ func initRuntime(ctx context.Context, params runCmdParams, args []string) (*runt
return nil, fmt.Errorf("enable bundle mode (ie. --bundle) to verify bundle files or directories")
}

rt, err := runtime.NewRuntime(ctx, params.rt)
if err != nil {
return nil, err
}

return rt, nil
return runtime.NewRuntime(ctx, params.rt)
}

func startRuntime(ctx context.Context, rt *runtime.Runtime, serverMode bool) {
Expand Down
11 changes: 9 additions & 2 deletions docs/content/security.md
Expand Up @@ -27,6 +27,12 @@ startup:
OPA will exit immediately with a non-zero status code if only one of these flags
is specified.

The server can track the certificate and key files' contents, and reload them if necessary:

- ``--tls-cert-refresh=<duration>`` specifies how often OPA should check the TLS certificate and
private key file for changes (defaults to 0s, disabling periodic refresh). This argument accepts
any duration, such as "30s", "5m" or "24h".

Note that for using TLS-based authentication, a CA cert file can be provided:

- ``--tls-ca-cert-file=<path>`` specifies the path of the file containing the CA cert.
Expand Down Expand Up @@ -78,8 +84,9 @@ curl http://localhost:8181/v1/data
curl -k https://localhost:8181/v1/data
```

> We have to use cURL's `-k/--insecure` flag because we are using a
> self-signed certificate.
{{< info >}}
We have to use cURL's `-k/--insecure` flag because we are using a self-signed certificate.
{{< /info >}}

## Authentication and Authorization

Expand Down
9 changes: 9 additions & 0 deletions runtime/runtime.go
Expand Up @@ -96,6 +96,14 @@ type Params struct {
// is nil, the server will NOT use TLS.
Certificate *tls.Certificate

// CertificateFile and CertificateKeyFile are the paths to the cert and its
// keyfile. It'll be used to periodically reload the files from disk if they
// have changed. The server will attempt to refresh every 5 minutes, unless
// a different CertificateRefresh time.Duration is provided
CertificateFile string
CertificateKeyFile string
CertificateRefresh time.Duration

// CertPool holds the CA certs trusted by the OPA server.
CertPool *x509.CertPool

Expand Down Expand Up @@ -403,6 +411,7 @@ func (rt *Runtime) Serve(ctx context.Context) error {
WithAddresses(*rt.Params.Addrs).
WithH2CEnabled(rt.Params.H2CEnabled).
WithCertificate(rt.Params.Certificate).
WithCertificatePaths(rt.Params.CertificateFile, rt.Params.CertificateKeyFile, rt.Params.CertificateRefresh).
WithCertPool(rt.Params.CertPool).
WithAuthentication(rt.Params.Authentication).
WithAuthorization(rt.Params.Authorization).
Expand Down
76 changes: 76 additions & 0 deletions server/certs.go
@@ -0,0 +1,76 @@
// Copyright 2021 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package server

import (
"bytes"
"crypto/sha256"
"crypto/tls"
"io"
"os"
"time"

"github.com/open-policy-agent/opa/logging"
)

func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
s.certMtx.RLock()
defer s.certMtx.RUnlock()
return s.cert, nil
}

func (s *Server) certLoop(logger logging.Logger) Loop {
return func() error {
for range time.NewTicker(s.certRefresh).C {
certHash, err := hash(s.certFile)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
}
certKeyHash, err := hash(s.certKeyFile)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
continue
}

s.certMtx.Lock()

different := !bytes.Equal(s.certFileHash, certHash) ||
!bytes.Equal(s.certKeyFileHash, certKeyHash)

if different { // load and store
newCert, err := tls.LoadX509KeyPair(s.certFile, s.certKeyFile)
if err != nil {
logger.Info("Failed to refresh server certificate: %s.", err.Error())
s.certMtx.Unlock()
continue
}
s.cert = &newCert
s.certFileHash = certHash
s.certKeyFileHash = certKeyHash
logger.Debug("Refreshed server certificate.")
}

s.certMtx.Unlock()
}

return nil
}
}

func hash(file string) ([]byte, error) {
f, err := os.Open(file)
if err != nil {
return nil, err
}
defer f.Close()

h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return nil, err
}

return h.Sum(nil), nil
}
43 changes: 36 additions & 7 deletions server/server.go
Expand Up @@ -105,6 +105,12 @@ type Server struct {
authentication AuthenticationScheme
authorization AuthorizationScheme
cert *tls.Certificate
certMtx sync.RWMutex
certFile string
certFileHash []byte
certKeyFile string
certKeyFileHash []byte
certRefresh time.Duration
certPool *x509.CertPool
minTLSVersion uint16
mtx sync.RWMutex
Expand Down Expand Up @@ -231,6 +237,15 @@ func (s *Server) WithCertificate(cert *tls.Certificate) *Server {
return s
}

// WithCertificatePaths sets the server-side certificate and keyfile paths
// that the server will periodically check for changes, and reload if necessary.
func (s *Server) WithCertificatePaths(certFile, keyFile string, refresh time.Duration) *Server {
s.certFile = certFile
s.certKeyFile = keyFile
s.certRefresh = refresh
return s
}

// WithCertPool sets the server-side cert pool that the server will use.
func (s *Server) WithCertPool(pool *x509.CertPool) *Server {
s.certPool = pool
Expand Down Expand Up @@ -332,12 +347,12 @@ func (s *Server) Listeners() ([]Loop, error) {

for t, binding := range handlerBindings {
for _, addr := range binding.addrs {
loop, listener, err := s.getListener(addr, binding.handler, t)
l, listener, err := s.getListener(addr, binding.handler, t)
if err != nil {
return nil, err
}
s.httpListeners = append(s.httpListeners, listener)
loops = append(loops, loop)
loops = append(loops, l...)
}
}

Expand Down Expand Up @@ -399,7 +414,7 @@ type httpListener interface {
Addr() string
ListenAndServe() error
ListenAndServeTLS(certFile, keyFile string) error
Shutdown(ctx context.Context) error
Shutdown(context.Context) error
Type() httpListenerType
}

Expand Down Expand Up @@ -488,26 +503,39 @@ func isMinTLSVersionSupported(TLSVersion uint16) bool {
return false
}

func (s *Server) getListener(addr string, h http.Handler, t httpListenerType) (Loop, httpListener, error) {
func (s *Server) getListener(addr string, h http.Handler, t httpListenerType) ([]Loop, httpListener, error) {
parsedURL, err := parseURL(addr, s.cert != nil)
if err != nil {
return nil, nil, err
}

var loops []Loop
var loop Loop
var listener httpListener
switch parsedURL.Scheme {
case "unix":
loop, listener, err = s.getListenerForUNIXSocket(parsedURL, h, t)
loops = []Loop{loop}
case "http":
loop, listener, err = s.getListenerForHTTPServer(parsedURL, h, t)
loops = []Loop{loop}
case "https":
loop, listener, err = s.getListenerForHTTPSServer(parsedURL, h, t)
logger := s.manager.Logger().WithFields(map[string]interface{}{
"cert-file": s.certFile,
"cert-key-file": s.certKeyFile,
})
if s.certRefresh > 0 {
certLoop := s.certLoop(logger)
loops = []Loop{loop, certLoop}
} else {
loops = []Loop{loop}
}
default:
err = fmt.Errorf("invalid url scheme %q", parsedURL.Scheme)
}

return loop, listener, err
return loops, listener, err
}

func (s *Server) getListenerForHTTPServer(u *url.URL, h http.Handler, t httpListenerType) (Loop, httpListener, error) {
Expand All @@ -521,6 +549,7 @@ func (s *Server) getListenerForHTTPServer(u *url.URL, h http.Handler, t httpList
}

l := newHTTPListener(&h1s, t)

return l.ListenAndServe, l, nil
}

Expand All @@ -534,8 +563,8 @@ func (s *Server) getListenerForHTTPSServer(u *url.URL, h http.Handler, t httpLis
Addr: u.Host,
Handler: h,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{*s.cert},
ClientCAs: s.certPool,
GetCertificate: s.getCertificate,
ClientCAs: s.certPool,
},
}
if s.authentication == AuthenticationTLS {
Expand Down

0 comments on commit cc4816e

Please sign in to comment.