Skip to content

Commit

Permalink
Add time.Location support
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyakaznacheev committed Oct 11, 2022
2 parents 130ef83 + badbae6 commit 3bb055c
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 134 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -192,6 +192,7 @@ There are following supported types:
- maps (of any other supported type);
- `time.Duration`;
- `time.Time` (layout by default is RFC3339, may be overridden by `env-layout`);
- `*time.Location` (time zone parsing [depends](https://pkg.go.dev/time#LoadLocation) on running machine)
- any type implementing `cleanenv.Setter` interface.


Expand Down
223 changes: 107 additions & 116 deletions cleanenv.go
Expand Up @@ -5,7 +5,6 @@ import (
"flag"
"fmt"
"io"
"math"
"net/url"
"os"
"path/filepath"
Expand All @@ -29,18 +28,25 @@ const (
const (
// Name of the environment variable or a list of names
TagEnv = "env"

// Value parsing layout (for types like time.Time)
TagEnvLayout = "env-layout"

// Default value
TagEnvDefault = "env-default"

// Custom list and map separator
TagEnvSeparator = "env-separator"

// Environment variable description
TagEnvDescription = "env-description"

// Flag to mark a field as updatable
TagEnvUpd = "env-upd"

// Flag to mark a field as required
TagEnvRequired = "env-required"

// Flag to specify prefix for structure fields
TagEnvPrefix = "env-prefix"
)
Expand All @@ -49,15 +55,15 @@ const (
//
// To implement a custom value setter you need to add a SetValue function to your type that will receive a string raw value:
//
// type MyField string
// type MyField string
//
// func (f *MyField) SetValue(s string) error {
// if s == "" {
// return fmt.Errorf("field value can't be empty")
// }
// *f = MyField("my field is: " + s)
// return nil
// }
// func (f *MyField) SetValue(s string) error {
// if s == "" {
// return fmt.Errorf("field value can't be empty")
// }
// *f = MyField("my field is: " + s)
// return nil
// }
type Setter interface {
SetValue(string) error
}
Expand All @@ -72,20 +78,20 @@ type Updater interface {
//
// Example:
//
// type ConfigDatabase struct {
// Port string `yaml:"port" env:"PORT" env-default:"5432"`
// Host string `yaml:"host" env:"HOST" env-default:"localhost"`
// Name string `yaml:"name" env:"NAME" env-default:"postgres"`
// User string `yaml:"user" env:"USER" env-default:"user"`
// Password string `yaml:"password" env:"PASSWORD"`
// }
// type ConfigDatabase struct {
// Port string `yaml:"port" env:"PORT" env-default:"5432"`
// Host string `yaml:"host" env:"HOST" env-default:"localhost"`
// Name string `yaml:"name" env:"NAME" env-default:"postgres"`
// User string `yaml:"user" env:"USER" env-default:"user"`
// Password string `yaml:"password" env:"PASSWORD"`
// }
//
// var cfg ConfigDatabase
// var cfg ConfigDatabase
//
// err := cleanenv.ReadConfig("config.yml", &cfg)
// if err != nil {
// ...
// }
// err := cleanenv.ReadConfig("config.yml", &cfg)
// if err != nil {
// ...
// }
func ReadConfig(path string, cfg interface{}) error {
err := parseFile(path, cfg)
if err != nil {
Expand Down Expand Up @@ -178,11 +184,58 @@ func parseENV(r io.Reader, _ interface{}) error {
}

for env, val := range vars {
os.Setenv(env, val)
if err = os.Setenv(env, val); err != nil {
return fmt.Errorf("set environment: %w", err)
}
}

return nil
}

// parseSlice parses value into a slice of given type
func parseSlice(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) {
sliceValue := reflect.MakeSlice(valueType, 0, 0)
if valueType.Elem().Kind() == reflect.Uint8 {
sliceValue = reflect.ValueOf([]byte(value))
} else if len(strings.TrimSpace(value)) != 0 {
values := strings.Split(value, sep)
sliceValue = reflect.MakeSlice(valueType, len(values), len(values))

for i, val := range values {
if err := parseValue(sliceValue.Index(i), val, sep, layout); err != nil {
return nil, err
}
}
}
return &sliceValue, nil
}

// parseMap parses value into a map of given type
func parseMap(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) {
mapValue := reflect.MakeMap(valueType)
if len(strings.TrimSpace(value)) != 0 {
pairs := strings.Split(value, sep)
for _, pair := range pairs {
kvPair := strings.SplitN(pair, ":", 2)
if len(kvPair) != 2 {
return nil, fmt.Errorf("invalid map item: %q", pair)
}
k := reflect.New(valueType.Key()).Elem()
err := parseValue(k, kvPair[0], sep, layout)
if err != nil {
return nil, err
}
v := reflect.New(valueType.Elem()).Elem()
err = parseValue(v, kvPair[1], sep, layout)
if err != nil {
return nil, err
}
mapValue.SetMapIndex(k, v)
}
}
return &mapValue, nil
}

// structMeta is a structure metadata entity
type structMeta struct {
envList []string
Expand All @@ -198,14 +251,15 @@ type structMeta struct {

// isFieldValueZero determines if fieldValue empty or not
func (sm *structMeta) isFieldValueZero() bool {
return isZero(sm.fieldValue)
return sm.fieldValue.IsZero()
}

// parseFunc custom value parser function
type parseFunc func(*reflect.Value, string, *string) error

// Any specific supported struct can be added here
var validStructs = map[reflect.Type]parseFunc{

reflect.TypeOf(time.Time{}): func(field *reflect.Value, value string, layout *string) error {
var l string
if layout != nil {
Expand All @@ -220,6 +274,7 @@ var validStructs = map[reflect.Type]parseFunc{
field.Set(reflect.ValueOf(val))
return nil
},

reflect.TypeOf(url.URL{}): func(field *reflect.Value, value string, _ *string) error {
val, err := url.Parse(value)
if err != nil {
Expand All @@ -228,6 +283,16 @@ var validStructs = map[reflect.Type]parseFunc{
field.Set(reflect.ValueOf(*val))
return nil
},

reflect.TypeOf(&time.Location{}): func(field *reflect.Value, value string, _ *string) error {
loc, err := time.LoadLocation(value)
if err != nil {
return err
}

field.Set(reflect.ValueOf(loc))
return nil
},
}

// readStructMetadata reads structure metadata (types, tags, etc.)
Expand Down Expand Up @@ -268,12 +333,14 @@ func readStructMetadata(cfgRoot interface{}) ([]structMeta, error) {

// process nested structure (except of supported ones)
if fld := s.Field(idx); fld.Kind() == reflect.Struct {

// add structure to parsing stack
if _, found := validStructs[fld.Type()]; !found {
prefix, _ := fType.Tag.Lookup(TagEnvPrefix)
cfgStack = append(cfgStack, cfgNode{fld.Addr().Interface(), sPrefix + prefix})
continue
}

// process time.Time
if l, ok := fType.Tag.Lookup(TagEnvLayout); ok {
layout = &l
Expand Down Expand Up @@ -357,9 +424,10 @@ func readEnvVars(cfg interface{}, update bool) error {
}

if rawValue == nil && meta.required && meta.isFieldValueZero() {
err := fmt.Errorf("field %q is required but the value is not provided",
meta.fieldName)
return err
return fmt.Errorf(
"field %q is required but the value is not provided",
meta.fieldName,
)
}

if rawValue == nil && meta.isFieldValueZero() {
Expand Down Expand Up @@ -392,8 +460,8 @@ func parseValue(field reflect.Value, value, sep string, layout *string) error {
}

valueType := field.Type()

switch valueType.Kind() {

// parse string value
case reflect.String:
field.SetString(value)
Expand All @@ -406,16 +474,22 @@ func parseValue(field reflect.Value, value, sep string, layout *string) error {
}
field.SetBool(b)

// parse integer (or time) value
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if field.Kind() == reflect.Int64 && valueType.PkgPath() == "time" && valueType.Name() == "Duration" {
// parse integer
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
number, err := strconv.ParseInt(value, 0, valueType.Bits())
if err != nil {
return err
}
field.SetInt(number)

case reflect.Int64:
if valueType == reflect.TypeOf(time.Duration(0)) {
// try to parse time
d, err := time.ParseDuration(value)
if err != nil {
return err
}
field.SetInt(int64(d))

} else {
// parse regular integer
number, err := strconv.ParseInt(value, 0, valueType.Bits())
Expand Down Expand Up @@ -459,62 +533,18 @@ func parseValue(field reflect.Value, value, sep string, layout *string) error {

field.Set(*mapValue)

case reflect.Struct:
default:
// look for supported struct parser
if structParser, found := validStructs[valueType]; found {
return structParser(&field, value, layout)
}

default:
return fmt.Errorf("unsupported type %s.%s", valueType.PkgPath(), valueType.Name())
}

return nil
}

// parseSlice parses value into a slice of given type
func parseSlice(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) {
sliceValue := reflect.MakeSlice(valueType, 0, 0)
if valueType.Elem().Kind() == reflect.Uint8 {
sliceValue = reflect.ValueOf([]byte(value))
} else if len(strings.TrimSpace(value)) != 0 {
values := strings.Split(value, sep)
sliceValue = reflect.MakeSlice(valueType, len(values), len(values))

for i, val := range values {
if err := parseValue(sliceValue.Index(i), val, sep, layout); err != nil {
return nil, err
}
}
}
return &sliceValue, nil
}

// parseMap parses value into a map of given type
func parseMap(valueType reflect.Type, value string, sep string, layout *string) (*reflect.Value, error) {
mapValue := reflect.MakeMap(valueType)
if len(strings.TrimSpace(value)) != 0 {
pairs := strings.Split(value, sep)
for _, pair := range pairs {
kvPair := strings.SplitN(pair, ":", 2)
if len(kvPair) != 2 {
return nil, fmt.Errorf("invalid map item: %q", pair)
}
k := reflect.New(valueType.Key()).Elem()
err := parseValue(k, kvPair[0], sep, layout)
if err != nil {
return nil, err
}
v := reflect.New(valueType.Elem()).Elem()
err = parseValue(v, kvPair[1], sep, layout)
if err != nil {
return nil, err
}
mapValue.SetMapIndex(k, v)
}
}
return &mapValue, nil
}

// GetDescription returns a description of environment variables.
// You can provide a custom header text.
func GetDescription(cfg interface{}, headerText *string) (string, error) {
Expand Down Expand Up @@ -583,42 +613,3 @@ func FUsage(w io.Writer, cfg interface{}, headerText *string, usageFuncs ...func
fmt.Fprintln(w, text)
}
}

// isZero is a backport of reflect.Value.IsZero()
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return math.Float64bits(v.Float()) == 0
case reflect.Complex64, reflect.Complex128:
c := v.Complex()
return math.Float64bits(real(c)) == 0 && math.Float64bits(imag(c)) == 0
case reflect.Array:
for i := 0; i < v.Len(); i++ {
if !isZero(v.Index(i)) {
return false
}
}
return true
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
return v.IsNil()
case reflect.String:
return v.Len() == 0
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
if !isZero(v.Field(i)) {
return false
}
}
return true
default:
// This should never happens, but will act as a safeguard for
// later, as a default value doesn't makes sense here.
panic(fmt.Sprintf("Value.IsZero: %v", v.Kind()))
}
}

0 comments on commit 3bb055c

Please sign in to comment.