Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if TLS certificate and key file have been modified #345

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
89 changes: 69 additions & 20 deletions config/http_config.go
Expand Up @@ -526,7 +526,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
return newRT(tlsConfig)
}

return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT)
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT)
}

type authorizationCredentialsRoundTripper struct {
Expand Down Expand Up @@ -695,7 +695,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
if len(rt.config.TLSConfig.CAFile) == 0 {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, tlsTransport)
t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -824,12 +824,39 @@ func (c *TLSConfig) SetDirectory(dir string) {
c.KeyFile = JoinDir(dir, c.KeyFile)
}

// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain TLSConfig
return unmarshal((*plain)(c))
}

// readCertAndKey reads the cert and key files from the disk.
func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) {
certData, err := ioutil.ReadFile(certFile)
if err != nil {
return nil, nil, err
}

keyData, err := ioutil.ReadFile(keyFile)
if err != nil {
return nil, nil, err
}

return certData, keyData, nil
}

// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
}

cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
}

return &cert, nil
}

Expand All @@ -855,23 +882,30 @@ func updateRootCA(cfg *tls.Config, b []byte) bool {
// tlsRoundTripper is a RoundTripper that updates automatically its TLS
// configuration whenever the content of the CA file changes.
type tlsRoundTripper struct {
caFile string
caFile string
certFile string
keyFile string

// newRT returns a new RoundTripper.
newRT func(*tls.Config) (http.RoundTripper, error)

mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
tlsConfig *tls.Config
mtx sync.RWMutex
rt http.RoundTripper
hashCAFile []byte
hashCertFile []byte
hashKeyFile []byte
tlsConfig *tls.Config
}

func NewTLSRoundTripper(
cfg *tls.Config,
caFile string,
caFile, certFile, keyFile string,
newRT func(*tls.Config) (http.RoundTripper, error),
) (http.RoundTripper, error) {
t := &tlsRoundTripper{
caFile: caFile,
certFile: certFile,
keyFile: keyFile,
newRT: newRT,
tlsConfig: cfg,
}
Expand All @@ -881,33 +915,44 @@ func NewTLSRoundTripper(
return nil, err
}
t.rt = rt
_, t.hashCAFile, err = t.getCAWithHash()
_, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash()
if err != nil {
return nil, err
}

return t, nil
}

func (t *tlsRoundTripper) getCAWithHash() ([]byte, []byte, error) {
b, err := readCAFile(t.caFile)
func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) {
b1, err := readCAFile(t.caFile)
if err != nil {
return nil, nil, err
return nil, nil, nil, nil, err
}
h1 := sha256.Sum256(b1)

var h2, h3 [32]byte
if t.certFile != "" {
b2, b3, err := readCertAndKey(t.certFile, t.keyFile)
if err != nil {
return nil, nil, nil, nil, err
}
h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3)
}
h := sha256.Sum256(b)
return b, h[:], nil

return b1, h1[:], h2[:], h3[:], nil
}

// RoundTrip implements the http.RoundTrip interface.
func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
b, h, err := t.getCAWithHash()
caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash()
if err != nil {
return nil, err
}

t.mtx.RLock()
equal := bytes.Equal(h[:], t.hashCAFile)
equal := bytes.Equal(caHash[:], t.hashCAFile) &&
bytes.Equal(certHash[:], t.hashCertFile) &&
bytes.Equal(keyHash[:], t.hashKeyFile)
rt := t.rt
t.mtx.RUnlock()
if equal {
Expand All @@ -916,8 +961,10 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
}

// Create a new RoundTripper.
// The cert and key files are read separately by the client
// using GetClientCertificate.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, b) {
if !updateRootCA(tlsConfig, caData) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile)
}
rt, err = t.newRT(tlsConfig)
Expand All @@ -928,7 +975,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {

t.mtx.Lock()
t.rt = rt
t.hashCAFile = h[:]
t.hashCAFile = caHash[:]
t.hashCertFile = certHash[:]
t.hashKeyFile = keyHash[:]
t.mtx.Unlock()

return rt.RoundTrip(req)
Expand Down
117 changes: 115 additions & 2 deletions config/http_config_test.go
Expand Up @@ -718,15 +718,15 @@ func TestTLSConfigInvalidCA(t *testing.T) {
KeyFile: ClientKeyNoPassPath,
ServerName: "",
InsecureSkipVerify: false},
errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
}, {
configTLSConfig: TLSConfig{
CAFile: "",
CertFile: ClientCertificatePath,
KeyFile: MissingKey,
ServerName: "",
InsecureSkipVerify: false},
errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
},
}

Expand Down Expand Up @@ -1532,3 +1532,116 @@ func TestOAuth2Proxy(t *testing.T) {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}

func TestModifyTLSCertificates(t *testing.T) {
bs := getCertificateBlobs(t)

tmpDir, err := ioutil.TempDir("", "modifytlscertificates")
if err != nil {
t.Fatal("Failed to create tmp dir", err)
}
defer os.RemoveAll(tmpDir)
ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key")

handler := func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, ExpectedMessage)
}
testServer, err := newTestServer(handler)
if err != nil {
t.Fatal(err.Error())
}
defer testServer.Close()

tests := []struct {
ca string
cert string
key string

errMsg string

modification func()
}{
{
ca: ClientCertificatePath,
cert: ClientCertificatePath,
key: ClientKeyNoPassPath,

errMsg: "certificate signed by unknown authority",

modification: func() { writeCertificate(bs, TLSCAChainPath, ca) },
},
{
ca: TLSCAChainPath,
cert: WrongClientCertPath,
key: ClientKeyNoPassPath,

errMsg: "private key does not match public key",

modification: func() { writeCertificate(bs, ClientCertificatePath, cert) },
},
{
ca: TLSCAChainPath,
cert: ClientCertificatePath,
key: WrongClientCertPath,

errMsg: "found a certificate rather than a key in the PEM for the private key",

modification: func() { writeCertificate(bs, ClientKeyNoPassPath, key) },
},
}

cfg := HTTPClientConfig{
TLSConfig: TLSConfig{
CAFile: ca,
CertFile: cert,
KeyFile: key,
InsecureSkipVerify: false},
}

var c *http.Client
for i, tc := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
writeCertificate(bs, tc.ca, ca)
writeCertificate(bs, tc.cert, cert)
writeCertificate(bs, tc.key, key)
if c == nil {
c, err = NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
}

req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
if err != nil {
t.Fatalf("Error creating HTTP request: %v", err)
}

r, err := c.Do(req)
if err == nil {
r.Body.Close()
t.Fatalf("Could connect to the test server.")
}
if !strings.Contains(err.Error(), tc.errMsg) {
t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err)
}

tc.modification()

r, err = c.Do(req)
if err != nil {
t.Fatalf("Expected no error, got %q", err)
}

b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
t.Errorf("Can't read the server response body")
}

got := strings.TrimSpace(string(b))
if ExpectedMessage != got {
t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got)
}
})
}
}