Skip to content

Commit

Permalink
New parser
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Jul 21, 2022
1 parent 7766e43 commit e046101
Show file tree
Hide file tree
Showing 5 changed files with 819 additions and 98 deletions.
74 changes: 62 additions & 12 deletions aconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
type Loader struct {
config Config
dst interface{}
parser *structParser
fields []*fieldData
fsys fs.FS
flagSet *flag.FlagSet
Expand All @@ -32,6 +33,11 @@ type Loader struct {

// Config to configure configuration loader.
type Config struct {
// NewParser set to true enables a new and better struct parser.
// Default is false because there might be bugs.
// In the future new parser will be enabled by default.
NewParser bool

SkipDefaults bool // SkipDefaults set to true will not load config from 'default' tag.
SkipFiles bool // SkipFiles set to true will not load config from files.
SkipEnv bool // SkipEnv set to true will not load config from environment variables.
Expand Down Expand Up @@ -168,24 +174,37 @@ func (l *Loader) init() {
l.config.Args = os.Args[1:]
}

l.fields = l.getFields(l.dst)
if l.config.NewParser {
l.parser = newStructParser(l.config)
if err := l.parser.parseStruct(l.dst); err != nil {
l.errInit = err
return
}
} else {
l.fields = l.getFields(l.dst)
}

l.flagSet = flag.NewFlagSet(l.config.FlagPrefix, flag.ContinueOnError)
if !l.config.SkipFlags {
names := make(map[string]bool, len(l.fields))
for _, field := range l.fields {
flagName := l.fullTag(l.config.FlagPrefix, field, flagNameTag)
if flagName == "" {
continue
}
if names[flagName] && !l.config.AllowDuplicates {
l.errInit = fmt.Errorf("duplicate flag %q", flagName)
return
if l.config.NewParser {
l.flagSet = l.parser.flagSet
} else {
for _, field := range l.fields {
flagName := l.fullTag(l.config.FlagPrefix, field, flagNameTag)
if flagName == "" {
continue
}
if names[flagName] && !l.config.AllowDuplicates {
l.errInit = fmt.Errorf("duplicate flag %q", flagName)
return
}
names[flagName] = true
l.flagSet.String(flagName, field.Tag(defaultValueTag), field.Tag(usageTag))
}
names[flagName] = true
l.flagSet.String(flagName, field.Tag(defaultValueTag), field.Tag(usageTag))
}
}

if l.config.FileFlag != "" {
// TODO: should be prefixed ?
l.flagSet.String(l.config.FileFlag, "", "config file param")
Expand Down Expand Up @@ -261,6 +280,12 @@ func (l *Loader) loadSources() error {
return fmt.Errorf("load flags: %w", err)
}
}

if l.config.NewParser {
if err := l.parser.apply(l.dst); err != nil {
return fmt.Errorf("apply: %w", err)
}
}
return nil
}

Expand All @@ -277,6 +302,10 @@ func (l *Loader) checkRequired() error {
}

func (l *Loader) loadDefaults() error {
if l.config.NewParser {
return nil
}

for _, field := range l.fields {
defaultValue := field.Tag(defaultValueTag)
if err := l.setFieldData(field, defaultValue); err != nil {
Expand Down Expand Up @@ -327,6 +356,13 @@ func (l *Loader) loadFile(file string) error {

tag := decoder.Format()

if l.config.NewParser {
if err := l.parser.applyLevel(tag, actualFields); err != nil {
return fmt.Errorf("apply %s: %w", tag, err)
}
return nil
}

for _, field := range l.fields {
name := l.fullTag("", field, tag)
if name == "" {
Expand Down Expand Up @@ -379,6 +415,13 @@ func (l *Loader) loadEnvironment() error {
actualEnvs := getEnv()
dupls := make(map[string]struct{})

if l.config.NewParser {
if err := l.parser.applyFlat("env", actualEnvs); err != nil {
return fmt.Errorf("apply env: %w", err)
}
return nil
}

for _, field := range l.fields {
envName := l.fullTag(l.config.EnvPrefix, field, envNameTag)
if envName == "" {
Expand Down Expand Up @@ -410,6 +453,13 @@ func (l *Loader) loadFlags() error {
actualFlags := getFlags(l.flagSet)
dupls := make(map[string]struct{})

if l.config.NewParser {
if err := l.parser.applyFlat("flag", actualFlags); err != nil {
return fmt.Errorf("apply flag: %w", err)
}
return nil
}

for _, field := range l.fields {
flagName := l.fullTag(l.config.FlagPrefix, field, flagNameTag)
if flagName == "" {
Expand All @@ -430,7 +480,7 @@ func (l *Loader) postFlagCheck(values map[string]interface{}, dupls map[string]s
delete(values, name)
}
for flag, value := range values {
if strings.HasPrefix(flag, l.config.EnvPrefix) {
if strings.HasPrefix(flag, l.config.FlagPrefix) {
return fmt.Errorf("unknown flag %s=%v (see AllowUnknownFlags config param)", flag, value)
}
}
Expand Down

0 comments on commit e046101

Please sign in to comment.