Skip to content

Commit

Permalink
basichost: append certhash for webrtc addresses provided via address …
Browse files Browse the repository at this point in the history
…factory (#2774)
  • Loading branch information
sukunrt committed Apr 23, 2024
1 parent e61c36f commit 0385ec9
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 6 deletions.
7 changes: 5 additions & 2 deletions p2p/host/basic/basic_host.go
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/prometheus/client_golang/prometheus"

Expand Down Expand Up @@ -786,15 +787,17 @@ func (h *BasicHost) Addrs() []ma.Multiaddr {
copy(addrs, addrsOld)

for i, addr := range addrs {
if ok, n := libp2pwebtransport.IsWebtransportMultiaddr(addr); ok && n == 0 {
wtOK, wtN := libp2pwebtransport.IsWebtransportMultiaddr(addr)
webrtcOK, webrtcN := libp2pwebrtc.IsWebRTCDirectMultiaddr(addr)
if (wtOK && wtN == 0) || (webrtcOK && webrtcN == 0) {
t := s.TransportForListening(addr)
tpt, ok := t.(addCertHasher)
if !ok {
continue
}
addrWithCerthash, added := tpt.AddCertHashes(addr)
if !added {
log.Debug("Couldn't add certhashes to webtransport multiaddr because we aren't listening on webtransport")
log.Debugf("Couldn't add certhashes to multiaddr: %s", addr)
continue
}
addrs[i] = addrWithCerthash
Expand Down
41 changes: 41 additions & 0 deletions p2p/test/basichost/basic_host_test.go
Expand Up @@ -3,6 +3,7 @@ package basichost
import (
"context"
"fmt"
"strings"
"testing"
"time"

Expand All @@ -12,6 +13,8 @@ import (
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -158,3 +161,41 @@ func TestNewStreamTransientConnection(t *testing.T) {
<-done
<-done
}

func TestAddrFactorCertHashAppend(t *testing.T) {
wtAddr := "/ip4/1.2.3.4/udp/1/quic-v1/webtransport"
webrtcAddr := "/ip4/1.2.3.4/udp/2/webrtc-direct"
addrsFactory := func(addrs []ma.Multiaddr) []ma.Multiaddr {
return append(addrs,
ma.StringCast(wtAddr),
ma.StringCast(webrtcAddr),
)
}
h, err := libp2p.New(
libp2p.AddrsFactory(addrsFactory),
libp2p.Transport(libp2pwebrtc.New),
libp2p.Transport(libp2pwebtransport.New),
libp2p.ListenAddrStrings(
"/ip4/0.0.0.0/udp/0/quic-v1/webtransport",
"/ip4/0.0.0.0/udp/0/webrtc-direct",
),
)
require.NoError(t, err)
require.Eventually(t, func() bool {
addrs := h.Addrs()
var hasWebRTC, hasWebTransport bool
for _, addr := range addrs {
if strings.HasPrefix(addr.String(), webrtcAddr) {
if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err == nil {
hasWebRTC = true
}
}
if strings.HasPrefix(addr.String(), wtAddr) {
if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err == nil {
hasWebTransport = true
}
}
}
return hasWebRTC && hasWebTransport
}, 5*time.Second, 100*time.Millisecond)
}
56 changes: 52 additions & 4 deletions p2p/transport/webrtc/transport.go
Expand Up @@ -40,16 +40,13 @@ import (
"github.com/libp2p/go-msgio"

ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"

"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
)

var dialMatcher = mafmt.And(mafmt.UDP, mafmt.Base(ma.P_WEBRTC_DIRECT), mafmt.Base(ma.P_CERTHASH))

var webrtcComponent *ma.Component

func init() {
Expand Down Expand Up @@ -179,7 +176,8 @@ func (t *WebRTCTransport) Proxy() bool {
}

func (t *WebRTCTransport) CanDial(addr ma.Multiaddr) bool {
return dialMatcher.Matches(addr)
isValid, n := IsWebRTCDirectMultiaddr(addr)
return isValid && n > 0
}

// Listen returns a listener for addr.
Expand Down Expand Up @@ -514,6 +512,24 @@ func (t *WebRTCTransport) noiseHandshake(ctx context.Context, pc *webrtc.PeerCon
return secureConn.RemotePublicKey(), nil
}

func (t *WebRTCTransport) AddCertHashes(addr ma.Multiaddr) (ma.Multiaddr, bool) {
listenerFingerprint, err := t.getCertificateFingerprint()
if err != nil {
return nil, false
}

encodedLocalFingerprint, err := encodeDTLSFingerprint(listenerFingerprint)
if err != nil {
return nil, false
}

certComp, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, encodedLocalFingerprint)
if err != nil {
return nil, false
}
return addr.Encapsulate(certComp), true
}

type netConnWrapper struct {
*stream
}
Expand Down Expand Up @@ -601,3 +617,35 @@ func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configurat
IncomingDataChannels: incomingDataChannels,
}, nil
}

// IsWebRTCDirectMultiaddr returns whether addr is a /webrtc-direct multiaddr with the count of certhashes
// in addr
func IsWebRTCDirectMultiaddr(addr ma.Multiaddr) (bool, int) {
var foundUDP, foundWebRTC bool
certHashCount := 0
ma.ForEach(addr, func(c ma.Component) bool {
if !foundUDP {
if c.Protocol().Code == ma.P_UDP {
foundUDP = true
}
return true
}
if !foundWebRTC && foundUDP {
// protocol after udp must be webrtc-direct
if c.Protocol().Code != ma.P_WEBRTC_DIRECT {
return false
}
foundWebRTC = true
return true
}
if foundWebRTC {
if c.Protocol().Code == ma.P_CERTHASH {
certHashCount++
} else {
return false
}
}
return true
})
return foundUDP && foundWebRTC, certHashCount
}
67 changes: 67 additions & 0 deletions p2p/transport/webrtc/transport_test.go
Expand Up @@ -40,6 +40,58 @@ func getTransport(t *testing.T, opts ...Option) (*WebRTCTransport, peer.ID) {
return transport, peerID
}

func TestIsWebRTCDirectMultiaddr(t *testing.T) {
invalid := []string{
"/ip4/1.2.3.4/tcp/10/",
"/ip6/1::3/udp/100/quic-v1/",
"/ip4/1.2.3.4/udp/1/quic-v1/webrtc-direct",
}

valid := []struct {
addr string
count int
}{
{
addr: "/ip4/1.2.3.4/udp/1234/webrtc-direct",
count: 0,
},
{
addr: "/dns/test.test/udp/1234/webrtc-direct",
count: 0,
},
{
addr: "/ip4/1.2.3.4/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg",
count: 1,
},
{
addr: "/ip6/0:0:0:0:0:0:0:1/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg",
count: 1,
},
{
addr: "/dns/test.test/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg",
count: 1,
},
{
addr: "/dns/test.test/udp/1234/webrtc-direct/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7YFrV4VZ3hpEKTd_zg/certhash/uEiAsGPzpiPGQzSlVHRXrUCT5EkTV7ZGrV4VZ3hpEKTd_zg",
count: 2,
},
}

for _, addr := range invalid {
a := ma.StringCast(addr)
isValid, n := IsWebRTCDirectMultiaddr(a)
require.Equal(t, 0, n)
require.False(t, isValid)
}

for _, tc := range valid {
a := ma.StringCast(tc.addr)
isValid, n := IsWebRTCDirectMultiaddr(a)
require.Equal(t, tc.count, n)
require.True(t, isValid)
}
}

func TestTransportWebRTC_CanDial(t *testing.T) {
tr, _ := getTransport(t)
invalid := []string{
Expand All @@ -65,6 +117,21 @@ func TestTransportWebRTC_CanDial(t *testing.T) {
}
}

func TestTransportAddCertHasher(t *testing.T) {
tr, _ := getTransport(t)
addrs := []string{
"/ip4/1.2.3.4/udp/1/webrtc-direct",
"/ip6/1::3/udp/2/webrtc-direct",
}
for _, a := range addrs {
addr, added := tr.AddCertHashes(ma.StringCast(a))
require.True(t, added)
_, err := addr.ValueForProtocol(ma.P_CERTHASH)
require.NoError(t, err)
require.True(t, strings.HasPrefix(addr.String(), a))
}
}

func TestTransportWebRTC_ListenFailsOnNonWebRTCMultiaddr(t *testing.T) {
tr, _ := getTransport(t)
testAddrs := []string{
Expand Down

0 comments on commit 0385ec9

Please sign in to comment.