Skip to content

Commit

Permalink
rename CheckReachability to GetReachability
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Apr 25, 2024
1 parent 3afc95f commit af0ac51
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 46 deletions.
27 changes: 8 additions & 19 deletions p2p/protocol/autonatv2/autonat.go
Expand Up @@ -82,11 +82,12 @@ type AutoNAT struct {
mx sync.Mutex
peers *peersMap

allowAllAddrs bool // for testing
// allowAllAddrs enables using private and localhost addresses for reachability checks.
// This is only useful for testing.
allowAllAddrs bool
}

// New returns a new AutoNAT instance. The returned instance runs the server when the provided host
// is publicly reachable.
// New returns a new AutoNAT instance.
// host and dialerHost should have the same dialing capabilities. In case the host doesn't support
// a transport, dial back requests for address for that transport will be ignored.
func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
Expand All @@ -99,19 +100,12 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT,
// We are listening on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged
// event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers.
//
// We listen on event.EvtLocalReachabilityChanged to Disable the server if we are not
// publicly reachable. Currently this event is sent by the AutoNAT v1 module. During the
// transition period from AutoNAT v1 to v2, there won't be enough v2 servers on the network
// and most clients will be unable to discover a peer which supports AutoNAT v2. So, we use
// v1 to determine reachability for the transition period.
//
// Once there are enough v2 servers on the network for nodes to determine their reachability
// using AutoNAT v2, we'll use Address Pipeline
// (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a future release)
// to determine reachability using v2 client and send this event from Address Pipeline, if
// we are publicly reachable.
sub, err := host.EventBus().Subscribe([]interface{}{
new(event.EvtLocalReachabilityChanged),
new(event.EvtPeerProtocolsUpdated),
new(event.EvtPeerConnectednessChanged),
new(event.EvtPeerIdentificationCompleted),
Expand All @@ -132,6 +126,7 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT,
peers: newPeersMap(),
}
an.cli.RegisterDialBack()
an.srv.Enable()

an.wg.Add(1)
go an.background()
Expand All @@ -149,12 +144,6 @@ func (an *AutoNAT) background() {
return
case e := <-an.sub.Out():
switch evt := e.(type) {
case event.EvtLocalReachabilityChanged:
if evt.Reachability == network.ReachabilityPrivate {
an.srv.Disable()
} else {
an.srv.Enable()
}
case event.EvtPeerProtocolsUpdated:
an.updatePeer(evt.Peer)
case event.EvtPeerConnectednessChanged:
Expand All @@ -171,8 +160,8 @@ func (an *AutoNAT) Close() {
an.wg.Wait()
}

// CheckReachability makes a single dial request for checking reachability for requested addresses
func (an *AutoNAT) CheckReachability(ctx context.Context, reqs []Request) (Result, error) {
// GetReachability makes a single dial request for checking reachability for requested addresses
func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) {
if !an.allowAllAddrs {
for _, r := range reqs {
if !manet.IsPublicAddr(r.Addr) {
Expand All @@ -185,7 +174,7 @@ func (an *AutoNAT) CheckReachability(ctx context.Context, reqs []Request) (Resul
return Result{}, ErrNoValidPeers
}

res, err := an.cli.CheckReachability(ctx, p, reqs)
res, err := an.cli.GetReachability(ctx, p, reqs)
if err != nil {
log.Debugf("reachability check with %s failed, err: %s", p, err)
return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err)
Expand Down
10 changes: 5 additions & 5 deletions p2p/protocol/autonatv2/autonat_test.go
Expand Up @@ -80,7 +80,7 @@ func idAndWait(t *testing.T, cli *AutoNAT, srv *AutoNAT) {

func TestAutoNATPrivateAddr(t *testing.T) {
an := newAutoNAT(t, nil)
res, err := an.CheckReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}})
res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}})
require.Equal(t, res, Result{})
require.Contains(t, err.Error(), "private address cannot be verified by autonatv2")
}
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestClientRequest(t *testing.T) {
s.Reset()
})

res, err := an.CheckReachability(context.Background(), []Request{
res, err := an.GetReachability(context.Background(), []Request{
{Addr: addrs[0], SendDialData: true}, {Addr: addrs[1]},
})
require.Equal(t, res, Result{})
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestClientServerError(t *testing.T) {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
b.SetStreamHandler(DialProtocol, tc.handler)
addrs := an.host.Addrs()
res, err := an.CheckReachability(
res, err := an.GetReachability(
context.Background(),
newTestRequests(addrs, false))
require.Equal(t, res, Result{})
Expand Down Expand Up @@ -280,7 +280,7 @@ func TestClientDataRequest(t *testing.T) {
b.SetStreamHandler(DialProtocol, tc.handler)
addrs := an.host.Addrs()

res, err := an.CheckReachability(
res, err := an.GetReachability(
context.Background(),
[]Request{
{Addr: addrs[0], SendDialData: true},
Expand Down Expand Up @@ -489,7 +489,7 @@ func TestClientDialBacks(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
addrs := an.host.Addrs()
b.SetStreamHandler(DialProtocol, tc.handler)
res, err := an.CheckReachability(
res, err := an.GetReachability(
context.Background(),
[]Request{
{Addr: addrs[0], SendDialData: true},
Expand Down
4 changes: 2 additions & 2 deletions p2p/protocol/autonatv2/client.go
Expand Up @@ -38,8 +38,8 @@ func (ac *client) RegisterDialBack() {
ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack)
}

// CheckReachability verifies address reachability with a AutoNAT v2 server p.
func (ac *client) CheckReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
// GetReachability verifies address reachability with a AutoNAT v2 server p.
func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
ctx, cancel := context.WithTimeout(ctx, streamTimeout)
defer cancel()

Expand Down
28 changes: 15 additions & 13 deletions p2p/protocol/autonatv2/server.go
Expand Up @@ -31,8 +31,8 @@ type server struct {
dialerHost host.Host
limiter *rateLimiter

// dialDataRequestPolicy is used to determine whether dialing the address requires receiving dial data.
// It is set to amplification attack prevention by default.
// dialDataRequestPolicy is used to determine whether dialing the address requires receiving
// dial data. It is set to amplification attack prevention by default.
dialDataRequestPolicy dataRequestPolicyFunc

// for tests
Expand Down Expand Up @@ -98,7 +98,7 @@ func (as *server) handleDialRequest(s network.Stream) {
}
if msg.GetDialRequest() == nil {
s.Reset()
log.Debugf("invalid message type from %s: %T", p, msg.Msg)
log.Debugf("invalid message type from %s: %T expected: DialRequest", p, msg.Msg)
return
}

Expand All @@ -119,7 +119,7 @@ func (as *server) handleDialRequest(s network.Stream) {
continue
}
// Check if the host can dial the address. This check ensures that we do not
// attempt dialing an IPv6 address if we have no IPv6 connectivity as the host dialer's
// attempt dialing an IPv6 address if we have no IPv6 connectivity as the host's
// black hole detector is likely to be more accurate.
if as.host.Network().CanDial(p, a) != network.DialabilityDialable {
continue
Expand All @@ -141,14 +141,13 @@ func (as *server) handleDialRequest(s network.Stream) {
}
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
log.Debugf("failed to write response to %s: %s", p, err)
log.Debugf("failed to write dial refused response to %s: %s", p, err)
return
}
return
}

isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr)

if !as.limiter.Accept(p, isDialDataRequired) {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
Expand All @@ -159,10 +158,10 @@ func (as *server) handleDialRequest(s network.Stream) {
}
if err := w.WriteMsg(&msg); err != nil {
s.Reset()
log.Debugf("failed to write response to %s: %s", p, err)
log.Debugf("failed to write request rejected response to %s: %s", p, err)
return
}
log.Debugf("rejecting request from %s: rate limit exceeded", p)
log.Debugf("rejected request from %s: rate limit exceeded", p)
return
}
defer as.limiter.CompleteRequest(p)
Expand Down Expand Up @@ -248,12 +247,14 @@ func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialSt
return pb.DialStatus_E_DIAL_BACK_ERROR
}

// Since the underlying connection is on a separate dialer, it'll be closed after this function returns.
// Connection close will drop all the queued writes. To ensure message delivery, do a CloseWrite and
// wait a second for the peer to Close its end of the stream.
// Since the underlying connection is on a separate dialer, it'll be closed after this
// function returns. Connection close will drop all the queued writes.
// To ensure message delivery, do a CloseWrite and read a byte from the stream. The peer
// actually sends a DialDataResponse back but we only care about the fact that the DialBack
// message has reached the peer. So we ignore that message on the read side.
s.CloseWrite()
s.SetDeadline(as.now().Add(1 * time.Second))
b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately
s.SetDeadline(as.now().Add(5 * time.Second)) // 5 is a magic number
b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately
s.Read(b)

return pb.DialStatus_OK
Expand All @@ -275,6 +276,7 @@ type rateLimiter struct {
dialDataReqs []time.Time
// ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the
// same peer
// TODO: Should we allow a few concurrent requests per peer?
ongoingReqs map[peer.ID]struct{}

now func() time.Time // for tests
Expand Down
14 changes: 7 additions & 7 deletions p2p/protocol/autonatv2/server_test.go
Expand Up @@ -35,7 +35,7 @@ func TestServerInvalidAddrsRejected(t *testing.T) {

idAndWait(t, c, an)

res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
})
Expand All @@ -47,7 +47,7 @@ func TestServerInvalidAddrsRejected(t *testing.T) {

idAndWait(t, c, an)

res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
})
Expand Down Expand Up @@ -84,10 +84,10 @@ func TestServerDataRequest(t *testing.T) {
}
}

_, err := c.CheckReachability(context.Background(), []Request{{Addr: tcpAddr, SendDialData: true}, {Addr: quicAddr}})
_, err := c.GetReachability(context.Background(), []Request{{Addr: tcpAddr, SendDialData: true}, {Addr: quicAddr}})
require.Error(t, err)

res, err := c.CheckReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}})
res, err := c.GetReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}})
require.NoError(t, err)

require.Equal(t, Result{
Expand All @@ -113,7 +113,7 @@ func TestServerDial(t *testing.T) {
hostAddrs := c.host.Addrs()

t.Run("unreachable addr", func(t *testing.T) {
res, err := c.CheckReachability(context.Background(),
res, err := c.GetReachability(context.Background(),
append([]Request{{Addr: unreachableAddr, SendDialData: true}}, newTestRequests(hostAddrs, false)...))
require.NoError(t, err)
require.Equal(t, Result{
Expand All @@ -125,7 +125,7 @@ func TestServerDial(t *testing.T) {
})

t.Run("reachable addr", func(t *testing.T) {
res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false))
res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false))
require.NoError(t, err)
require.Equal(t, Result{
Idx: 0,
Expand All @@ -137,7 +137,7 @@ func TestServerDial(t *testing.T) {

t.Run("dialback error", func(t *testing.T) {
c.host.RemoveStreamHandler(DialBackProtocol)
res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false))
res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false))
require.NoError(t, err)
require.Equal(t, Result{
Idx: 0,
Expand Down

0 comments on commit af0ac51

Please sign in to comment.