Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Commit

Permalink
Fix multiple validator inconsistency (#401)
Browse files Browse the repository at this point in the history
Co-authored-by: Mislav Marohnić <git@mislav.net>
  • Loading branch information
fatihdumanli and mislav committed Mar 17, 2022
1 parent 1b28f27 commit 099a968
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 35 deletions.
68 changes: 33 additions & 35 deletions survey.go
Expand Up @@ -297,51 +297,49 @@ func Ask(qs []*Question, response interface{}, opts ...AskOpt) error {
return errors.New("cannot call Ask() with a nil reference to record the answers")
}

validate := func(q *Question, val interface{}) error {
if q.Validate != nil {
if err := q.Validate(val); err != nil {
return err
}
}
for _, v := range options.Validators {
if err := v(val); err != nil {
return err
}
}
return nil
}

// go over every question
for _, q := range qs {
// If Prompt implements controllable stdio, pass in specified stdio.
if p, ok := q.Prompt.(wantsStdio); ok {
p.WithStdio(options.Stdio)
}

// grab the user input and save it
ans, err := q.Prompt.Prompt(&options.PromptConfig)
// if there was a problem
if err != nil {
return err
}

// build up a list of validators that we have to apply to this question
validators := []Validator{}

// make sure to include the question specific one
if q.Validate != nil {
validators = append(validators, q.Validate)
}
// add any "global" validators
validators = append(validators, options.Validators...)

// apply every validator to thte response
for _, validator := range validators {
// wait for a valid response
for invalid := validator(ans); invalid != nil; invalid = validator(ans) {
err := q.Prompt.Error(&options.PromptConfig, invalid)
// if there was a problem
if err != nil {
return err
}

// ask for more input
if promptAgainer, ok := q.Prompt.(PromptAgainer); ok {
ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, invalid)
} else {
ans, err = q.Prompt.Prompt(&options.PromptConfig)
}
// if there was a problem
if err != nil {
var ans interface{}
var validationErr error
// prompt and validation loop
for {
if validationErr != nil {
if err := q.Prompt.Error(&options.PromptConfig, validationErr); err != nil {
return err
}
}
var err error
if promptAgainer, ok := q.Prompt.(PromptAgainer); ok && validationErr != nil {
ans, err = promptAgainer.PromptAgain(&options.PromptConfig, ans, validationErr)
} else {
ans, err = q.Prompt.Prompt(&options.PromptConfig)
}
if err != nil {
return err
}
validationErr = validate(q, ans)
if validationErr == nil {
break
}
}

if q.Transform != nil {
Expand Down
70 changes: 70 additions & 0 deletions survey_test.go
@@ -1,6 +1,8 @@
package survey

import (
"errors"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -560,3 +562,71 @@ func Test_computeCursorOffset_Select(t *testing.T) {
})
}
}

func TestAsk_Validation(t *testing.T) {
p := &mockPrompt{
answers: []string{"", "company", "COM", "com"},
}

var res struct {
TLDN string
}
err := Ask([]*Question{
{
Name: "TLDN",
Prompt: p,
Validate: func(v interface{}) error {
s := v.(string)
if strings.ToLower(s) != s {
return errors.New("value contains uppercase characters")
}
return nil
},
},
}, &res, WithValidator(MinLength(1)), WithValidator(MaxLength(5)))
if err != nil {
t.Fatalf("Ask() = %v", err)
}

if res.TLDN != "com" {
t.Errorf("answer: %q, want %q", res.TLDN, "com")
}
if p.cleanups != 1 {
t.Errorf("cleanups: %d, want %d", p.cleanups, 1)
}
if err1 := p.printedErrors[0].Error(); err1 != "value is too short. Min length is 1" {
t.Errorf("printed error 1: %q, want %q", err1, "value is too short. Min length is 1")
}
if err2 := p.printedErrors[1].Error(); err2 != "value is too long. Max length is 5" {
t.Errorf("printed error 2: %q, want %q", err2, "value is too long. Max length is 5")
}
if err3 := p.printedErrors[2].Error(); err3 != "value contains uppercase characters" {
t.Errorf("printed error 2: %q, want %q", err3, "value contains uppercase characters")
}
}

type mockPrompt struct {
index int
answers []string
cleanups int
printedErrors []error
}

func (p *mockPrompt) Prompt(*PromptConfig) (interface{}, error) {
if p.index >= len(p.answers) {
return nil, errors.New("no more answers")
}
val := p.answers[p.index]
p.index++
return val, nil
}

func (p *mockPrompt) Cleanup(*PromptConfig, interface{}) error {
p.cleanups++
return nil
}

func (p *mockPrompt) Error(_ *PromptConfig, err error) error {
p.printedErrors = append(p.printedErrors, err)
return nil
}

0 comments on commit 099a968

Please sign in to comment.