diff --git a/p2p/host/peerstore/pstoreds/addr_book.go b/p2p/host/peerstore/pstoreds/addr_book.go index 12cb0814e8..88082102bb 100644 --- a/p2p/host/peerstore/pstoreds/addr_book.go +++ b/p2p/host/peerstore/pstoreds/addr_book.go @@ -207,6 +207,9 @@ func (ab *dsAddrBook) Close() error { // // If the cache argument is true, the record is inserted in the cache when loaded from the datastore. func (ab *dsAddrBook) loadRecord(id peer.ID, cache bool, update bool) (pr *addrsRecord, err error) { + if err := id.Validate(); err != nil { + return nil, err + } if e, ok := ab.cache.Get(id); ok { pr = e.(*addrsRecord) pr.Lock() @@ -421,6 +424,11 @@ func (ab *dsAddrBook) AddrStream(ctx context.Context, p peer.ID) <-chan ma.Multi // ClearAddrs will delete all known addresses for a peer ID. func (ab *dsAddrBook) ClearAddrs(p peer.ID) { + if err := p.Validate(); err != nil { + // nothing to do + return + } + ab.cache.Remove(p) key := addrBookBase.ChildString(b32.RawStdEncoding.EncodeToString([]byte(p))) diff --git a/p2p/host/peerstore/pstoreds/metadata.go b/p2p/host/peerstore/pstoreds/metadata.go index 4d285af789..bf7655231c 100644 --- a/p2p/host/peerstore/pstoreds/metadata.go +++ b/p2p/host/peerstore/pstoreds/metadata.go @@ -41,6 +41,9 @@ func NewPeerMetadata(_ context.Context, store ds.Datastore, _ Options) (*dsPeerM } func (pm *dsPeerMetadata) Get(p peer.ID, key string) (interface{}, error) { + if err := p.Validate(); err != nil { + return nil, err + } k := pmBase.ChildString(base32.RawStdEncoding.EncodeToString([]byte(p))).ChildString(key) value, err := pm.ds.Get(k) if err != nil { @@ -58,6 +61,9 @@ func (pm *dsPeerMetadata) Get(p peer.ID, key string) (interface{}, error) { } func (pm *dsPeerMetadata) Put(p peer.ID, key string, val interface{}) error { + if err := p.Validate(); err != nil { + return err + } k := pmBase.ChildString(base32.RawStdEncoding.EncodeToString([]byte(p))).ChildString(key) var buf pool.Buffer if err := gob.NewEncoder(&buf).Encode(&val); err != nil { diff --git a/p2p/host/peerstore/pstoreds/protobook.go b/p2p/host/peerstore/pstoreds/protobook.go index 27e70f556d..84d546b912 100644 --- a/p2p/host/peerstore/pstoreds/protobook.go +++ b/p2p/host/peerstore/pstoreds/protobook.go @@ -39,6 +39,10 @@ func NewProtoBook(meta pstore.PeerMetadata) *dsProtoBook { } func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { + if err := p.Validate(); err != nil { + return err + } + s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -52,6 +56,10 @@ func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { } func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { + if err := p.Validate(); err != nil { + return err + } + s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -69,6 +77,10 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { } func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { + if err := p.Validate(); err != nil { + return nil, err + } + s := pb.segments.get(p) s.RLock() defer s.RUnlock() @@ -87,6 +99,10 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { } func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { + if err := p.Validate(); err != nil { + return nil, err + } + s := pb.segments.get(p) s.RLock() defer s.RUnlock() @@ -107,6 +123,10 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, } func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { + if err := p.Validate(); err != nil { + return err + } + s := pb.segments.get(p) s.Lock() defer s.Unlock() diff --git a/p2p/host/peerstore/pstoremem/addr_book.go b/p2p/host/peerstore/pstoremem/addr_book.go index cb7340f34b..48f7f44996 100644 --- a/p2p/host/peerstore/pstoremem/addr_book.go +++ b/p2p/host/peerstore/pstoremem/addr_book.go @@ -196,6 +196,11 @@ func (mab *memoryAddrBook) ConsumePeerRecord(recordEnvelope *record.Envelope, tt } func (mab *memoryAddrBook) addAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration, signed bool) { + if err := p.Validate(); err != nil { + log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err) + return + } + // if ttl is zero, exit. nothing to do. if ttl <= 0 { return @@ -244,12 +249,22 @@ func (mab *memoryAddrBook) addAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du // SetAddr calls mgr.SetAddrs(p, addr, ttl) func (mab *memoryAddrBook) SetAddr(p peer.ID, addr ma.Multiaddr, ttl time.Duration) { + if err := p.Validate(); err != nil { + log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err) + return + } + mab.SetAddrs(p, []ma.Multiaddr{addr}, ttl) } // SetAddrs sets the ttl on addresses. This clears any TTL there previously. // This is used when we receive the best estimate of the validity of an address. func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Duration) { + if err := p.Validate(); err != nil { + log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err) + return + } + s := mab.segments.get(p) s.Lock() defer s.Unlock() @@ -287,6 +302,11 @@ func (mab *memoryAddrBook) SetAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Du // UpdateAddrs updates the addresses associated with the given peer that have // the given oldTTL to have the given newTTL. func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL time.Duration) { + if err := p.Validate(); err != nil { + log.Warningf("tried to set addrs for invalid peer ID %s: %s", p, err) + return + } + s := mab.segments.get(p) s.Lock() defer s.Unlock() @@ -310,6 +330,11 @@ func (mab *memoryAddrBook) UpdateAddrs(p peer.ID, oldTTL time.Duration, newTTL t // Addrs returns all known (and valid) addresses for a given peer func (mab *memoryAddrBook) Addrs(p peer.ID) []ma.Multiaddr { + if err := p.Validate(); err != nil { + // invalid peer ID = no addrs + return nil + } + s := mab.segments.get(p) s.RLock() defer s.RUnlock() @@ -336,6 +361,11 @@ func validAddrs(amap map[string]*expiringAddr) []ma.Multiaddr { // given peer id, if one exists. // Returns nil if no signed PeerRecord exists for the peer. func (mab *memoryAddrBook) GetPeerRecord(p peer.ID) *record.Envelope { + if err := p.Validate(); err != nil { + // invalid peer ID = no addrs + return nil + } + s := mab.segments.get(p) s.RLock() defer s.RUnlock() @@ -356,6 +386,11 @@ func (mab *memoryAddrBook) GetPeerRecord(p peer.ID) *record.Envelope { // ClearAddrs removes all previously stored addresses func (mab *memoryAddrBook) ClearAddrs(p peer.ID) { + if err := p.Validate(); err != nil { + // nothing to clear + return + } + s := mab.segments.get(p) s.Lock() defer s.Unlock() @@ -367,6 +402,13 @@ func (mab *memoryAddrBook) ClearAddrs(p peer.ID) { // AddrStream returns a channel on which all new addresses discovered for a // given peer ID will be published. func (mab *memoryAddrBook) AddrStream(ctx context.Context, p peer.ID) <-chan ma.Multiaddr { + if err := p.Validate(); err != nil { + log.Warningf("tried to get addrs for invalid peer ID %s: %s", p, err) + ch := make(chan ma.Multiaddr) + close(ch) + return ch + } + s := mab.segments.get(p) s.RLock() defer s.RUnlock() diff --git a/p2p/host/peerstore/pstoremem/metadata.go b/p2p/host/peerstore/pstoremem/metadata.go index 7ded769192..7bd0fb3d72 100644 --- a/p2p/host/peerstore/pstoremem/metadata.go +++ b/p2p/host/peerstore/pstoremem/metadata.go @@ -35,6 +35,9 @@ func NewPeerMetadata() *memoryPeerMetadata { } func (ps *memoryPeerMetadata) Put(p peer.ID, key string, val interface{}) error { + if err := p.Validate(); err != nil { + return err + } ps.dslock.Lock() defer ps.dslock.Unlock() if vals, ok := val.(string); ok && internKeys[key] { @@ -49,6 +52,9 @@ func (ps *memoryPeerMetadata) Put(p peer.ID, key string, val interface{}) error } func (ps *memoryPeerMetadata) Get(p peer.ID, key string) (interface{}, error) { + if err := p.Validate(); err != nil { + return nil, err + } ps.dslock.RLock() defer ps.dslock.RUnlock() i, ok := ps.ds[metakey{p, key}] diff --git a/p2p/host/peerstore/pstoremem/protobook.go b/p2p/host/peerstore/pstoremem/protobook.go index 04d8ec47ae..af44c6fe38 100644 --- a/p2p/host/peerstore/pstoremem/protobook.go +++ b/p2p/host/peerstore/pstoremem/protobook.go @@ -67,6 +67,10 @@ func (pb *memoryProtoBook) internProtocol(proto string) string { } func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error { + if err := p.Validate(); err != nil { + return err + } + s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -82,6 +86,10 @@ func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error { } func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error { + if err := p.Validate(); err != nil { + return err + } + s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -100,6 +108,10 @@ func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error { } func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { + if err := p.Validate(); err != nil { + return nil, err + } + s := pb.segments.get(p) s.RLock() defer s.RUnlock() @@ -113,6 +125,10 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { } func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { + if err := p.Validate(); err != nil { + return err + } + s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -130,6 +146,10 @@ func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { } func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { + if err := p.Validate(); err != nil { + return nil, err + } + s := pb.segments.get(p) s.RLock() defer s.RUnlock() diff --git a/p2p/host/peerstore/test/peerstore_suite.go b/p2p/host/peerstore/test/peerstore_suite.go index 8a629624be..5797297aec 100644 --- a/p2p/host/peerstore/test/peerstore_suite.go +++ b/p2p/host/peerstore/test/peerstore_suite.go @@ -279,6 +279,29 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { if !reflect.DeepEqual(supported, protos[2:]) { t.Fatal("expected only one protocol to remain") } + + // test bad peer IDs + badp := peer.ID("") + + err = ps.AddProtocols(badp, protos...) + if err == nil { + t.Fatal("expected error when using a bad peer ID") + } + + _, err = ps.GetProtocols(badp) + if err == nil || err == pstore.ErrNotFound { + t.Fatal("expected error when using a bad peer ID") + } + + _, err = ps.SupportsProtocols(badp, "q", "w", "a", "y", "b") + if err == nil || err == pstore.ErrNotFound { + t.Fatal("expected error when using a bad peer ID") + } + + err = ps.RemoveProtocols(badp) + if err == nil || err == pstore.ErrNotFound { + t.Fatal("expected error when using a bad peer ID") + } } } @@ -309,6 +332,10 @@ func testBasicPeerstore(ps pstore.Peerstore) func(t *testing.T) { if !pinfo.Addrs[0].Equal(addrs[0]) { t.Fatal("stored wrong address") } + + // should fail silently... + ps.AddAddrs("", addrs, pstore.PermanentAddrTTL) + ps.Addrs("") } } @@ -355,6 +382,12 @@ func testMetadata(ps pstore.Peerstore) func(t *testing.T) { continue } } + if err := ps.Put("", "foobar", "thing"); err == nil { + t.Errorf("expected error for bad peer ID") + } + if _, err := ps.Get("", "foobar"); err == nil || err == pstore.ErrNotFound { + t.Errorf("expected error for bad peer ID") + } } }