/
DerReader.kt
341 lines (293 loc) · 10.4 KB
/
DerReader.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
/*
* Copyright (C) 2020 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3.tls.internal.der
import java.math.BigInteger
import java.net.ProtocolException
import okio.Buffer
import okio.BufferedSource
import okio.ByteString
import okio.ForwardingSource
import okio.Source
import okio.buffer
/**
 * Streaming decoder of data encoded following Abstract Syntax Notation One (ASN.1). There are
 * multiple variants of ASN.1, including:
 *
 *  * DER: Distinguished Encoding Rules. This further constrains ASN.1 for deterministic encoding.
 *  * BER: Basic Encoding Rules.
 *
 * This class was implemented according to the [X.690 spec][x690], and under the advice of
 * [Let's Encrypt's ASN.1 and DER][asn1_and_der] guide.
 *
 * Only DER is accepted: indefinite lengths, constructed strings, and non-minimal length encodings
 * are rejected with a [ProtocolException].
 *
 * [x690]: https://www.itu.int/rec/T-REC-X.690
 * [asn1_and_der]: https://letsencrypt.org/docs/a-warm-welcome-to-asn1-and-der/
 */
internal class DerReader(source: Source) {
  private val countingSource: CountingSource = CountingSource(source)
  private val source: BufferedSource = countingSource.buffer()

  /**
   * Total bytes consumed by this reader thus far: bytes pulled from the underlying source minus
   * those still sitting unread in the buffer.
   */
  private val byteCount: Long
    get() = countingSource.bytesRead - source.buffer.size

  /**
   * The [byteCount] at which the current scope ends (after which [hasNext] returns false), or -1L
   * for no limit.
   */
  private var limit = -1L

  /** Type hints scoped to the call stack, manipulated with [withTypeHint]. */
  private val typeHintStack = mutableListOf<Any?>()

  /**
   * The type hint for the current object. Used to pick adapters based on other fields, such as
   * in extensions which have different types depending on their extension ID.
   *
   * Reading returns null when no [withTypeHint] scope is active. Writing replaces the hint of the
   * innermost scope and fails if there is no active scope.
   */
  var typeHint: Any?
    get() = typeHintStack.lastOrNull()
    set(value) {
      typeHintStack[typeHintStack.size - 1] = value
    }

  /** Names leading to the current location in the ASN.1 document. */
  private val path = mutableListOf<String>()

  /** Whether the value currently being read (inside [read]) uses the constructed encoding. */
  private var constructed = false

  /** A header returned by [peekHeader] that has not yet been consumed by [read]. */
  private var peekedHeader: DerHeader? = null

  /** Bytes remaining before [limit], or -1L if the current scope is unbounded. */
  private val bytesLeft: Long
    get() = if (limit == -1L) -1L else (limit - byteCount)

  /** Returns true if another value can be read in the current scope. */
  fun hasNext(): Boolean = peekHeader() != null

  /**
   * Returns the next header to process unless this scope is exhausted.
   *
   * This returns null if:
   *
   *  * The stream is exhausted.
   *  * We've read all of the bytes of an object whose length is known.
   *  * We've reached the [DerHeader.TAG_END_OF_CONTENTS] of an object whose length is unknown.
   */
  fun peekHeader(): DerHeader? {
    var result = peekedHeader

    if (result == null) {
      result = readHeader()
      peekedHeader = result
    }

    if (result.isEndOfData) return null

    return result
  }

  /**
   * Consume the next header in the stream and return it. If there is no header to read because we
   * have reached a limit, this returns [END_OF_DATA].
   *
   * @throws ProtocolException if the length uses the indefinite form, more than 8 bytes, a
   *     non-minimal long form, or exceeds [Long.MAX_VALUE].
   */
  internal fun readHeader(): DerHeader {
    require(peekedHeader == null)

    // We've hit a local limit.
    if (byteCount == limit) return END_OF_DATA

    // We've exhausted the source stream.
    if (limit == -1L && source.exhausted()) return END_OF_DATA

    // Read the tag.
    val tagAndClass = source.readByte().toInt() and 0xff
    val tagClass = tagAndClass and 0b1100_0000
    val constructed = (tagAndClass and 0b0010_0000) == 0b0010_0000
    val tag = when (val tag0 = tagAndClass and 0b0001_1111) {
      0b0001_1111 -> readVariableLengthLong()
      else -> tag0.toLong()
    }

    // Read the length.
    val length0 = source.readByte().toInt() and 0xff
    val length = when {
      length0 == 0b1000_0000 -> {
        throw ProtocolException("indefinite length not permitted for DER")
      }
      (length0 and 0b1000_0000) == 0b1000_0000 -> {
        // Length specified over multiple bytes.
        val lengthBytes = length0 and 0b0111_1111
        if (lengthBytes > 8) {
          throw ProtocolException("length encoded with more than 8 bytes is not supported")
        }

        var lengthBits = source.readByte().toLong() and 0xff
        // DER requires the minimal encoding: no leading zero byte, and lengths below 128 must
        // use the short form.
        if (lengthBits == 0L || lengthBytes == 1 && lengthBits and 0b1000_0000 == 0L) {
          throw ProtocolException("invalid encoding for length")
        }

        for (i in 1 until lengthBytes) {
          lengthBits = lengthBits shl 8
          lengthBits += source.readByte().toInt() and 0xff
        }

        if (lengthBits < 0) throw ProtocolException("length > Long.MAX_VALUE")

        lengthBits
      }
      else -> {
        // Length is 127 or fewer bytes.
        (length0 and 0b0111_1111).toLong()
      }
    }

    // Note that this may be an encoded "end of data" header.
    return DerHeader(tagClass, tag, constructed, length)
  }

  /**
   * Consume a header and execute [block], which should consume the entire value described by the
   * header. It is an error to not consume a full value in [block].
   *
   * While [block] runs, the read limit is narrowed to the header's length and [name] (if
   * non-null) is appended to the path reported by [toString]. Both are restored afterwards, even
   * if [block] throws.
   *
   * @throws ProtocolException if there is no next value, if the value's length exceeds the
   *     enclosing scope, or if [block] reads past the value's end.
   */
  internal inline fun <T> read(name: String?, block: (DerHeader) -> T): T {
    if (!hasNext()) throw ProtocolException("expected a value")

    val header = peekedHeader!!
    peekedHeader = null

    val pushedLimit = limit
    val pushedConstructed = constructed

    val newLimit = if (header.length != -1L) byteCount + header.length else -1L
    if (pushedLimit != -1L && newLimit > pushedLimit) {
      throw ProtocolException("enclosed object too large")
    }

    limit = newLimit
    constructed = header.constructed
    if (name != null) path += name
    try {
      val result = block(header)

      // The object processed bytes beyond its range.
      if (newLimit != -1L && byteCount > newLimit) {
        throw ProtocolException("unexpected byte count at $this")
      }

      return result
    } finally {
      peekedHeader = null
      limit = pushedLimit
      constructed = pushedConstructed
      if (name != null) path.removeAt(path.size - 1)
    }
  }

  /**
   * Execute [block] with a new namespace for type hints. Type hints from the enclosing type are no
   * longer usable by the current type's members.
   */
  fun <T> withTypeHint(block: () -> T): T {
    typeHintStack.add(null)
    try {
      return block()
    } finally {
      typeHintStack.removeAt(typeHintStack.size - 1)
    }
  }

  /** Reads a BOOLEAN whose content is exactly one byte; any nonzero byte is true. */
  fun readBoolean(): Boolean {
    if (bytesLeft != 1L) throw ProtocolException("unexpected length: $bytesLeft at $this")
    return source.readByte().toInt() != 0
  }

  /**
   * Reads the remaining bytes of the current value as a big-endian two's-complement integer.
   *
   * @throws ProtocolException if the value is empty or the current scope has no length limit.
   */
  fun readBigInteger(): BigInteger {
    // -1L means there is no enclosing length to read up to; reject it here rather than letting
    // readByteArray() fail with an IllegalArgumentException.
    if (bytesLeft == 0L || bytesLeft == -1L) {
      throw ProtocolException("unexpected length: $bytesLeft at $this")
    }
    val byteArray = source.readByteArray(bytesLeft)
    return BigInteger(byteArray)
  }

  /** Reads a big-endian, sign-extended integer of 1 to 8 content bytes. */
  fun readLong(): Long {
    if (bytesLeft !in 1..8) throw ProtocolException("unexpected length: $bytesLeft at $this")

    var result = source.readByte().toLong() // No "and 0xff" because this is a signed value!
    while (byteCount < limit) {
      result = result shl 8
      result += source.readByte().toInt() and 0xff
    }
    return result
  }

  /**
   * Reads a primitive BIT STRING: a leading byte holding the unused-bit count, followed by the
   * bit data.
   */
  fun readBitString(): BitString {
    if (bytesLeft == -1L || constructed) {
      throw ProtocolException("constructed bit strings not supported for DER")
    }
    if (bytesLeft < 1) {
      throw ProtocolException("malformed bit string")
    }
    val unusedBitCount = source.readByte().toInt() and 0xff
    val byteString = source.readByteString(bytesLeft)
    return BitString(byteString, unusedBitCount)
  }

  /** Reads the remaining bytes of a primitive OCTET STRING. */
  fun readOctetString(): ByteString {
    if (bytesLeft == -1L || constructed) {
      throw ProtocolException("constructed octet strings not supported for DER")
    }
    return source.readByteString(bytesLeft)
  }

  /** Decodes the remaining bytes of a primitive string value as UTF-8. */
  fun readUtf8String(): String {
    if (bytesLeft == -1L || constructed) {
      throw ProtocolException("constructed strings not supported for DER")
    }
    return source.readUtf8(bytesLeft)
  }

  /**
   * Reads an OBJECT IDENTIFIER as dotted decimal, like "1.2.840.113549". The first subidentifier
   * packs the first two arcs as `40 * x + y`.
   */
  fun readObjectIdentifier(): String {
    val result = Buffer()
    val dot = '.'.code.toByte().toInt()
    when (val xy = readVariableLengthLong()) {
      in 0L until 40L -> {
        result.writeDecimalLong(0)
        result.writeByte(dot)
        result.writeDecimalLong(xy)
      }
      in 40L until 80L -> {
        result.writeDecimalLong(1)
        result.writeByte(dot)
        result.writeDecimalLong(xy - 40L)
      }
      else -> {
        result.writeDecimalLong(2)
        result.writeByte(dot)
        result.writeDecimalLong(xy - 80L)
      }
    }
    while (byteCount < limit) {
      result.writeByte(dot)
      result.writeDecimalLong(readVariableLengthLong())
    }
    return result.readUtf8()
  }

  /** Reads a RELATIVE-OID as dotted decimal; each subidentifier is a single arc. */
  fun readRelativeObjectIdentifier(): String {
    val result = Buffer()
    val dot = '.'.code.toByte().toInt()
    while (byteCount < limit) {
      if (result.size > 0) {
        result.writeByte(dot)
      }
      result.writeDecimalLong(readVariableLengthLong())
    }
    return result.readUtf8()
  }

  /**
   * Reads a base-128 integer, as used for high tag numbers and object identifier subidentifiers.
   * Each byte contributes its low 7 bits, most significant group first; a set high bit means
   * another byte follows.
   *
   * @throws ProtocolException if the value does not fit in a non-negative [Long].
   */
  private fun readVariableLengthLong(): Long {
    var result = 0L
    while (true) {
      val byteN = source.readByte().toLong() and 0xff

      // Shifting left by 7 must not push any set bit into (or past) the sign bit.
      if (result > (Long.MAX_VALUE shr 7)) {
        throw ProtocolException("variable length long too large at $this")
      }
      result = (result shl 7) or (byteN and 0b0111_1111)

      if ((byteN and 0b1000_0000L) == 0L) return result
    }
  }

  /** Read a value as bytes without interpretation of its contents. */
  fun readUnknown(): ByteString {
    return source.readByteString(bytesLeft)
  }

  /** Returns the path of names leading to the current location, joined with " / ". */
  override fun toString(): String = path.joinToString(separator = " / ")

  companion object {
    /**
     * A synthetic value that indicates there's no more bytes. Values with equivalent data may also
     * show up in ASN.1 streams to also indicate the end of SEQUENCE, SET or other constructed
     * value.
     */
    private val END_OF_DATA = DerHeader(
      tagClass = DerHeader.TAG_CLASS_UNIVERSAL,
      tag = DerHeader.TAG_END_OF_CONTENTS,
      constructed = false,
      length = -1L
    )
  }

  /** A source that keeps track of how many bytes it's consumed. */
  private class CountingSource(source: Source) : ForwardingSource(source) {
    /** Bytes delivered by the delegate so far, not counting end-of-stream results. */
    var bytesRead = 0L

    override fun read(sink: Buffer, byteCount: Long): Long {
      val result = delegate.read(sink, byteCount)
      if (result == -1L) return -1L
      bytesRead += result
      return result
    }
  }
}