diff --git a/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/CIOEngine.kt b/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/CIOEngine.kt index a2aa49ea077..0f4ba1c93d1 100644 --- a/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/CIOEngine.kt +++ b/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/CIOEngine.kt @@ -33,7 +33,11 @@ internal class CIOEngine( private val selectorManager: SelectorManager by lazy { SelectorManager(dispatcher) } - private val connectionFactory = ConnectionFactory(selectorManager, config.maxConnectionsCount) + private val connectionFactory = ConnectionFactory( + selectorManager, + config.maxConnectionsCount, + config.endpoint.maxConnectionsPerRoute + ) private val requestsJob: CoroutineContext @@ -42,6 +46,7 @@ internal class CIOEngine( private val proxy: ProxyConfig? = when (val type = config.proxy?.type) { ProxyType.SOCKS, null -> null + ProxyType.HTTP -> config.proxy else -> throw IllegalStateException("CIO engine does not currently support $type proxies.") } diff --git a/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/ConnectionFactory.kt b/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/ConnectionFactory.kt index 3b392745349..8b114df285a 100644 --- a/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/ConnectionFactory.kt +++ b/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/ConnectionFactory.kt @@ -6,29 +6,35 @@ package io.ktor.client.engine.cio import io.ktor.network.selector.* import io.ktor.network.sockets.* +import io.ktor.util.collections.* import kotlinx.coroutines.sync.* internal class ConnectionFactory( private val selector: SelectorManager, - maxConnectionsCount: Int + connectionsLimit: Int, + private val addressConnectionsLimit: Int ) { - private val semaphore = Semaphore(maxConnectionsCount) + private val limit = Semaphore(connectionsLimit) + private val addressLimit = ConcurrentMap() suspend fun connect( address: InetSocketAddress, configuration: SocketOptions.TCPClientSocketOptions.() -> Unit = {} ): Socket { - semaphore.acquire() + limit.acquire() + addressLimit.computeIfAbsent(address) { Semaphore(addressConnectionsLimit) }.acquire() + return try { aSocket(selector).tcpNoDelay().tcp().connect(address, configuration) } catch (cause: Throwable) { // a failure or cancellation - semaphore.release() + limit.release() throw cause } } - fun release() { - semaphore.release() + fun release(address: InetSocketAddress) { + addressLimit[address]!!.release() + limit.release() } } diff --git a/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/Endpoint.kt b/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/Endpoint.kt index c6e95c800f9..22cbe2401ee 100644 --- a/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/Endpoint.kt +++ b/ktor-client/ktor-client-cio/jvmAndNix/src/io/ktor/client/engine/cio/Endpoint.kt @@ -188,7 +188,7 @@ internal class Endpoint( } catch (_: Throwable) { } - connectionFactory.release() + connectionFactory.release(address) throw cause } } @@ -229,7 +229,8 @@ internal class Endpoint( } private fun releaseConnection() { - connectionFactory.release() + val address = InetSocketAddress(host, port) + connectionFactory.release(address) connections.decrementAndGet() } diff --git a/ktor-client/ktor-client-mock/common/src/io/ktor/client/engine/mock/MockUtils.kt b/ktor-client/ktor-client-mock/common/src/io/ktor/client/engine/mock/MockUtils.kt index c96501da33e..c7efda3aead 100644 --- a/ktor-client/ktor-client-mock/common/src/io/ktor/client/engine/mock/MockUtils.kt +++ b/ktor-client/ktor-client-mock/common/src/io/ktor/client/engine/mock/MockUtils.kt @@ -31,8 +31,8 @@ public suspend fun OutgoingContent.toByteArray(): ByteArray = when (this) { else -> ByteArray(0) } -@OptIn(DelicateCoroutinesApi::class) @Suppress("KDocMissingDocumentation") +@OptIn(DelicateCoroutinesApi::class) public suspend fun OutgoingContent.toByteReadPacket(): ByteReadPacket = when (this) { is OutgoingContent.ByteArrayContent -> ByteReadPacket(bytes()) is OutgoingContent.ReadChannelContent -> readFrom().readRemaining()