Commit

extract messaging components from IpfsDHT into its own struct. create a new struct that manages sending DHT messages that can be used independently from the DHT.
aschmahmann committed May 28, 2020
1 parent 5da7e89 commit eca6287
Showing 9 changed files with 208 additions and 166 deletions.
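Taken together, the changes below route DHT RPCs through a ProtocolMessenger that can be constructed and used without an IpfsDHT. The following is a minimal sketch of that standalone use, inferred only from the calls visible in this diff (NewProtocolMessenger, Ping, GetClosestPeers); the validator value and the exact return types are assumptions.

```go
package main

import (
	"context"
	"fmt"

	"github.com/libp2p/go-libp2p-core/host"
	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/libp2p/go-libp2p-core/protocol"
	record "github.com/libp2p/go-libp2p-record"

	dht "github.com/libp2p/go-libp2p-kad-dht"
)

// pingAndQuery drives the extracted messenger directly from a host, without
// building a full IpfsDHT. The constructor mirrors the call added in dht.go:
// NewProtocolMessenger(dht.host, dht.protocols, dht.Validator).
func pingAndQuery(ctx context.Context, h host.Host, protos []protocol.ID, target peer.ID) error {
	var validator record.NamespacedValidator // example validator; any record.Validator should do
	messenger := dht.NewProtocolMessenger(h, protos, validator)

	// Ping and GetClosestPeers are the messenger methods visible in this commit.
	if err := messenger.Ping(ctx, target); err != nil {
		return fmt.Errorf("ping failed: %w", err)
	}
	closer, err := messenger.GetClosestPeers(ctx, target, target)
	if err != nil {
		return err
	}
	fmt.Printf("%s reported %d closer peers\n", target, len(closer))
	return nil
}

func main() {} // placeholder so the sketch builds as a standalone file
```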
87 changes: 3 additions & 84 deletions dht.go
@@ -1,7 +1,6 @@
package dht

import (
"bytes"
"context"
"errors"
"fmt"
@@ -32,7 +31,6 @@ import (
goprocessctx "github.com/jbenet/goprocess/context"
"github.com/multiformats/go-base32"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
"go.opencensus.io/tag"
"go.uber.org/zap"
)
@@ -101,8 +99,7 @@ type IpfsDHT struct {
ctx context.Context
proc goprocess.Process

strmap map[peer.ID]*messageSender
smlk sync.Mutex
protoMessenger *ProtocolMessenger

plk sync.Mutex

@@ -183,6 +180,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error)
dht.enableValues = cfg.enableValues

dht.Validator = cfg.validator
dht.protoMessenger = NewProtocolMessenger(dht.host, dht.protocols, dht.Validator)

dht.auto = cfg.mode
switch cfg.mode {
@@ -273,7 +271,6 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) {
selfKey: kb.ConvertPeerID(h.ID()),
peerstore: h.Peerstore(),
host: h,
strmap: make(map[peer.ID]*messageSender),
birth: time.Now(),
protocols: protocols,
protocolsStrs: protocol.ConvertToStrings(protocols),
@@ -477,67 +474,8 @@ func (dht *IpfsDHT) persistRTPeersInPeerStore() {
}
}

// putValueToPeer stores the given key/value pair at the peer 'p'
func (dht *IpfsDHT) putValueToPeer(ctx context.Context, p peer.ID, rec *recpb.Record) error {
pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0)
pmes.Record = rec
rpmes, err := dht.sendRequest(ctx, p, pmes)
if err != nil {
logger.Debugw("failed to put value to peer", "to", p, "key", loggableKeyBytes(rec.Key), "error", err)
return err
}

if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) {
logger.Infow("value not put correctly", "put-message", pmes, "get-message", rpmes)
return errors.New("value not put correctly")
}

return nil
}

var errInvalidRecord = errors.New("received invalid record")

// getValueOrPeers queries a particular peer p for the value for
// key. It returns either the value or a list of closer peers.
// NOTE: It will update the dht's peerstore with any new addresses
// it finds for the given peer.
func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) {
pmes, err := dht.getValueSingle(ctx, p, key)
if err != nil {
return nil, nil, err
}

// Perhaps we were given closer peers
peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())

if record := pmes.GetRecord(); record != nil {
// Success! We were given the value
logger.Debug("got value")

// make sure record is valid.
err = dht.Validator.Validate(string(record.GetKey()), record.GetValue())
if err != nil {
logger.Debug("received invalid record (discarded)")
// return a sentinal to signify an invalid record was received
err = errInvalidRecord
record = new(recpb.Record)
}
return record, peers, err
}

if len(peers) > 0 {
return nil, peers, nil
}

return nil, nil, routing.ErrNotFound
}

// getValueSingle simply performs the get value RPC with the given parameters
func (dht *IpfsDHT) getValueSingle(ctx context.Context, p peer.ID, key string) (*pb.Message, error) {
pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0)
return dht.sendRequest(ctx, p, pmes)
}

// getLocal attempts to retrieve the value from the datastore
func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) {
logger.Debugw("finding value in datastore", "key", loggableKeyString(key))
@@ -627,17 +565,6 @@ func (dht *IpfsDHT) FindLocal(id peer.ID) peer.AddrInfo {
}
}

// findPeerSingle asks peer 'p' if they know where the peer with id 'id' is
func (dht *IpfsDHT) findPeerSingle(ctx context.Context, p peer.ID, id peer.ID) (*pb.Message, error) {
pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0)
return dht.sendRequest(ctx, p, pmes)
}

func (dht *IpfsDHT) findProvidersSingle(ctx context.Context, p peer.ID, key multihash.Multihash) (*pb.Message, error) {
pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0)
return dht.sendRequest(ctx, p, pmes)
}

// nearestPeersToQuery returns the routing tables closest peers.
func (dht *IpfsDHT) nearestPeersToQuery(pmes *pb.Message, count int) []peer.ID {
closer := dht.routingTable.NearestPeers(kb.ConvertKey(string(pmes.GetKey())), count)
@@ -778,15 +705,7 @@ func (dht *IpfsDHT) Host() host.Host {

// Ping sends a ping message to the passed peer and waits for a response.
func (dht *IpfsDHT) Ping(ctx context.Context, p peer.ID) error {
req := pb.NewMessage(pb.Message_PING, nil, 0)
resp, err := dht.sendRequest(ctx, p, req)
if err != nil {
return fmt.Errorf("sending request: %w", err)
}
if resp.Type != pb.Message_PING {
return fmt.Errorf("got unexpected response type: %v", resp.Type)
}
return nil
return dht.protoMessenger.Ping(ctx, p)
}

// newContextWithLocalTags returns a new context.Context with the InstanceID and
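The helper RPCs removed above (putValueToPeer, getValueOrPeers, getValueSingle, findPeerSingle, findProvidersSingle, and the inline Ping body) live on as methods of the new ProtocolMessenger, defined in a file not shown in this excerpt. As one example, here is a sketch of what the relocated Ping presumably looks like, reconstructed from the body deleted from IpfsDHT.Ping; the receiver and the pm.m field holding the messageManager are inferred from the dht_test.go changes below, not copied from the commit.

```go
package dht

import (
	"context"
	"fmt"

	"github.com/libp2p/go-libp2p-core/peer"

	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
)

// Ping sends a PING request to p and checks the response type, mirroring the
// body removed from IpfsDHT.Ping above. Only the message construction and the
// response check come from the diff; the field path pm.m is an assumption.
func (pm *ProtocolMessenger) Ping(ctx context.Context, p peer.ID) error {
	req := pb.NewMessage(pb.Message_PING, nil, 0)
	resp, err := pm.m.sendRequest(ctx, p, req)
	if err != nil {
		return fmt.Errorf("sending request: %w", err)
	}
	if resp.Type != pb.Message_PING {
		return fmt.Errorf("got unexpected response type: %v", resp.Type)
	}
	return nil
}
```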
72 changes: 50 additions & 22 deletions dht_net.go
@@ -9,8 +9,10 @@ import (
"time"

"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"

"github.com/libp2p/go-libp2p-kad-dht/metrics"
pb "github.com/libp2p/go-libp2p-kad-dht/pb"
@@ -208,12 +210,38 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool {
}
}

type messageManager struct {
host host.Host // the network services we need
strmap map[peer.ID]*messageSender
smlk sync.Mutex
protocols []protocol.ID
}

func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) {
m.smlk.Lock()
defer m.smlk.Unlock()
ms, ok := m.strmap[p]
if !ok {
return
}
delete(m.strmap, p)

// Do this asynchronously as ms.lk can block for a while.
go func() {
if err := ms.lk.Lock(ctx); err != nil {
return
}
defer ms.lk.Unlock()
ms.invalidate()
}()
}

// sendRequest sends out a request, but also makes sure to
// measure the RTT for latency measurements.
func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
func (m *messageManager) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

ms, err := dht.messageSenderForPeer(ctx, p)
ms, err := m.messageSenderForPeer(ctx, p)
if err != nil {
stats.Record(ctx,
metrics.SentRequests.M(1),
@@ -240,15 +268,15 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
metrics.SentBytes.M(int64(pmes.Size())),
metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)),
)
dht.peerstore.RecordLatency(p, time.Since(start))
m.host.Peerstore().RecordLatency(p, time.Since(start))
return rpmes, nil
}

// sendMessage sends out a message
func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
func (m *messageManager) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

ms, err := dht.messageSenderForPeer(ctx, p)
ms, err := m.messageSenderForPeer(ctx, p)
if err != nil {
stats.Record(ctx,
metrics.SentMessages.M(1),
@@ -274,30 +302,30 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
return nil
}

func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
dht.smlk.Lock()
ms, ok := dht.strmap[p]
func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
m.smlk.Lock()
ms, ok := m.strmap[p]
if ok {
dht.smlk.Unlock()
m.smlk.Unlock()
return ms, nil
}
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
dht.strmap[p] = ms
dht.smlk.Unlock()
ms = &messageSender{p: p, m: m, lk: newCtxMutex()}
m.strmap[p] = ms
m.smlk.Unlock()

if err := ms.prepOrInvalidate(ctx); err != nil {
dht.smlk.Lock()
defer dht.smlk.Unlock()
m.smlk.Lock()
defer m.smlk.Unlock()

if msCur, ok := dht.strmap[p]; ok {
if msCur, ok := m.strmap[p]; ok {
// Changed. Use the new one, old one is invalid and
// not in the map so we can just throw it away.
if ms != msCur {
return msCur, nil
}
// Not changed, remove the now invalid stream from the
// map.
delete(dht.strmap, p)
delete(m.strmap, p)
}
// Invalid but not in map. Must have been removed by a disconnect.
return nil, err
@@ -307,11 +335,11 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
}

type messageSender struct {
s network.Stream
r msgio.ReadCloser
lk ctxMutex
p peer.ID
dht *IpfsDHT
s network.Stream
r msgio.ReadCloser
lk ctxMutex
p peer.ID
m *messageManager

invalid bool
singleMes int
@@ -352,7 +380,7 @@ func (ms *messageSender) prep(ctx context.Context) error {
// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
// one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for
// backwards compatibility reasons).
nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...)
if err != nil {
return err
}
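dht_net.go now owns the stream-sender machinery through messageManager, but the constructor that ties it to ProtocolMessenger lives in a file not included in this excerpt. Below is a sketch of the likely wiring; only the messageManager fields and the NewProtocolMessenger argument list come from this diff, and the ProtocolMessenger field names are guesses.

```go
package dht

import (
	"github.com/libp2p/go-libp2p-core/host"
	"github.com/libp2p/go-libp2p-core/peer"
	"github.com/libp2p/go-libp2p-core/protocol"

	record "github.com/libp2p/go-libp2p-record"
)

// ProtocolMessenger as assumed here: the messageManager plus the record
// validator. Only the m field is implied by dht_test.go; validator is a guess.
type ProtocolMessenger struct {
	m         *messageManager
	validator record.Validator
}

// NewProtocolMessenger mirrors the call site added in dht.go,
// NewProtocolMessenger(dht.host, dht.protocols, dht.Validator); the body is a
// guess at the obvious wiring of the fields shown above.
func NewProtocolMessenger(h host.Host, protos []protocol.ID, validator record.Validator) *ProtocolMessenger {
	return &ProtocolMessenger{
		m: &messageManager{
			host:      h,
			strmap:    make(map[peer.ID]*messageSender),
			protocols: protos,
		},
		validator: validator,
	}
}
```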
8 changes: 4 additions & 4 deletions dht_test.go
@@ -570,14 +570,14 @@ func TestInvalidMessageSenderTracking(t *testing.T) {
defer dht.Close()

foo := peer.ID("asdasd")
_, err := dht.messageSenderForPeer(ctx, foo)
_, err := dht.protoMessenger.m.messageSenderForPeer(ctx, foo)
if err == nil {
t.Fatal("that shouldnt have succeeded")
}

dht.smlk.Lock()
mscnt := len(dht.strmap)
dht.smlk.Unlock()
dht.protoMessenger.m.smlk.Lock()
mscnt := len(dht.protoMessenger.m.strmap)
dht.protoMessenger.m.smlk.Unlock()

if mscnt > 0 {
t.Fatal("should have no message senders in map")
4 changes: 1 addition & 3 deletions lookup.go
@@ -10,7 +10,6 @@ import (
"github.com/libp2p/go-libp2p-core/routing"

"github.com/ipfs/go-cid"
pb "github.com/libp2p/go-libp2p-kad-dht/pb"
kb "github.com/libp2p/go-libp2p-kbucket"
"github.com/multiformats/go-base32"
)
@@ -89,12 +88,11 @@ func (dht *IpfsDHT) GetClosestPeersSeeded(ctx context.Context, key string, seedP
ID: p,
})

pmes, err := dht.findPeerSingle(ctx, p, peer.ID(key))
peers, err := dht.protoMessenger.GetClosestPeers(ctx, p, peer.ID(key))
if err != nil {
logger.Debugf("error getting closer peers: %s", err)
return nil, err
}
peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())

// For DHT query command
routing.PublishQueryEvent(ctx, &routing.QueryEvent{
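In lookup.go, a single GetClosestPeers call on the messenger replaces the two-step findPeerSingle plus PBPeersToPeerInfos sequence. Below is a sketch of what that method presumably does, reconstructed from the removed dht.go helper and the removed conversion above; the receiver and the exact signature are assumptions.

```go
package dht

import (
	"context"

	"github.com/libp2p/go-libp2p-core/peer"

	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
)

// GetClosestPeers asks p for the peers it knows closest to id, folding the
// removed findPeerSingle helper and the PBPeersToPeerInfos conversion from
// lookup.go into one call. The body is reconstructed, not copied from the commit.
func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id peer.ID) ([]*peer.AddrInfo, error) {
	pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0)
	respMsg, err := pm.m.sendRequest(ctx, p, pmes)
	if err != nil {
		return nil, err
	}
	return pb.PBPeersToPeerInfos(respMsg.GetCloserPeers()), nil
}
```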
