/
embedded_walker.go
144 lines (129 loc) · 4.5 KB
/
embedded_walker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package errcheck
import (
"fmt"
"go/types"
)
// walkThroughEmbeddedInterfaces returns a slice of Interfaces that
// we need to walk through in order to reach the actual definition,
// in an Interface, of the method selected by the given selection.
//
// false will be returned in the second return value if:
//   - the right side of the selection is not a function
//   - the actual definition of the function is not in an Interface
//   - the interface that explicitly declares the function cannot be
//     reached through named embedded interfaces (for example, when it is
//     an unnamed interface literal embedded in another interface)
//
// The returned slice will contain all the interface types that need
// to be walked through to reach the actual definition.
//
// For example, say we have:
//
//	type Inner interface {Method()}
//	type Middle interface {Inner}
//	type Outer interface {Middle}
//	type T struct {Outer}
//	type U struct {T}
//	type V struct {U}
//
// And then the selector:
//
//	V.Method
//
// We'll return [Outer, Middle, Inner] by first walking through the embedded structs
// until we reach the Outer interface, then descending through the embedded interfaces
// until we find the one that actually explicitly defines Method.
func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) {
	fn, ok := sel.Obj().(*types.Func)
	if !ok {
		return nil, false
	}

	// Start off at the receiver.
	currentT := sel.Recv()

	// First, we can walk through any Struct fields provided
	// by the selection Index() method. We ignore the last
	// index because it would give the method itself.
	indexes := sel.Index()
	if len(indexes) == 0 {
		// Defensive: slicing with len(indexes)-1 below would panic on an
		// empty index path, so treat it as "not found" instead.
		return nil, false
	}
	for _, fieldIndex := range indexes[:len(indexes)-1] {
		currentT = getTypeAtFieldIndex(currentT, fieldIndex)
	}

	// Now currentT is either a type implementing the actual function,
	// an Invalid type (if the receiver is a package), or an interface.
	//
	// If it's not an Interface, then we're done, as this function
	// only cares about Interface-defined functions.
	//
	// If it is an Interface, we potentially need to continue digging until
	// we find the Interface that actually explicitly defines the function.
	interfaceT, ok := maybeUnname(currentT).(*types.Interface)
	if !ok {
		return nil, false
	}

	// The first interface we pass through is this one we've found. We return the possibly
	// wrapping types.Named because it is more useful to work with for callers.
	result := []types.Type{currentT}

	// If this interface itself explicitly defines the given method
	// then we're done digging.
	for !explicitlyDefinesMethod(interfaceT, fn) {
		// Otherwise, we find which of the embedded interfaces _does_
		// define the method, add it to our list, and loop.
		namedInterfaceT, ok := getEmbeddedInterfaceDefiningMethod(interfaceT, fn)
		if !ok {
			// getEmbeddedInterfaceDefiningMethod only reports *named*
			// embedded interfaces, so the declaring interface may be
			// unreachable here even for well-typed code, e.g.
			//
			//	type Outer interface{ interface{ Method() } }
			//
			// Report "not found" (already part of this function's
			// documented contract) rather than crashing the analysis.
			return nil, false
		}
		nextInterfaceT, ok := namedInterfaceT.Underlying().(*types.Interface)
		if !ok {
			return nil, false
		}
		result = append(result, namedInterfaceT)
		interfaceT = nextInterfaceT
	}

	return result, true
}
func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type {
t := maybeUnname(maybeDereference(startingAt))
s, ok := t.(*types.Struct)
if !ok {
panic(fmt.Sprintf("cannot get Field of a type that is not a struct, got a %T", t))
}
return s.Field(fieldIndex).Type()
}
// getEmbeddedInterfaceDefiningMethod searches the interfaces directly embedded
// in interfaceT for one whose method set contains fn. If found, the
// types.Named wrapping that embedded interface is returned along with true
// in the second value.
//
// Only embedded elements that are defined (named) interface types are
// considered. Other elements are skipped, such as an unnamed interface
// literal or, in a constraint interface, a defined non-interface type. If
// no named embedded interface defines fn, nil and false are returned.
func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) {
	for i := 0; i < interfaceT.NumEmbeddeds(); i++ {
		// Embedded returns nil when the i'th embedded element is not a
		// defined type; calling Underlying on that nil *types.Named would
		// panic. (EmbeddedType is the non-deprecated accessor, but it yields
		// the embedded type as written, possibly an alias, rather than the
		// *types.Named this function must return.)
		embedded := interfaceT.Embedded(i)
		if embedded == nil {
			continue
		}
		// A defined type embedded in a constraint interface need not be an
		// interface itself (e.g. `interface{ MyInt }`), so check first.
		embeddedInterface, ok := embedded.Underlying().(*types.Interface)
		if !ok {
			continue
		}
		if definesMethod(embeddedInterface, fn) {
			return embedded, true
		}
	}
	return nil, false
}
// explicitlyDefinesMethod reports whether fn appears among the methods
// declared directly in interfaceT (its ExplicitMethod list). Methods that
// interfaceT only obtains by embedding other interfaces do not count.
// Methods are matched by *types.Func identity.
func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool {
	found := false
	for i, n := 0, interfaceT.NumExplicitMethods(); i < n && !found; i++ {
		found = interfaceT.ExplicitMethod(i) == fn
	}
	return found
}
// definesMethod reports whether fn belongs to the complete method set of
// interfaceT (as enumerated by NumMethods/Method). Unlike
// explicitlyDefinesMethod, this also covers methods contributed by embedded
// interfaces. Methods are matched by *types.Func identity.
func definesMethod(interfaceT *types.Interface, fn *types.Func) bool {
	for idx := interfaceT.NumMethods() - 1; idx >= 0; idx-- {
		if interfaceT.Method(idx) == fn {
			return true
		}
	}
	return false
}
// maybeDereference removes one level of pointer indirection from t. For a
// *types.Pointer it returns the pointee type; any other type, including
// nil, is returned unchanged.
func maybeDereference(t types.Type) types.Type {
	switch ptr := t.(type) {
	case *types.Pointer:
		return ptr.Elem()
	default:
		return t
	}
}
// maybeUnname peels a *types.Named wrapper off t and returns the named
// type's underlying type. Any other type, including nil, is returned
// unchanged.
func maybeUnname(t types.Type) types.Type {
	named, isNamed := t.(*types.Named)
	if !isNamed {
		return t
	}
	return named.Underlying()
}