Skip to content

Commit

Permalink
limit concurrent client requests
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 21, 2023
1 parent e1c362a commit acb1c88
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 59 deletions.
4 changes: 2 additions & 2 deletions p2p/protocol/autonatv2/autonat.go
Expand Up @@ -75,7 +75,7 @@ type AutoNAT struct {
cancel context.CancelFunc
wg sync.WaitGroup

srv *Server
srv *server
cli *client

mx sync.Mutex
Expand Down Expand Up @@ -123,7 +123,7 @@ func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error)
ctx: ctx,
cancel: cancel,
sub: sub,
srv: NewServer(h, dialer, s),
srv: newServer(h, dialer, s),
cli: newClient(h),
allowAllAddrs: s.allowAllAddrs,
peers: newPeersMap(),
Expand Down
8 changes: 4 additions & 4 deletions p2p/protocol/autonatv2/autonat_test.go
Expand Up @@ -84,7 +84,7 @@ func TestAutoNATPrivateAddr(t *testing.T) {
}

func TestClientRequest(t *testing.T) {
an := newAutoNAT(t, nil, allowAll)
an := newAutoNAT(t, nil, allowAllAddrs)

addrs := an.host.Addrs()

Expand Down Expand Up @@ -127,7 +127,7 @@ func TestClientRequest(t *testing.T) {
}

func TestClientServerError(t *testing.T) {
an := newAutoNAT(t, nil, allowAll)
an := newAutoNAT(t, nil, allowAllAddrs)
addrs := an.host.Addrs()

p := bhost.NewBlankHost(swarmt.GenSwarm(t))
Expand Down Expand Up @@ -179,7 +179,7 @@ func TestClientServerError(t *testing.T) {
}

func TestClientDataRequest(t *testing.T) {
an := newAutoNAT(t, nil, allowAll)
an := newAutoNAT(t, nil, allowAllAddrs)
addrs := an.host.Addrs()

p := bhost.NewBlankHost(swarmt.GenSwarm(t))
Expand Down Expand Up @@ -275,7 +275,7 @@ func TestClientDataRequest(t *testing.T) {
}

func TestClientDialBacks(t *testing.T) {
an := newAutoNAT(t, nil, allowAll)
an := newAutoNAT(t, nil, allowAllAddrs)
addrs := an.host.Addrs()

p := bhost.NewBlankHost(swarmt.GenSwarm(t))
Expand Down
14 changes: 4 additions & 10 deletions p2p/protocol/autonatv2/options.go
Expand Up @@ -2,6 +2,7 @@ package autonatv2

import "time"

// autoNATSettings is used to configure AutoNAT
type autoNATSettings struct {
allowAllAddrs bool
serverRPM int
Expand All @@ -25,11 +26,6 @@ func defaultSettings() *autoNATSettings {

type AutoNATOption func(s *autoNATSettings) error

func allowAll(s *autoNATSettings) error {
s.allowAllAddrs = true
return nil
}

func WithServerRateLimit(rpm, perPeerRPM, dialDataRPM int) AutoNATOption {
return func(s *autoNATSettings) error {
s.serverRPM = rpm
Expand All @@ -46,9 +42,7 @@ func WithDataRequestPolicy(drp dataRequestPolicyFunc) AutoNATOption {
}
}

func WithNow(now func() time.Time) AutoNATOption {
return func(s *autoNATSettings) error {
s.now = now
return nil
}
func allowAllAddrs(s *autoNATSettings) error {
s.allowAllAddrs = true
return nil
}
104 changes: 67 additions & 37 deletions p2p/protocol/autonatv2/server.go
Expand Up @@ -20,18 +20,26 @@ import (

type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool

type Server struct {
dialer host.Host
host host.Host
// server implements the AutoNATv2 server.
//
// It rate limits requests on a global level, per peer level and on whether the request requires dial data.
type server struct {
host host.Host
dialerHost host.Host
limiter *rateLimiter

// dialDataRequestPolicy is used to determine whether dialing the address requires receiving dial data.
// It is set to amplification attack prevention by default.
dialDataRequestPolicy dataRequestPolicyFunc
allowAllAddrs bool
limiter *rateLimiter
now func() time.Time // for tests

// for tests
now func() time.Time
allowAllAddrs bool
}

func NewServer(host, dialer host.Host, s *autoNATSettings) *Server {
return &Server{
dialer: dialer,
func newServer(host, dialer host.Host, s *autoNATSettings) *server {
return &server{
dialerHost: dialer,
host: host,
dialDataRequestPolicy: s.dataRequestPolicy,
allowAllAddrs: s.allowAllAddrs,
Expand All @@ -45,15 +53,15 @@ func NewServer(host, dialer host.Host, s *autoNATSettings) *Server {
}
}

func (as *Server) Enable() {
func (as *server) Enable() {
as.host.SetStreamHandler(DialProtocol, as.handleDialRequest)
}

func (as *Server) Disable() {
func (as *server) Disable() {
as.host.RemoveStreamHandler(DialProtocol)
}

func (as *Server) handleDialRequest(s network.Stream) {
func (as *server) handleDialRequest(s network.Stream) {
if err := s.Scope().SetService(ServiceName); err != nil {
s.Reset()
log.Debugf("failed to attach stream to service %s: %w", ServiceName, err)
Expand Down Expand Up @@ -85,6 +93,7 @@ func (as *Server) handleDialRequest(s network.Stream) {
}

nonce := msg.GetDialRequest().Nonce
// parse peer's addresses
var dialAddr ma.Multiaddr
var addrIdx int
for i, ab := range msg.GetDialRequest().GetAddrs() {
Expand All @@ -96,7 +105,7 @@ func (as *Server) handleDialRequest(s network.Stream) {
continue
}
if (!as.allowAllAddrs && !manet.IsPublicAddr(a)) ||
(!as.dialer.Network().CanDial(a)) {
(!as.dialerHost.Network().CanDial(a)) {
continue
}
dialAddr = a
Expand All @@ -105,6 +114,7 @@ func (as *Server) handleDialRequest(s network.Stream) {
}

w := pbio.NewDelimitedWriter(s)
// No dialable address
if dialAddr == nil {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
Expand All @@ -122,6 +132,7 @@ func (as *Server) handleDialRequest(s network.Stream) {
}

isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr)

if !as.limiter.Accept(p, isDialDataRequired) {
msg = pb.Message{
Msg: &pb.Message_DialResponse{
Expand All @@ -138,21 +149,23 @@ func (as *Server) handleDialRequest(s network.Stream) {
log.Debugf("rejecting request from %s: rate limit exceeded", p)
return
}
defer as.limiter.CompleteRequest(p)

if isDialDataRequired {
err := getDialData(w, r, &msg, addrIdx)
if err != nil {
if err := getDialData(w, r, &msg, addrIdx); err != nil {
s.Reset()
log.Debugf("%s refused dial data request: %s", p, err)
return
}
}
status := as.dialBack(s.Conn().RemotePeer(), dialAddr, nonce)

dialStatus := as.dialBack(s.Conn().RemotePeer(), dialAddr, nonce)

msg = pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_OK,
DialStatus: status,
DialStatus: dialStatus,
AddrIdx: uint32(addrIdx),
},
},
Expand All @@ -164,17 +177,7 @@ func (as *Server) handleDialRequest(s network.Stream) {
}
}

// amplificationAttackPrevention requests data when the peer's observed IP address is different
// from the dial back IP address
func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool {
connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr())
if err != nil {
return true
}
dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr
return !connIP.Equal(dialIP)
}

// getDialData gets data from the client for dialing the address
func getDialData(w pbio.Writer, r pbio.Reader, msg *pb.Message, addrIdx int) error {
numBytes := minHandshakeSizeBytes + rand.Intn(maxHandshakeSizeBytes-minHandshakeSizeBytes)
*msg = pb.Message{
Expand All @@ -200,16 +203,16 @@ func getDialData(w pbio.Writer, r pbio.Reader, msg *pb.Message, addrIdx int) err
return nil
}

func (as *Server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus {
func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus {
ctx, cancel := context.WithTimeout(context.Background(), dialBackDialTimeout)
as.dialer.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL)
as.dialerHost.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL)
defer func() {
cancel()
as.dialer.Network().ClosePeer(p)
as.dialer.Peerstore().ClearAddrs(p)
as.dialer.Peerstore().RemovePeer(p)
as.dialerHost.Network().ClosePeer(p)
as.dialerHost.Peerstore().ClearAddrs(p)
as.dialerHost.Peerstore().RemovePeer(p)
}()
s, err := as.dialer.NewStream(ctx, p, DialBackProtocol)
s, err := as.dialerHost.NewStream(ctx, p, DialBackProtocol)
if err != nil {
return pb.DialStatus_E_DIAL_ERROR
}
Expand All @@ -223,7 +226,7 @@ func (as *Server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialSt
}

// Since the underlying connection is on a separate dialer, it'll be closed after this function returns.
// Connection close will drop all the queued writes. To ensure message delivery, do a close write and
// Connection close will drop all the queued writes. To ensure message delivery, do a CloseWrite and
// wait a second for the peer to Close its end of the stream.
s.CloseWrite()
s.SetDeadline(as.now().Add(1 * time.Second))
Expand All @@ -233,7 +236,8 @@ func (as *Server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialSt
return pb.DialStatus_OK
}

// rateLimiter implements a sliding window rate limit of requests per minute.
// rateLimiter implements a sliding window rate limit of requests per minute. It allows 1 concurrent request
// per peer. It rate limits requests globally, at a peer level and depending on whether it requires dial data.
type rateLimiter struct {
PerPeerRPM int
RPM int
Expand All @@ -243,25 +247,33 @@ type rateLimiter struct {
reqs []time.Time
peerReqs map[peer.ID][]time.Time
dialDataReqs []time.Time
now func() time.Time // for tests
ongoingReqs map[peer.ID]struct{}

now func() time.Time // for tests
}

func (r *rateLimiter) Accept(p peer.ID, requiresData bool) bool {
r.mu.Lock()
defer r.mu.Unlock()
if r.peerReqs == nil {
r.peerReqs = make(map[peer.ID][]time.Time)
r.ongoingReqs = make(map[peer.ID]struct{})
}

nw := r.now()
r.cleanup(p, nw)

if _, ok := r.ongoingReqs[p]; ok {
return false
}
if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.PerPeerRPM {
return false
}
if requiresData && len(r.dialDataReqs) >= r.DialDataRPM {
return false
}

r.ongoingReqs[p] = struct{}{}
r.reqs = append(r.reqs, nw)
r.peerReqs[p] = append(r.peerReqs[p], nw)
if requiresData {
Expand Down Expand Up @@ -302,3 +314,21 @@ func (r *rateLimiter) cleanup(p peer.ID, now time.Time) {
}
r.peerReqs[p] = r.peerReqs[p][idx:]
}

func (r *rateLimiter) CompleteRequest(p peer.ID) {
r.mu.Lock()
defer r.mu.Unlock()

delete(r.ongoingReqs, p)
}

// amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed
// IP address is different from the dial back IP address
func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool {
connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr())
if err != nil {
return true
}
dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr
return !connIP.Equal(dialIP)
}
19 changes: 13 additions & 6 deletions p2p/protocol/autonatv2/server_test.go
Expand Up @@ -24,12 +24,12 @@ func newTestRequests(addrs []ma.Multiaddr, sendDialData bool) (reqs []Request) {

func TestServerAllAddrsInvalid(t *testing.T) {
dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP))
an := newAutoNAT(t, dialer, allowAll)
an := newAutoNAT(t, dialer, allowAllAddrs)
defer an.Close()
defer an.host.Close()
an.srv.Enable()

c := newAutoNAT(t, nil, allowAll)
c := newAutoNAT(t, nil, allowAllAddrs)
defer c.Close()
defer c.host.Close()

Expand All @@ -46,7 +46,7 @@ func TestServerPrivateRejected(t *testing.T) {
defer an.host.Close()
an.srv.Enable()

c := newAutoNAT(t, nil, allowAll)
c := newAutoNAT(t, nil, allowAllAddrs)
defer c.Close()
defer c.host.Close()

Expand All @@ -59,7 +59,7 @@ func TestServerPrivateRejected(t *testing.T) {

func TestServerDataRequest(t *testing.T) {
dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
an := newAutoNAT(t, dialer, allowAll, WithDataRequestPolicy(
an := newAutoNAT(t, dialer, allowAllAddrs, WithDataRequestPolicy(
func(s network.Stream, dialAddr ma.Multiaddr) bool {
if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil {
return true
Expand Down Expand Up @@ -102,11 +102,11 @@ func TestServerDataRequest(t *testing.T) {
}

func TestServerDial(t *testing.T) {
an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowAll)
an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowAllAddrs)
defer an.host.Close()
an.srv.Enable()

c := newAutoNAT(t, nil, allowAll)
c := newAutoNAT(t, nil, allowAllAddrs)
defer c.Close()
defer c.host.Close()

Expand Down Expand Up @@ -141,22 +141,29 @@ func TestRateLimiter(t *testing.T) {
require.True(t, r.Accept("peer1", false))

cl.AdvanceBy(10 * time.Second)
require.False(t, r.Accept("peer1", false)) // first request is still active
r.CompleteRequest("peer1")

require.True(t, r.Accept("peer1", false))
r.CompleteRequest("peer1")

cl.AdvanceBy(10 * time.Second)
require.False(t, r.Accept("peer1", false))

cl.AdvanceBy(10 * time.Second)
require.True(t, r.Accept("peer2", false))
r.CompleteRequest("peer2")

cl.AdvanceBy(10 * time.Second)
require.False(t, r.Accept("peer3", false))

cl.AdvanceBy(21 * time.Second) // first request expired
require.True(t, r.Accept("peer1", false))
r.CompleteRequest("peer1")

cl.AdvanceBy(10 * time.Second)
require.True(t, r.Accept("peer3", true))
r.CompleteRequest("peer3")

cl.AdvanceBy(50 * time.Second)
require.False(t, r.Accept("peer3", true))
Expand Down

0 comments on commit acb1c88

Please sign in to comment.