diff --git a/core/hash/consistenthash.go b/core/hash/consistenthash.go index 0adf3b9e5d6c..b805c8cc1f86 100644 --- a/core/hash/consistenthash.go +++ b/core/hash/consistenthash.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/zeromicro/go-zero/core/lang" - "github.com/zeromicro/go-zero/core/mapping" ) const ( @@ -183,5 +182,5 @@ func innerRepr(node interface{}) string { } func repr(node interface{}) string { - return mapping.Repr(node) + return lang.Repr(node) } diff --git a/core/lang/lang.go b/core/lang/lang.go index ff7305dbe11f..3e2932fda962 100644 --- a/core/lang/lang.go +++ b/core/lang/lang.go @@ -1,5 +1,11 @@ package lang +import ( + "fmt" + "reflect" + "strconv" +) + // Placeholder is a placeholder object that can be used globally. var Placeholder PlaceholderType @@ -9,3 +15,64 @@ type ( // PlaceholderType represents a placeholder type. PlaceholderType = struct{} ) + +// Repr returns the string representation of v. +func Repr(v interface{}) string { + if v == nil { + return "" + } + + // if func (v *Type) String() string, we can't use Elem() + switch vt := v.(type) { + case fmt.Stringer: + return vt.String() + } + + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr && !val.IsNil() { + val = val.Elem() + } + + return reprOfValue(val) +} + +func reprOfValue(val reflect.Value) string { + switch vt := val.Interface().(type) { + case bool: + return strconv.FormatBool(vt) + case error: + return vt.Error() + case float32: + return strconv.FormatFloat(float64(vt), 'f', -1, 32) + case float64: + return strconv.FormatFloat(vt, 'f', -1, 64) + case fmt.Stringer: + return vt.String() + case int: + return strconv.Itoa(vt) + case int8: + return strconv.Itoa(int(vt)) + case int16: + return strconv.Itoa(int(vt)) + case int32: + return strconv.Itoa(int(vt)) + case int64: + return strconv.FormatInt(vt, 10) + case string: + return vt + case uint: + return strconv.FormatUint(uint64(vt), 10) + case uint8: + return strconv.FormatUint(uint64(vt), 10) + case uint16: + return strconv.FormatUint(uint64(vt), 10) + case uint32: + return strconv.FormatUint(uint64(vt), 10) + case uint64: + return strconv.FormatUint(vt, 10) + case []byte: + return string(vt) + default: + return fmt.Sprint(val.Interface()) + } +} diff --git a/core/lang/lang_test.go b/core/lang/lang_test.go new file mode 100644 index 000000000000..1527defd8015 --- /dev/null +++ b/core/lang/lang_test.go @@ -0,0 +1,131 @@ +package lang + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRepr(t *testing.T) { + var ( + f32 float32 = 1.1 + f64 = 2.2 + i8 int8 = 1 + i16 int16 = 2 + i32 int32 = 3 + i64 int64 = 4 + u8 uint8 = 5 + u16 uint16 = 6 + u32 uint32 = 7 + u64 uint64 = 8 + ) + tests := []struct { + v interface{} + expect string + }{ + { + nil, + "", + }, + { + mockStringable{}, + "mocked", + }, + { + new(mockStringable), + "mocked", + }, + { + newMockPtr(), + "mockptr", + }, + { + &mockOpacity{ + val: 1, + }, + "{1}", + }, + { + true, + "true", + }, + { + false, + "false", + }, + { + f32, + "1.1", + }, + { + f64, + "2.2", + }, + { + i8, + "1", + }, + { + i16, + "2", + }, + { + i32, + "3", + }, + { + i64, + "4", + }, + { + u8, + "5", + }, + { + u16, + "6", + }, + { + u32, + "7", + }, + { + u64, + "8", + }, + { + []byte(`abcd`), + "abcd", + }, + { + mockOpacity{val: 1}, + "{1}", + }, + } + + for _, test := range tests { + t.Run(test.expect, func(t *testing.T) { + assert.Equal(t, test.expect, Repr(test.v)) + }) + } +} + +type mockStringable struct{} + +func (m mockStringable) String() string { + return "mocked" +} + +type mockPtr struct{} + +func newMockPtr() *mockPtr { + return new(mockPtr) +} + +func (m *mockPtr) String() string { + return "mockptr" +} + +type mockOpacity struct { + val int +} diff --git a/core/mapping/fieldoptions.go b/core/mapping/fieldoptions.go index 1e8a2ab628d7..14b3c84b32e5 100644 --- a/core/mapping/fieldoptions.go +++ b/core/mapping/fieldoptions.go @@ -13,6 +13,7 @@ type ( Optional bool Options []string Default string + EnvVar string Range *numberRange } @@ -106,5 +107,6 @@ func (o *fieldOptions) toOptionsWithContext(key string, m Valuer, fullName strin Optional: optional, Options: o.Options, Default: o.Default, + EnvVar: o.EnvVar, }, nil } diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 0f456a35986a..9fcf1de5e852 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -12,6 +12,7 @@ import ( "github.com/zeromicro/go-zero/core/jsonx" "github.com/zeromicro/go-zero/core/lang" + "github.com/zeromicro/go-zero/core/proc" "github.com/zeromicro/go-zero/core/stringx" ) @@ -92,8 +93,7 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f rve := rv.Elem() numFields := rte.NumField() for i := 0; i < numFields; i++ { - field := rte.Field(i) - if err := u.processField(field, rve.Field(i), m, fullName); err != nil { + if err := u.processField(rte.Field(i), rve.Field(i), m, fullName); err != nil { return err } } @@ -338,6 +338,24 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(field reflect.StructField, val return false, nil } +func (u *Unmarshaler) processFieldWithEnvValue(field reflect.StructField, value reflect.Value, + envVal string, opts *fieldOptionsWithContext, fullName string) error { + fieldKind := field.Type.Kind() + switch fieldKind { + case durationType.Kind(): + if err := fillDurationValue(fieldKind, value, envVal); err != nil { + return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err) + } + + return nil + case reflect.String: + value.SetString(envVal) + return nil + default: + return u.processFieldPrimitiveWithJSONNumber(field, value, json.Number(envVal), opts, fullName) + } +} + func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value, m valuerWithParent, fullName string) error { key, opts, err := u.parseOptionsWithContext(field, m, fullName) @@ -346,6 +364,13 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect } fullName = join(fullName, key) + if opts != nil && len(opts.EnvVar) > 0 { + envVal := proc.Env(opts.EnvVar) + if len(envVal) > 0 { + return u.processFieldWithEnvValue(field, value, envVal, opts, fullName) + } + } + canonicalKey := key if u.opts.canonicalKey != nil { canonicalKey = u.opts.canonicalKey(key) diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 9619d5da3bdf..6fdcdaacbb60 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -3,6 +3,7 @@ package mapping import ( "encoding/json" "fmt" + "os" "strconv" "strings" "testing" @@ -3089,6 +3090,129 @@ func TestUnmarshalValuer(t *testing.T) { assert.NotNil(t, err) } +func TestUnmarshal_EnvString(t *testing.T) { + type Value struct { + Name string `key:"name,env=TEST_NAME_STRING"` + } + + const ( + envName = "TEST_NAME_STRING" + envVal = "this is a name" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) + + var v Value + assert.NoError(t, UnmarshalKey(emptyMap, &v)) + assert.Equal(t, envVal, v.Name) +} + +func TestUnmarshal_EnvStringOverwrite(t *testing.T) { + type Value struct { + Name string `key:"name,env=TEST_NAME_STRING"` + } + + const ( + envName = "TEST_NAME_STRING" + envVal = "this is a name" + ) + os.Setenv(envName, envVal) + defer os.Unsetenv(envName) + + var v Value + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "name": "local value", + }, &v)) + assert.Equal(t, envVal, v.Name) +} + +func TestUnmarshal_EnvInt(t *testing.T) { + type Value struct { + Age int `key:"age,env=TEST_NAME_INT"` + } + + const envName = "TEST_NAME_INT" + os.Setenv(envName, "123") + defer os.Unsetenv(envName) + + var v Value + assert.NoError(t, UnmarshalKey(emptyMap, &v)) + assert.Equal(t, 123, v.Age) +} + +func TestUnmarshal_EnvIntOverwrite(t *testing.T) { + type Value struct { + Age int `key:"age,env=TEST_NAME_INT"` + } + + const envName = "TEST_NAME_INT" + os.Setenv(envName, "123") + defer os.Unsetenv(envName) + + var v Value + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "age": 18, + }, &v)) + assert.Equal(t, 123, v.Age) +} + +func TestUnmarshal_EnvFloat(t *testing.T) { + type Value struct { + Age float32 `key:"name,env=TEST_NAME_FLOAT"` + } + + const envName = "TEST_NAME_FLOAT" + os.Setenv(envName, "123.45") + defer os.Unsetenv(envName) + + var v Value + assert.NoError(t, UnmarshalKey(emptyMap, &v)) + assert.Equal(t, float32(123.45), v.Age) +} + +func TestUnmarshal_EnvFloatOverwrite(t *testing.T) { + type Value struct { + Age float32 `key:"age,env=TEST_NAME_FLOAT"` + } + + const envName = "TEST_NAME_FLOAT" + os.Setenv(envName, "123.45") + defer os.Unsetenv(envName) + + var v Value + assert.NoError(t, UnmarshalKey(map[string]interface{}{ + "age": 18.5, + }, &v)) + assert.Equal(t, float32(123.45), v.Age) +} + +func TestUnmarshal_EnvDuration(t *testing.T) { + type Value struct { + Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"` + } + + const envName = "TEST_NAME_DURATION" + os.Setenv(envName, "1s") + defer os.Unsetenv(envName) + + var v Value + assert.NoError(t, UnmarshalKey(emptyMap, &v)) + assert.Equal(t, time.Second, v.Duration) +} + +func TestUnmarshal_EnvDurationBadValue(t *testing.T) { + type Value struct { + Duration time.Duration `key:"duration,env=TEST_NAME_BAD_DURATION"` + } + + const envName = "TEST_NAME_BAD_DURATION" + os.Setenv(envName, "bad") + defer os.Unsetenv(envName) + + var v Value + assert.NotNil(t, UnmarshalKey(emptyMap, &v)) +} + func BenchmarkUnmarshalString(b *testing.B) { type inner struct { Value string `key:"value"` diff --git a/core/mapping/utils.go b/core/mapping/utils.go index 8d21fd5836de..652b0c0c1ed3 100644 --- a/core/mapping/utils.go +++ b/core/mapping/utils.go @@ -10,11 +10,13 @@ import ( "strings" "sync" + "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/stringx" ) const ( defaultOption = "default" + envOption = "env" inheritOption = "inherit" stringOption = "string" optionalOption = "optional" @@ -63,22 +65,7 @@ func Deref(t reflect.Type) reflect.Type { // Repr returns the string representation of v. func Repr(v interface{}) string { - if v == nil { - return "" - } - - // if func (v *Type) String() string, we can't use Elem() - switch vt := v.(type) { - case fmt.Stringer: - return vt.String() - } - - val := reflect.ValueOf(v) - if val.Kind() == reflect.Ptr && !val.IsNil() { - val = val.Elem() - } - - return reprOfValue(val) + return lang.Repr(v) } // ValidatePtr validates v if it's a valid pointer. @@ -354,26 +341,33 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error { case option == optionalOption: fieldOpts.Optional = true case strings.HasPrefix(option, optionsOption): - segs := strings.Split(option, equalToken) - if len(segs) != 2 { - return fmt.Errorf("field %s has wrong options", fieldName) + val, err := parseProperty(fieldName, optionsOption, option) + if err != nil { + return err } - fieldOpts.Options = parseOptions(segs[1]) + fieldOpts.Options = parseOptions(val) case strings.HasPrefix(option, defaultOption): - segs := strings.Split(option, equalToken) - if len(segs) != 2 { - return fmt.Errorf("field %s has wrong default option", fieldName) + val, err := parseProperty(fieldName, defaultOption, option) + if err != nil { + return err + } + + fieldOpts.Default = val + case strings.HasPrefix(option, envOption): + val, err := parseProperty(fieldName, envOption, option) + if err != nil { + return err } - fieldOpts.Default = strings.TrimSpace(segs[1]) + fieldOpts.EnvVar = val case strings.HasPrefix(option, rangeOption): - segs := strings.Split(option, equalToken) - if len(segs) != 2 { - return fmt.Errorf("field %s has wrong range", fieldName) + val, err := parseProperty(fieldName, rangeOption, option) + if err != nil { + return err } - nr, err := parseNumberRange(segs[1]) + nr, err := parseNumberRange(val) if err != nil { return err } @@ -398,6 +392,15 @@ func parseOptions(val string) []string { return strings.Split(val, optionSeparator) } +func parseProperty(field, tag, val string) (string, error) { + segs := strings.Split(val, equalToken) + if len(segs) != 2 { + return "", fmt.Errorf("field %s has wrong %s", field, tag) + } + + return strings.TrimSpace(segs[1]), nil +} + func parseSegments(val string) []string { var segments []string var escaped, grouped bool @@ -447,47 +450,6 @@ func parseSegments(val string) []string { return segments } -func reprOfValue(val reflect.Value) string { - switch vt := val.Interface().(type) { - case bool: - return strconv.FormatBool(vt) - case error: - return vt.Error() - case float32: - return strconv.FormatFloat(float64(vt), 'f', -1, 32) - case float64: - return strconv.FormatFloat(vt, 'f', -1, 64) - case fmt.Stringer: - return vt.String() - case int: - return strconv.Itoa(vt) - case int8: - return strconv.Itoa(int(vt)) - case int16: - return strconv.Itoa(int(vt)) - case int32: - return strconv.Itoa(int(vt)) - case int64: - return strconv.FormatInt(vt, 10) - case string: - return vt - case uint: - return strconv.FormatUint(uint64(vt), 10) - case uint8: - return strconv.FormatUint(uint64(vt), 10) - case uint16: - return strconv.FormatUint(uint64(vt), 10) - case uint32: - return strconv.FormatUint(uint64(vt), 10) - case uint64: - return strconv.FormatUint(vt, 10) - case []byte: - return string(vt) - default: - return fmt.Sprint(val.Interface()) - } -} - func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error { switch kind { case reflect.Bool: diff --git a/core/mapping/utils_test.go b/core/mapping/utils_test.go index dec085846b3b..a36e5f9ac392 100644 --- a/core/mapping/utils_test.go +++ b/core/mapping/utils_test.go @@ -296,127 +296,3 @@ func TestSetValueFormatErrors(t *testing.T) { }) } } - -func TestRepr(t *testing.T) { - var ( - f32 float32 = 1.1 - f64 = 2.2 - i8 int8 = 1 - i16 int16 = 2 - i32 int32 = 3 - i64 int64 = 4 - u8 uint8 = 5 - u16 uint16 = 6 - u32 uint32 = 7 - u64 uint64 = 8 - ) - tests := []struct { - v interface{} - expect string - }{ - { - nil, - "", - }, - { - mockStringable{}, - "mocked", - }, - { - new(mockStringable), - "mocked", - }, - { - newMockPtr(), - "mockptr", - }, - { - &mockOpacity{ - val: 1, - }, - "{1}", - }, - { - true, - "true", - }, - { - false, - "false", - }, - { - f32, - "1.1", - }, - { - f64, - "2.2", - }, - { - i8, - "1", - }, - { - i16, - "2", - }, - { - i32, - "3", - }, - { - i64, - "4", - }, - { - u8, - "5", - }, - { - u16, - "6", - }, - { - u32, - "7", - }, - { - u64, - "8", - }, - { - []byte(`abcd`), - "abcd", - }, - { - mockOpacity{val: 1}, - "{1}", - }, - } - - for _, test := range tests { - t.Run(test.expect, func(t *testing.T) { - assert.Equal(t, test.expect, Repr(test.v)) - }) - } -} - -type mockStringable struct{} - -func (m mockStringable) String() string { - return "mocked" -} - -type mockPtr struct{} - -func newMockPtr() *mockPtr { - return new(mockPtr) -} - -func (m *mockPtr) String() string { - return "mockptr" -} - -type mockOpacity struct { - val int -} diff --git a/core/stringx/strings.go b/core/stringx/strings.go index 850eaebfbb64..c0c4510c7e33 100644 --- a/core/stringx/strings.go +++ b/core/stringx/strings.go @@ -69,6 +69,33 @@ func HasEmpty(args ...string) bool { return false } +// Join joins any number of elements into a single string, separating them with given sep. +// Empty elements are ignored. However, if the argument list is empty or all its elements are empty, +// Join returns an empty string. +func Join(sep byte, elem ...string) string { + var size int + for _, e := range elem { + size += len(e) + } + if size == 0 { + return "" + } + + buf := make([]byte, 0, size+len(elem)-1) + for _, e := range elem { + if len(e) == 0 { + continue + } + + if len(buf) > 0 { + buf = append(buf, sep) + } + buf = append(buf, e...) + } + + return string(buf) +} + // NotEmpty checks if all strings are not empty in args. func NotEmpty(args ...string) bool { return !HasEmpty(args...) diff --git a/core/stringx/strings_test.go b/core/stringx/strings_test.go index 1e88b829e213..6ef8755e5842 100644 --- a/core/stringx/strings_test.go +++ b/core/stringx/strings_test.go @@ -147,6 +147,42 @@ func TestFirstN(t *testing.T) { } } +func TestJoin(t *testing.T) { + tests := []struct { + name string + input []string + expect string + }{ + { + name: "all blanks", + input: []string{"", ""}, + expect: "", + }, + { + name: "two values", + input: []string{"012", "abc"}, + expect: "012.abc", + }, + { + name: "last blank", + input: []string{"abc", ""}, + expect: "abc", + }, + { + name: "first blank", + input: []string{"", "abc"}, + expect: "abc", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expect, Join('.', test.input...)) + }) + } +} + func TestRemove(t *testing.T) { cases := []struct { input []string