diff --git a/config/http_config.go b/config/http_config.go index bcda953a..e8013289 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -526,7 +526,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT return newRT(tlsConfig) } - return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, newRT) + return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.CAFile, cfg.TLSConfig.CertFile, cfg.TLSConfig.KeyFile, newRT) } type authorizationCredentialsRoundTripper struct { @@ -695,7 +695,7 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro if len(rt.config.TLSConfig.CAFile) == 0 { t, _ = tlsTransport(tlsConfig) } else { - t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, tlsTransport) + t, err = NewTLSRoundTripper(tlsConfig, rt.config.TLSConfig.CAFile, rt.config.TLSConfig.CertFile, rt.config.TLSConfig.KeyFile, tlsTransport) if err != nil { return nil, err } @@ -824,12 +824,39 @@ func (c *TLSConfig) SetDirectory(dir string) { c.KeyFile = JoinDir(dir, c.KeyFile) } +// UnmarshalYAML implements the yaml.Unmarshaler interface. +func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain TLSConfig + return unmarshal((*plain)(c)) +} + +// readCertAndKey reads the cert and key files from the disk. +func readCertAndKey(certFile, keyFile string) ([]byte, []byte, error) { + certData, err := ioutil.ReadFile(certFile) + if err != nil { + return nil, nil, err + } + + keyData, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, nil, err + } + + return certData, keyData, nil +} + // getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate. -func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) +func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + certData, keyData, err := readCertAndKey(c.CertFile, c.KeyFile) + if err != nil { + return nil, fmt.Errorf("unable to read specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err) + } + + cert, err := tls.X509KeyPair(certData, keyData) if err != nil { return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err) } + return &cert, nil } @@ -855,23 +882,30 @@ func updateRootCA(cfg *tls.Config, b []byte) bool { // tlsRoundTripper is a RoundTripper that updates automatically its TLS // configuration whenever the content of the CA file changes. type tlsRoundTripper struct { - caFile string + caFile string + certFile string + keyFile string + // newRT returns a new RoundTripper. newRT func(*tls.Config) (http.RoundTripper, error) - mtx sync.RWMutex - rt http.RoundTripper - hashCAFile []byte - tlsConfig *tls.Config + mtx sync.RWMutex + rt http.RoundTripper + hashCAFile []byte + hashCertFile []byte + hashKeyFile []byte + tlsConfig *tls.Config } func NewTLSRoundTripper( cfg *tls.Config, - caFile string, + caFile, certFile, keyFile string, newRT func(*tls.Config) (http.RoundTripper, error), ) (http.RoundTripper, error) { t := &tlsRoundTripper{ caFile: caFile, + certFile: certFile, + keyFile: keyFile, newRT: newRT, tlsConfig: cfg, } @@ -881,7 +915,7 @@ func NewTLSRoundTripper( return nil, err } t.rt = rt - _, t.hashCAFile, err = t.getCAWithHash() + _, t.hashCAFile, t.hashCertFile, t.hashKeyFile, err = t.getTLSFilesWithHash() if err != nil { return nil, err } @@ -889,25 +923,36 @@ func NewTLSRoundTripper( return t, nil } -func (t *tlsRoundTripper) getCAWithHash() ([]byte, []byte, error) { - b, err := readCAFile(t.caFile) +func (t *tlsRoundTripper) getTLSFilesWithHash() ([]byte, []byte, []byte, []byte, error) { + b1, err := readCAFile(t.caFile) if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err + } + h1 := sha256.Sum256(b1) + + var h2, h3 [32]byte + if t.certFile != "" { + b2, b3, err := readCertAndKey(t.certFile, t.keyFile) + if err != nil { + return nil, nil, nil, nil, err + } + h2, h3 = sha256.Sum256(b2), sha256.Sum256(b3) } - h := sha256.Sum256(b) - return b, h[:], nil + return b1, h1[:], h2[:], h3[:], nil } // RoundTrip implements the http.RoundTrip interface. func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - b, h, err := t.getCAWithHash() + caData, caHash, certHash, keyHash, err := t.getTLSFilesWithHash() if err != nil { return nil, err } t.mtx.RLock() - equal := bytes.Equal(h[:], t.hashCAFile) + equal := bytes.Equal(caHash[:], t.hashCAFile) && + bytes.Equal(certHash[:], t.hashCertFile) && + bytes.Equal(keyHash[:], t.hashKeyFile) rt := t.rt t.mtx.RUnlock() if equal { @@ -916,8 +961,10 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { } // Create a new RoundTripper. + // The cert and key files are read separately by the client + // using GetClientCertificate. tlsConfig := t.tlsConfig.Clone() - if !updateRootCA(tlsConfig, b) { + if !updateRootCA(tlsConfig, caData) { return nil, fmt.Errorf("unable to use specified CA cert %s", t.caFile) } rt, err = t.newRT(tlsConfig) @@ -928,7 +975,9 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { t.mtx.Lock() t.rt = rt - t.hashCAFile = h[:] + t.hashCAFile = caHash[:] + t.hashCertFile = certHash[:] + t.hashKeyFile = keyHash[:] t.mtx.Unlock() return rt.RoundTrip(req) diff --git a/config/http_config_test.go b/config/http_config_test.go index 66b5141a..4ab623c6 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -718,7 +718,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { KeyFile: ClientKeyNoPassPath, ServerName: "", InsecureSkipVerify: false}, - errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath), + errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath), }, { configTLSConfig: TLSConfig{ CAFile: "", @@ -726,7 +726,7 @@ func TestTLSConfigInvalidCA(t *testing.T) { KeyFile: MissingKey, ServerName: "", InsecureSkipVerify: false}, - errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey), + errorMessage: fmt.Sprintf("unable to read specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey), }, } @@ -1532,3 +1532,116 @@ func TestOAuth2Proxy(t *testing.T) { t.Errorf("Error loading OAuth2 client config: %v", err) } } + +func TestModifyTLSCertificates(t *testing.T) { + bs := getCertificateBlobs(t) + + tmpDir, err := ioutil.TempDir("", "modifytlscertificates") + if err != nil { + t.Fatal("Failed to create tmp dir", err) + } + defer os.RemoveAll(tmpDir) + ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key") + + handler := func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, ExpectedMessage) + } + testServer, err := newTestServer(handler) + if err != nil { + t.Fatal(err.Error()) + } + defer testServer.Close() + + tests := []struct { + ca string + cert string + key string + + errMsg string + + modification func() + }{ + { + ca: ClientCertificatePath, + cert: ClientCertificatePath, + key: ClientKeyNoPassPath, + + errMsg: "certificate signed by unknown authority", + + modification: func() { writeCertificate(bs, TLSCAChainPath, ca) }, + }, + { + ca: TLSCAChainPath, + cert: WrongClientCertPath, + key: ClientKeyNoPassPath, + + errMsg: "private key does not match public key", + + modification: func() { writeCertificate(bs, ClientCertificatePath, cert) }, + }, + { + ca: TLSCAChainPath, + cert: ClientCertificatePath, + key: WrongClientCertPath, + + errMsg: "found a certificate rather than a key in the PEM for the private key", + + modification: func() { writeCertificate(bs, ClientKeyNoPassPath, key) }, + }, + } + + cfg := HTTPClientConfig{ + TLSConfig: TLSConfig{ + CAFile: ca, + CertFile: cert, + KeyFile: key, + InsecureSkipVerify: false}, + } + + var c *http.Client + for i, tc := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + writeCertificate(bs, tc.ca, ca) + writeCertificate(bs, tc.cert, cert) + writeCertificate(bs, tc.key, key) + if c == nil { + c, err = NewClientFromConfig(cfg, "test") + if err != nil { + t.Fatalf("Error creating HTTP Client: %v", err) + } + } + + req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) + if err != nil { + t.Fatalf("Error creating HTTP request: %v", err) + } + + r, err := c.Do(req) + if err == nil { + r.Body.Close() + t.Fatalf("Could connect to the test server.") + } + if !strings.Contains(err.Error(), tc.errMsg) { + t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err) + } + + tc.modification() + + r, err = c.Do(req) + if err != nil { + t.Fatalf("Expected no error, got %q", err) + } + + b, err := ioutil.ReadAll(r.Body) + r.Body.Close() + if err != nil { + t.Errorf("Can't read the server response body") + } + + got := strings.TrimSpace(string(b)) + if ExpectedMessage != got { + t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got) + } + }) + } +}