Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Fix multiple validator inconsistency #401

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 33 additions & 35 deletions survey.go
Expand Up @@ -297,51 +297,49 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error {
return errors.New("cannot call Ask() with a nil reference to record the answers")
}

validate := func(q *Question, val interface{}) error {
if q.Validate != nil {
if err := q.Validate(val); err != nil {
return err
}
}
for _, v := range options.Validators {
if err := v(val); err != nil {
return err
}
}
return nil
}

// go over every question
for _, q := range qs {
// If Prompt implements controllable stdio, pass in specified stdio.
if p, ok := q.Prompt.(wantsStdio); ok {
p.WithStdio(options.Stdio)
}

// grab the user input and save it
ans, err := q.Prompt.Prompt(&options.PromptConfig)
// if there was a problem
if err != nil {
return err
}

// build up a list of validators that we have to apply to this question
validators := []Validator{}

// make sure to include the question specific one
if q.Validate != nil {
validators = append(validators, q.Validate)
}
// add any "global" validators
validators = append(validators, options.Validators...)

// apply every validator to thte response
for _, validator := range validators {
// wait for a valid response
for invalid := validator(ans); invalid != nil; invalid = validator(ans) {
err := q.Prompt.Error(&options.PromptConfig, invalid)
// if there was a problem
if err != nil {
return err
}

// ask for more input
if promptAgainer, ok := q.Prompt.(PromptAgainer); ok {
ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, invalid)
} else {
ans, err = q.Prompt.Prompt(&options.PromptConfig)
}
// if there was a problem
if err != nil {
var ans interface{}
var validationErr error
// prompt and validation loop
for {
if validationErr != nil {
if err := q.Prompt.Error(&options.PromptConfig, validationErr); err != nil {
return err
}
}
var err error
if promptAgainer, ok := q.Prompt.(PromptAgainer); ok && validationErr != nil {
ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, validationErr)
} else {
ans, err = q.Prompt.Prompt(&options.PromptConfig)
}
if err != nil {
return err
}
validationErr = validate(q, ans)
if validationErr == nil {
break
}
}

if q.Transform != nil {
Expand Down
70 changes: 70 additions & 0 deletions survey_test.go
@@ -1,6 +1,8 @@
package survey

import (
"errors"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -560,3 +562,71 @@ func Test_computeCursorOffset_Select(t *testing.T) {
})
}
}

func TestAsk_Validation(t *testing.T) {
p := &mockPrompt{
answers: []string{"", "company", "COM", "com"},
}

var res struct {
TLDN string
}
err := Ask([]*Question{
{
Name: "TLDN",
Prompt: p,
Validate: func(v interface{}) error {
s := v.(string)
if strings.ToLower(s) != s {
return errors.New("value contains uppercase characters")
}
return nil
},
},
}, &res, WithValidator(MinLength(1)), WithValidator(MaxLength(5)))
if err != nil {
t.Fatalf("Ask() = %v", err)
}

if res.TLDN != "com" {
t.Errorf("answer: %q, want %q", res.TLDN, "com")
}
if p.cleanups != 1 {
t.Errorf("cleanups: %d, want %d", p.cleanups, 1)
}
if err1 := p.printedErrors[0].Error(); err1 != "value is too short. Min length is 1" {
t.Errorf("printed error 1: %q, want %q", err1, "value is too short. Min length is 1")
}
if err2 := p.printedErrors[1].Error(); err2 != "value is too long. Max length is 5" {
t.Errorf("printed error 2: %q, want %q", err2, "value is too long. Max length is 5")
}
if err3 := p.printedErrors[2].Error(); err3 != "value contains uppercase characters" {
t.Errorf("printed error 2: %q, want %q", err3, "value contains uppercase characters")
}
}

type mockPrompt struct {
index int
answers []string
cleanups int
printedErrors []error
}

func (p *mockPrompt) Prompt(*PromptConfig) (interface{}, error) {
if p.index >= len(p.answers) {
return nil, errors.New("no more answers")
}
val := p.answers[p.index]
p.index++
return val, nil
}

func (p *mockPrompt) Cleanup(*PromptConfig, interface{}) error {
p.cleanups++
return nil
}

func (p *mockPrompt) Error(_ *PromptConfig, err error) error {
p.printedErrors = append(p.printedErrors, err)
return nil
}