Skip to content

Commit

Permalink
Decapsulate by component rather than string
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Mar 27, 2024
1 parent b6caabb commit 1bb15f9
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 15 deletions.
49 changes: 34 additions & 15 deletions multiaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"log"
"strings"

"golang.org/x/exp/slices"
)
Expand Down Expand Up @@ -150,26 +149,46 @@ func (m *multiaddr) Encapsulate(o Multiaddr) Multiaddr {
}

// Decapsulate unwraps Multiaddr up until the given Multiaddr is found.
func (m *multiaddr) Decapsulate(o Multiaddr) Multiaddr {
s1 := m.String()
s2 := o.String()
i := strings.LastIndex(s1, s2)
if i < 0 {
// if multiaddr not contained, returns a copy.
cpy := make([]byte, len(m.bytes))
copy(cpy, m.bytes)
return &multiaddr{bytes: cpy}
func (m *multiaddr) Decapsulate(right Multiaddr) Multiaddr {
if right == nil {
return m
}

if i == 0 {
leftParts := Split(m)
rightParts := Split(right)

lastIndex := -1
for i := range leftParts {
foundMatch := false
for j, rightC := range rightParts {
if len(leftParts) <= i+j {
foundMatch = false
break
}

foundMatch = rightC.Equal(leftParts[i+j])
if !foundMatch {
break
}
}

if foundMatch {
lastIndex = i
}
}

if lastIndex == 0 {
return nil
}

ma, err := NewMultiaddr(s1[:i])
if err != nil {
panic("Multiaddr.Decapsulate incorrect byte boundaries.")
if lastIndex < 0 {
// if multiaddr not contained, returns a copy.
cpy := make([]byte, len(m.bytes))
copy(cpy, m.bytes)
return &multiaddr{bytes: cpy}
}
return ma

return Join(leftParts[:lastIndex]...)
}

var ErrProtocolNotFound = fmt.Errorf("protocol not found in multiaddr")
Expand Down
148 changes: 148 additions & 0 deletions multiaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package multiaddr

import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
Expand Down Expand Up @@ -454,6 +455,47 @@ func TestDecapsulateComment(t *testing.T) {
require.Nil(t, rest, "expected a nil multiaddr if we decapsulate everything")
}

func TestDecapsulate(t *testing.T) {
t.Run("right is nil", func(t *testing.T) {
left := StringCast("/ip4/1.2.3.4/tcp/1")
var right Multiaddr
left.Decapsulate(right)
})

testcases := []struct {
left, right, expected string
}{
{"/ip4/1.2.3.4/tcp/1234", "/ip4/1.2.3.4", ""},
{"/ip4/1.2.3.4", "/ip4/1.2.3.4/tcp/1234", "/ip4/1.2.3.4"},
{"/ip4/1.2.3.5/tcp/1234", "/ip4/5.3.2.1", "/ip4/1.2.3.5/tcp/1234"},
{"/ip4/1.2.3.5/udp/1234/quic-v1", "/udp/1234", "/ip4/1.2.3.5"},
{"/ip4/1.2.3.6/udp/1234/quic-v1", "/udp/1234/quic-v1", "/ip4/1.2.3.6"},
{"/ip4/1.2.3.7/tcp/1234", "/ws", "/ip4/1.2.3.7/tcp/1234"},
{"/dnsaddr/wss.com/tcp/4001", "/ws", "/dnsaddr/wss.com/tcp/4001"},
{"/dnsaddr/wss.com/tcp/4001/ws", "/wss", "/dnsaddr/wss.com/tcp/4001/ws"},
{"/dnsaddr/wss.com/ws", "/wss", "/dnsaddr/wss.com/ws"},
{"/dnsaddr/wss.com/ws", "/dnsaddr/wss.com", ""},
{"/dnsaddr/wss.com/tcp/4001/wss", "/wss", "/dnsaddr/wss.com/tcp/4001"},
}

for _, tc := range testcases {
t.Run(tc.left, func(t *testing.T) {
left := StringCast(tc.left)
right := StringCast(tc.right)
actualMa := left.Decapsulate(right)

if tc.expected == "" {
require.Nil(t, actualMa, "expected nil")
return
}

actual := actualMa.String()
expected := StringCast(tc.expected).String()
require.Equal(t, expected, actual)
})
}
}

func assertValueForProto(t *testing.T, a Multiaddr, p int, exp string) {
t.Logf("checking for %s in %s", ProtocolWithCode(p).Name, a)
fv, err := a.ValueForProtocol(p)
Expand Down Expand Up @@ -882,3 +924,109 @@ func TestDNS(t *testing.T) {
t.Fatal("expected equality")
}
}

// Returns a nil component and nil error if there is not enough data to generate a component.
// Returns the unconsumed portion of the input data.
func generateComponent(d []byte) (*Component, []byte, error) {
if len(d) <= 2 {
return nil, d, nil
}

// Pick a random protocol
idx := int(binary.LittleEndian.Uint16(d[:2]))
p := Protocols[idx%len(Protocols)]
d = d[2:]

if p.Size == 0 {
// No arg
return newComponent(p, nil), d, nil
} else if p.Size > 0 {
// Fixed size
if len(d) < (p.Size / 8) {
return nil, d, nil
}
return newComponent(p, d[:p.Size/8]), d, nil
} else {
// Varint

// Limit the value size to 1k bytes. Not sure if this is enforced anywhere?
if len(d) <= 2 {
return nil, d, nil
}
valSize := int(binary.LittleEndian.Uint16(d[:2]))
d = d[2:]
valSize = valSize % 1024

if len(d) < valSize {
return nil, d, nil
}

return newComponent(p, d[:valSize]), d[valSize:], nil
}
}

func validateComponent(c *Component) bool {
// TODO: we are validating the component by converting into a multiaddr first.
// Do users ever parse bytes as a component directly? That would open up other issues.
// There's a public api for this with .UnmarshalBinary.
_, err := NewMultiaddrBytes(c.Bytes())
return err == nil
}

func generateMultiaddr(d []byte) (Multiaddr, []byte, error) {
if len(d) == 0 {
return nil, d, nil
}

componentCount := int(uint8(d[0]))
d = d[1:]

if componentCount == 0 {
return nil, d, nil
}

parts := make([]Multiaddr, 0, componentCount)
var err error
var c *Component
for i := 0; i < componentCount; i++ {
c, d, err = generateComponent(d)
if err != nil {
return nil, d, err
}
if c == nil {
return nil, d, nil
}
if !validateComponent(c) {
return nil, d, nil
}

parts = append(parts, c)
}
return Join(parts...), d, nil
}

func FuzzDecapsulate(f *testing.F) {
f.Fuzz(func(t *testing.T, left, right []byte) {
leftMa, _, err := generateMultiaddr(left)
if err != nil {
t.Fatal(err)
}
if leftMa == nil {
return
}

rightMa, _, err := generateMultiaddr(right)
if err != nil {
t.Fatal(err)
}
if rightMa == nil {
return
}

_ = leftMa.Decapsulate(rightMa)

// --------------------
// TODO test invariants
// --------------------
})
}

0 comments on commit 1bb15f9

Please sign in to comment.