diff --git a/multiaddr.go b/multiaddr.go index af26d44..4b3a360 100644 --- a/multiaddr.go +++ b/multiaddr.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "log" - "strings" "golang.org/x/exp/slices" ) @@ -150,26 +149,46 @@ func (m *multiaddr) Encapsulate(o Multiaddr) Multiaddr { } // Decapsulate unwraps Multiaddr up until the given Multiaddr is found. -func (m *multiaddr) Decapsulate(o Multiaddr) Multiaddr { - s1 := m.String() - s2 := o.String() - i := strings.LastIndex(s1, s2) - if i < 0 { - // if multiaddr not contained, returns a copy. - cpy := make([]byte, len(m.bytes)) - copy(cpy, m.bytes) - return &multiaddr{bytes: cpy} +func (m *multiaddr) Decapsulate(right Multiaddr) Multiaddr { + if right == nil { + return m } - if i == 0 { + leftParts := Split(m) + rightParts := Split(right) + + lastIndex := -1 + for i := range leftParts { + foundMatch := false + for j, rightC := range rightParts { + if len(leftParts) <= i+j { + foundMatch = false + break + } + + foundMatch = rightC.Equal(leftParts[i+j]) + if !foundMatch { + break + } + } + + if foundMatch { + lastIndex = i + } + } + + if lastIndex == 0 { return nil } - ma, err := NewMultiaddr(s1[:i]) - if err != nil { - panic("Multiaddr.Decapsulate incorrect byte boundaries.") + if lastIndex < 0 { + // if multiaddr not contained, returns a copy. + cpy := make([]byte, len(m.bytes)) + copy(cpy, m.bytes) + return &multiaddr{bytes: cpy} } - return ma + + return Join(leftParts[:lastIndex]...) } var ErrProtocolNotFound = fmt.Errorf("protocol not found in multiaddr") diff --git a/multiaddr_test.go b/multiaddr_test.go index 33a0663..b33e94d 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -454,6 +454,47 @@ func TestDecapsulateComment(t *testing.T) { require.Nil(t, rest, "expected a nil multiaddr if we decapsulate everything") } +func TestDecapsulate(t *testing.T) { + t.Run("right is nil", func(t *testing.T) { + left := StringCast("/ip4/1.2.3.4/tcp/1") + var right Multiaddr + left.Decapsulate(right) + }) + + testcases := []struct { + left, right, expected string + }{ + {"/ip4/1.2.3.4/tcp/1234", "/ip4/1.2.3.4", ""}, + {"/ip4/1.2.3.4", "/ip4/1.2.3.4/tcp/1234", "/ip4/1.2.3.4"}, + {"/ip4/1.2.3.5/tcp/1234", "/ip4/5.3.2.1", "/ip4/1.2.3.5/tcp/1234"}, + {"/ip4/1.2.3.5/udp/1234/quic-v1", "/udp/1234", "/ip4/1.2.3.5"}, + {"/ip4/1.2.3.6/udp/1234/quic-v1", "/udp/1234/quic-v1", "/ip4/1.2.3.6"}, + {"/ip4/1.2.3.7/tcp/1234", "/ws", "/ip4/1.2.3.7/tcp/1234"}, + {"/dnsaddr/wss.com/tcp/4001", "/ws", "/dnsaddr/wss.com/tcp/4001"}, + {"/dnsaddr/wss.com/tcp/4001/ws", "/wss", "/dnsaddr/wss.com/tcp/4001/ws"}, + {"/dnsaddr/wss.com/ws", "/wss", "/dnsaddr/wss.com/ws"}, + {"/dnsaddr/wss.com/ws", "/dnsaddr/wss.com", ""}, + {"/dnsaddr/wss.com/tcp/4001/wss", "/wss", "/dnsaddr/wss.com/tcp/4001"}, + } + + for _, tc := range testcases { + t.Run(tc.left, func(t *testing.T) { + left := StringCast(tc.left) + right := StringCast(tc.right) + actualMa := left.Decapsulate(right) + + if tc.expected == "" { + require.Nil(t, actualMa, "expected nil") + return + } + + actual := actualMa.String() + expected := StringCast(tc.expected).String() + require.Equal(t, expected, actual) + }) + } +} + func assertValueForProto(t *testing.T, a Multiaddr, p int, exp string) { t.Logf("checking for %s in %s", ProtocolWithCode(p).Name, a) fv, err := a.ValueForProtocol(p)