diff --git a/survey.go b/survey.go index 01063f05..95136ebe 100644 --- a/survey.go +++ b/survey.go @@ -297,6 +297,20 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error { return errors.New("cannot call Ask() with a nil reference to record the answers") } + validate := func(q *Question, val interface{}) error { + if q.Validate != nil { + if err := q.Validate(val); err != nil { + return err + } + } + for _, v := range options.Validators { + if err := v(val); err != nil { + return err + } + } + return nil + } + // go over every question for _, q := range qs { // If Prompt implements controllable stdio, pass in specified stdio. @@ -304,44 +318,28 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error { p.WithStdio(options.Stdio) } - // grab the user input and save it - ans, err := q.Prompt.Prompt(&options.PromptConfig) - // if there was a problem - if err != nil { - return err - } - - // build up a list of validators that we have to apply to this question - validators := []Validator{} - - // make sure to include the question specific one - if q.Validate != nil { - validators = append(validators, q.Validate) - } - // add any "global" validators - validators = append(validators, options.Validators...) - - // apply every validator to thte response - for _, validator := range validators { - // wait for a valid response - for invalid := validator(ans); invalid != nil; invalid = validator(ans) { - err := q.Prompt.Error(&options.PromptConfig, invalid) - // if there was a problem - if err != nil { - return err - } - - // ask for more input - if promptAgainer, ok := q.Prompt.(PromptAgainer); ok { - ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, invalid) - } else { - ans, err = q.Prompt.Prompt(&options.PromptConfig) - } - // if there was a problem - if err != nil { + var ans interface{} + var validationErr error + // prompt and validation loop + for { + if validationErr != nil { + if err := q.Prompt.Error(&options.PromptConfig, validationErr); err != nil { return err } } + var err error + if promptAgainer, ok := q.Prompt.(PromptAgainer); ok && validationErr != nil { + ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, validationErr) + } else { + ans, err = q.Prompt.Prompt(&options.PromptConfig) + } + if err != nil { + return err + } + validationErr = validate(q, ans) + if validationErr == nil { + break + } } if q.Transform != nil { diff --git a/survey_test.go b/survey_test.go index 121cf62f..68c80eeb 100644 --- a/survey_test.go +++ b/survey_test.go @@ -1,6 +1,8 @@ package survey import ( + "errors" + "strings" "testing" "time" @@ -560,3 +562,71 @@ func Test_computeCursorOffset_Select(t *testing.T) { }) } } + +func TestAsk_Validation(t *testing.T) { + p := &mockPrompt{ + answers: []string{"", "company", "COM", "com"}, + } + + var res struct { + TLDN string + } + err := Ask([]*Question{ + { + Name: "TLDN", + Prompt: p, + Validate: func(v interface{}) error { + s := v.(string) + if strings.ToLower(s) != s { + return errors.New("value contains uppercase characters") + } + return nil + }, + }, + }, &res, WithValidator(MinLength(1)), WithValidator(MaxLength(5))) + if err != nil { + t.Fatalf("Ask() = %v", err) + } + + if res.TLDN != "com" { + t.Errorf("answer: %q, want %q", res.TLDN, "com") + } + if p.cleanups != 1 { + t.Errorf("cleanups: %d, want %d", p.cleanups, 1) + } + if err1 := p.printedErrors[0].Error(); err1 != "value is too short. Min length is 1" { + t.Errorf("printed error 1: %q, want %q", err1, "value is too short. Min length is 1") + } + if err2 := p.printedErrors[1].Error(); err2 != "value is too long. Max length is 5" { + t.Errorf("printed error 2: %q, want %q", err2, "value is too long. Max length is 5") + } + if err3 := p.printedErrors[2].Error(); err3 != "value contains uppercase characters" { + t.Errorf("printed error 2: %q, want %q", err3, "value contains uppercase characters") + } +} + +type mockPrompt struct { + index int + answers []string + cleanups int + printedErrors []error +} + +func (p *mockPrompt) Prompt(*PromptConfig) (interface{}, error) { + if p.index >= len(p.answers) { + return nil, errors.New("no more answers") + } + val := p.answers[p.index] + p.index++ + return val, nil +} + +func (p *mockPrompt) Cleanup(*PromptConfig, interface{}) error { + p.cleanups++ + return nil +} + +func (p *mockPrompt) Error(_ *PromptConfig, err error) error { + p.printedErrors = append(p.printedErrors, err) + return nil +}