diff --git a/README.md b/README.md index 4b4b4e1..e5e1b39 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,7 @@ There are following supported types: - maps (of any other supported type); - `time.Duration`; - `time.Time` (layout by default is RFC3339, may be overridden by `env-layout`); +- `*time.Location` (time zone parsing [depends](https://pkg.go.dev/time#LoadLocation) on running machine) - any type implementing `cleanenv.Setter` interface. diff --git a/cleanenv.go b/cleanenv.go index 226c104..5dd5e7b 100644 --- a/cleanenv.go +++ b/cleanenv.go @@ -5,7 +5,6 @@ import ( "flag" "fmt" "io" - "math" "net/url" "os" "path/filepath" @@ -29,18 +28,25 @@ const ( const ( // Name of the environment variable or a list of names TagEnv = "env" + // Value parsing layout (for types like time.Time) TagEnvLayout = "env-layout" + // Default value TagEnvDefault = "env-default" + // Custom list and map separator TagEnvSeparator = "env-separator" + // Environment variable description TagEnvDescription = "env-description" + // Flag to mark a field as updatable TagEnvUpd = "env-upd" + // Flag to mark a field as required TagEnvRequired = "env-required" + // Flag to specify prefix for structure fields TagEnvPrefix = "env-prefix" ) @@ -49,15 +55,15 @@ const ( // // To implement a custom value setter you need to add a SetValue function to your type that will receive a string raw value: // -// type MyField string +// type MyField string // -// func (f *MyField) SetValue(s string) error { -// if s == "" { -// return fmt.Errorf("field value can't be empty") -// } -// *f = MyField("my field is: " + s) -// return nil -// } +// func (f *MyField) SetValue(s string) error { +// if s == "" { +// return fmt.Errorf("field value can't be empty") +// } +// *f = MyField("my field is: " + s) +// return nil +// } type Setter interface { SetValue(string) error } @@ -72,20 +78,20 @@ type Updater interface { // // Example: // -// type ConfigDatabase struct { -// Port string `yaml:"port" env:"PORT" env-default:"5432"` -// Host string `yaml:"host" env:"HOST" env-default:"localhost"` -// Name string `yaml:"name" env:"NAME" env-default:"postgres"` -// User string `yaml:"user" env:"USER" env-default:"user"` -// Password string `yaml:"password" env:"PASSWORD"` -// } +// type ConfigDatabase struct { +// Port string `yaml:"port" env:"PORT" env-default:"5432"` +// Host string `yaml:"host" env:"HOST" env-default:"localhost"` +// Name string `yaml:"name" env:"NAME" env-default:"postgres"` +// User string `yaml:"user" env:"USER" env-default:"user"` +// Password string `yaml:"password" env:"PASSWORD"` +// } // -// var cfg ConfigDatabase +// var cfg ConfigDatabase // -// err := cleanenv.ReadConfig("config.yml", &cfg) -// if err != nil { -// ... -// } +// err := cleanenv.ReadConfig("config.yml", &cfg) +// if err != nil { +// ... +// } func ReadConfig(path string, cfg interface{}) error { err := parseFile(path, cfg) if err != nil { @@ -178,11 +184,58 @@ func parseENV(r io.Reader, _ interface{}) error { } for env, val := range vars { - os.Setenv(env, val) + if err = os.Setenv(env, val); err != nil { + return fmt.Errorf("set environment: %w", err) + } } + return nil } +// parseSlice parses value into a slice of given type +func parseSlice(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) { + sliceValue := reflect.MakeSlice(valueType, 0, 0) + if valueType.Elem().Kind() == reflect.Uint8 { + sliceValue = reflect.ValueOf([]byte(value)) + } else if len(strings.TrimSpace(value)) != 0 { + values := strings.Split(value, sep) + sliceValue = reflect.MakeSlice(valueType, len(values), len(values)) + + for i, val := range values { + if err := parseValue(sliceValue.Index(i), val, sep, layout); err != nil { + return nil, err + } + } + } + return &sliceValue, nil +} + +// parseMap parses value into a map of given type +func parseMap(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) { + mapValue := reflect.MakeMap(valueType) + if len(strings.TrimSpace(value)) != 0 { + pairs := strings.Split(value, sep) + for _, pair := range pairs { + kvPair := strings.SplitN(pair, ":", 2) + if len(kvPair) != 2 { + return nil, fmt.Errorf("invalid map item: %q", pair) + } + k := reflect.New(valueType.Key()).Elem() + err := parseValue(k, kvPair[0], sep, layout) + if err != nil { + return nil, err + } + v := reflect.New(valueType.Elem()).Elem() + err = parseValue(v, kvPair[1], sep, layout) + if err != nil { + return nil, err + } + mapValue.SetMapIndex(k, v) + } + } + return &mapValue, nil +} + // structMeta is a structure metadata entity type structMeta struct { envList []string @@ -198,7 +251,7 @@ type structMeta struct { // isFieldValueZero determines if fieldValue empty or not func (sm *structMeta) isFieldValueZero() bool { - return isZero(sm.fieldValue) + return sm.fieldValue.IsZero() } // parseFunc custom value parser function @@ -206,6 +259,7 @@ type parseFunc func(*reflect.Value, string, *string) error // Any specific supported struct can be added here var validStructs = map[reflect.Type]parseFunc{ + reflect.TypeOf(time.Time{}): func(field *reflect.Value, value string, layout *string) error { var l string if layout != nil { @@ -220,6 +274,7 @@ var validStructs = map[reflect.Type]parseFunc{ field.Set(reflect.ValueOf(val)) return nil }, + reflect.TypeOf(url.URL{}): func(field *reflect.Value, value string, _ *string) error { val, err := url.Parse(value) if err != nil { @@ -228,6 +283,16 @@ var validStructs = map[reflect.Type]parseFunc{ field.Set(reflect.ValueOf(*val)) return nil }, + + reflect.TypeOf(&time.Location{}): func(field *reflect.Value, value string, _ *string) error { + loc, err := time.LoadLocation(value) + if err != nil { + return err + } + + field.Set(reflect.ValueOf(loc)) + return nil + }, } // readStructMetadata reads structure metadata (types, tags, etc.) @@ -268,12 +333,14 @@ func readStructMetadata(cfgRoot interface{}) ([]structMeta, error) { // process nested structure (except of supported ones) if fld := s.Field(idx); fld.Kind() == reflect.Struct { + // add structure to parsing stack if _, found := validStructs[fld.Type()]; !found { prefix, _ := fType.Tag.Lookup(TagEnvPrefix) cfgStack = append(cfgStack, cfgNode{fld.Addr().Interface(), sPrefix + prefix}) continue } + // process time.Time if l, ok := fType.Tag.Lookup(TagEnvLayout); ok { layout = &l @@ -357,9 +424,10 @@ func readEnvVars(cfg interface{}, update bool) error { } if rawValue == nil && meta.required && meta.isFieldValueZero() { - err := fmt.Errorf("field %q is required but the value is not provided", - meta.fieldName) - return err + return fmt.Errorf( + "field %q is required but the value is not provided", + meta.fieldName, + ) } if rawValue == nil && meta.isFieldValueZero() { @@ -392,8 +460,8 @@ func parseValue(field reflect.Value, value, sep string, layout *string) error { } valueType := field.Type() - switch valueType.Kind() { + // parse string value case reflect.String: field.SetString(value) @@ -406,16 +474,22 @@ func parseValue(field reflect.Value, value, sep string, layout *string) error { } field.SetBool(b) - // parse integer (or time) value - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if field.Kind() == reflect.Int64 && valueType.PkgPath() == "time" && valueType.Name() == "Duration" { + // parse integer + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + number, err := strconv.ParseInt(value, 0, valueType.Bits()) + if err != nil { + return err + } + field.SetInt(number) + + case reflect.Int64: + if valueType == reflect.TypeOf(time.Duration(0)) { // try to parse time d, err := time.ParseDuration(value) if err != nil { return err } field.SetInt(int64(d)) - } else { // parse regular integer number, err := strconv.ParseInt(value, 0, valueType.Bits()) @@ -459,62 +533,18 @@ func parseValue(field reflect.Value, value, sep string, layout *string) error { field.Set(*mapValue) - case reflect.Struct: + default: + // look for supported struct parser if structParser, found := validStructs[valueType]; found { return structParser(&field, value, layout) } - default: return fmt.Errorf("unsupported type %s.%s", valueType.PkgPath(), valueType.Name()) } return nil } -// parseSlice parses value into a slice of given type -func parseSlice(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) { - sliceValue := reflect.MakeSlice(valueType, 0, 0) - if valueType.Elem().Kind() == reflect.Uint8 { - sliceValue = reflect.ValueOf([]byte(value)) - } else if len(strings.TrimSpace(value)) != 0 { - values := strings.Split(value, sep) - sliceValue = reflect.MakeSlice(valueType, len(values), len(values)) - - for i, val := range values { - if err := parseValue(sliceValue.Index(i), val, sep, layout); err != nil { - return nil, err - } - } - } - return &sliceValue, nil -} - -// parseMap parses value into a map of given type -func parseMap(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) { - mapValue := reflect.MakeMap(valueType) - if len(strings.TrimSpace(value)) != 0 { - pairs := strings.Split(value, sep) - for _, pair := range pairs { - kvPair := strings.SplitN(pair, ":", 2) - if len(kvPair) != 2 { - return nil, fmt.Errorf("invalid map item: %q", pair) - } - k := reflect.New(valueType.Key()).Elem() - err := parseValue(k, kvPair[0], sep, layout) - if err != nil { - return nil, err - } - v := reflect.New(valueType.Elem()).Elem() - err = parseValue(v, kvPair[1], sep, layout) - if err != nil { - return nil, err - } - mapValue.SetMapIndex(k, v) - } - } - return &mapValue, nil -} - // GetDescription returns a description of environment variables. // You can provide a custom header text. func GetDescription(cfg interface{}, headerText *string) (string, error) { @@ -583,42 +613,3 @@ func FUsage(w io.Writer, cfg interface{}, headerText *string, usageFuncs ...func fmt.Fprintln(w, text) } } - -// isZero is a backport of reflect.Value.IsZero() -func isZero(v reflect.Value) bool { - switch v.Kind() { - case reflect.Bool: - return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return math.Float64bits(v.Float()) == 0 - case reflect.Complex64, reflect.Complex128: - c := v.Complex() - return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0 - case reflect.Array: - for i := 0; i < v.Len(); i++ { - if !isZero(v.Index(i)) { - return false - } - } - return true - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: - return v.IsNil() - case reflect.String: - return v.Len() == 0 - case reflect.Struct: - for i := 0; i < v.NumField(); i++ { - if !isZero(v.Field(i)) { - return false - } - } - return true - default: - // This should never happens, but will act as a safeguard for - // later, as a default value doesn't makes sense here. - panic(fmt.Sprintf("Value.IsZero: %v", v.Kind())) - } -} diff --git a/cleanenv_test.go b/cleanenv_test.go index e368a0c..b83cce0 100644 --- a/cleanenv_test.go +++ b/cleanenv_test.go @@ -50,13 +50,15 @@ func TestReadEnvVars(t *testing.T) { } type AllTypes struct { - Integer int64 `env:"TEST_INTEGER"` - UnsInteger uint64 `env:"TEST_UNSINTEGER"` - Float float64 `env:"TEST_FLOAT"` - Boolean bool `env:"TEST_BOOLEAN"` - String string `env:"TEST_STRING"` - Duration time.Duration `env:"TEST_DURATION"` - Time time.Time `env:"TEST_TIME"` + Integer int64 `env:"TEST_INTEGER"` + UnsInteger uint64 `env:"TEST_UNSINTEGER"` + Float float64 `env:"TEST_FLOAT"` + Boolean bool `env:"TEST_BOOLEAN"` + String string `env:"TEST_STRING"` + Duration time.Duration `env:"TEST_DURATION"` + Time time.Time `env:"TEST_TIME"` + // Location depends on the system, so we test it with time.UTC + Location *time.Location `env:"TEST_LOCATION"` ArrayInt []int `env:"TEST_ARRAYINT"` ArrayString []string `env:"TEST_ARRAYSTRING"` MapStringInt map[string]int `env:"TEST_MAPSTRINGINT"` @@ -104,13 +106,15 @@ func TestReadEnvVars(t *testing.T) { { name: "all types", env: map[string]string{ - "TEST_INTEGER": "-5", - "TEST_UNSINTEGER": "5", - "TEST_FLOAT": "5.5", - "TEST_BOOLEAN": "true", - "TEST_STRING": "test", - "TEST_DURATION": "1h5m10s", - "TEST_TIME": "2012-04-23T18:25:43.511Z", + "TEST_INTEGER": "-5", + "TEST_UNSINTEGER": "5", + "TEST_FLOAT": "5.5", + "TEST_BOOLEAN": "true", + "TEST_STRING": "test", + "TEST_DURATION": "1h5m10s", + "TEST_TIME": "2012-04-23T18:25:43.511Z", + // Location depends on the system, so we test it with time.UTC + "TEST_LOCATION": "UTC", "TEST_ARRAYINT": "1,2,3", "TEST_ARRAYSTRING": "a,b,c", "TEST_MAPSTRINGINT": "a:1,b:2,c:3", @@ -125,6 +129,7 @@ func TestReadEnvVars(t *testing.T) { String: "test", Duration: durationFunc("1h5m10s"), Time: timeFunc("2012-04-23T18:25:43.511Z", time.RFC3339), + Location: time.UTC, ArrayInt: []int{1, 2, 3}, ArrayString: []string{"a", "b", "c"}, MapStringInt: map[string]int{ @@ -373,7 +378,7 @@ func TestReadEnvVarsURL(t *testing.T) { } if !reflect.DeepEqual(tt.cfg, tt.want) { fmt.Println(tt.cfg.(*WithURL).DatabaseURL) - t.Errorf("wrong data %v, want %v", tt.cfg, tt.want) + t.Errorf("wrong data: got %v, want %v", tt.cfg, tt.want) } }) } @@ -428,7 +433,6 @@ func TestReadEnvVarsTime(t *testing.T) { }) } } - func TestReadEnvVarsWithPrefix(t *testing.T) { type Logging struct { Debug bool `env:"DEBUG"` @@ -630,9 +634,7 @@ number = 1 float = 2.3 string = "test" boolean = true - array = [1, 2, 3] - [object] one = 1 two = 2`, @@ -1163,3 +1165,25 @@ no-env: this }) } } + +// TestTimeLocation tests *time.Location parse. It is a pointer type, +// so we need to compare it with pointer manually, +// because reflect.DeepEqual() compares only pointer values, not their structs +func TestTimeLocation(t *testing.T) { + want := time.UTC + + var S struct { + Location *time.Location `env:"TEST_LOCATION"` + } + + os.Setenv("TEST_LOCATION", "UTC") + defer os.Clearenv() + + if err := ReadEnv(&S); err != nil { + t.Fatal("cannot read env:", err) + } + + if want != S.Location { + t.Errorf("wrong location pointers: got %p, want %p", S.Location, want) + } +}