Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-2036 Fix CIO connection limit #3140

Merged
merged 3 commits into from Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -5,6 +5,7 @@
package io.ktor.client.engine.cio

import io.ktor.client.call.*
import io.ktor.client.network.sockets.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
Expand Down Expand Up @@ -147,13 +148,19 @@ class CIORequestTest : TestWithKtor() {
}

test { client ->
var fail: Throwable? = null
for (i in 0..1000) {
try {
client.get("http://something.wrong").body<String>()
} catch (cause: UnresolvedAddressException) {
// ignore
} catch (cause: Throwable) {
fail = cause
}
}

assertNotNull(fail)
if (fail !is ConnectTimeoutException && fail !is UnresolvedAddressException) {
fail("Expected ConnectTimeoutException or UnresolvedAddressException, got $fail", fail)
}
}
}
}
Expand Up @@ -33,7 +33,11 @@ internal class CIOEngine(

private val selectorManager: SelectorManager by lazy { SelectorManager(dispatcher) }

private val connectionFactory = ConnectionFactory(selectorManager, config.maxConnectionsCount)
private val connectionFactory = ConnectionFactory(
selectorManager,
config.maxConnectionsCount,
config.endpoint.maxConnectionsPerRoute
)

private val requestsJob: CoroutineContext

Expand All @@ -42,6 +46,7 @@ internal class CIOEngine(
private val proxy: ProxyConfig? = when (val type = config.proxy?.type) {
ProxyType.SOCKS,
null -> null

ProxyType.HTTP -> config.proxy
else -> throw IllegalStateException("CIO engine does not currently support $type proxies.")
}
Expand Down
Expand Up @@ -6,29 +6,37 @@ package io.ktor.client.engine.cio

import io.ktor.network.selector.*
import io.ktor.network.sockets.*
import io.ktor.util.collections.*
import kotlinx.coroutines.sync.*

internal class ConnectionFactory(
private val selector: SelectorManager,
maxConnectionsCount: Int
connectionsLimit: Int,
private val addressConnectionsLimit: Int
) {
private val semaphore = Semaphore(maxConnectionsCount)
private val limit = Semaphore(connectionsLimit)
private val addressLimit = ConcurrentMap<InetSocketAddress, Semaphore>()

suspend fun connect(
address: InetSocketAddress,
configuration: SocketOptions.TCPClientSocketOptions.() -> Unit = {}
): Socket {
semaphore.acquire()
limit.acquire()
val addressSemaphore = addressLimit.computeIfAbsent(address) { Semaphore(addressConnectionsLimit) }
addressSemaphore.acquire()

return try {
aSocket(selector).tcpNoDelay().tcp().connect(address, configuration)
} catch (cause: Throwable) {
// a failure or cancellation
semaphore.release()
addressSemaphore.release()
limit.release()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should you release addressLimit here too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, thanks

throw cause
}
}

fun release() {
semaphore.release()
fun release(address: InetSocketAddress) {
addressLimit[address]!!.release()
limit.release()
}
}
Expand Up @@ -188,7 +188,7 @@ internal class Endpoint(
} catch (_: Throwable) {
}

connectionFactory.release()
connectionFactory.release(address)
throw cause
}
}
Expand Down Expand Up @@ -229,7 +229,8 @@ internal class Endpoint(
}

private fun releaseConnection() {
connectionFactory.release()
val address = InetSocketAddress(host, port)
connectionFactory.release(address)
connections.decrementAndGet()
}

Expand Down
Expand Up @@ -31,8 +31,8 @@ public suspend fun OutgoingContent.toByteArray(): ByteArray = when (this) {
else -> ByteArray(0)
}

@OptIn(DelicateCoroutinesApi::class)
@Suppress("KDocMissingDocumentation")
@OptIn(DelicateCoroutinesApi::class)
public suspend fun OutgoingContent.toByteReadPacket(): ByteReadPacket = when (this) {
is OutgoingContent.ByteArrayContent -> ByteReadPacket(bytes())
is OutgoingContent.ReadChannelContent -> readFrom().readRemaining()
Expand Down