basichost: don't wait for Identify #2551

Open

wants to merge 5 commits into master
71 changes: 40 additions & 31 deletions p2p/host/basic/basic_host.go
@@ -7,6 +7,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/libp2p/go-libp2p/core/connmgr"
@@ -646,24 +647,32 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, fmt.Errorf("failed to open stream: %w", err)
}

// Wait for any in-progress identifies on the connection to finish. This
// is faster than negotiating.
//
// If the other side doesn't support identify, that's fine. This will
// just be a no-op.
select {
case <-h.ids.IdentifyWait(s.Conn()):
case <-ctx.Done():
_ = s.Reset()
return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
}
// If pids contains only a single protocol, optimistically use that protocol (i.e. don't wait for
// multistream negotiation).
Comment on lines +650 to +651

Member: Should we add a summary of the algorithm to the documentation for the method NewStream?

var pref protocol.ID
if len(pids) == 1 {
pref = pids[0]
} else if len(pids) > 1 {
// Wait for any in-progress identifies on the connection to finish.
// This is faster than negotiating.
// If the other side doesn't support identify, that's fine. This will just be a no-op.
select {
case <-h.ids.IdentifyWait(s.Conn()):
case <-ctx.Done():
_ = s.Reset()
Member: NIT: Is just s.Reset() better?

return nil, fmt.Errorf("identify failed to complete: %w", ctx.Err())
}

pref, err := h.preferredProtocol(p, pids)
if err != nil {
_ = s.Reset()
return nil, err
// If Identify has finished, we know which protocols the peer supports.
// We don't need to do a multistream negotiation.
// Instead, we just pick the first supported protocol.
var err error
pref, err = h.preferredProtocol(p, pids)
if err != nil {
_ = s.Reset()
Member: NIT: Is just s.Reset() better?

return nil, err
}
}

if pref != "" {
if err := s.SetProtocol(pref); err != nil {
return nil, err
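
For context, here is a minimal caller-side sketch of the two paths the reworked NewStream creates. This is not part of the PR; the protocol IDs and function names are illustrative assumptions.

```go
package example

import (
	"context"
	"fmt"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/peer"
)

// openSingle passes exactly one protocol ID. With this PR, NewStream returns
// optimistically, without waiting for Identify or for multistream negotiation;
// a "protocol not supported" outcome only surfaces on the first read or write.
func openSingle(ctx context.Context, h host.Host, p peer.ID) (network.Stream, error) {
	return h.NewStream(ctx, p, "/myapp/2.0.0")
}

// openMultiple passes several protocol IDs. NewStream still waits for any
// in-progress Identify so it can pick the first protocol the peer supports,
// and falls back to multistream negotiation if Identify learned nothing.
func openMultiple(ctx context.Context, h host.Host, p peer.ID) (network.Stream, error) {
	s, err := h.NewStream(ctx, p, "/myapp/2.0.0", "/myapp/1.0.0")
	if err != nil {
		return nil, fmt.Errorf("open stream: %w", err)
	}
	return s, nil
}
```

The single-protocol path trades the up-front Identify wait for a possible late failure on the first read or write; the streamWrapper changes further down in this file are what surface that failure.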
@@ -736,22 +745,10 @@ func (h *BasicHost) Connect(ctx context.Context, pi peer.AddrInfo) error {
// the connection once it has been opened.
func (h *BasicHost) dialPeer(ctx context.Context, p peer.ID) error {
log.Debugf("host %s dialing %s", h.ID(), p)
c, err := h.Network().DialPeer(ctx, p)
if err != nil {
if _, err := h.Network().DialPeer(ctx, p); err != nil {
return fmt.Errorf("failed to dial: %w", err)
}

// TODO: Consider removing this? On one hand, it's nice because we can
// assume that things like the agent version are usually set when this
// returns. On the other hand, we don't _really_ need to wait for this.
//
// This is mostly here to preserve existing behavior.
select {
case <-h.ids.IdentifyWait(c):
case <-ctx.Done():
return fmt.Errorf("identify failed to complete: %w", ctx.Err())
}

log.Debugf("host %s finished dialing %s", h.ID(), p)
return nil
}
@@ -1049,14 +1046,26 @@ func (h *BasicHost) Close() error {
type streamWrapper struct {
network.Stream
rw io.ReadWriteCloser

calledRead atomic.Bool
Comment on lines +1049 to +1050

Member: Should we document why we need this? Maybe with a link to multiformats/go-multistream#20

Member: Why do we need this streamWrapper? All it does is wrap CloseWrite, which can be concurrent with Read and Write. @MarcoPolo @Stebalien

}

func (s *streamWrapper) Read(b []byte) (int, error) {
return s.rw.Read(b)
n, err := s.rw.Read(b)
if s.calledRead.CompareAndSwap(false, true) {
if errors.Is(err, network.ErrReset) {
return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
}
}
return n, err
}

func (s *streamWrapper) Write(b []byte) (int, error) {
return s.rw.Write(b)
n, err := s.rw.Write(b)
if s.calledRead.Load() && errors.Is(err, network.ErrReset) {
return n, msmux.ErrNotSupported[protocol.ID]{Protos: []protocol.ID{s.Protocol()}}
Comment on lines +1065 to +1066

Member: I think this should be !s.calledRead.Load()

Member: There is a race condition here, but I think it doesn't matter because I'm not sure you can use streams in the way required to trigger this condition.

The race is:
goroutine1 does a successful read and is about to do the CompareAndSwap.
Then goroutine2 does a write and receives a StreamReset for some reason (can it?).
Now goroutine2 evaluates if s.calledRead.Load() && errors.Is(err, network.ErrReset) before goroutine1 could do the CompareAndSwap.

Contributor: I don't see the race condition. I think this should be !s.calledRead.Load(). My assumption is that the goal here is for either Read or Write to return a msmux.ErrNotSupported on this specific type of error, and that it's okay if both return msmux.ErrNotSupported. I don't think you can enter a race condition where neither returns msmux.ErrNotSupported, but you can enter one where both return it (which is okay).

Member (@sukunrt, Mar 21, 2024): I now think this whole logic should be a part of multistream.lazyClientConn

Contributor: Yeah, I agree.

}
return n, err
}

func (s *streamWrapper) Close() error {
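Given the Read/Write error mapping above, a caller that opened a stream optimistically now learns that the peer doesn't support the chosen protocol from the first read or write rather than from NewStream itself. Below is a hedged sketch of handling that; the protocol handling, helper names, and fallback policy are assumptions, and the exact conditions under which a reset is translated into ErrNotSupported are still being discussed in the thread above.

```go
package example

import (
	"context"
	"errors"
	"io"

	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/protocol"
	msmux "github.com/multiformats/go-multistream"
)

// tryProtocol opens a stream optimistically with a single protocol ID and
// reports ok=false (with a nil error) if the peer turns out not to support it,
// which now shows up as msmux.ErrNotSupported on the first write or read.
func tryProtocol(ctx context.Context, h host.Host, p peer.ID, proto protocol.ID, req []byte) (resp []byte, ok bool, err error) {
	s, err := h.NewStream(ctx, p, proto)
	if err != nil {
		return nil, false, err
	}
	defer s.Close()

	if _, err := s.Write(req); err != nil {
		s.Reset()
		return nil, false, ignoreNotSupported(err)
	}
	if err := s.CloseWrite(); err != nil { // signal end of request
		s.Reset()
		return nil, false, err
	}
	resp, err = io.ReadAll(s)
	if err != nil {
		s.Reset()
		return nil, false, ignoreNotSupported(err)
	}
	return resp, true, nil
}

// ignoreNotSupported maps multistream's "protocol not supported" error to nil
// so tryProtocol signals it via ok=false; any other error is passed through.
func ignoreNotSupported(err error) error {
	var notSupported msmux.ErrNotSupported[protocol.ID]
	if errors.As(err, &notSupported) {
		return nil
	}
	return err
}
```

A caller would typically try its newest protocol first and, when ok is false, retry with an older protocol or pass several protocol IDs to NewStream to force negotiation.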
16 changes: 12 additions & 4 deletions p2p/host/basic/basic_host_test.go
@@ -535,8 +535,17 @@ func TestProtoDowngrade(t *testing.T) {
// This is _almost_ instantaneous, but this test fails once every ~1k runs without this.
time.Sleep(time.Millisecond)

sub, err := h1.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
require.NoError(t, err)
defer sub.Close()

h2pi := h2.Peerstore().PeerInfo(h2.ID())
require.NoError(t, h1.Connect(ctx, h2pi))
select {
case <-sub.Out():
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}

s2, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing")
require.NoError(t, err)
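
Because Connect and dialPeer no longer wait for Identify, code that needs identification data (agent version, supported protocols) right after connecting has to wait for it explicitly, as the updated test above does via the event bus. A minimal sketch of that pattern follows; the function name and timeout are illustrative.

```go
package example

import (
	"context"
	"fmt"
	"time"

	"github.com/libp2p/go-libp2p/core/event"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
)

// connectAndIdentify connects to pi and then blocks until Identify has
// completed for that peer, or until the timeout or context expires.
func connectAndIdentify(ctx context.Context, h host.Host, pi peer.AddrInfo, timeout time.Duration) error {
	sub, err := h.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
	if err != nil {
		return err
	}
	defer sub.Close()

	if err := h.Connect(ctx, pi); err != nil {
		return err
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for {
		select {
		case e := <-sub.Out():
			// The bus delivers completions for every peer; only return for ours.
			if evt, ok := e.(event.EvtPeerIdentificationCompleted); ok && evt.Peer == pi.ID {
				return nil
			}
		case <-timer.C:
			return fmt.Errorf("identify for %s did not complete within %s", pi.ID, timeout)
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
```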
@@ -704,13 +713,12 @@ func TestHostAddrChangeDetection(t *testing.T) {
}

func TestNegotiationCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

h1, h2 := getHostPair(t)
defer h1.Close()
defer h2.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// pre-negotiation so we can make the negotiation hang.
h2.Network().SetStreamHandler(func(s network.Stream) {
<-ctx.Done() // wait till the test is done.
@@ -722,7 +730,7 @@ func TestNegotiationCancel(t *testing.T) {

errCh := make(chan error, 1)
go func() {
s, err := h1.NewStream(ctx2, h2.ID(), "/testing")
s, err := h1.NewStream(ctx2, h2.ID(), "/testing", "/testing2")
if s != nil {
errCh <- fmt.Errorf("expected to fail negotiation")
return
2 changes: 1 addition & 1 deletion p2p/protocol/circuitv2/client/reservation.go
@@ -89,7 +89,7 @@ func Reserve(ctx context.Context, h host.Host, ai peer.AddrInfo) (*Reservation,

if err := rd.ReadMsg(&msg); err != nil {
s.Reset()
return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message: %w", err: err}
return nil, ReservationError{Status: pbv2.Status_CONNECTION_FAILED, Reason: "error reading reservation response message", err: err}
}

if msg.GetType() != pbv2.HopMessage_STATUS {
11 changes: 7 additions & 4 deletions p2p/protocol/holepunch/holepunch_test.go
@@ -123,7 +123,8 @@ func TestDirectDialWorks(t *testing.T) {
require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
require.NoError(t, h1ps.DirectConnect(h2.ID()))
require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1)
require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1)
// h1 might finish the handshake first, but h2 should see the connection shortly after
require.Eventually(t, func() bool { return len(h2.Network().ConnsToPeer(h1.ID())) > 0 }, time.Second, 25*time.Millisecond)
events := tr.getEvents()
require.Len(t, events, 1)
require.Equal(t, holepunch.DirectDialEvtT, events[0].Type)
@@ -340,9 +341,10 @@ func TestFailuresOnResponder(t *testing.T) {
defer relay.Close()

s, err := h2.NewStream(network.WithUseTransient(context.Background(), "holepunch"), h1.ID(), holepunch.Protocol)
require.NoError(t, err)

go tc.initiator(s)
// h1 will reset the stream. This might or might not happen before multistream has finished.
if err == nil {
go tc.initiator(s)
}

getTracerError := func(tr *mockEventTracer) []string {
var errs []string
@@ -487,6 +489,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc
ID: h2.ID(),
Addrs: []ma.Multiaddr{raddr},
}))
require.Eventually(t, func() bool { return len(h2.Network().ConnsToPeer(h1.ID())) > 0 }, time.Second, 50*time.Millisecond)
return
}

24 changes: 12 additions & 12 deletions p2p/protocol/identify/id_test.go
@@ -473,25 +473,25 @@ func TestUserAgent(t *testing.T) {
defer cancel()

h1, err := libp2p.New(libp2p.UserAgent("foo"), libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer h1.Close()

h2, err := libp2p.New(libp2p.UserAgent("bar"), libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer h2.Close()

err = h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
if err != nil {
t.Fatal(err)
sub, err := h1.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
require.NoError(t, err)
defer sub.Close()

require.NoError(t, h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}))
select {
case <-sub.Out():
case <-time.After(time.Second):
t.Fatal("timeout")
}
av, err := h1.Peerstore().Get(h2.ID(), "AgentVersion")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
if ver, ok := av.(string); !ok || ver != "bar" {
t.Errorf("expected agent version %q, got %q", "bar", av)
}
2 changes: 2 additions & 0 deletions p2p/test/quic/quic_test.go
@@ -61,6 +61,7 @@ func TestQUICAndWebTransport(t *testing.T) {
)
require.NoError(t, err)
require.NoError(t, h2.Connect(ctx, peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()}))
require.Eventually(t, func() bool { return len(h1.Network().ConnsToPeer(h2.ID())) > 0 }, time.Second, 25*time.Millisecond)
for _, conns := range [][]network.Conn{h2.Network().ConnsToPeer(h1.ID()), h1.Network().ConnsToPeer(h2.ID())} {
require.Len(t, conns, 1)
if _, err := conns[0].LocalMultiaddr().ValueForProtocol(ma.P_WEBTRANSPORT); err == nil {
@@ -78,6 +79,7 @@
)
require.NoError(t, err)
require.NoError(t, h3.Connect(ctx, peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()}))
require.Eventually(t, func() bool { return len(h1.Network().ConnsToPeer(h3.ID())) > 0 }, time.Second, 25*time.Millisecond)
for _, conns := range [][]network.Conn{h3.Network().ConnsToPeer(h1.ID()), h1.Network().ConnsToPeer(h3.ID())} {
require.Len(t, conns, 1)
if _, err := conns[0].LocalMultiaddr().ValueForProtocol(ma.P_WEBTRANSPORT); err != nil {
12 changes: 11 additions & 1 deletion p2p/test/swarm/swarm_test.go
@@ -193,6 +193,7 @@ func TestLimitStreamsWhenHangingHandlers(t *testing.T) {

// Open streamLimit streams
success := 0
errCnt := 0
// we make a lot of tries because identify and identify push take up a few streams
for i := 0; i < 1000 && success < streamLimit; i++ {
mgr, err = rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits))
@@ -206,6 +207,7 @@

s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
if err != nil {
errCnt++
continue
}

@@ -227,7 +229,11 @@

sender.Peerstore().AddAddrs(receiver.ID(), receiver.Addrs(), peerstore.PermanentAddrTTL)

_, err = sender.NewStream(context.Background(), receiver.ID(), pid)
s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
// stream is not received by the peer before the first write or read
require.NoError(t, err)
var b [1]byte
_, err = io.ReadFull(s, b[:])
require.Error(t, err)

// Close the open streams
@@ -236,6 +242,10 @@
// Next call should succeed
require.Eventually(t, func() bool {
s, err := sender.NewStream(context.Background(), receiver.ID(), pid)
// stream is not received by the peer before the first write or read
require.NoError(t, err)
var b [1]byte
_, err = io.ReadFull(s, b[:])
if err == nil {
s.Close()
return true
9 changes: 6 additions & 3 deletions p2p/test/transport/gating_test.go
@@ -181,7 +181,8 @@ func TestInterceptAccept(t *testing.T) {
}

h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
// use two protocols here, so we actually enter multistream negotiation
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
require.Error(t, err)
if _, err := h2.Addrs()[0].ValueForProtocol(ma.P_WEBRTC_DIRECT); err != nil {
// WebRTC rejects connection attempt before an error can be sent to the client.
@@ -218,7 +219,8 @@ func TestInterceptSecuredIncoming(t *testing.T) {
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
// use two protocols here, so we actually enter multistream negotiation
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
@@ -254,7 +256,8 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
// use two protocols here, so we actually enter multistream negotiation
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID, protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
13 changes: 9 additions & 4 deletions p2p/test/transport/transport_test.go
@@ -549,14 +549,19 @@ func TestListenerStreamResets(t *testing.T) {
}))

h1.SetStreamHandler("reset", func(s network.Stream) {
// Make sure the multistream negotiation actually succeeds before resetting.
// This is necessary because we don't have stream error codes yet.
s.Read(make([]byte, 4))
s.Write([]byte("pong"))
s.Read(make([]byte, 4))
s.Reset()
})

s, err := h2.NewStream(context.Background(), h1.ID(), "reset")
if err != nil {
require.ErrorIs(t, err, network.ErrReset)
return
}
require.NoError(t, err)
s.Write([]byte("ping"))
s.Read(make([]byte, 4))
s.Write([]byte("ping"))

_, err = s.Read([]byte{0})
require.ErrorIs(t, err, network.ErrReset)
Expand Down