diff --git a/btf/core.go b/btf/core.go index 54c911ab7..f952b654e 100644 --- a/btf/core.go +++ b/btf/core.go @@ -855,7 +855,7 @@ func coreAreTypesCompatible(localType Type, targetType Type) error { depth = 0 ) - for ; l != nil && t != nil; l, t = localTs.shift(), targetTs.shift() { + for ; l != nil && t != nil; l, t = localTs.Shift(), targetTs.Shift() { if depth >= maxTypeDepth { return errors.New("types are nested too deep") } @@ -873,8 +873,8 @@ func coreAreTypesCompatible(localType Type, targetType Type) error { case *Pointer, *Array: depth++ - walkType(localType, localTs.push) - walkType(targetType, targetTs.push) + walkType(localType, localTs.Push) + walkType(targetType, targetTs.Push) case *FuncProto: tv := targetType.(*FuncProto) @@ -883,8 +883,8 @@ func coreAreTypesCompatible(localType Type, targetType Type) error { } depth++ - walkType(localType, localTs.push) - walkType(targetType, targetTs.push) + walkType(localType, localTs.Push) + walkType(targetType, targetTs.Push) default: return fmt.Errorf("unsupported type %T", localType) diff --git a/btf/types.go b/btf/types.go index eba4a9632..81980e8d8 100644 --- a/btf/types.go +++ b/btf/types.go @@ -4,11 +4,11 @@ import ( "fmt" "io" "math" - "math/bits" "reflect" "strings" "github.com/cilium/ebpf/asm" + "github.com/cilium/ebpf/internal" ) const maxTypeDepth = 32 @@ -678,7 +678,7 @@ type copier map[Type]Type func (c copier) copy(typ *Type, transform Transformer) { var work typeDeque - for t := typ; t != nil; t = work.pop() { + for t := typ; t != nil; t = work.Pop() { // *t is the identity of the type. if cpy := c[*t]; cpy != nil { *t = cpy @@ -696,97 +696,11 @@ func (c copier) copy(typ *Type, transform Transformer) { *t = cpy // Mark any nested types for copying. - walkType(cpy, work.push) + walkType(cpy, work.Push) } } -type typeDeque = deque[*Type] - -// deque implements a double ended queue. -type deque[T any] struct { - elems []T - read, write uint64 - mask uint64 -} - -func (dq *deque[T]) empty() bool { - return dq.read == dq.write -} - -func (dq *deque[T]) remainingCap() int { - return len(dq.elems) - int(dq.write-dq.read) -} - -// push adds an element to the end. -func (dq *deque[T]) push(e T) { - if dq.remainingCap() >= 1 { - dq.elems[dq.write&dq.mask] = e - dq.write++ - return - } - - elems := dq.linearise(1) - elems = append(elems, e) - - dq.elems = elems[:cap(elems)] - dq.mask = uint64(cap(elems)) - 1 - dq.read, dq.write = 0, uint64(len(elems)) -} - -// shift returns the first element or the zero value. -func (dq *deque[T]) shift() T { - var zero T - - if dq.empty() { - return zero - } - - index := dq.read & dq.mask - t := dq.elems[index] - dq.elems[index] = zero - dq.read++ - return t -} - -// pop returns the last element or the zero value. -func (dq *deque[T]) pop() T { - var zero T - - if dq.empty() { - return zero - } - - dq.write-- - index := dq.write & dq.mask - t := dq.elems[index] - dq.elems[index] = zero - return t -} - -// linearise the contents of the deque. -// -// The returned slice has space for at least n more elements and has power -// of two capacity. -func (dq *deque[T]) linearise(n int) []T { - length := dq.write - dq.read - need := length + uint64(n) - if need < length { - panic("overflow") - } - - // Round up to the new power of two which is at least 8. - // See https://jameshfisher.com/2018/03/30/round-up-power-2/ - capacity := 1 << (64 - bits.LeadingZeros64(need-1)) - if capacity < 8 { - capacity = 8 - } - - types := make([]T, length, capacity) - pivot := dq.read & dq.mask - copied := copy(types, dq.elems[pivot:]) - copy(types[copied:], dq.elems[:pivot]) - return types -} +type typeDeque = internal.Deque[*Type] // inflateRawTypes takes a list of raw btf types linked via type IDs, and turns // it into a graph of Types connected via pointers. diff --git a/btf/types_test.go b/btf/types_test.go index b46d8d2f1..569d32bd5 100644 --- a/btf/types_test.go +++ b/btf/types_test.go @@ -193,83 +193,6 @@ func countChildren(t *testing.T, typ reflect.Type) int { return n } -func TestDeque(t *testing.T) { - t.Run("pop", func(t *testing.T) { - var dq deque[int] - dq.push(1) - dq.push(2) - - if dq.pop() != 2 { - t.Error("Didn't pop 2 first") - } - - if dq.pop() != 1 { - t.Error("Didn't pop 1 second") - } - - if dq.pop() != 0 { - t.Error("Didn't pop zero") - } - }) - - t.Run("shift", func(t *testing.T) { - var td deque[int] - td.push(1) - td.push(2) - - if td.shift() != 1 { - t.Error("Didn't shift 1 first") - } - - if td.shift() != 2 { - t.Error("Didn't shift b second") - } - - if td.shift() != 0 { - t.Error("Didn't shift zero") - } - }) - - t.Run("push", func(t *testing.T) { - var td deque[int] - td.push(1) - td.push(2) - td.shift() - - for i := 1; i <= 12; i++ { - td.push(i) - } - - if td.shift() != 2 { - t.Error("Didn't shift 2 first") - } - for i := 1; i <= 12; i++ { - if v := td.shift(); v != i { - t.Fatalf("Shifted %d at pos %d", v, i) - } - } - }) - - t.Run("linearise", func(t *testing.T) { - var td deque[int] - td.push(1) - td.push(2) - - all := td.linearise(0) - if len(all) != 2 { - t.Fatal("Expected 2 elements, got", len(all)) - } - - if cap(all)&(cap(all)-1) != 0 { - t.Fatalf("Capacity %d is not a power of two", cap(all)) - } - - if all[0] != 1 || all[1] != 2 { - t.Fatal("Elements don't match") - } - }) -} - type testFormattableType struct { name string extra []interface{} @@ -471,7 +394,7 @@ func BenchmarkWalk(b *testing.B) { for i := 0; i < b.N; i++ { var dq typeDeque - walkType(typ, dq.push) + walkType(typ, dq.Push) } }) } diff --git a/internal/deque.go b/internal/deque.go new file mode 100644 index 000000000..1abc9a9ba --- /dev/null +++ b/internal/deque.go @@ -0,0 +1,89 @@ +package internal + +import "math/bits" + +// Deque implements a double ended queue. +type Deque[T any] struct { + elems []T + read, write uint64 + mask uint64 +} + +func (dq *Deque[T]) Empty() bool { + return dq.read == dq.write +} + +func (dq *Deque[T]) remainingCap() int { + return len(dq.elems) - int(dq.write-dq.read) +} + +// Push adds an element to the end. +func (dq *Deque[T]) Push(e T) { + if dq.remainingCap() >= 1 { + dq.elems[dq.write&dq.mask] = e + dq.write++ + return + } + + elems := dq.linearise(1) + elems = append(elems, e) + + dq.elems = elems[:cap(elems)] + dq.mask = uint64(cap(elems)) - 1 + dq.read, dq.write = 0, uint64(len(elems)) +} + +// Shift returns the first element or the zero value. +func (dq *Deque[T]) Shift() T { + var zero T + + if dq.Empty() { + return zero + } + + index := dq.read & dq.mask + t := dq.elems[index] + dq.elems[index] = zero + dq.read++ + return t +} + +// Pop returns the last element or the zero value. +func (dq *Deque[T]) Pop() T { + var zero T + + if dq.Empty() { + return zero + } + + dq.write-- + index := dq.write & dq.mask + t := dq.elems[index] + dq.elems[index] = zero + return t +} + +// linearise the contents of the deque. +// +// The returned slice has space for at least n more elements and has power +// of two capacity. +func (dq *Deque[T]) linearise(n int) []T { + length := dq.write - dq.read + need := length + uint64(n) + if need < length { + panic("overflow") + } + + // Round up to the new power of two which is at least 8. + // See https://jameshfisher.com/2018/03/30/round-up-power-2/ + capacity := 1 << (64 - bits.LeadingZeros64(need-1)) + if capacity < 8 { + capacity = 8 + } + + types := make([]T, length, capacity) + pivot := dq.read & dq.mask + copied := copy(types, dq.elems[pivot:]) + copy(types[copied:], dq.elems[:pivot]) + return types +} diff --git a/internal/deque_test.go b/internal/deque_test.go new file mode 100644 index 000000000..d611c0719 --- /dev/null +++ b/internal/deque_test.go @@ -0,0 +1,80 @@ +package internal + +import "testing" + +func TestDeque(t *testing.T) { + t.Run("pop", func(t *testing.T) { + var dq Deque[int] + dq.Push(1) + dq.Push(2) + + if dq.Pop() != 2 { + t.Error("Didn't pop 2 first") + } + + if dq.Pop() != 1 { + t.Error("Didn't pop 1 second") + } + + if dq.Pop() != 0 { + t.Error("Didn't pop zero") + } + }) + + t.Run("shift", func(t *testing.T) { + var td Deque[int] + td.Push(1) + td.Push(2) + + if td.Shift() != 1 { + t.Error("Didn't shift 1 first") + } + + if td.Shift() != 2 { + t.Error("Didn't shift b second") + } + + if td.Shift() != 0 { + t.Error("Didn't shift zero") + } + }) + + t.Run("push", func(t *testing.T) { + var td Deque[int] + td.Push(1) + td.Push(2) + td.Shift() + + for i := 1; i <= 12; i++ { + td.Push(i) + } + + if td.Shift() != 2 { + t.Error("Didn't shift 2 first") + } + for i := 1; i <= 12; i++ { + if v := td.Shift(); v != i { + t.Fatalf("Shifted %d at pos %d", v, i) + } + } + }) + + t.Run("linearise", func(t *testing.T) { + var td Deque[int] + td.Push(1) + td.Push(2) + + all := td.linearise(0) + if len(all) != 2 { + t.Fatal("Expected 2 elements, got", len(all)) + } + + if cap(all)&(cap(all)-1) != 0 { + t.Fatalf("Capacity %d is not a power of two", cap(all)) + } + + if all[0] != 1 || all[1] != 2 { + t.Fatal("Elements don't match") + } + }) +}