diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala index 3380308b65..1a333c8c78 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala @@ -58,7 +58,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co val wireLog = new BusLogging(context.system.eventStream, "", classOf[Diagnostics], context.system.asInstanceOf[ExtendedActorSystem].logFilter) with DiagnosticLoggingAdapter - def diag(message: T, direction: String) = { + def diag(message: T, direction: String): Unit = { require(direction == "IN" || direction == "OUT") val channelId_opt = Logs.channelId(message) wireLog.mdc(Logs.mdc(LogCategory(message), remoteNodeId_opt, channelId_opt)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala index a5c8a3ae30..961603998f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala @@ -105,7 +105,10 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A d.transport ! TransportHandler.Listener(self) Metrics.PeerConnectionsConnecting.withTag(Tags.ConnectionState, Tags.ConnectionStates.Initializing).increment() log.info(s"using features=$localFeatures") - val localInit = protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil))) + val localInit = NodeAddress.extractIPAddress(d.pendingAuth.address) match { + case Some(remoteAddress) if !d.pendingAuth.outgoing => protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil), InitTlv.RemoteAddress(remoteAddress))) + case _ => protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil))) + } d.transport ! localInit startSingleTimer(INIT_TIMER, InitTimeout, conf.initTimeout) goto(INITIALIZING) using InitializingData(chainHash, d.pendingAuth, d.remoteNodeId, d.transport, peer, localInit, doSync, d.isPersistent) @@ -118,6 +121,7 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A d.transport ! TransportHandler.ReadAck(remoteInit) log.info(s"peer is using features=${remoteInit.features}, networks=${remoteInit.networks.mkString(",")}") + remoteInit.remoteAddress_opt.foreach(address => log.info("peer reports that our IP address is {} (public={})", address.socketAddress.toString, NodeAddress.isPublicIPAddress(address))) val featureGraphErr_opt = Features.validateFeatureGraph(remoteInit.features) if (remoteInit.networks.nonEmpty && remoteInit.networks.intersect(d.localInit.networks).isEmpty) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index 3974851076..1c5e58f01e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} import fr.acinq.eclair.blockchain.fee.FeeratePerKw import fr.acinq.eclair.channel.ChannelType -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, ShortChannelId, TimestampMilli, TimestampSecond, UInt64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, ShortChannelId, TimestampSecond, UInt64} import scodec.bits.ByteVector import java.net.{Inet4Address, Inet6Address, InetAddress, InetSocketAddress} @@ -49,6 +49,7 @@ sealed trait HtlcSettlementMessage extends UpdateMessage { def id: Long } // <- case class Init(features: Features, tlvStream: TlvStream[InitTlv] = TlvStream.empty) extends SetupMessage { val networks = tlvStream.get[InitTlv.Networks].map(_.chainHashes).getOrElse(Nil) + val remoteAddress_opt = tlvStream.get[InitTlv.RemoteAddress].map(_.address) } case class Warning(channelId: ByteVector32, data: ByteVector, tlvStream: TlvStream[WarningTlv] = TlvStream.empty) extends SetupMessage with HasChannelId { @@ -215,28 +216,52 @@ case class Color(r: Byte, g: Byte, b: Byte) { // @formatter:off sealed trait NodeAddress { def socketAddress: InetSocketAddress } sealed trait OnionAddress extends NodeAddress +sealed trait IPAddress extends NodeAddress +// @formatter:on + object NodeAddress { /** - * Creates a NodeAddress from a host and port. - * - * Note that non-onion hosts will be resolved. - * - * We don't attempt to resolve onion addresses (it will be done by the tor proxy), so we just recognize them based on - * the .onion TLD and rely on their length to separate v2/v3. - */ + * Creates a NodeAddress from a host and port. + * + * Note that non-onion hosts will be resolved. + * + * We don't attempt to resolve onion addresses (it will be done by the tor proxy), so we just recognize them based on + * the .onion TLD and rely on their length to separate v2/v3. + */ def fromParts(host: String, port: Int): Try[NodeAddress] = Try { host match { case _ if host.endsWith(".onion") && host.length == 22 => Tor2(host.dropRight(6), port) case _ if host.endsWith(".onion") && host.length == 62 => Tor3(host.dropRight(6), port) - case _ => InetAddress.getByName(host) match { + case _ => InetAddress.getByName(host) match { case a: Inet4Address => IPv4(a, port) case a: Inet6Address => IPv6(a, port) } } } + + def extractIPAddress(socketAddress: InetSocketAddress): Option[IPAddress] = { + val address = socketAddress.getAddress + address match { + case address: Inet4Address => Some(IPv4(address, socketAddress.getPort)) + case address: Inet6Address => Some(IPv6(address, socketAddress.getPort)) + case _ => None + } + } + + private def isPrivate(address: InetAddress): Boolean = address.isAnyLocalAddress || address.isLoopbackAddress || address.isLinkLocalAddress || address.isSiteLocalAddress + + def isPublicIPAddress(address: NodeAddress): Boolean = { + address match { + case IPv4(ipv4, _) if !isPrivate(ipv4) => true + case IPv6(ipv6, _) if !isPrivate(ipv6) => true + case _ => false + } + } } -case class IPv4(ipv4: Inet4Address, port: Int) extends NodeAddress { override def socketAddress = new InetSocketAddress(ipv4, port) } -case class IPv6(ipv6: Inet6Address, port: Int) extends NodeAddress { override def socketAddress = new InetSocketAddress(ipv6, port) } + +// @formatter:off +case class IPv4(ipv4: Inet4Address, port: Int) extends IPAddress { override def socketAddress = new InetSocketAddress(ipv4, port) } +case class IPv6(ipv6: Inet6Address, port: Int) extends IPAddress { override def socketAddress = new InetSocketAddress(ipv6, port) } case class Tor2(tor2: String, port: Int) extends OnionAddress { override def socketAddress = InetSocketAddress.createUnresolved(tor2 + ".onion", port) } case class Tor3(tor3: String, port: Int) extends OnionAddress { override def socketAddress = InetSocketAddress.createUnresolved(tor3 + ".onion", port) } // @formatter:on diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala index ebcb8a9252..881077fdbf 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/SetupAndControlTlv.scala @@ -35,6 +35,12 @@ object InitTlv { /** The chains the node is interested in. */ case class Networks(chainHashes: List[ByteVector32]) extends InitTlv + /** + * When receiving an incoming connection, we can send back the public address our peer is connecting from. + * This lets our peer discover if its public IP has changed from within its local network. + */ + case class RemoteAddress(address: NodeAddress) extends InitTlv + } object InitTlvCodecs { @@ -42,9 +48,11 @@ object InitTlvCodecs { import InitTlv._ private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks] + private val remoteAddress: Codec[RemoteAddress] = variableSizeBytesLong(varintoverflow, nodeaddress).as[RemoteAddress] val initTlvCodec = tlvStream(discriminated[InitTlv].by(varint) .typecase(UInt64(1), networks) + .typecase(UInt64(3), remoteAddress) ) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala index 07a1ce48be..3e802efc57 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala @@ -18,8 +18,8 @@ package fr.acinq.eclair.io import akka.actor.PoisonPill import akka.testkit.{TestFSMRef, TestProbe} -import fr.acinq.bitcoin.{Block, ByteVector32} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.bitcoin.{Block, ByteVector32} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features.{BasicMultiPartPayment, ChannelRangeQueries, PaymentSecret, VariableLengthOnion} import fr.acinq.eclair.TestConstants._ @@ -93,6 +93,20 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) } + test("send incoming connection's remote address in init") { f => + import f._ + val probe = TestProbe() + val incomingConnection = PeerConnection.PendingAuth(connection.ref, None, fakeIPAddress.socketAddress, origin_opt = None, transport_opt = Some(transport.ref), isPersistent = true) + assert(!incomingConnection.outgoing) + probe.send(peerConnection, incomingConnection) + transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId)) + switchboard.expectMsg(PeerConnection.Authenticated(peerConnection, remoteNodeId)) + probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref, nodeParams.chainHash, nodeParams.features, doSync = false)) + transport.expectMsgType[TransportHandler.Listener] + val localInit = transport.expectMsgType[protocol.Init] + assert(localInit.remoteAddress_opt === Some(fakeIPAddress)) + } + test("handle connection closed during authentication") { f => import f._ val probe = TestProbe() @@ -417,7 +431,7 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi assert(peerConnection.stateName === PeerConnection.CONNECTED) probe.send(peerConnection, FundingLocked(ByteVector32(hex"0000000000000000000000000000000000000000000000000000000000000000"), randomKey().publicKey)) peerConnection.stateData match { - case d : PeerConnection.ConnectedData => assert(d.isPersistent) + case d: PeerConnection.ConnectedData => assert(d.isPersistent) case _ => fail() } } @@ -459,5 +473,23 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi peer.send(peerConnection, message) transport.expectMsg(message) } + + test("filter private IP addresses") { _ => + val testCases = Seq( + NodeAddress.fromParts("127.0.0.1", 9735).get -> false, + NodeAddress.fromParts("0.0.0.0", 9735).get -> false, + NodeAddress.fromParts("192.168.0.1", 9735).get -> false, + NodeAddress.fromParts("140.82.121.3", 9735).get -> true, + NodeAddress.fromParts("0000:0000:0000:0000:0000:0000:0000:0001", 9735).get -> false, + NodeAddress.fromParts("b643:8bb1:c1f9:0556:487c:0acb:2ba3:3cc2", 9735).get -> true, + NodeAddress.fromParts("hsmithsxurybd7uh.onion", 9735).get -> false, + NodeAddress.fromParts("iq7zhmhck54vcax2vlrdcavq2m32wao7ekh6jyeglmnuuvv3js57r4id.onion", 9735).get -> false, + ) + for ((address, expected) <- testCases) { + val isPublic = NodeAddress.isPublicIPAddress(address) + assert(isPublic === expected) + } + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala index 8fcbca9e59..94ad9a2388 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecsSpec.scala @@ -30,7 +30,7 @@ import org.scalatest.funsuite.AnyFunSuite import scodec.DecodeResult import scodec.bits.{BinStringSyntax, ByteVector, HexStringSyntax} -import java.net.{Inet4Address, InetAddress} +import java.net.{Inet4Address, Inet6Address, InetAddress} /** * Created by PM on 31/05/2016. @@ -49,24 +49,28 @@ class LightningMessageCodecsSpec extends AnyFunSuite { def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey test("encode/decode init message") { - case class TestCase(encoded: ByteVector, rawFeatures: ByteVector, networks: List[ByteVector32], valid: Boolean, reEncoded: Option[ByteVector] = None) + case class TestCase(encoded: ByteVector, rawFeatures: ByteVector, networks: List[ByteVector32], address: Option[IPAddress], valid: Boolean, reEncoded: Option[ByteVector] = None) val chainHash1 = ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101") val chainHash2 = ByteVector32(hex"0202020202020202020202020202020202020202020202020202020202020202") + val remoteAddress1 = IPv4(InetAddress.getByAddress(Array[Byte](140.toByte, 82.toByte, 121.toByte, 3.toByte)).asInstanceOf[Inet4Address], 9735) + val remoteAddress2 = IPv6(InetAddress.getByAddress(hex"b643 8bb1 c1f9 0556 487c 0acb 2ba3 3cc2".toArray).asInstanceOf[Inet6Address], 9736) val testCases = Seq( - TestCase(hex"0000 0000", hex"", Nil, valid = true), // no features - TestCase(hex"0000 0002088a", hex"088a", Nil, valid = true), // no global features - TestCase(hex"00020200 0000", hex"0200", Nil, valid = true, Some(hex"0000 00020200")), // no local features - TestCase(hex"00020200 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size - TestCase(hex"00020200 0003020002", hex"020202", Nil, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes - TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size - TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes - TestCase(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, valid = true), // unknown odd records - TestCase(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, valid = false), // unknown even records - TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, valid = false), // invalid tlv stream - TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), valid = true), // single network - TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), valid = true), // multiple networks - TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010103012a", hex"088a", List(chainHash1), valid = true), // network and unknown odd records - TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, valid = false) // network and unknown even records + TestCase(hex"0000 0000", hex"", Nil, None, valid = true), // no features + TestCase(hex"0000 0002088a", hex"088a", Nil, None, valid = true), // no global features + TestCase(hex"00020200 0000", hex"0200", Nil, None, valid = true, Some(hex"0000 00020200")), // no local features + TestCase(hex"00020200 0002088a", hex"0a8a", Nil, None, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size + TestCase(hex"00020200 0003020002", hex"020202", Nil, None, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes + TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, None, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size + TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, None, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes + TestCase(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, None, valid = true), // unknown odd records + TestCase(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, None, valid = false), // unknown even records + TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, None, valid = false), // invalid tlv stream + TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), None, valid = true), // single network + TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 0307018c5279032607", hex"088a", List(chainHash1), Some(remoteAddress1), valid = true), // single network and IPv4 address + TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 031302b6438bb1c1f90556487c0acb2ba33cc22608", hex"088a", List(chainHash1), Some(remoteAddress2), valid = true), // single network and IPv6 address + TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), None, valid = true), // multiple networks + TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 c9012a", hex"088a", List(chainHash1), None, valid = true), // network and unknown odd records + TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 02012a", hex"088a", Nil, None, valid = false) // network and unknown even records ) for (testCase <- testCases) { @@ -74,6 +78,7 @@ class LightningMessageCodecsSpec extends AnyFunSuite { val init = initCodec.decode(testCase.encoded.bits).require.value assert(init.features.toByteVector === testCase.rawFeatures) assert(init.networks === testCase.networks) + assert(init.remoteAddress_opt === testCase.address) val encoded = initCodec.encode(init).require assert(encoded.bytes === testCase.reEncoded.getOrElse(testCase.encoded)) assert(initCodec.decode(encoded).require.value === init)