Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Fix multiple validator inconsistency #401

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
83 changes: 36 additions & 47 deletions survey.go
Expand Up @@ -297,53 +297,49 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error {
return errors.New("cannot call Ask() with a nil reference to record the answers")
}

validate := func(q *Question, val interface{}) error {
if q.Validate != nil {
if err := q.Validate(val); err != nil {
return err
}
}
for _, v := range options.Validators {
if err := v(val); err != nil {
return err
}
}
return nil
}

// go over every question
for _, q := range qs {
// If Prompt implements controllable stdio, pass in specified stdio.
if p, ok := q.Prompt.(wantsStdio); ok {
p.WithStdio(options.Stdio)
}

// grab the user input and save it
ans, err := q.Prompt.Prompt(&options.PromptConfig)
// if there was a problem
if err != nil {
return err
}

// build up a list of validators that we have to apply to this question
validators := []Validator{}

// make sure to include the question specific one
if q.Validate != nil {
validators = append(validators, q.Validate)
}
// add any "global" validators
for _, validator := range options.Validators {
validators = append(validators, validator)
}

// apply every validator to thte response
for _, validator := range validators {
// wait for a valid response
for invalid := validator(ans); invalid != nil; invalid = validator(ans) {
err := q.Prompt.Error(&options.PromptConfig, invalid)
// if there was a problem
if err != nil {
return err
}

// ask for more input
if promptAgainer, ok := q.Prompt.(PromptAgainer); ok {
ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, invalid)
} else {
ans, err = q.Prompt.Prompt(&options.PromptConfig)
}
// if there was a problem
if err != nil {
var ans interface{}
var validationErr error
// prompt and validation loop
for {
if validationErr != nil {
if err := q.Prompt.Error(&options.PromptConfig, validationErr); err != nil {
return err
}
}
var err error
if promptAgainer, ok := q.Prompt.(PromptAgainer); ok && validationErr != nil {
ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, validationErr)
} else {
ans, err = q.Prompt.Prompt(&options.PromptConfig)
}
if err != nil {
return err
}
validationErr = validate(q, ans)
if validationErr == nil {
break
}
}

if q.Transform != nil {
Expand All @@ -356,21 +352,14 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error {
}

// tell the prompt to cleanup with the validated value
q.Prompt.Cleanup(&options.PromptConfig, ans)

// if something went wrong
if err != nil {
// stop listening
if err := q.Prompt.Cleanup(&options.PromptConfig, ans); err != nil {
return err
}

// add it to the map
err = core.WriteAnswer(response, q.Name, ans)
// if something went wrong
if err != nil {
// add the answer to the response data structure
if err := core.WriteAnswer(response, q.Name, ans); err != nil {
return err
}

}

// return the response
Expand Down
69 changes: 69 additions & 0 deletions survey_test.go
@@ -1,6 +1,7 @@
package survey

import (
"errors"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -518,3 +519,71 @@ func Test_computeCursorOffset_Select(t *testing.T) {
})
}
}

func TestAsk_Validation(t *testing.T) {
p := &mockPrompt{
answers: []string{"", "company", "COM", "com"},
}

var res struct {
TLDN string
}
err := Ask([]*Question{
{
Name: "TLDN",
Prompt: p,
Validate: func(v interface{}) error {
s := v.(string)
if strings.ToLower(s) != s {
return errors.New("value contains uppercase characters")
}
return nil
},
},
}, &res, WithValidator(MinLength(1)), WithValidator(MaxLength(5)))
if err != nil {
t.Fatalf("Ask() = %v", err)
}

if res.TLDN != "com" {
t.Errorf("answer: %q, want %q", res.TLDN, "com")
}
if p.cleanups != 1 {
t.Errorf("cleanups: %d, want %d", p.cleanups, 1)
}
if err1 := p.printedErrors[0].Error(); err1 != "value is too short. Min length is 1" {
t.Errorf("printed error 1: %q, want %q", err1, "value is too short. Min length is 1")
}
if err2 := p.printedErrors[1].Error(); err2 != "value is too long. Max length is 5" {
t.Errorf("printed error 2: %q, want %q", err2, "value is too long. Max length is 5")
}
if err3 := p.printedErrors[2].Error(); err3 != "value contains uppercase characters" {
t.Errorf("printed error 2: %q, want %q", err3, "value contains uppercase characters")
}
}

type mockPrompt struct {
index int
answers []string
cleanups int
printedErrors []error
}

func (p *mockPrompt) Prompt(*PromptConfig) (interface{}, error) {
if p.index >= len(p.answers) {
return nil, errors.New("no more answers")
}
val := p.answers[p.index]
p.index++
return val, nil
}

func (p *mockPrompt) Cleanup(*PromptConfig, interface{}) error {
p.cleanups++
return nil
}

func (p *mockPrompt) Error(_ *PromptConfig, err error) error {
p.printedErrors = append(p.printedErrors, err)
return nil
}