diff --git a/network/p2p/middleware.go b/network/p2p/middleware.go index ba9fd508eea..1243a87e112 100644 --- a/network/p2p/middleware.go +++ b/network/p2p/middleware.go @@ -13,6 +13,7 @@ import ( libp2pnetwork "github.com/libp2p/go-libp2p-core/network" "github.com/rs/zerolog" + "github.com/onflow/flow-go/crypto" "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" @@ -315,7 +316,7 @@ func (m *Middleware) handleIncomingStream(s libp2pnetwork.Stream) { log.Info().Msg("incoming stream received") //create a new readConnection with the context of the middleware - conn := newReadConnection(m.ctx, s, m.processMessage, log, m.metrics, LargeMsgMaxUnicastMsgSize) + conn := newReadConnection(m.ctx, s, m.processUnauthenticatedMessage, log, m.metrics, LargeMsgMaxUnicastMsgSize) // kick off the receive loop to continuously receive messages m.wg.Add(1) @@ -359,8 +360,28 @@ func (m *Middleware) Unsubscribe(channel network.Channel) error { return nil } -// processMessage processes a message and eventually passes it to the overlay -func (m *Middleware) processMessage(msg *message.Message) { +// processMessage processes a message and a source (indicated by its PublicKey) and eventually passes it to the overlay +func (m *Middleware) processMessage(msg *message.Message, originKey crypto.PublicKey) { + identities, err := m.ov.Identity() + if err != nil { + m.log.Error().Err(err).Msg("failed to retrieve identities list while delivering a message") + return + } + + // check the origin of the message corresponds to the one claimed in the OriginID + originID := flow.HashToID(msg.OriginID) + + originIdentity, found := identities[originID] + if !found || !originIdentity.NetworkPubKey.Equals(originKey) { + m.log.Warn().Msgf("message claiming to be from nodeID %v with key %x was actually signed by %x and dropped", originID, originIdentity.NetworkPubKey, originKey) + return + } + + m.processUnauthenticatedMessage(msg) +} + +// processUnAuthenticatedMessage processes a message and eventually passes it to the overlay +func (m *Middleware) processUnauthenticatedMessage(msg *message.Message) { // run through all the message validators for _, v := range m.validators { diff --git a/network/p2p/readSubscription.go b/network/p2p/readSubscription.go index f81756e3def..5e643bf2e55 100644 --- a/network/p2p/readSubscription.go +++ b/network/p2p/readSubscription.go @@ -5,9 +5,12 @@ import ( "strings" "sync" + "github.com/libp2p/go-libp2p-core/peer" pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/rs/zerolog" + "github.com/onflow/flow-go/crypto" "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/network/message" _ "github.com/onflow/flow-go/utils/binstat" @@ -20,13 +23,13 @@ type readSubscription struct { log zerolog.Logger sub *pubsub.Subscription metrics module.NetworkMetrics - callback func(msg *message.Message) + callback func(msg *message.Message, pubKey crypto.PublicKey) } // newReadSubscription reads the messages coming in on the subscription func newReadSubscription(ctx context.Context, sub *pubsub.Subscription, - callback func(msg *message.Message), + callback func(msg *message.Message, pubKey crypto.PublicKey), log zerolog.Logger, metrics module.NetworkMetrics) *readSubscription { @@ -87,10 +90,29 @@ func (r *readSubscription) receiveLoop(wg *sync.WaitGroup) { return } + // if pubsub.WithMessageSigning(true) and pubsub.WithStrictSignatureVerification(true), + // the emitter is authenticated + emitter, err := peer.IDFromBytes(rawMsg.From) + if err != nil { + r.log.Err(err).Msgf("failed to unmarshal peerID %v of a message", rawMsg.From) + return + } + // we use ECDSA, so the PeerID should be an identity multihash and this should succeed + pk, err := emitter.ExtractPublicKey() + if err != nil { + r.log.Err(err).Msgf("failed to extract public key of peerID %v", emitter.String()) + return + } + flowKey, err := FlowPublicKeyFromLibP2P(pk) + if err != nil { + r.log.Err(err).Msgf("failed to extract flow public key of libp2p key %v", err) + return + } + // log metrics r.metrics.NetworkMessageReceived(msg.Size(), msg.ChannelID, msg.Type) // call the callback - r.callback(&msg) + r.callback(&msg, flowKey) } } diff --git a/network/test/meshengine_test.go b/network/test/meshengine_test.go index f31e246d85e..603b6167179 100644 --- a/network/test/meshengine_test.go +++ b/network/test/meshengine_test.go @@ -169,7 +169,7 @@ func (suite *MeshEngineTestSuite) allToAllScenario(send ConduitSendWrapperFunc) for i := 0; i < pubsub.GossipSubD*count; i++ { select { case <-suite.obs: - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): assert.FailNow(suite.T(), "could not receive pubsub tag indicating mesh formed") } } diff --git a/network/test/middleware_test.go b/network/test/middleware_test.go index 5584f2d0ee5..7d68f53cda0 100644 --- a/network/test/middleware_test.go +++ b/network/test/middleware_test.go @@ -268,6 +268,45 @@ func (m *MiddlewareTestSuite) TestEcho() { } } +// TestSpoofedPubSubHello evaluates checking the originID of the message w.r.t. its libp2p network ID on PubSub +// we check a pubsub message with a spoofed OriginID does not get delivered +// This would be doubled with cryptographic verification of the libp2p network ID in production (see message signing options in pubSub initialization) +func (m *MiddlewareTestSuite) TestSpoofedPubSubHello() { + first := 0 + last := m.size - 1 + firstNode := m.ids[first].NodeID + lastNode := m.ids[last].NodeID + + // initially subscribe the nodes to the channel + for _, mw := range m.mws { + err := mw.Subscribe(testChannel) + require.NoError(m.Suite.T(), err) + } + + // set up waiting for m.size pubsub tags indicating a mesh has formed + for i := 0; i < m.size; i++ { + select { + case <-m.obs: + case <-time.After(2 * time.Second): + assert.FailNow(m.T(), "could not receive pubsub tag indicating mesh formed") + } + } + + var spoofedID flow.Identifier + copy(spoofedID[:16], firstNode[:16]) + copy(spoofedID[16:], lastNode[16:]) + + message1 := createMessage(spoofedID, lastNode, "hello1") + + err := m.mws[first].Publish(message1, testChannel) + assert.NoError(m.T(), err) + + // assert that the spoofed message is not received by the target node + assert.Never(m.T(), func() bool { + return !m.ov[last].AssertNumberOfCalls(m.T(), "Receive", 0) + }, 2*time.Second, time.Millisecond) +} + // TestMaxMessageSize_SendDirect evaluates that invoking SendDirect method of the middleware on a message // size beyond the permissible unicast message size returns an error. func (m *MiddlewareTestSuite) TestMaxMessageSize_SendDirect() { @@ -395,7 +434,7 @@ func (m *MiddlewareTestSuite) TestUnsubscribe() { for i := 0; i < m.size; i++ { select { case <-m.obs: - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): assert.FailNow(m.T(), "could not receive pubsub tag indicating mesh formed") } }