Skip to content

Commit

Permalink
KTOR-4323: WebSockets use custom serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
hfhbd committed May 15, 2022
1 parent ebaf9f1 commit f9c64e2
Show file tree
Hide file tree
Showing 15 changed files with 156 additions and 44 deletions.
9 changes: 4 additions & 5 deletions ktor-client/ktor-client-core/api/ktor-client-core.api
Expand Up @@ -745,15 +745,13 @@ public final class io/ktor/client/plugins/websocket/BuildersKt {
public static synthetic fun wss$default (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
}

public final class io/ktor/client/plugins/websocket/ClientSessionsKt {
public static final fun getConverter (Lio/ktor/client/plugins/websocket/DefaultClientWebSocketSession;)Lio/ktor/serialization/WebsocketContentConverter;
}

public abstract interface class io/ktor/client/plugins/websocket/ClientWebSocketSession : io/ktor/websocket/WebSocketSession {
public abstract interface class io/ktor/client/plugins/websocket/ClientWebSocketSession : io/ktor/serialization/SerializableWebSocketSession {
public abstract fun getCall ()Lio/ktor/client/call/HttpClientCall;
public abstract fun getConverter ()Lio/ktor/serialization/WebsocketContentConverter;
}

public final class io/ktor/client/plugins/websocket/ClientWebSocketSession$DefaultImpls {
public static fun getConverter (Lio/ktor/client/plugins/websocket/ClientWebSocketSession;)Lio/ktor/serialization/WebsocketContentConverter;
public static fun send (Lio/ktor/client/plugins/websocket/ClientWebSocketSession;Lio/ktor/websocket/Frame;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

Expand All @@ -762,6 +760,7 @@ public final class io/ktor/client/plugins/websocket/DefaultClientWebSocketSessio
public fun flush (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun getCall ()Lio/ktor/client/call/HttpClientCall;
public fun getCloseReason ()Lkotlinx/coroutines/Deferred;
public fun getConverter ()Lio/ktor/serialization/WebsocketContentConverter;
public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext;
public fun getExtensions ()Ljava/util/List;
public fun getIncoming ()Lkotlinx/coroutines/channels/ReceiveChannel;
Expand Down
Expand Up @@ -15,11 +15,17 @@ import io.ktor.websocket.serialization.*
/**
* Client specific [WebSocketSession].
*/
public interface ClientWebSocketSession : WebSocketSession {
public interface ClientWebSocketSession : SerializableWebSocketSession {
/**
* [HttpClientCall] associated with session.
*/
public val call: HttpClientCall

/**
* Converter for web socket session, if plugin [WebSockets] is installed
*/
public override val converter: WebsocketContentConverter?
get() = call.client.pluginOrNull(WebSockets)?.contentConverter
}

/**
Expand All @@ -35,12 +41,6 @@ internal class DelegatingClientWebSocketSession(
session: WebSocketSession
) : ClientWebSocketSession, WebSocketSession by session

/**
* Converter for web socket session
*/
public val DefaultClientWebSocketSession.converter: WebsocketContentConverter?
get() = call.client.pluginOrNull(WebSockets)?.contentConverter

/**
* Serializes [data] to a frame and enqueues this frame.
* May suspend if the outgoing queue is full.
Expand Down
2 changes: 1 addition & 1 deletion ktor-client/ktor-client-tests/build.gradle.kts
Expand Up @@ -9,7 +9,7 @@ import java.net.*
description = "Common tests for client"

plugins {
id("kotlinx-serialization")
kotlin("plugin.serialization")
}

open class KtorTestServer : DefaultTask() {
Expand Down
Expand Up @@ -10,10 +10,14 @@ import io.ktor.client.plugins.websocket.*
import io.ktor.client.tests.utils.*
import io.ktor.http.*
import io.ktor.serialization.*
import io.ktor.serialization.kotlinx.*
import io.ktor.test.dispatcher.*
import io.ktor.util.reflect.*
import io.ktor.utils.io.charsets.*
import io.ktor.websocket.*
import kotlinx.serialization.*
import kotlinx.serialization.Serializer
import kotlinx.serialization.json.*
import kotlin.test.*

private const val TEST_SIZE: Int = 100
Expand Down Expand Up @@ -142,6 +146,29 @@ class WebSocketTest : ClientLoader(100000) {
}
}

@Serializer(forClass = Data::class)
object DataSerializer

@Test
fun testWebSocketSerializationWithCustomSerializer() = clientTests(listOf("Android", "Apache", "Curl")) {
config {
WebSockets {
contentConverter = KotlinxWebsocketSerializationConverter(Json)
}
}

test { client ->
client.webSocket("$TEST_WEBSOCKET_SERVER/websockets/echo") {
repeat(TEST_SIZE) {
val originalData = Data("hello")
sendSerialized(originalData, DataSerializer)
val actual = receiveDeserialized(DataSerializer)
assertEquals(originalData, actual)
}
}
}
}

@Test
fun testSerializationWithNoConverter() = clientTests(listOf("Android", "Apache", "Curl")) {
config {
Expand Down
Expand Up @@ -2,6 +2,7 @@ public abstract interface class io/ktor/server/websocket/DefaultWebSocketServerS
}

public final class io/ktor/server/websocket/DefaultWebSocketServerSession$DefaultImpls {
public static fun getConverter (Lio/ktor/server/websocket/DefaultWebSocketServerSession;)Lio/ktor/serialization/WebsocketContentConverter;
public static fun send (Lio/ktor/server/websocket/DefaultWebSocketServerSession;Lio/ktor/websocket/Frame;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

Expand Down Expand Up @@ -42,17 +43,18 @@ public final class io/ktor/server/websocket/RoutingKt {
public static synthetic fun webSocketRaw$default (Lio/ktor/server/routing/Route;Ljava/lang/String;ZLkotlin/jvm/functions/Function2;ILjava/lang/Object;)V
}

public abstract interface class io/ktor/server/websocket/WebSocketServerSession : io/ktor/websocket/WebSocketSession {
public abstract interface class io/ktor/server/websocket/WebSocketServerSession : io/ktor/serialization/SerializableWebSocketSession {
public abstract fun getCall ()Lio/ktor/server/application/ApplicationCall;
public abstract fun getConverter ()Lio/ktor/serialization/WebsocketContentConverter;
}

public final class io/ktor/server/websocket/WebSocketServerSession$DefaultImpls {
public static fun getConverter (Lio/ktor/server/websocket/WebSocketServerSession;)Lio/ktor/serialization/WebsocketContentConverter;
public static fun send (Lio/ktor/server/websocket/WebSocketServerSession;Lio/ktor/websocket/Frame;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/ktor/server/websocket/WebSocketServerSessionKt {
public static final fun getApplication (Lio/ktor/server/websocket/WebSocketServerSession;)Lio/ktor/server/application/Application;
public static final fun getConverter (Lio/ktor/server/websocket/WebSocketServerSession;)Lio/ktor/serialization/WebsocketContentConverter;
}

public final class io/ktor/server/websocket/WebSocketUpgrade : io/ktor/http/content/OutgoingContent$ProtocolUpgrade {
Expand Down
@@ -1,5 +1,9 @@
description = ""

plugins {
kotlin("plugin.serialization")
}

kotlin.sourceSets {
jvmAndNixMain {
dependencies {
Expand All @@ -11,6 +15,7 @@ kotlin.sourceSets {
jvmAndNixTest {
dependencies {
api(project(":ktor-server:ktor-server-plugins:ktor-server-content-negotiation"))
api(project(":ktor-shared:ktor-serialization:ktor-serialization-kotlinx:ktor-serialization-kotlinx-json"))
api(project(":ktor-server:ktor-server-cio"))
}
}
Expand Down
Expand Up @@ -14,11 +14,17 @@ import io.ktor.websocket.serialization.*
/**
* Represents a server-side web socket session
*/
public interface WebSocketServerSession : WebSocketSession {
public interface WebSocketServerSession : SerializableWebSocketSession {
/**
* Associated received [call] that originating this session
*/
public val call: ApplicationCall

/**
* Converter for web socket session, if plugin [WebSockets] is installed
*/
public override val converter: WebsocketContentConverter?
get() = application.plugin(WebSockets).contentConverter
}

/**
Expand All @@ -33,12 +39,6 @@ public interface DefaultWebSocketServerSession : DefaultWebSocketSession, WebSoc
*/
public val WebSocketServerSession.application: Application get() = call.application

/**
* Converter for web socket session
*/
public val WebSocketServerSession.converter: WebsocketContentConverter?
get() = application.plugin(WebSockets).contentConverter

/**
* Serializes [data] to a frame and enqueues this frame.
* May suspend if the outgoing queue is full.
Expand Down
Expand Up @@ -4,11 +4,52 @@

package io.ktor.tests.websocket

import io.ktor.serialization.kotlinx.*
import io.ktor.server.application.*
import io.ktor.server.cio.*
import io.ktor.server.routing.*
import io.ktor.server.websocket.*
import io.ktor.util.*
import io.ktor.utils.io.*
import io.ktor.websocket.*
import kotlinx.serialization.Serializer
import kotlinx.serialization.json.*
import kotlin.test.*

@InternalAPI
class CIOWebSocketTest : WebSocketEngineSuite<CIOApplicationEngine, CIOApplicationEngine.Configuration>(CIO) {
init {
enableSsl = false
enableHttp2 = false
}

override fun plugins(application: Application, routingConfigurer: Routing.() -> Unit) {
application.install(WebSockets) {
contentConverter = KotlinxWebsocketSerializationConverter(Json)
}
super.plugins(application, routingConfigurer)
}

data class Data(val s: Int)

@Serializer(forClass = Data::class)
object DataSerializer

@Test
fun testWebSocketCustomSerializer() = runTest {
createAndStartServer {
webSocket("/") {
assertEquals(Data(42), receiveDeserialized(DataSerializer))
println("CALLED")
}
}

useSocket {
negotiateHttpWebSocket()
output.writeFrame(Frame.Text(Json.encodeToString(DataSerializer, Data(42))), masking = false)
output.writeFrame(Frame.Close(), false)
output.flush()
assertCloseFrame()
}
}
}
Expand Up @@ -34,7 +34,7 @@ abstract class WebSocketEngineSuite<TEngine : ApplicationEngine, TConfiguration
override val timeout = 30.seconds

override fun plugins(application: Application, routingConfigurer: Routing.() -> Unit) {
application.install(WebSockets)
application.pluginOrNull(WebSockets) ?: application.install(WebSockets)
super.plugins(application, routingConfigurer)
}

Expand Down Expand Up @@ -580,7 +580,7 @@ abstract class WebSocketEngineSuite<TEngine : ApplicationEngine, TConfiguration
}
}

private suspend fun Connection.negotiateHttpWebSocket() {
internal suspend fun Connection.negotiateHttpWebSocket() {
// send upgrade request
output.apply {
writeFully(
Expand Down Expand Up @@ -610,7 +610,7 @@ abstract class WebSocketEngineSuite<TEngine : ApplicationEngine, TConfiguration
assertEquals("websocket", headers[HttpHeaders.Upgrade])
}

private suspend fun Connection.assertCloseFrame(
internal suspend fun Connection.assertCloseFrame(
closeCode: Short = CloseReason.Codes.NORMAL.code,
replyCloseFrame: Boolean = true
) {
Expand Down Expand Up @@ -667,7 +667,7 @@ abstract class WebSocketEngineSuite<TEngine : ApplicationEngine, TConfiguration
}
}

private suspend inline fun useSocket(block: Connection.() -> Unit) {
internal suspend inline fun useSocket(block: Connection.() -> Unit) {
SelectorManager().use {
aSocket(it).tcp().connect("localhost", port) {
noDelay = true
Expand Down
8 changes: 8 additions & 0 deletions ktor-shared/ktor-serialization/api/ktor-serialization.api
Expand Up @@ -28,6 +28,14 @@ public final class io/ktor/serialization/JsonConvertException : io/ktor/serializ
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/Throwable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
}

public abstract interface class io/ktor/serialization/SerializableWebSocketSession : io/ktor/websocket/WebSocketSession {
public abstract fun getConverter ()Lio/ktor/serialization/WebsocketContentConverter;
}

public final class io/ktor/serialization/SerializableWebSocketSession$DefaultImpls {
public static fun send (Lio/ktor/serialization/SerializableWebSocketSession;Lio/ktor/websocket/Frame;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public class io/ktor/serialization/WebsocketContentConvertException : io/ktor/serialization/ContentConvertException {
public fun <init> (Ljava/lang/String;Ljava/lang/Throwable;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/Throwable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
Expand Down
@@ -0,0 +1,14 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.serialization

import io.ktor.websocket.*

public interface SerializableWebSocketSession: WebSocketSession {
/**
* Converter for web socket session, if plugin [WebSockets] is installed
*/
public val converter: WebsocketContentConverter?
}
Expand Up @@ -16,3 +16,8 @@ public final class io/ktor/serialization/kotlinx/KotlinxWebsocketSerializationCo
public fun serialize (Ljava/nio/charset/Charset;Lio/ktor/util/reflect/TypeInfo;Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class io/ktor/serialization/kotlinx/KotlinxWebsocketSerializationConverterKt {
public static final fun receiveDeserialized (Lio/ktor/serialization/SerializableWebSocketSession;Lkotlinx/serialization/DeserializationStrategy;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun sendSerialized (Lio/ktor/serialization/SerializableWebSocketSession;Ljava/lang/Object;Lkotlinx/serialization/SerializationStrategy;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

Expand Up @@ -38,7 +38,6 @@ internal abstract class KotlinxSerializationBase<T>(
}

internal open class SerializationParameters(
open val format: SerialFormat,
open val value: Any,
open val typeInfo: TypeInfo,
open val charset: Charset
Expand All @@ -47,9 +46,8 @@ internal open class SerializationParameters(
}

internal class SerializationNegotiationParameters(
override val format: SerialFormat,
override val value: Any,
override val typeInfo: TypeInfo,
override val charset: Charset,
val contentType: ContentType
) : SerializationParameters(format, value, typeInfo, charset)
) : SerializationParameters(value, typeInfo, charset)
Expand Up @@ -38,7 +38,6 @@ public class KotlinxSerializationConverter(
): OutgoingContent {
return serializationBase.serialize(
SerializationNegotiationParameters(
format,
value,
typeInfo,
charset,
Expand Down Expand Up @@ -71,7 +70,6 @@ public class KotlinxSerializationConverter(
}
return serializeContent(
parameters.serializer,
parameters.format,
parameters.value,
parameters.contentType,
parameters.charset
Expand All @@ -81,7 +79,6 @@ public class KotlinxSerializationConverter(

private fun serializeContent(
serializer: KSerializer<*>,
format: SerialFormat,
value: Any,
contentType: ContentType,
charset: Charset
Expand Down

0 comments on commit f9c64e2

Please sign in to comment.