Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Add WriteAnswer support for promoted fields #366

Merged
merged 1 commit into from Aug 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 36 additions & 18 deletions core/write.go
Expand Up @@ -24,6 +24,11 @@ type OptionAnswer struct {
Index int
}

type reflectField struct {
value reflect.Value
fieldType reflect.StructField
}

func OptionAnswerList(incoming []string) []OptionAnswer {
list := []OptionAnswer{}
for i, opt := range incoming {
Expand Down Expand Up @@ -63,13 +68,12 @@ func WriteAnswer(t interface{}, name string, v interface{}) (err error) {
}

// get the name of the field that matches the string we were given
fieldIndex, err := findFieldIndex(elem, name)
field, _, err := findField(elem, name)
// if something went wrong
if err != nil {
// bubble up
return err
}
field := elem.Field(fieldIndex)
// handle references to the Settable interface aswell
if s, ok := field.Interface().(Settable); ok {
// use the interface method
Expand Down Expand Up @@ -156,37 +160,51 @@ func IsFieldNotMatch(err error) (string, bool) {

// BUG(AlecAivazis): the current implementation might cause weird conflicts if there are
// two fields with same name that only differ by casing.
func findFieldIndex(s reflect.Value, name string) (int, error) {
// the type of the value
sType := s.Type()
func findField(s reflect.Value, name string) (reflect.Value, reflect.StructField, error) {

// first look for matching tags so we can overwrite matching field names
for i := 0; i < sType.NumField(); i++ {
// the field we are current scanning
field := sType.Field(i)
fields := flattenFields(s)

// first look for matching tags so we can overwrite matching field names
for _, f := range fields {
// the value of the survey tag
tag := field.Tag.Get(tagName)
tag := f.fieldType.Tag.Get(tagName)
// if the tag matches the name we are looking for
if tag != "" && tag == name {
// then we found our index
return i, nil
return f.value, f.fieldType, nil
}
}

// then look for matching names
for i := 0; i < sType.NumField(); i++ {
// the field we are current scanning
field := sType.Field(i)

for _, f := range fields {
// if the name of the field matches what we're looking for
if strings.ToLower(field.Name) == strings.ToLower(name) {
return i, nil
if strings.ToLower(f.fieldType.Name) == strings.ToLower(name) {
return f.value, f.fieldType, nil
}
}

// we didn't find the field
return -1, errFieldNotMatch{name}
return reflect.Value{}, reflect.StructField{}, errFieldNotMatch{name}
}

func flattenFields(s reflect.Value) []reflectField {
sType := s.Type()
numField := sType.NumField()
fields := make([]reflectField, 0, numField)
for i := 0; i < numField; i++ {
fieldType := sType.Field(i)
field := s.Field(i)

if field.Kind() == reflect.Struct && fieldType.Anonymous {
// field is a promoted structure
for _, f := range flattenFields(field) {
fields = append(fields, f)
}
continue
}
fields = append(fields, reflectField{field, fieldType})
}
return fields
}

// isList returns true if the element is something we can Len()
Expand Down
178 changes: 157 additions & 21 deletions core/write_test.go
Expand Up @@ -305,12 +305,12 @@ func TestWriteAnswer_returnsErrWhenFieldNotFound(t *testing.T) {
}
}

func TestFindFieldIndex_canFindExportedField(t *testing.T) {
func TestFindField_canFindExportedField(t *testing.T) {
// create a reflective wrapper over the struct to look through
val := reflect.ValueOf(struct{ Name string }{})
val := reflect.ValueOf(struct{ Name string }{Name: "Jack"})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "name")
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
Expand All @@ -319,20 +319,28 @@ func TestFindFieldIndex_canFindExportedField(t *testing.T) {
}

// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Name" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right field type
if fieldType.Name != "Name" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
}
}

func TestFindFieldIndex_canFindTaggedField(t *testing.T) {
func TestFindField_canFindTaggedField(t *testing.T) {
// the struct to look through
val := reflect.ValueOf(struct {
Username string `survey:"name"`
}{})
}{
Username: "Jack",
})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "name")
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
Expand All @@ -341,52 +349,180 @@ func TestFindFieldIndex_canFindTaggedField(t *testing.T) {
}

// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Username" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

func TestFindFieldIndex_canHandleCapitalAnswerNames(t *testing.T) {
func TestFindField_canHandleCapitalAnswerNames(t *testing.T) {
// create a reflective wrapper over the struct to look through
val := reflect.ValueOf(struct{ Name string }{})
val := reflect.ValueOf(struct{ Name string }{Name: "Jack"})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "Name")
field, fieldType, err := findField(val, "Name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}

// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Name" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Name" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
}
}

func TestFindFieldIndex_tagOverwriteFieldName(t *testing.T) {
func TestFindField_tagOverwriteFieldName(t *testing.T) {
// the struct to look through
val := reflect.ValueOf(struct {
Name string
Username string `survey:"name"`
}{})
}{
Name: "Ralf",
Username: "Jack",
})

// find the field matching "name"
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}

// make sure we got the right value
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

func TestFindField_supportsPromotedFields(t *testing.T) {
// create a reflective wrapper over the struct to look through
type Common struct {
Name string
}

type Strct struct {
Common // Name field added by composition
Username string
}

val := reflect.ValueOf(Strct{Common: Common{Name: "Jack"}})

// find the field matching "name"
field, fieldType, err := findField(val, "Name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}
// make sure we got the right value
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Name" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
}
}

func TestFindField_promotedFieldsWithTag(t *testing.T) {
// create a reflective wrapper over the struct to look through
type Common struct {
Username string `survey:"name"`
}

type Strct struct {
Common // Name field added by composition
Name string
}

val := reflect.ValueOf(Strct{
Common: Common{Username: "Jack"},
Name: "Ralf",
})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "name")
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}
// make sure we got the right value
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

func TestFindField_promotedFieldsDontHavePriorityOverTags(t *testing.T) {
// create a reflective wrapper over the struct to look through
type Common struct {
Name string
}

type Strct struct {
Common // Name field added by composition
Username string `survey:"name"`
}

val := reflect.ValueOf(Strct{
Common: Common{Name: "Ralf"},
Username: "Jack",
})

// find the field matching "name"
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}
// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Username" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

Expand Down