Skip to content

Commit

Permalink
feat(Smuggle): fields-path param can contain method calls
Browse files Browse the repository at this point in the history
Signed-off-by: Maxime Soulé <btik-git@scoubidou.com>
  • Loading branch information
maxatome committed Apr 22, 2024
1 parent 856e80a commit 22ed66e
Show file tree
Hide file tree
Showing 3 changed files with 425 additions and 95 deletions.
137 changes: 122 additions & 15 deletions td/td_smuggle.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018-2023, Maxime Soulé
// Copyright (c) 2018-2024, Maxime Soulé
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -72,6 +72,7 @@ var smuggleValueType = reflect.TypeOf(smuggleValue{})
type smuggleField struct {
Name string
Indexed bool
Method bool
}

func joinFieldsPath(path []smuggleField) string {
Expand All @@ -84,6 +85,9 @@ func joinFieldsPath(path []smuggleField) string {
buf.WriteByte('.')
}
buf.WriteString(part.Name)
if part.Method {
buf.WriteString("()")
}
}
}
return buf.String()
Expand All @@ -94,6 +98,7 @@ func splitFieldsPath(origPath string) ([]smuggleField, error) {
return nil, fmt.Errorf("FIELD_PATH cannot be empty")
}

privateField := ""
var res []smuggleField
for path := origPath; len(path) > 0; {
r, _ := utf8.DecodeRuneInString(path)
Expand Down Expand Up @@ -130,12 +135,33 @@ func splitFieldsPath(origPath string) ([]smuggleField, error) {
field, path = path[:end], path[end:]
}

for j, r := range field {
if !unicode.IsLetter(r) && (j == 0 || !unicode.IsNumber(r)) {
return nil, fmt.Errorf("unexpected %q in field name %q in FIELDS_PATH %q", r, field, origPath)
if strings.HasSuffix(field, "()") {
if len(field) == 2 {
return nil, fmt.Errorf("missing method name before () in FIELDS_PATH %q", origPath)
}
for j, r := range field[:len(field)-2] {
if j == 0 && !unicode.IsUpper(r) {
return nil, fmt.Errorf("method name %q is not public in FIELDS_PATH %q", field, origPath)
}
if !unicode.IsLetter(r) && !unicode.IsNumber(r) {
return nil, fmt.Errorf("unexpected %q in method name %q in FIELDS_PATH %q", r, field, origPath)
}
}
if privateField != "" {
return nil, fmt.Errorf("cannot call method %s as it is based on an unexported field %q in FIELDS_PATH %q", field, privateField, origPath)
}
res = append(res, smuggleField{Name: field[:len(field)-2], Method: true})
} else {
for j, r := range field {
if privateField == "" && j == 0 && !unicode.IsUpper(r) {
privateField = field
}
if !unicode.IsLetter(r) && (j == 0 || !unicode.IsNumber(r)) {
return nil, fmt.Errorf("unexpected %q in field name %q in FIELDS_PATH %q", r, field, origPath)
}
}
res = append(res, smuggleField{Name: field})
}
res = append(res, smuggleField{Name: field})
}
}
return res, nil
Expand All @@ -155,7 +181,63 @@ func buildFieldsPathFn(path string) (func(any) (smuggleValue, error), error) {
vgot := reflect.ValueOf(got)

for idxPart, field := range parts {
if field.Method {
var method reflect.Value
for {
method = vgot.MethodByName(field.Name)
if !method.IsValid() {
switch vgot.Kind() {
case reflect.Interface, reflect.Ptr:
if !vgot.IsNil() {
vgot = vgot.Elem()
continue
}
return smuggleValue{}, nilFieldErr(parts[:idxPart])
}
if idxPart > 0 {
return smuggleValue{}, fmt.Errorf(
"field %s (type %s) does not implement %s() method",
joinFieldsPath(parts[:idxPart]),
vgot.Type(),
field.Name)
}
return smuggleValue{}, fmt.Errorf(
"type %s has no method %s()", vgot.Type(), field.Name)
}
break
}
mt := method.Type()
if mt.NumIn() != 0 ||
(mt.NumOut() != 1 && (mt.NumOut() != 2 || mt.Out(1) != types.Error)) {
return smuggleValue{}, fmt.Errorf(
"cannot call %s, signature %s not handled, only func() A or func() (A, error) allowed",
joinFieldsPath(parts[:idxPart+1]),
method.Type())
}
var ret []reflect.Value
var panicked any
func() {
defer func() { panicked = recover() }()
ret = method.Call(nil)
}()
if panicked != nil {
return smuggleValue{}, fmt.Errorf(
"method %s panicked: %v",
joinFieldsPath(parts[:idxPart+1]),
panicked)
}
if len(ret) == 2 && !ret[1].IsNil() {
return smuggleValue{}, fmt.Errorf(
"method %s returned an error: %w",
joinFieldsPath(parts[:idxPart+1]),
ret[1].Interface().(error))
}
vgot = ret[0]
continue
}

// Resolve all interface and pointer dereferences
origKind := vgot.Kind()
for {
switch vgot.Kind() {
case reflect.Interface, reflect.Ptr:
Expand All @@ -178,13 +260,22 @@ func buildFieldsPathFn(path string) (func(any) (smuggleValue, error), error) {
}
continue
}
deref := ""
if origKind != vgot.Kind() {
deref = " (after dereferencing)"
}
if idxPart == 0 {
return smuggleValue{},
fmt.Errorf("it is a %s and should be a struct", vgot.Kind())
fmt.Errorf("it is a %s%s and should be a struct", vgot.Kind(), deref)
}
if parts[idxPart-1].Method {
return smuggleValue{}, fmt.Errorf(
"method %s returned a %s%s and should be a struct",
joinFieldsPath(parts[:idxPart]), vgot.Kind(), deref)
}
return smuggleValue{}, fmt.Errorf(
"field %q is a %s and should be a struct",
joinFieldsPath(parts[:idxPart]), vgot.Kind())
"field %q is a %s%s and should be a struct",
joinFieldsPath(parts[:idxPart]), vgot.Kind(), deref)
}

switch vgot.Kind() {
Expand Down Expand Up @@ -546,6 +637,7 @@ func buildCaster(outType reflect.Type, useString bool) reflect.Value {
// several struct layers.
//
// type A struct{ Num int }
// // func (a *A) String() string { return fmt.Sprintf("Num is %d", a.Num) }
// type B struct{ As map[string]*A }
// type C struct{ B B }
// got := C{B: B{As: map[string]*A{"foo": {Num: 12}}}}
Expand All @@ -563,15 +655,30 @@ func buildCaster(outType reflect.Type, useString bool) reflect.Value {
// // Tests that got.B.As["foo"].Num is 12
// td.Cmp(t, got, td.Smuggle("B.As[foo].Num", 12))
//
// Contrary to [JSONPointer] operator, private fields can be
// followed. Arrays, slices and maps work using the index/key inside
// square brackets (e.g. [12] or [foo]). Maps work only for simple key
// types (string or numbers), without "" when using strings
// (e.g. [foo]).
// In addition, simple public methods can also be called like in:
//
// td.Cmp(t, got, td.Smuggle("B.As[foo].String()", "Num is 12"))
//
// Allowed methods must not take any parameter and must return one
// value or a value and an error. For the latter case, if the method
// returns a non-nil error, the comparison fails. The comparison also
// fails if a panic occurs or if a method cannot be called. No private
// fields should be traversed before calling the method. For fun,
// consider a more complex example involving [reflect] and chaining
// method calls:
//
// got := reflect.Valueof(&C{B: B{As: map[string]*A{"foo": {Num: 12}}}})
// td.Cmp(t, got, td.Smuggle("Elem().Interface().B.As[foo].String()", "Num is 12"))
//
// Contrary to [JSONPointer] operator, private fields can be followed
// and public methods on public fields can be called. Arrays, slices
// and maps work using the index/key inside square brackets (e.g. [12]
// or [foo]). Maps work only for simple key types (string or numbers),
// without "" when using strings (e.g. [foo]).
//
// Behind the scenes, a temporary function is automatically created to
// achieve the same goal, but add some checks against nil values and
// auto-dereference interfaces and pointers, even on several levels,
// achieve the same goal, but adds some checks against nil values and
// auto-dereferences interfaces and pointers, even on several levels,
// like in:
//
// type A struct{ N any }
Expand Down

0 comments on commit 22ed66e

Please sign in to comment.