Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support encoding/decoding ByteArrays in parameterized types as ByteStrings. #2383

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -4,7 +4,9 @@ import kotlinx.serialization.*

/**
* Specifies that a [ByteArray] shall be encoded/decoded as CBOR major type 2: a byte string.
* For types other than [ByteArray], [ByteString] will have no effect.
* When annotating a parameterized type such as [List] or [Map], this will also apply to any wrapped [ByteArray] types,
* encoding/decoding them as CBOR major type 2: a byte string. For types other than [ByteArray], [ByteString] will have
* no effect.
*
* Example usage:
*
Expand Down
Expand Up @@ -195,28 +195,38 @@ internal class CborEncoder(private val output: ByteArrayOutput) {
}
}

private class CborMapReader(cbor: Cbor, decoder: CborDecoder) : CborListReader(cbor, decoder) {
private class CborMapReader(
cbor: Cbor,
decoder: CborDecoder,
decodeByteArrayAsByteString: Boolean,
) : CborListReader(cbor, decoder, decodeByteArrayAsByteString) {
override fun skipBeginToken() = setSize(decoder.startMap() * 2)
}

private open class CborListReader(cbor: Cbor, decoder: CborDecoder) : CborReader(cbor, decoder) {
private open class CborListReader(
cbor: Cbor,
decoder: CborDecoder,
decodeByteArrayAsByteString: Boolean,
) : CborReader(cbor, decoder, decodeByteArrayAsByteString) {
private var ind = 0

override fun skipBeginToken() = setSize(decoder.startArray())

override fun decodeElementIndex(descriptor: SerialDescriptor) = if (!finiteMode && decoder.isEnd() || (finiteMode && ind >= size)) CompositeDecoder.DECODE_DONE else ind++
}

internal open class CborReader(private val cbor: Cbor, protected val decoder: CborDecoder) : AbstractDecoder() {
internal open class CborReader(
private val cbor: Cbor,
protected val decoder: CborDecoder,
private var decodeByteArrayAsByteString: Boolean = false,
) : AbstractDecoder() {

protected var size = -1
private set
protected var finiteMode = false
private set
private var readProperties: Int = 0

private var decodeByteArrayAsByteString = false

protected fun setSize(size: Int) {
if (size >= 0) {
finiteMode = true
Expand All @@ -232,9 +242,9 @@ internal open class CborReader(private val cbor: Cbor, protected val decoder: Cb
@OptIn(ExperimentalSerializationApi::class)
override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
val re = when (descriptor.kind) {
StructureKind.LIST, is PolymorphicKind -> CborListReader(cbor, decoder)
StructureKind.MAP -> CborMapReader(cbor, decoder)
else -> CborReader(cbor, decoder)
StructureKind.LIST, is PolymorphicKind -> CborListReader(cbor, decoder, decodeByteArrayAsByteString)
StructureKind.MAP -> CborMapReader(cbor, decoder, decodeByteArrayAsByteString)
else -> CborReader(cbor, decoder, decodeByteArrayAsByteString)
}
re.skipBeginToken()
return re
Expand Down
Expand Up @@ -139,6 +139,65 @@ class CborReaderTest {
)
}

@Test
fun testReadByteStringWhenRepeated() {
/* A1 # map(1)
* 6B # text(11)
* 62797465537472696E6773 # "byteStrings"
* 81 # array(1)
* 44 # bytes(4)
* 01020304 # "\x01\x02\x03\x04"
*/
assertEquals(
expected = RepeatedByteString(listOf(byteArrayOf(1, 2, 3, 4))),
actual = Cbor.decodeFromHexString(
deserializer = RepeatedByteString.serializer(),
hex = "a16b62797465537472696e6773814401020304"
)
)

/* A1 # map(1)
* 6B # text(11)
* 62797465537472696E6773 # "byteStrings"
* 80 # array(0)
*/
assertEquals(
expected = RepeatedByteString(listOf()),
actual = Cbor.decodeFromHexString(
deserializer = RepeatedByteString.serializer(),
hex = "a16b62797465537472696e677380"
)
)
}

@Test
fun testReadRepeatedByteStringWithByteArray() {
/* A2 # map(2)
* 6B # text(11)
* 62797465537472696E6773 # "byteStrings"
* 81 # array(1)
* 44 # bytes(4)
* 01020304 # "\x01\x02\x03\x04"
* 69 # text(9)
* 627974654172726179 # "byteArray"
* 84 # array(4)
* 05 # unsigned(5)
* 06 # unsigned(6)
* 07 # unsigned(7)
* 08 # unsigned(8)
*/
assertEquals(
expected = RepeatedByteStringWithByteArray(
byteStrings = listOf(byteArrayOf(1, 2, 3, 4)),
byteArray = byteArrayOf(5, 6, 7, 8),
),
actual = Cbor.decodeFromHexString(
deserializer = RepeatedByteStringWithByteArray.serializer(),
hex = "a26b62797465537472696e6773814401020304696279746541727261798405060708"
)
)
}

/**
* CBOR hex data represents serialized versions of [TypesUmbrella] (which does **not** have a root property 'a') so
* decoding to [Simple] (which has the field 'a') is expected to fail.
Expand Down
Expand Up @@ -90,6 +90,53 @@ data class NullableByteString(
}
}

@Serializable
data class RepeatedByteString(@ByteString val byteStrings: List<ByteArray>) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as RepeatedByteString

return byteStrings.toTypedArray().contentDeepEquals(other.byteStrings.toTypedArray())
}

override fun hashCode(): Int {
return byteStrings.hashCode()
}

override fun toString(): String {
return "RepeatedByteString(byteStrings=${byteStrings.map { it.contentToString() }}"
}
}

@Serializable
data class RepeatedByteStringWithByteArray(
@ByteString val byteStrings: List<ByteArray>,
val byteArray: ByteArray,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as RepeatedByteStringWithByteArray

if (!byteStrings.toTypedArray().contentDeepEquals(other.byteStrings.toTypedArray())) return false
return byteArray.contentEquals(other.byteArray)
}

override fun hashCode(): Int {
var result = byteStrings.hashCode()
result = 31 * result + byteArray.contentHashCode()
return result
}

override fun toString(): String {
return "RepeatedByteStringWithByteArray(byteStrings=${byteStrings.map { it.contentToString() }}, " +
"byteArray=${byteArray.contentToString()})"
}
}

@Serializable(with = CustomByteStringSerializer::class)
data class CustomByteString(val a: Byte, val b: Byte, val c: Byte)

Expand Down