diff --git a/config/http_config.go b/config/http_config.go index eedb2169..063edde0 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -80,12 +80,41 @@ func (tv *TLSVersion) UnmarshalYAML(unmarshal func(interface{}) error) error { } func (tv *TLSVersion) MarshalYAML() (interface{}, error) { + if tv != nil || *tv == 0 { + return []byte("null"), nil + } for s, v := range TLSVersions { if *tv == v { return s, nil } } - return fmt.Sprintf("%v", tv), nil + return nil, fmt.Errorf("unknown TLS version: %d", tv) +} + +// MarshalJSON implements the json.Unmarshaler interface for TLSVersion. +func (tv *TLSVersion) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + if v, ok := TLSVersions[s]; ok { + *tv = v + return nil + } + return fmt.Errorf("unknown TLS version: %s", s) +} + +// MarshalJSON implements the json.Marshaler interface for TLSVersion. +func (tv *TLSVersion) MarshalJSON() ([]byte, error) { + if tv != nil || *tv == 0 { + return []byte("null"), nil + } + for s, v := range TLSVersions { + if *tv == v { + return []byte(s), nil + } + } + return nil, fmt.Errorf("unknown TLS version: %d", tv) } // BasicAuth contains basic HTTP authentication credentials. @@ -751,7 +780,7 @@ type TLSConfig struct { // Disable target certificate validation. InsecureSkipVerify bool `yaml:"insecure_skip_verify" json:"insecure_skip_verify"` // Minimum TLS version. - MinVersion TLSVersion `yaml:"min_version,omitempty"` + MinVersion TLSVersion `yaml:"min_version,omitempty" json:"min_version,omitempty"` } // SetDirectory joins any relative file paths with dir. diff --git a/config/testdata/tls_config.empty.good.json b/config/testdata/tls_config.empty.good.json new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/config/testdata/tls_config.empty.good.json @@ -0,0 +1 @@ +{} diff --git a/config/testdata/tls_config.insecure.good.json b/config/testdata/tls_config.insecure.good.json new file mode 100644 index 00000000..57fcd98f --- /dev/null +++ b/config/testdata/tls_config.insecure.good.json @@ -0,0 +1 @@ +{"insecure_skip_verify": true} diff --git a/config/testdata/tls_config.tlsversion.good.json b/config/testdata/tls_config.tlsversion.good.json new file mode 100644 index 00000000..c8d41cde --- /dev/null +++ b/config/testdata/tls_config.tlsversion.good.json @@ -0,0 +1 @@ +{"min_version": "TLS11"} diff --git a/config/tls_config_test.go b/config/tls_config_test.go index b164a2f3..c759b86e 100644 --- a/config/tls_config_test.go +++ b/config/tls_config_test.go @@ -14,23 +14,39 @@ package config import ( + "bytes" "crypto/tls" + "fmt" "io/ioutil" + "path/filepath" "reflect" "testing" + "encoding/json" + "gopkg.in/yaml.v2" ) -// LoadTLSConfig parses the given YAML file into a tls.Config. +// LoadTLSConfig parses the given file into a tls.Config. func LoadTLSConfig(filename string) (*tls.Config, error) { content, err := ioutil.ReadFile(filename) if err != nil { return nil, err } cfg := TLSConfig{} - if err = yaml.UnmarshalStrict(content, &cfg); err != nil { - return nil, err + switch filepath.Ext(filename) { + case ".yml": + if err = yaml.UnmarshalStrict(content, &cfg); err != nil { + return nil, err + } + case ".json": + decoder := json.NewDecoder(bytes.NewReader(content)) + decoder.DisallowUnknownFields() + if err = decoder.Decode(&cfg); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("Unknown extension: %s", filepath.Ext(filename)) } return NewTLSConfig(&cfg) } @@ -39,6 +55,16 @@ var expectedTLSConfigs = []struct { filename string config *tls.Config }{ + { + filename: "tls_config.empty.good.json", + config: &tls.Config{}, + }, { + filename: "tls_config.insecure.good.json", + config: &tls.Config{InsecureSkipVerify: true}, + }, { + filename: "tls_config.tlsversion.good.json", + config: &tls.Config{MinVersion: tls.VersionTLS11}, + }, { filename: "tls_config.empty.good.yml", config: &tls.Config{}, @@ -55,7 +81,7 @@ func TestValidTLSConfig(t *testing.T) { for _, cfg := range expectedTLSConfigs { got, err := LoadTLSConfig("testdata/" + cfg.filename) if err != nil { - t.Errorf("Error parsing %s: %s", cfg.filename, err) + t.Fatalf("Error parsing %s: %s", cfg.filename, err) } // non-nil functions are never equal. got.GetClientCertificate = nil