Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly handle buffer boundaries while decoding escape sequences from json stream #1706

Merged
merged 2 commits into from Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -140,7 +140,7 @@ internal abstract class AbstractJsonLexer {
open fun ensureHaveChars() {}

// Used as bound check in loops
abstract fun definitelyNotEof(position: Int): Int
abstract fun prefetchOrEof(position: Int): Int

abstract fun tryConsumeComma(): Boolean

Expand Down Expand Up @@ -182,7 +182,7 @@ internal abstract class AbstractJsonLexer {
val source = source
var cpos = currentPosition
while (true) {
cpos = definitelyNotEof(cpos)
cpos = prefetchOrEof(cpos)
if (cpos == -1) break // could be inline function but KT-1436
val c = source[cpos++]
if (c == ' ' || c == '\n' || c == '\r' || c == '\t') continue
Expand Down Expand Up @@ -223,7 +223,7 @@ internal abstract class AbstractJsonLexer {
val source = source
var cpos = currentPosition
while (true) {
cpos = definitelyNotEof(cpos)
cpos = prefetchOrEof(cpos)
if (cpos == -1) break
val ch = source[cpos]
if (ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t') {
Expand All @@ -244,7 +244,7 @@ internal abstract class AbstractJsonLexer {
*/
fun tryConsumeNotNull(): Boolean {
var current = skipWhitespaces()
current = definitelyNotEof(current)
current = prefetchOrEof(current)
// Cannot consume null due to EOF, maybe something else
val len = source.length - current
if (len < 4 || current == -1) return true
Expand All @@ -264,7 +264,7 @@ internal abstract class AbstractJsonLexer {
var current = currentPosition
// Skip whitespaces
while (true) {
current = definitelyNotEof(current)
current = prefetchOrEof(current)
if (current == -1) break
val c = source[current]
// Faster than char2TokenClass actually
Expand Down Expand Up @@ -317,13 +317,15 @@ internal abstract class AbstractJsonLexer {
while (char != STRING) {
if (char == STRING_ESC) {
usedAppend = true
currentPosition = appendEscape(lastPosition, currentPosition)
currentPosition = prefetchOrEof(appendEscape(lastPosition, currentPosition))
if (currentPosition == -1)
fail("EOF", currentPosition)
qwwdfsad marked this conversation as resolved.
Show resolved Hide resolved
lastPosition = currentPosition
} else if (++currentPosition >= source.length) {
usedAppend = true
// end of chunk
appendRange(lastPosition, currentPosition)
currentPosition = definitelyNotEof(currentPosition)
currentPosition = prefetchOrEof(currentPosition)
if (currentPosition == -1)
fail("EOF", currentPosition)
lastPosition = currentPosition
Expand Down Expand Up @@ -395,7 +397,7 @@ internal abstract class AbstractJsonLexer {
if (current >= source.length) {
usedAppend = true
appendRange(currentPosition, current)
val eof = definitelyNotEof(current)
val eof = prefetchOrEof(current)
if (eof == -1) {
// to handle plain lenient strings, such as top-level
currentPosition = current
Expand All @@ -421,7 +423,7 @@ internal abstract class AbstractJsonLexer {

private fun appendEsc(startPosition: Int): Int {
var currentPosition = startPosition
currentPosition = definitelyNotEof(currentPosition)
currentPosition = prefetchOrEof(currentPosition)
if (currentPosition == -1) fail("Expected escape sequence to continue, got EOF")
val currentChar = source[currentPosition++]
if (currentChar == UNICODE_ESC) {
Expand All @@ -435,7 +437,13 @@ internal abstract class AbstractJsonLexer {
}

private fun appendHex(source: CharSequence, startPos: Int): Int {
if (startPos + 4 >= source.length) fail("Unexpected EOF during unicode escape")
if (startPos + 4 >= source.length) {
currentPosition = startPos
ensureHaveChars()
if (currentPosition + 4 >= source.length)
fail("Unexpected EOF during unicode escape")
return appendHex(source, currentPosition)
}
escapedString.append(
((fromHexChar(source, startPos) shl 12) +
(fromHexChar(source, startPos + 1) shl 8) +
Expand Down Expand Up @@ -520,7 +528,7 @@ internal abstract class AbstractJsonLexer {
* that doesn't allocate and also doesn't support any radix but 10
*/
var current = skipWhitespaces()
current = definitelyNotEof(current)
current = prefetchOrEof(current)
if (current >= source.length || current == -1) fail("EOF")
val hasQuotation = if (source[current] == STRING) {
// Check it again
Expand Down Expand Up @@ -598,7 +606,7 @@ internal abstract class AbstractJsonLexer {
* in 6-th bit and we leverage this fact, our implementation consumes boolean literals
* in a case-insensitive manner.
*/
var current = definitelyNotEof(start)
var current = prefetchOrEof(start)
if (current >= source.length || current == -1) fail("EOF")
return when (source[current++].code or asciiCaseMask) {
't'.code -> {
Expand Down
@@ -1,8 +1,12 @@
/*
* Copyright 2017-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization.json.internal

internal class StringJsonLexer(override val source: String) : AbstractJsonLexer() {

override fun definitelyNotEof(position: Int): Int = if (position < source.length) position else -1
override fun prefetchOrEof(position: Int): Int = if (position < source.length) position else -1

override fun consumeNextToken(): Byte {
val source = source
Expand Down
@@ -1,8 +1,11 @@
/*
* Copyright 2017-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization.json

import kotlinx.serialization.*
import kotlinx.serialization.builtins.*
import kotlinx.serialization.json.internal.*
import kotlinx.serialization.test.*
import kotlin.random.*
import kotlin.test.*
Expand Down Expand Up @@ -59,7 +62,7 @@ class JsonUnicodeTest : JsonTestBase() {
@Test
fun testRandomEscapeSequences() = noJs { // Too slow on JS
repeat(10_000) {
val s = generateRandomString()
val s = generateRandomUnicodeString(Random.nextInt(1, 2047))
try {
assertSerializedAndRestored(s, String.serializer())
} catch (e: Throwable) {
Expand All @@ -68,21 +71,4 @@ class JsonUnicodeTest : JsonTestBase() {
}
}
}

private fun generateRandomString(): String {
val size = Random.nextInt(1, 2047)
return buildString(size) {
repeat(size) {
val pickEscape = Random.nextBoolean()
if (pickEscape) {
// Definitely escape symbol
// null can be appended as well, completely okay
append(ESCAPE_STRINGS.random())
} else {
// Any symbol, including escaping one
append(Char(Random.nextInt(Char.MIN_VALUE.code..Char.MAX_VALUE.code)))
}
}
}
}
}
@@ -1,10 +1,13 @@
/*
* Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2017-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.serialization.test

import kotlinx.serialization.*
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.json.internal.ESCAPE_STRINGS
import kotlin.random.Random
import kotlin.random.nextInt
import kotlin.test.*

fun SerialDescriptor.assertDescriptorEqualsTo(other: SerialDescriptor) {
Expand Down Expand Up @@ -40,3 +43,18 @@ inline fun assertFailsWithMissingField(block: () -> Unit) {
val e = assertFailsWith<SerializationException>(block = block)
assertTrue(e.message?.contains("but it was missing") ?: false)
}

fun generateRandomUnicodeString(size: Int): String {
return buildString(size) {
repeat(size) {
val pickEscape = Random.nextBoolean()
if (pickEscape) {
// Definitely an escape symbol
append(ESCAPE_STRINGS.random().takeIf { it != null } ?: 'N')
} else {
// Any symbol, including escaping one
append(Char(Random.nextInt(Char.MIN_VALUE.code..Char.MAX_VALUE.code)).takeIf { it.isDefined() && !it.isSurrogate()} ?: 'U')
}
}
}
}
Expand Up @@ -58,7 +58,7 @@ internal class ReaderJsonLexer(
ensureHaveChars()
var current = currentPosition
while (true) {
current = definitelyNotEof(current)
current = prefetchOrEof(current)
if (current == -1) break // could be inline function but KT-1436
val c = source[current]
// Inlined skipWhitespaces without field spill and nested loop. Also faster then char2TokenClass
Expand Down Expand Up @@ -93,7 +93,7 @@ internal class ReaderJsonLexer(
currentPosition = 0
}

override fun definitelyNotEof(position: Int): Int {
override fun prefetchOrEof(position: Int): Int {
if (position < source.length) return position
currentPosition = position
ensureHaveChars()
Expand All @@ -106,7 +106,7 @@ internal class ReaderJsonLexer(
val source = source
var cpos = currentPosition
while (true) {
cpos = definitelyNotEof(cpos)
cpos = prefetchOrEof(cpos)
if (cpos == -1) break
val ch = source[cpos++]
return when (val tc = charToTokenClass(ch)) {
Expand Down Expand Up @@ -141,7 +141,7 @@ internal class ReaderJsonLexer(
var current = currentPosition
val closingQuote = indexOf('"', current)
if (closingQuote == -1) {
current = definitelyNotEof(current)
current = prefetchOrEof(current)
if (current == -1) fail(TC_STRING)
// it's also possible just to resize buffer,
// instead of falling back to slow path,
Expand Down
Expand Up @@ -4,13 +4,15 @@

package kotlinx.serialization.features

import kotlinx.serialization.*
import kotlinx.serialization.SerializationException
import kotlinx.serialization.StringData
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.*
import kotlinx.serialization.json.internal.BATCH_SIZE
import kotlinx.serialization.test.decodeViaStream
import kotlinx.serialization.test.encodeViaStream
import kotlinx.serialization.test.*
import org.junit.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

Expand Down Expand Up @@ -65,4 +67,22 @@ class JsonJvmStreamsTest {
Json.decodeViaStream(String.serializer(), "\"")
}
}

@Test
fun testRandomEscapeSequences() {
repeat(1000) {
val s = generateRandomUnicodeString(strLen)
try {
val serializer = String.serializer()
val b = ByteArrayOutputStream()
Json.encodeToStream(serializer, s, b)
val restored = Json.decodeFromStream(serializer, ByteArrayInputStream(b.toByteArray()))
assertEquals(s, restored)
} catch (e: Throwable) {
// Not assertion error to preserve cause
throw IllegalStateException("Unexpectedly failed test, cause string: $s", e)
}
}
}

}