Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-5199 Support WebSockets in Curl engine #3950

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions ktor-client/ktor-client-curl/build.gradle.kts
Expand Up @@ -50,6 +50,7 @@ kotlin {
dependencies {
api(project(":ktor-client:ktor-client-plugins:ktor-client-logging"))
api(project(":ktor-client:ktor-client-plugins:ktor-client-json"))
implementation(libs.kotlinx.serialization.json)
}
}
}
Expand Down
@@ -0,0 +1,84 @@
#ifndef CURLINC_WEBSOCKETS_H
#define CURLINC_WEBSOCKETS_H
/***************************************************************************
* _ _ ____ _
* Project ___| | | | _ \| |
* / __| | | | |_) | |
* | (__| |_| | _ <| |___
* \___|\___/|_| \_\_____|
*
* Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
*
* This software is licensed as described in the file COPYING, which
* you should have received as part of this distribution. The terms
* are also available at https://curl.se/docs/copyright.html.
*
* You may opt to use, copy, modify, merge, publish, distribute and/or sell
* copies of the Software, and permit persons to whom the Software is
* furnished to do so, under the terms of the COPYING file.
*
* This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
* KIND, either express or implied.
*
* SPDX-License-Identifier: curl
*
***************************************************************************/

#ifdef __cplusplus
extern "C" {
#endif

struct curl_ws_frame {
int age; /* zero */
int flags; /* See the CURLWS_* defines */
curl_off_t offset; /* the offset of this data into the frame */
curl_off_t bytesleft; /* number of pending bytes left of the payload */
size_t len; /* size of the current data chunk */
};

/* flag bits */
#define CURLWS_TEXT (1<<0)
#define CURLWS_BINARY (1<<1)
#define CURLWS_CONT (1<<2)
#define CURLWS_CLOSE (1<<3)
#define CURLWS_PING (1<<4)
#define CURLWS_OFFSET (1<<5)

/*
* NAME curl_ws_recv()
*
* DESCRIPTION
*
* Receives data from the websocket connection. Use after successful
* curl_easy_perform() with CURLOPT_CONNECT_ONLY option.
*/
CURL_EXTERN CURLcode curl_ws_recv(CURL *curl, void *buffer, size_t buflen,
size_t *recv,
const struct curl_ws_frame **metap);

/* flags for curl_ws_send() */
#define CURLWS_PONG (1<<6)

/*
* NAME curl_ws_send()
*
* DESCRIPTION
*
* Sends data over the websocket connection. Use after successful
* curl_easy_perform() with CURLOPT_CONNECT_ONLY option.
*/
CURL_EXTERN CURLcode curl_ws_send(CURL *curl, const void *buffer,
size_t buflen, size_t *sent,
curl_off_t fragsize,
unsigned int flags);

/* bits for the CURLOPT_WS_OPTIONS bitmask: */
#define CURLWS_RAW_MODE (1<<0)

CURL_EXTERN const struct curl_ws_frame *curl_ws_meta(CURL *curl);

#ifdef __cplusplus
}
#endif

#endif /* CURLINC_WEBSOCKETS_H */
Expand Up @@ -8,10 +8,10 @@ import io.ktor.client.engine.*
import io.ktor.client.engine.curl.internal.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.sse.*
import io.ktor.client.plugins.websocket.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.cio.*
import io.ktor.util.*
import io.ktor.util.date.*
import io.ktor.utils.io.*
import kotlinx.coroutines.*
Expand All @@ -21,7 +21,7 @@ internal class CurlClientEngine(
) : HttpClientEngineBase("ktor-curl") {
override val dispatcher = Dispatchers.Unconfined

override val supportedCapabilities = setOf(HttpTimeoutCapability, SSECapability)
override val supportedCapabilities = setOf(HttpTimeoutCapability, WebSocketCapability, SSECapability)

private val curlProcessor = CurlProcessor(coroutineContext)

Expand All @@ -46,9 +46,15 @@ internal class CurlClientEngine(
rawHeaders.release()

val responseBody: Any = if (needToProcessSSE(data, status, headers)) {
DefaultClientSSESession(data.body as SSEClientContent, bodyChannel, callContext)
val content = data.body as SSEClientContent
val body = responseBody as CurlHttpResponseBody
DefaultClientSSESession(content, body.bodyChannel, callContext)
} else if (data.isUpgradeRequest()) {
val websocket = responseBody as CurlWebSocketResponseBody
CurlWebSocketSession(websocket, callContext)
bjhham marked this conversation as resolved.
Show resolved Hide resolved
} else {
bodyChannel
val body = responseBody as CurlHttpResponseBody
body.bodyChannel
}

HttpResponseData(
Expand Down
Expand Up @@ -18,6 +18,12 @@ internal class RequestContainer(
val completionHandler: CompletableDeferred<CurlSuccess>
)

/**
* A class responsible for processing requests asynchronously.
*
* It holds a dispatcher interacting with curl multi interface API,
* which requires API calls from single thread.
*/
internal class CurlProcessor(coroutineContext: CoroutineContext) {
@OptIn(InternalAPI::class)
private val curlDispatcher: CloseableCoroutineDispatcher =
Expand All @@ -28,6 +34,8 @@ internal class CurlProcessor(coroutineContext: CoroutineContext) {

private val curlScope = CoroutineScope(coroutineContext + curlDispatcher)
private val requestQueue: Channel<RequestContainer> = Channel(Channel.UNLIMITED)
private val requestCounter = atomic(0L)
private val curlProtocols by lazy { getCurlProtocols() }

init {
val init = curlScope.launch {
Expand All @@ -42,9 +50,14 @@ internal class CurlProcessor(coroutineContext: CoroutineContext) {
}

suspend fun executeRequest(request: CurlRequestData): CurlSuccess {
if (request.isUpgradeRequest && !curlProtocols.contains(request.protocol)) {
error("WebSockets are supported in experimental libcurl 7.86 and greater")
}

val result = CompletableDeferred<CurlSuccess>()
requestQueue.send(RequestContainer(request, result))
curlApi!!.wakeup()
nextRequest {
requestQueue.send(RequestContainer(request, result))
}
return result.await()
}

Expand All @@ -54,7 +67,7 @@ internal class CurlProcessor(coroutineContext: CoroutineContext) {
val api = curlApi!!
while (!requestQueue.isClosedForReceive) {
drainRequestQueue(api)
api.perform()
api.perform(requestCounter)
}
}
}
Expand Down Expand Up @@ -86,6 +99,8 @@ internal class CurlProcessor(coroutineContext: CoroutineContext) {
if (!closed.compareAndSet(false, true)) return

requestQueue.close()
nextRequest()

GlobalScope.launch(curlDispatcher) {
curlScope.coroutineContext[Job]!!.join()
curlApi!!.close()
Expand All @@ -100,4 +115,10 @@ internal class CurlProcessor(coroutineContext: CoroutineContext) {
curlApi!!.cancelRequest(easyHandle, cause)
}
}

private inline fun nextRequest(body: (Long) -> Unit = {}) = try {
body(requestCounter.incrementAndGet())
} finally {
curlApi!!.wakeup()
}
}
Expand Up @@ -20,6 +20,17 @@ internal typealias EasyHandle = COpaquePointer
@OptIn(ExperimentalForeignApi::class)
internal typealias MultiHandle = COpaquePointer

/**
* Curl manages websocket headers internally:
* @see <a href="https://github.com/curl/curl/blob/f0986c6e18417865f49e725201a5224d9b5af849/lib/ws.c#L684">List of headers</a>
*/
internal val DISALLOWED_WEBSOCKET_HEADERS = setOf(
HttpHeaders.Upgrade,
HttpHeaders.Connection,
HttpHeaders.SecWebSocketVersion,
HttpHeaders.SecWebSocketKey
)

@OptIn(ExperimentalForeignApi::class)
internal fun CURLMcode.verify() {
if (this != CURLM_OK) {
Expand Down Expand Up @@ -71,6 +82,7 @@ internal fun HttpRequestData.headersToCurl(): CPointer<curl_slist> {
var result: CPointer<curl_slist>? = null

mergeHeaders(headers, body) { key, value ->
if (isUpgradeRequest() && DISALLOWED_WEBSOCKET_HEADERS.contains(key)) return@mergeHeaders
val header = "$key: $value"
result = curl_slist_append(result, header)
}
Expand All @@ -87,3 +99,16 @@ internal fun UInt.fromCurl(): HttpProtocolVersion = when (this) {
/* old curl fallback */
else -> HttpProtocolVersion.HTTP_1_1
}

/**
* Retrieves the supported protocols for the current version of cURL.
*
* @return The list of supported protocols as strings, e.g. [ftp, http, ws]
*/
@OptIn(ExperimentalForeignApi::class)
internal fun getCurlProtocols(): List<String> {
bjhham marked this conversation as resolved.
Show resolved Hide resolved
val currentVersion = CURLversion.values().first { it.value == CURLVERSION_NOW.toUInt() }
val versionInfoPtr = curl_version_info(currentVersion)
val versionInfo = versionInfoPtr!!.reinterpret<curl_version_info_data>().pointed
return versionInfo.protocols?.toKStringList().orEmpty()
}
Expand Up @@ -6,26 +6,43 @@ package io.ktor.client.engine.curl.internal

import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.atomicfu.*
import kotlinx.cinterop.*
import kotlinx.coroutines.*
import libcurl.*
import platform.posix.*
import kotlin.coroutines.*

/**
* The callback is getting called on each completely parser header line.
*/
@OptIn(ExperimentalForeignApi::class)
internal fun onHeadersReceived(
buffer: CPointer<ByteVar>,
size: size_t,
count: size_t,
userdata: COpaquePointer
): Long {
val packet = userdata.fromCPointer<CurlResponseBuilder>().headersBytes
val response = userdata.fromCPointer<CurlResponseBuilder>()
val packet = response.headersBytes
val chunkSize = (size * count).toLong()
packet.writeFully(buffer, 0, chunkSize)

if (isFinalHeaderLine(chunkSize, buffer) && !response.bodyStartedReceiving.isCompleted) {
response.bodyStartedReceiving.complete(Unit)
}

return chunkSize
}

/**
* Checks if the given header represents the final header line (CR LF).
*
* @see <a href="https://curl.se/libcurl/c/CURLOPT_HEADERFUNCTION.html">Description.</a>
*/
@OptIn(ExperimentalForeignApi::class)
private fun isFinalHeaderLine(chunkSize: Long, buffer: CPointer<ByteVar>) =
chunkSize == 2L && buffer[0] == 0x0D.toByte() && buffer[1] == 0x0A.toByte()

@OptIn(ExperimentalForeignApi::class)
internal fun onBodyChunkReceived(
buffer: CPointer<ByteVar>,
Expand All @@ -34,44 +51,7 @@ internal fun onBodyChunkReceived(
userdata: COpaquePointer
): Int {
val wrapper = userdata.fromCPointer<CurlResponseBodyData>()
if (!wrapper.bodyStartedReceiving.isCompleted) {
wrapper.bodyStartedReceiving.complete(Unit)
}

val body = wrapper.body
if (body.isClosedForWrite) {
return if (body.closedCause != null) -1 else 0
}

val chunkSize = (size * count).toInt()

// TODO: delete `runBlocking` with fix of https://youtrack.jetbrains.com/issue/KTOR-6030/Migrate-to-new-kotlinx.io-library
val written = try {
runBlocking {
body.writeFully(buffer, 0, chunkSize)
}
chunkSize
} catch (cause: Throwable) {
return -1
}
if (written > 0) {
wrapper.bytesWritten += written
}
if (wrapper.bytesWritten.value == chunkSize) {
wrapper.bytesWritten.value = 0
return chunkSize
}

CoroutineScope(wrapper.callContext).launch {
try {
body.awaitFreeSpace()
} catch (_: Throwable) {
// no op, error will be handled on next write on cURL thread
} finally {
wrapper.onUnpause()
}
}
return CURL_WRITEFUNC_PAUSE
return wrapper.onBodyChunkReceived(buffer, size, count)
}

@OptIn(ExperimentalForeignApi::class)
Expand Down Expand Up @@ -118,11 +98,8 @@ internal class CurlRequestBodyData(
val onUnpause: () -> Unit
)

internal class CurlResponseBodyData(
val bodyStartedReceiving: CompletableDeferred<Unit>,
val body: ByteWriteChannel,
val callContext: CoroutineContext,
val onUnpause: () -> Unit
) {
internal val bytesWritten = atomic(0)
internal interface CurlResponseBodyData {
@OptIn(ExperimentalForeignApi::class)
fun onBodyChunkReceived(buffer: CPointer<ByteVar>, size: size_t, count: size_t): Int
fun close(cause: Throwable? = null)
}