diff --git a/config/http_config.go b/config/http_config.go index 37cb37ce..063edde0 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -36,25 +36,87 @@ import ( "gopkg.in/yaml.v2" ) -// DefaultHTTPClientConfig is the default HTTP client configuration. -var DefaultHTTPClientConfig = HTTPClientConfig{ - FollowRedirects: true, - EnableHTTP2: true, -} +var ( + // DefaultHTTPClientConfig is the default HTTP client configuration. + DefaultHTTPClientConfig = HTTPClientConfig{ + FollowRedirects: true, + EnableHTTP2: true, + } -// defaultHTTPClientOptions holds the default HTTP client options. -var defaultHTTPClientOptions = httpClientOptions{ - keepAlivesEnabled: true, - http2Enabled: true, - // 5 minutes is typically above the maximum sane scrape interval. So we can - // use keepalive for all configurations. - idleConnTimeout: 5 * time.Minute, -} + // defaultHTTPClientOptions holds the default HTTP client options. + defaultHTTPClientOptions = httpClientOptions{ + keepAlivesEnabled: true, + http2Enabled: true, + // 5 minutes is typically above the maximum sane scrape interval. So we can + // use keepalive for all configurations. + idleConnTimeout: 5 * time.Minute, + } +) type closeIdler interface { CloseIdleConnections() } +type TLSVersion uint16 + +var TLSVersions = map[string]TLSVersion{ + "TLS13": (TLSVersion)(tls.VersionTLS13), + "TLS12": (TLSVersion)(tls.VersionTLS12), + "TLS11": (TLSVersion)(tls.VersionTLS11), + "TLS10": (TLSVersion)(tls.VersionTLS10), +} + +func (tv *TLSVersion) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + err := unmarshal((*string)(&s)) + if err != nil { + return err + } + if v, ok := TLSVersions[s]; ok { + *tv = v + return nil + } + return fmt.Errorf("unknown TLS version: %s", s) +} + +func (tv *TLSVersion) MarshalYAML() (interface{}, error) { + if tv != nil || *tv == 0 { + return []byte("null"), nil + } + for s, v := range TLSVersions { + if *tv == v { + return s, nil + } + } + return nil, fmt.Errorf("unknown TLS version: %d", tv) +} + +// MarshalJSON implements the json.Unmarshaler interface for TLSVersion. +func (tv *TLSVersion) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + if v, ok := TLSVersions[s]; ok { + *tv = v + return nil + } + return fmt.Errorf("unknown TLS version: %s", s) +} + +// MarshalJSON implements the json.Marshaler interface for TLSVersion. +func (tv *TLSVersion) MarshalJSON() ([]byte, error) { + if tv != nil || *tv == 0 { + return []byte("null"), nil + } + for s, v := range TLSVersions { + if *tv == v { + return []byte(s), nil + } + } + return nil, fmt.Errorf("unknown TLS version: %d", tv) +} + // BasicAuth contains basic HTTP authentication credentials. type BasicAuth struct { Username string `yaml:"username" json:"username"` @@ -669,7 +731,10 @@ func cloneRequest(r *http.Request) *http.Request { // NewTLSConfig creates a new tls.Config from the given TLSConfig. func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) { - tlsConfig := &tls.Config{InsecureSkipVerify: cfg.InsecureSkipVerify} + tlsConfig := &tls.Config{ + InsecureSkipVerify: cfg.InsecureSkipVerify, + MinVersion: uint16(cfg.MinVersion), + } // If a CA cert is provided then let's read it in so we can validate the // scrape target's certificate properly. @@ -714,6 +779,8 @@ type TLSConfig struct { ServerName string `yaml:"server_name,omitempty" json:"server_name,omitempty"` // Disable target certificate validation. InsecureSkipVerify bool `yaml:"insecure_skip_verify" json:"insecure_skip_verify"` + // Minimum TLS version. + MinVersion TLSVersion `yaml:"min_version,omitempty" json:"min_version,omitempty"` } // SetDirectory joins any relative file paths with dir. @@ -726,12 +793,6 @@ func (c *TLSConfig) SetDirectory(dir string) { c.KeyFile = JoinDir(dir, c.KeyFile) } -// UnmarshalYAML implements the yaml.Unmarshaler interface. -func (c *TLSConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - type plain TLSConfig - return unmarshal((*plain)(c)) -} - // getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate. func (c *TLSConfig) getClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) diff --git a/config/http_config_test.go b/config/http_config_test.go index 884f344c..06eb6d04 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -627,7 +627,8 @@ func TestTLSConfig(t *testing.T) { CertFile: ClientCertificatePath, KeyFile: ClientKeyNoPassPath, ServerName: "localhost", - InsecureSkipVerify: false} + InsecureSkipVerify: false, + } tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath) if err != nil { @@ -640,7 +641,8 @@ func TestTLSConfig(t *testing.T) { expectedTLSConfig := &tls.Config{ RootCAs: rootCAs, ServerName: configTLSConfig.ServerName, - InsecureSkipVerify: configTLSConfig.InsecureSkipVerify} + InsecureSkipVerify: configTLSConfig.InsecureSkipVerify, + } tlsConfig, err := NewTLSConfig(&configTLSConfig) if err != nil { diff --git a/config/testdata/tls_config.empty.good.json b/config/testdata/tls_config.empty.good.json new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/config/testdata/tls_config.empty.good.json @@ -0,0 +1 @@ +{} diff --git a/config/testdata/tls_config.insecure.good.json b/config/testdata/tls_config.insecure.good.json new file mode 100644 index 00000000..57fcd98f --- /dev/null +++ b/config/testdata/tls_config.insecure.good.json @@ -0,0 +1 @@ +{"insecure_skip_verify": true} diff --git a/config/testdata/tls_config.tlsversion.good.json b/config/testdata/tls_config.tlsversion.good.json new file mode 100644 index 00000000..c8d41cde --- /dev/null +++ b/config/testdata/tls_config.tlsversion.good.json @@ -0,0 +1 @@ +{"min_version": "TLS11"} diff --git a/config/testdata/tls_config.tlsversion.good.yml b/config/testdata/tls_config.tlsversion.good.yml new file mode 100644 index 00000000..ee24ee67 --- /dev/null +++ b/config/testdata/tls_config.tlsversion.good.yml @@ -0,0 +1 @@ +min_version: TLS11 diff --git a/config/tls_config_test.go b/config/tls_config_test.go index 2b965ea6..c759b86e 100644 --- a/config/tls_config_test.go +++ b/config/tls_config_test.go @@ -14,23 +14,39 @@ package config import ( + "bytes" "crypto/tls" + "fmt" "io/ioutil" + "path/filepath" "reflect" "testing" + "encoding/json" + "gopkg.in/yaml.v2" ) -// LoadTLSConfig parses the given YAML file into a tls.Config. +// LoadTLSConfig parses the given file into a tls.Config. func LoadTLSConfig(filename string) (*tls.Config, error) { content, err := ioutil.ReadFile(filename) if err != nil { return nil, err } cfg := TLSConfig{} - if err = yaml.UnmarshalStrict(content, &cfg); err != nil { - return nil, err + switch filepath.Ext(filename) { + case ".yml": + if err = yaml.UnmarshalStrict(content, &cfg); err != nil { + return nil, err + } + case ".json": + decoder := json.NewDecoder(bytes.NewReader(content)) + decoder.DisallowUnknownFields() + if err = decoder.Decode(&cfg); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("Unknown extension: %s", filepath.Ext(filename)) } return NewTLSConfig(&cfg) } @@ -39,12 +55,25 @@ var expectedTLSConfigs = []struct { filename string config *tls.Config }{ + { + filename: "tls_config.empty.good.json", + config: &tls.Config{}, + }, { + filename: "tls_config.insecure.good.json", + config: &tls.Config{InsecureSkipVerify: true}, + }, { + filename: "tls_config.tlsversion.good.json", + config: &tls.Config{MinVersion: tls.VersionTLS11}, + }, { filename: "tls_config.empty.good.yml", config: &tls.Config{}, }, { filename: "tls_config.insecure.good.yml", config: &tls.Config{InsecureSkipVerify: true}, + }, { + filename: "tls_config.tlsversion.good.yml", + config: &tls.Config{MinVersion: tls.VersionTLS11}, }, } @@ -52,7 +81,7 @@ func TestValidTLSConfig(t *testing.T) { for _, cfg := range expectedTLSConfigs { got, err := LoadTLSConfig("testdata/" + cfg.filename) if err != nil { - t.Errorf("Error parsing %s: %s", cfg.filename, err) + t.Fatalf("Error parsing %s: %s", cfg.filename, err) } // non-nil functions are never equal. got.GetClientCertificate = nil