From 93a06df4bee36dc43b7d906ca395b0ac0d3229f3 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Fri, 24 Jun 2022 16:36:32 +0200 Subject: [PATCH] Do not use tree-based decoding for fast-path polymorphism (#1919) Do not use tree-based decoding for fast-path polymorphism and try to optimistically read it as very first key and then silently skip Fixes #1839 --- benchmark/build.gradle | 5 +- .../json/PolymorphismOverheadBenchmark.kt | 54 ++++++++++++++++ build.gradle | 3 +- .../src/kotlinx/serialization/json/Json.kt | 2 +- .../serialization/json/internal/JsonPath.kt | 2 +- .../json/internal/Polymorphic.kt | 7 ++- .../json/internal/StreamingJsonDecoder.kt | 61 +++++++++++++++++-- .../json/internal/lexer/AbstractJsonLexer.kt | 2 + .../json/internal/lexer/StringJsonLexer.kt | 26 ++++++-- .../DefaultPolymorphicSerializerTest.kt | 35 +++++++++++ .../serialization/json/JsonTestBase.kt | 2 +- .../kotlinx/serialization/json/JvmStreams.kt | 2 +- .../json/internal/JsonIterator.kt | 4 +- .../json/internal/JsonLexerJvm.kt | 11 ++-- .../features/JsonJvmStreamsTest.kt | 45 +++++++++++++- 15 files changed, 233 insertions(+), 28 deletions(-) create mode 100644 benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/PolymorphismOverheadBenchmark.kt create mode 100644 formats/json/commonTest/src/kotlinx/serialization/features/DefaultPolymorphicSerializerTest.kt diff --git a/benchmark/build.gradle b/benchmark/build.gradle index 8e0e4927b..0935e5a35 100644 --- a/benchmark/build.gradle +++ b/benchmark/build.gradle @@ -6,13 +6,12 @@ apply plugin: 'java' apply plugin: 'kotlin' apply plugin: 'kotlinx-serialization' apply plugin: 'idea' -apply plugin: 'net.ltgt.apt' apply plugin: 'com.github.johnrengelman.shadow' -apply plugin: 'me.champeau.gradle.jmh' +apply plugin: 'me.champeau.jmh' sourceCompatibility = 1.8 targetCompatibility = 1.8 -jmh.jmhVersion = 1.22 +jmh.jmhVersion = "1.22" jmhJar { baseName 'benchmarks' diff --git a/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/PolymorphismOverheadBenchmark.kt b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/PolymorphismOverheadBenchmark.kt new file mode 100644 index 000000000..b272bae6a --- /dev/null +++ b/benchmark/src/jmh/kotlin/kotlinx/benchmarks/json/PolymorphismOverheadBenchmark.kt @@ -0,0 +1,54 @@ +package kotlinx.benchmarks.json + +import kotlinx.serialization.* +import kotlinx.serialization.json.* +import kotlinx.serialization.modules.* +import org.openjdk.jmh.annotations.* +import java.util.concurrent.* + +@Warmup(iterations = 7, time = 1) +@Measurement(iterations = 5, time = 1) +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +@Fork(1) +open class PolymorphismOverheadBenchmark { + + @Serializable + @JsonClassDiscriminator("poly") + data class PolymorphicWrapper(val i: @Polymorphic Poly, val i2: Impl) // amortize the cost a bit + + @Serializable + data class BaseWrapper(val i: Impl, val i2: Impl) + + @JsonClassDiscriminator("poly") + interface Poly + + @Serializable + @JsonClassDiscriminator("poly") + class Impl(val a: Int, val b: String) : Poly + + private val impl = Impl(239, "average_size_string") + private val module = SerializersModule { + polymorphic(Poly::class) { + subclass(Impl.serializer()) + } + } + + private val json = Json { serializersModule = module } + private val implString = json.encodeToString(impl) + private val polyString = json.encodeToString(impl) + private val serializer = serializer() + + // 5000 + @Benchmark + fun base() = json.decodeFromString(Impl.serializer(), implString) + + // As of 1.3.x + // Baseline -- 1500 + // v1, no skip -- 2000 + // v2, with skip -- 3000 [withdrawn] + @Benchmark + fun poly() = json.decodeFromString(serializer, polyString) + +} diff --git a/build.gradle b/build.gradle index 69aa68dd1..60b7e2733 100644 --- a/build.gradle +++ b/build.gradle @@ -74,8 +74,7 @@ buildscript { // Various benchmarking stuff classpath "com.github.jengelman.gradle.plugins:shadow:4.0.2" - classpath "me.champeau.gradle:jmh-gradle-plugin:0.5.3" - classpath "net.ltgt.gradle:gradle-apt-plugin:0.21" + classpath "me.champeau.jmh:jmh-gradle-plugin:0.6.6" } } diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt b/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt index 4afe9e74c..8f1f02fd6 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/Json.kt @@ -96,7 +96,7 @@ public sealed class Json( */ public final override fun decodeFromString(deserializer: DeserializationStrategy, string: String): T { val lexer = StringJsonLexer(string) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null) val result = input.decodeSerializableValue(deserializer) lexer.expectEof() return result diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt index 4e055b234..14e70a425 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/JsonPath.kt @@ -24,7 +24,7 @@ internal class JsonPath { // Tombstone indicates that we are within a map, but the map key is currently being decoded. // It is also used to overwrite a previous map key to avoid memory leaks and misattribution. - object Tombstone + private object Tombstone /* * Serial descriptor, map key or the tombstone for map key diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt index ea65c48ac..c1c91264f 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/Polymorphic.kt @@ -9,6 +9,7 @@ import kotlinx.serialization.* import kotlinx.serialization.descriptors.* import kotlinx.serialization.internal.* import kotlinx.serialization.json.* +import kotlin.jvm.* @Suppress("UNCHECKED_CAST") internal inline fun JsonEncoder.encodePolymorphically( @@ -55,12 +56,13 @@ internal fun checkKind(kind: SerialKind) { } internal fun JsonDecoder.decodeSerializableValuePolymorphic(deserializer: DeserializationStrategy): T { + // NB: changes in this method should be reflected in StreamingJsonDecoder#decodeSerializableValue if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) { return deserializer.deserialize(this) } + val discriminator = deserializer.descriptor.classDiscriminator(json) val jsonTree = cast(decodeJsonElement(), deserializer.descriptor) - val discriminator = deserializer.descriptor.classDiscriminator(json) val type = jsonTree[discriminator]?.jsonPrimitive?.content val actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type) ?: throwSerializerNotFound(type, jsonTree) @@ -69,7 +71,8 @@ internal fun JsonDecoder.decodeSerializableValuePolymorphic(deserializer: De return json.readPolymorphicJson(discriminator, jsonTree, actualSerializer as DeserializationStrategy) } -private fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing { +@JvmName("throwSerializerNotFound") +internal fun throwSerializerNotFound(type: String?, jsonTree: JsonObject): Nothing { val suffix = if (type == null) "missing class discriminator ('null')" else "class discriminator '$type'" diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt index bf2290440..403e90deb 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/StreamingJsonDecoder.kt @@ -9,6 +9,7 @@ import kotlinx.serialization.descriptors.* import kotlinx.serialization.encoding.* import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE import kotlinx.serialization.encoding.CompositeDecoder.Companion.UNKNOWN_NAME +import kotlinx.serialization.internal.* import kotlinx.serialization.json.* import kotlinx.serialization.modules.* import kotlin.jvm.* @@ -21,11 +22,27 @@ internal open class StreamingJsonDecoder( final override val json: Json, private val mode: WriteMode, @JvmField internal val lexer: AbstractJsonLexer, - descriptor: SerialDescriptor + descriptor: SerialDescriptor, + discriminatorHolder: DiscriminatorHolder? ) : JsonDecoder, AbstractDecoder() { + // A mutable reference to the discriminator that have to be skipped when in optimistic phase + // of polymorphic serialization, see `decodeSerializableValue` + internal class DiscriminatorHolder(@JvmField var discriminatorToSkip: String?) + + private fun DiscriminatorHolder?.trySkip(unknownKey: String): Boolean { + if (this == null) return false + if (discriminatorToSkip == unknownKey) { + discriminatorToSkip = null + return true + } + return false + } + + override val serializersModule: SerializersModule = json.serializersModule private var currentIndex = -1 + private var discriminatorHolder: DiscriminatorHolder? = discriminatorHolder private val configuration = json.configuration private val elementMarker: JsonElementMarker? = if (configuration.explicitNulls) null else JsonElementMarker(descriptor) @@ -35,7 +52,40 @@ internal open class StreamingJsonDecoder( @Suppress("INVISIBLE_MEMBER", "INVISIBLE_REFERENCE") override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { try { - return decodeSerializableValuePolymorphic(deserializer) + /* + * This is an optimized path over decodeSerializableValuePolymorphic(deserializer): + * dSVP reads the very next JSON tree into a memory as JsonElement and then runs TreeJsonDecoder over it + * in order to deal with an arbitrary order of keys, but with the price of additional memory pressure + * and CPU consumption. + * We would like to provide best possible performance for data produced by kotlinx.serialization + * itself, for that we do the following optimistic optimization: + * + * 0) Remember current position in the string + * 1) Read the very next key of JSON structure + * 2) If it matches* the descriminator key, read the value, remember current position + * 3) Return the value, recover an initial position + * (*) -- if it doesn't match, fallback to dSVP method. + */ + if (deserializer !is AbstractPolymorphicSerializer<*> || json.configuration.useArrayPolymorphism) { + return deserializer.deserialize(this) + } + + val discriminator = deserializer.descriptor.classDiscriminator(json) + val type = lexer.consumeLeadingMatchingValue(discriminator, configuration.isLenient) + var actualSerializer: DeserializationStrategy? = null + if (type != null) { + actualSerializer = deserializer.findPolymorphicSerializerOrNull(this, type) + } + if (actualSerializer == null) { + // Fallback if we haven't found discriminator or serializer + return decodeSerializableValuePolymorphic(deserializer as DeserializationStrategy) + } + + discriminatorHolder = DiscriminatorHolder(discriminator) + @Suppress("UNCHECKED_CAST") + val result = actualSerializer.deserialize(this) as T + return result + } catch (e: MissingFieldException) { throw MissingFieldException(e.message + " at path: " + lexer.path.getPath(), e) } @@ -52,12 +102,13 @@ internal open class StreamingJsonDecoder( json, newMode, lexer, - descriptor + descriptor, + discriminatorHolder ) else -> if (mode == newMode && json.configuration.explicitNulls) { this } else { - StreamingJsonDecoder(json, newMode, lexer, descriptor) + StreamingJsonDecoder(json, newMode, lexer, descriptor, discriminatorHolder) } } } @@ -193,7 +244,7 @@ internal open class StreamingJsonDecoder( } private fun handleUnknown(key: String): Boolean { - if (configuration.ignoreUnknownKeys) { + if (configuration.ignoreUnknownKeys || discriminatorHolder.trySkip(key)) { lexer.skipElement(configuration.isLenient) } else { // Here we cannot properly update json path indicies diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt index 173e54a84..977347a55 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/AbstractJsonLexer.kt @@ -283,6 +283,8 @@ internal abstract class AbstractJsonLexer { return current } + abstract fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? + fun peekString(isLenient: Boolean): String? { val token = peekNextToken() val string = if (isLenient) { diff --git a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt index 0ff980a2c..9ccfbcc1d 100644 --- a/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt +++ b/formats/json/commonMain/src/kotlinx/serialization/json/internal/lexer/StringJsonLexer.kt @@ -78,10 +78,10 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer( override fun consumeKeyString(): String { /* - * For strings we assume that escaped symbols are rather an exception, so firstly - * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', - * than do our pessimistic check for backslash and fallback to slow-path if necessary. - */ + * For strings we assume that escaped symbols are rather an exception, so firstly + * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', + * than do our pessimistic check for backslash and fallback to slow-path if necessary. + */ consumeNextToken(STRING) val current = currentPosition val closingQuote = source.indexOf('"', current) @@ -96,4 +96,22 @@ internal class StringJsonLexer(override val source: String) : AbstractJsonLexer( this.currentPosition = closingQuote + 1 return source.substring(current, closingQuote) } + + override fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? { + val positionSnapshot = currentPosition + try { + // Malformed JSON, bailout + if (consumeNextToken() != TC_BEGIN_OBJ) return null + val firstKey = if (isLenient) consumeKeyString() else consumeStringLenientNotNull() + if (firstKey == keyToMatch) { + if (consumeNextToken() != TC_COLON) return null + val result = if (isLenient) consumeString() else consumeStringLenientNotNull() + return result + } + return null + } finally { + // Restore the position + currentPosition = positionSnapshot + } + } } diff --git a/formats/json/commonTest/src/kotlinx/serialization/features/DefaultPolymorphicSerializerTest.kt b/formats/json/commonTest/src/kotlinx/serialization/features/DefaultPolymorphicSerializerTest.kt new file mode 100644 index 000000000..d2f09f06b --- /dev/null +++ b/formats/json/commonTest/src/kotlinx/serialization/features/DefaultPolymorphicSerializerTest.kt @@ -0,0 +1,35 @@ +/* + * Copyright 2017-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ +package kotlinx.serialization.features + +import kotlinx.serialization.* +import kotlinx.serialization.json.* +import kotlinx.serialization.modules.* +import kotlin.test.* + +class DefaultPolymorphicSerializerTest : JsonTestBase() { + + @Serializable + abstract class Project { + abstract val name: String + } + + @Serializable + data class DefaultProject(override val name: String, val type: String): Project() + + val module = SerializersModule { + polymorphic(Project::class) { + defaultDeserializer { DefaultProject.serializer() } + } + } + + private val json = Json { serializersModule = module } + + @Test + fun test() = parametrizedTest { + assertEquals(DefaultProject("example", "unknown"), + json.decodeFromString(""" {"type":"unknown","name":"example"}""", it)) + } + +} diff --git a/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt index a88e264f5..4352aa6bf 100644 --- a/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt +++ b/formats/json/commonTest/src/kotlinx/serialization/json/JsonTestBase.kt @@ -67,7 +67,7 @@ abstract class JsonTestBase { } JsonTestingMode.TREE -> { val lexer = StringJsonLexer(source) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null) val tree = input.decodeJsonElement() lexer.expectEof() readJson(tree, deserializer) diff --git a/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt b/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt index 3b83299c1..be3a64db4 100644 --- a/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt +++ b/formats/json/jvmMain/src/kotlinx/serialization/json/JvmStreams.kt @@ -61,7 +61,7 @@ public fun Json.decodeFromStream( stream: InputStream ): T { val lexer = ReaderJsonLexer(stream) - val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(this, WriteMode.OBJ, lexer, deserializer.descriptor, null) val result = input.decodeSerializableValue(deserializer) lexer.expectEof() return result diff --git a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt index 790030825..3929c840a 100644 --- a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt +++ b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonIterator.kt @@ -56,7 +56,7 @@ private class JsonIteratorWsSeparated( private val deserializer: DeserializationStrategy ) : Iterator { override fun next(): T = - StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor) + StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor, null) .decodeSerializableValue(deserializer) override fun hasNext(): Boolean = lexer.isNotEof() @@ -75,7 +75,7 @@ private class JsonIteratorArrayWrapped( } else { lexer.consumeNextToken(COMMA) } - val input = StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor) + val input = StreamingJsonDecoder(json, WriteMode.OBJ, lexer, deserializer.descriptor, null) return input.decodeSerializableValue(deserializer) } diff --git a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt index eabfd0886..28ec2cfc3 100644 --- a/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt +++ b/formats/json/jvmMain/src/kotlinx/serialization/json/internal/JsonLexerJvm.kt @@ -133,10 +133,10 @@ internal class ReaderJsonLexer( override fun consumeKeyString(): String { /* - * For strings we assume that escaped symbols are rather an exception, so firstly - * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', - * than do our pessimistic check for backslash and fallback to slow-path if necessary. - */ + * For strings we assume that escaped symbols are rather an exception, so firstly + * we optimistically scan for closing quote via intrinsified and blazing-fast 'indexOf', + * than do our pessimistic check for backslash and fallback to slow-path if necessary. + */ consumeNextToken(STRING) var current = currentPosition val closingQuote = indexOf('"', current) @@ -174,4 +174,7 @@ internal class ReaderJsonLexer( override fun appendRange(fromIndex: Int, toIndex: Int) { escapedString.append(_source, fromIndex, toIndex - fromIndex) } + + // Can be carefully implemented but postponed for now + override fun consumeLeadingMatchingValue(keyToMatch: String, isLenient: Boolean): String? = null } diff --git a/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt b/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt index b576a2c1b..0de89d9c5 100644 --- a/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt +++ b/formats/json/jvmTest/src/kotlinx/serialization/features/JsonJvmStreamsTest.kt @@ -4,11 +4,11 @@ package kotlinx.serialization.features -import kotlinx.serialization.SerializationException -import kotlinx.serialization.StringData +import kotlinx.serialization.* import kotlinx.serialization.builtins.serializer import kotlinx.serialization.json.* import kotlinx.serialization.json.internal.BATCH_SIZE +import kotlinx.serialization.modules.* import kotlinx.serialization.test.* import org.junit.Test import java.io.ByteArrayInputStream @@ -85,4 +85,45 @@ class JsonJvmStreamsTest { } } + interface Poly + + @Serializable + @SerialName("Impl") + data class Impl(val str: String) : Poly + + @Test + fun testPolymorphismWhenCrossingBatchSizeNonLeadingKey() { + val json = Json { + serializersModule = SerializersModule { + polymorphic(Poly::class) { + subclass(Impl::class, Impl.serializer()) + } + } + } + + val longString = "a".repeat(BATCH_SIZE - 5) + val string = """{"str":"$longString", "type":"Impl"}""" + val golden = Impl(longString) + + val deserialized = json.decodeViaStream(serializer(), string) + assertEquals(golden, deserialized as Impl) + } + + @Test + fun testPolymorphismWhenCrossingBatchSize() { + val json = Json { + serializersModule = SerializersModule { + polymorphic(Poly::class) { + subclass(Impl::class, Impl.serializer()) + } + } + } + + val aLotOfWhiteSpaces = " ".repeat(BATCH_SIZE - 5) + val string = """{$aLotOfWhiteSpaces"type":"Impl", "str":"value"}""" + val golden = Impl("value") + + val deserialized = json.decodeViaStream(serializer(), string) + assertEquals(golden, deserialized as Impl) + } }