diff --git a/internal/btf/btf.go b/internal/btf/btf.go index d960335c5..47eb2d9d2 100644 --- a/internal/btf/btf.go +++ b/internal/btf/btf.go @@ -525,7 +525,7 @@ func fixupDatasec(rawTypes []rawType, rawStrings *stringTable, sectionSizes map[ // Copy creates a copy of Spec. func (s *Spec) Copy() *Spec { - types, _ := copyTypes(s.types, nil) + types := copyTypes(s.types, nil) namedTypes := make(map[essentialName][]Type) for _, typ := range types { diff --git a/internal/btf/core.go b/internal/btf/core.go index 61c916629..417f1e8ff 100644 --- a/internal/btf/core.go +++ b/internal/btf/core.go @@ -269,19 +269,13 @@ var errImpossibleRelocation = errors.New("impossible relocation") // the better the target is. func coreCalculateFixups(byteOrder binary.ByteOrder, local Type, targets []Type, relos CORERelos) ([]COREFixup, error) { localID := local.ID() - local, err := copyType(local, skipQualifiersAndTypedefs) - if err != nil { - return nil, err - } + local = Copy(local, UnderlyingType) bestScore := len(relos) var bestFixups []COREFixup for i := range targets { targetID := targets[i].ID() - target, err := copyType(targets[i], skipQualifiersAndTypedefs) - if err != nil { - return nil, err - } + target := Copy(targets[i], UnderlyingType) score := 0 // lower is better fixups := make([]COREFixup, 0, len(relos)) @@ -1009,31 +1003,3 @@ func coreAreMembersCompatible(localType Type, targetType Type) error { return fmt.Errorf("type %s: %w", localType, ErrNotSupported) } } - -func skipQualifiersAndTypedefs(typ Type) (Type, error) { - result := typ - for depth := 0; depth <= maxTypeDepth; depth++ { - switch v := (result).(type) { - case qualifier: - result = v.qualify() - case *Typedef: - result = v.Type - default: - return result, nil - } - } - return nil, errors.New("exceeded type depth") -} - -func skipQualifiers(typ Type) (Type, error) { - result := typ - for depth := 0; depth <= maxTypeDepth; depth++ { - switch v := (result).(type) { - case qualifier: - result = v.qualify() - default: - return result, nil - } - } - return nil, errors.New("exceeded type depth") -} diff --git a/internal/btf/core_test.go b/internal/btf/core_test.go index 3b23719c6..6957551ab 100644 --- a/internal/btf/core_test.go +++ b/internal/btf/core_test.go @@ -584,8 +584,9 @@ func TestCORECopyWithoutQualifiers(t *testing.T) { root := &Volatile{} root.Type = test.fn(root) - _, err := copyType(root, skipQualifiersAndTypedefs) - qt.Assert(t, err, qt.Not(qt.IsNil)) + cycle, ok := Copy(root, UnderlyingType).(*cycle) + qt.Assert(t, ok, qt.IsTrue) + qt.Assert(t, cycle.root, qt.Equals, root) }) } @@ -595,8 +596,7 @@ func TestCORECopyWithoutQualifiers(t *testing.T) { v := a.fn(&Pointer{Target: b.fn(&Int{Name: "z"})}) want := &Pointer{Target: &Int{Name: "z"}} - got, err := copyType(v, skipQualifiersAndTypedefs) - qt.Assert(t, err, qt.IsNil) + got := Copy(v, UnderlyingType) qt.Assert(t, got, qt.DeepEquals, want) }) } @@ -611,8 +611,7 @@ func TestCORECopyWithoutQualifiers(t *testing.T) { t.Log(q.name) } - got, err := copyType(v, skipQualifiersAndTypedefs) - qt.Assert(t, err, qt.IsNil) + got := Copy(v, UnderlyingType) qt.Assert(t, got, qt.DeepEquals, root) }) } diff --git a/internal/btf/format.go b/internal/btf/format.go index 159319c33..3acb9b38d 100644 --- a/internal/btf/format.go +++ b/internal/btf/format.go @@ -63,12 +63,7 @@ func (gf *GoFormatter) writeTypeDecl(name string, typ Type) error { return fmt.Errorf("need a name for type %s", typ) } - typ, err := skipQualifiers(typ) - if err != nil { - return err - } - - switch v := typ.(type) { + switch v := skipQualifiers(typ).(type) { case *Enum: fmt.Fprintf(&gf.w, "type %s int32", name) if len(v.Values) == 0 { @@ -83,10 +78,11 @@ func (gf *GoFormatter) writeTypeDecl(name string, typ Type) error { gf.w.WriteString(")") return nil - } - fmt.Fprintf(&gf.w, "type %s ", name) - return gf.writeTypeLit(typ, 0) + default: + fmt.Fprintf(&gf.w, "type %s ", name) + return gf.writeTypeLit(v, 0) + } } // writeType outputs the name of a named type or a literal describing the type. @@ -96,10 +92,7 @@ func (gf *GoFormatter) writeTypeDecl(name string, typ Type) error { // foo (if foo is a named type) // uint32 func (gf *GoFormatter) writeType(typ Type, depth int) error { - typ, err := skipQualifiers(typ) - if err != nil { - return err - } + typ = skipQualifiers(typ) name := gf.Names[typ] if name != "" { @@ -124,12 +117,8 @@ func (gf *GoFormatter) writeTypeLit(typ Type, depth int) error { return errNestedTooDeep } - typ, err := skipQualifiers(typ) - if err != nil { - return err - } - - switch v := typ.(type) { + var err error + switch v := skipQualifiers(typ).(type) { case *Int: gf.writeIntLit(v) @@ -154,7 +143,7 @@ func (gf *GoFormatter) writeTypeLit(typ Type, depth int) error { err = gf.writeDatasecLit(v, depth) default: - return fmt.Errorf("type %s: %w", typ, ErrNotSupported) + return fmt.Errorf("type %T: %w", v, ErrNotSupported) } if err != nil { @@ -302,3 +291,16 @@ func (gf *GoFormatter) writePadding(bytes uint32) { fmt.Fprintf(&gf.w, "_ [%d]byte; ", bytes) } } + +func skipQualifiers(typ Type) Type { + result := typ + for depth := 0; depth <= maxTypeDepth; depth++ { + switch v := (result).(type) { + case qualifier: + result = v.qualify() + default: + return result + } + } + return &cycle{typ} +} diff --git a/internal/btf/types.go b/internal/btf/types.go index 198e60189..2e6bc41ed 100644 --- a/internal/btf/types.go +++ b/internal/btf/types.go @@ -547,6 +547,20 @@ func (f *Float) copy() Type { return &cpy } +// cycle is a type which had to be elided since it exceeded maxTypeDepth. +type cycle struct { + root Type +} + +func (c *cycle) ID() TypeID { return math.MaxUint32 } +func (c *cycle) Format(fs fmt.State, verb rune) { formatType(fs, verb, c, "root=", c.root) } +func (c *cycle) TypeName() string { return "" } +func (c *cycle) walk(*typeDeque) {} +func (c *cycle) copy() Type { + cpy := *c + return &cpy +} + type sizer interface { size() uint32 } @@ -626,12 +640,7 @@ func Sizeof(typ Type) (int, error) { // // Currently only supports the subset of types necessary for bitfield relocations. func alignof(typ Type) (int, error) { - typ, err := skipQualifiersAndTypedefs(typ) - if err != nil { - return 0, err - } - - switch t := typ.(type) { + switch t := UnderlyingType(typ).(type) { case *Enum: return int(t.size()), nil case *Int: @@ -641,44 +650,40 @@ func alignof(typ Type) (int, error) { } } -// Copy a Type recursively. -func Copy(typ Type) Type { - typ, _ = copyType(typ, nil) - return typ -} - -// copy a Type recursively. +// Transformer modifies a given Type and returns the result. // -// typ may form a cycle. +// For example, UnderlyingType removes any qualifiers or typedefs from a type. +// See the example on Copy for how to use a transform. +type Transformer func(Type) Type + +// Copy a Type recursively. // -// Returns any errors from transform verbatim. -func copyType(typ Type, transform func(Type) (Type, error)) (Type, error) { +// typ may form a cycle. If transform is not nil, it is called with the +// to be copied type, and the returned value is copied instead. +func Copy(typ Type, transform Transformer) Type { copies := make(copier) - return typ, copies.copy(&typ, transform) + copies.copy(&typ, transform) + return typ } // copy a slice of Types recursively. // -// Types may form a cycle. -// -// Returns any errors from transform verbatim. -func copyTypes(types []Type, transform func(Type) (Type, error)) ([]Type, error) { +// See Copy for the semantics. +func copyTypes(types []Type, transform Transformer) []Type { result := make([]Type, len(types)) copy(result, types) copies := make(copier) for i := range result { - if err := copies.copy(&result[i], transform); err != nil { - return nil, err - } + copies.copy(&result[i], transform) } - return result, nil + return result } type copier map[Type]Type -func (c copier) copy(typ *Type, transform func(Type) (Type, error)) error { +func (c copier) copy(typ *Type, transform Transformer) { var work typeDeque for t := typ; t != nil; t = work.pop() { // *t is the identity of the type. @@ -689,11 +694,7 @@ func (c copier) copy(typ *Type, transform func(Type) (Type, error)) error { var cpy Type if transform != nil { - tf, err := transform(*t) - if err != nil { - return fmt.Errorf("copy %s: %w", *t, err) - } - cpy = tf.copy() + cpy = transform(*t).copy() } else { cpy = (*t).copy() } @@ -704,8 +705,6 @@ func (c copier) copy(typ *Type, transform func(Type) (Type, error)) error { // Mark any nested types for copying. cpy.walk(&work) } - - return nil } // typeDeque keeps track of pointers to types which still @@ -1040,9 +1039,6 @@ func newEssentialName(name string) essentialName { } // UnderlyingType skips qualifiers and Typedefs. -// -// May return typ verbatim if too many types have to be skipped to protect against -// circular Types. func UnderlyingType(typ Type) Type { result := typ for depth := 0; depth <= maxTypeDepth; depth++ { @@ -1055,8 +1051,7 @@ func UnderlyingType(typ Type) Type { return result } } - // Return the original argument, since we can't find an underlying type. - return typ + return &cycle{typ} } type formatState struct { @@ -1132,7 +1127,7 @@ func formatType(f fmt.State, verb rune, t formattableType, extra ...interface{}) switch v := arg.(type) { case string: _, _ = io.WriteString(f, v) - wantSpace = v[len(v)-1] != '=' + wantSpace = len(v) > 0 && v[len(v)-1] != '=' continue case formattableType: diff --git a/internal/btf/types_test.go b/internal/btf/types_test.go index f6f3d5770..2d46bb7ee 100644 --- a/internal/btf/types_test.go +++ b/internal/btf/types_test.go @@ -35,11 +35,11 @@ func TestSizeof(t *testing.T) { } } -func TestCopyType(t *testing.T) { - _, _ = copyType((*Void)(nil), nil) +func TestCopy(t *testing.T) { + _ = Copy((*Void)(nil), nil) in := &Int{Size: 4} - out, _ := copyType(in, nil) + out := Copy(in, nil) in.Size = 8 if size := out.(*Int).Size; size != 4 { @@ -47,13 +47,13 @@ func TestCopyType(t *testing.T) { } t.Run("cyclical", func(t *testing.T) { - _, _ = copyType(newCyclicalType(2), nil) + _ = Copy(newCyclicalType(2), nil) }) t.Run("identity", func(t *testing.T) { u16 := &Int{Size: 2} - out, _ := copyType(&Struct{ + out := Copy(&Struct{ Members: []Member{ {Name: "a", Type: u16}, {Name: "b", Type: u16}, @@ -65,6 +65,17 @@ func TestCopyType(t *testing.T) { }) } +func BenchmarkCopy(b *testing.B) { + typ := newCyclicalType(10) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + Copy(typ, nil) + } +} + // The following are valid Types. // // There currently is no better way to document which @@ -123,6 +134,7 @@ func TestType(t *testing.T) { Vars: []VarSecinfo{{Type: &Void{}}}, } }, + func() Type { return &cycle{&Void{}} }, } compareTypes := cmp.Comparer(func(a, b *Type) bool { @@ -244,6 +256,8 @@ func TestFormatType(t *testing.T) { t2 := &testFormattableType{"foo", []interface{}{t1}} + t3 := &testFormattableType{extra: []interface{}{""}} + tests := []struct { t formattableType fmt string @@ -262,6 +276,8 @@ func TestFormatType(t *testing.T) { {t2, "%v", []string{goType, t2.name}, []string{"extra"}}, // %1v does print nested types' extra. {t2, "%1v", []string{goType, t2.name, "extra"}, nil}, + // empty strings in extra don't emit anything. + {t3, "%v", []string{"[]"}, nil}, } for _, test := range tests { @@ -322,8 +338,9 @@ func TestUnderlyingType(t *testing.T) { root := &Volatile{} root.Type = test.fn(root) - got := UnderlyingType(root) - qt.Assert(t, got, qt.Equals, root) + got, ok := UnderlyingType(root).(*cycle) + qt.Assert(t, ok, qt.IsTrue) + qt.Assert(t, got.root, qt.Equals, root) }) } @@ -357,3 +374,12 @@ func BenchmarkUnderlyingType(b *testing.B) { } }) } + +// Copy can be used with UnderlyingType to strip qualifiers from a type graph. +func ExampleCopy_stripQualifiers() { + a := &Volatile{Type: &Pointer{Target: &Typedef{Name: "foo", Type: &Int{Size: 2}}}} + b := Copy(a, UnderlyingType) + // b has Volatile and Typedef removed. + fmt.Printf("%3v\n", b) + // Output: Pointer[target=Int[unsigned size=16]] +} diff --git a/internal/cmd/gentypes/main.go b/internal/cmd/gentypes/main.go index b8f0903bd..5d91e4bf0 100644 --- a/internal/cmd/gentypes/main.go +++ b/internal/cmd/gentypes/main.go @@ -479,7 +479,7 @@ import ( } func outputPatchedStruct(gf *btf.GoFormatter, w *bytes.Buffer, id string, s *btf.Struct, patches []patch) error { - s = btf.Copy(s).(*btf.Struct) + s = btf.Copy(s, nil).(*btf.Struct) for i, p := range patches { if err := p(s); err != nil {