Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-5216 Parse header with multiple challenges #3277

Merged
merged 3 commits into from Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -54,13 +54,13 @@ public class Auth private constructor(
val candidateProviders = HashSet(plugin.providers)

while (call.response.status == HttpStatusCode.Unauthorized) {
val headerValue = call.response.headers[HttpHeaders.WWWAuthenticate]
val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate)
marychatte marked this conversation as resolved.
Show resolved Hide resolved
val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList()

val authHeader = headerValue?.let { parseAuthorizationHeader(headerValue) }
val provider = when {
authHeader == null && candidateProviders.size == 1 -> candidateProviders.first()
authHeader == null -> return@intercept call
else -> candidateProviders.find { it.isApplicable(authHeader) } ?: return@intercept call
val (provider, authHeader) = when {
authHeaders.isEmpty() && candidateProviders.size == 1 -> candidateProviders.first() to null
authHeaders.isEmpty() -> return@intercept call
else -> findProviderAndHeader(candidateProviders, authHeaders) ?: return@intercept call
}
if (!provider.refreshToken(call.response)) return@intercept call

Expand All @@ -76,6 +76,21 @@ public class Auth private constructor(
return@intercept call
}
}

private fun findProviderAndHeader(
providers: Collection<AuthProvider>,
authHeaders: List<HttpAuthHeader>
): Pair<AuthProvider, HttpAuthHeader>? {
marychatte marked this conversation as resolved.
Show resolved Hide resolved
authHeaders.forEach { header ->
providers.forEach { provider ->
if (provider.isApplicable(header)) {
return provider to header
}
}
}

return null
}
}
}

Expand Down
1 change: 1 addition & 0 deletions ktor-http/api/ktor-http.api
Expand Up @@ -1114,6 +1114,7 @@ public final class io/ktor/http/auth/HttpAuthHeader$Single : io/ktor/http/auth/H

public final class io/ktor/http/auth/HttpAuthHeaderKt {
public static final fun parseAuthorizationHeader (Ljava/lang/String;)Lio/ktor/http/auth/HttpAuthHeader;
public static final fun parseAuthorizationHeaders (Ljava/lang/String;)Ljava/util/List;
}

public final class io/ktor/http/content/ByteArrayContent : io/ktor/http/content/OutgoingContent$ByteArrayContent {
Expand Down
124 changes: 96 additions & 28 deletions ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt
Expand Up @@ -8,7 +8,6 @@ import io.ktor.http.*
import io.ktor.http.parsing.*
import io.ktor.util.*
import io.ktor.utils.io.charsets.*
import kotlin.native.concurrent.*

private val TOKEN_EXTRA = setOf('!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~')
private val TOKEN68_EXTRA = setOf('-', '.', '_', '~', '+', '/')
Expand All @@ -19,20 +18,21 @@ private val escapeRegex: Regex = "\\\\.".toRegex()
* Parses an authorization header [headerValue] into a [HttpAuthHeader].
* @return [HttpAuthHeader] or `null` if argument string is blank.
* @throws [ParseException] on invalid header
*
* @see [parseAuthorizationHeaders]
*/
public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
var index = 0
index = headerValue.skipSpaces(index)

var tokenStartIndex = index
val tokenStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

// Auth scheme
val authScheme = headerValue.substring(tokenStartIndex until index)
index = headerValue.skipSpaces(index)
tokenStartIndex = index

if (authScheme.isBlank()) {
return null
Expand All @@ -42,44 +42,104 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
return HttpAuthHeader.Parameterized(authScheme, emptyList())
}

val token68 = matchToken68(headerValue, index)
val (indexAfterToken68, token68) = matchToken68(headerValue, index)
if (token68 != null) {
return HttpAuthHeader.Single(authScheme, token68)
return checkSingleHeader(indexAfterToken68, HttpAuthHeader.Single(authScheme, token68))
}

val parameters = matchParameters(headerValue, tokenStartIndex)
return HttpAuthHeader.Parameterized(authScheme, parameters)
val (endIndex, parameters) = matchParameters(headerValue, index)
return checkSingleHeader(endIndex, HttpAuthHeader.Parameterized(authScheme, parameters))
}

private fun checkSingleHeader(endIndex: Int, header: HttpAuthHeader): HttpAuthHeader {
return if (endIndex == -1) header else
throw ParseException("Function parseAuthorizationHeader can parse only one header")
}

private fun matchParameters(headerValue: String, startIndex: Int): Map<String, String> {
/**
* Parses an authorization header [headerValue] into a list of [HttpAuthHeader].
* @return a list of [HttpAuthHeader]
* @throws [ParseException] on invalid header
*/
public fun parseAuthorizationHeaders(headerValue: String): List<HttpAuthHeader> {
marychatte marked this conversation as resolved.
Show resolved Hide resolved
var index = 0
val headers = mutableListOf<HttpAuthHeader>()
while (index != -1) {
val (nextIndex, header) = parseAuthorizationHeader(headerValue, index)
marychatte marked this conversation as resolved.
Show resolved Hide resolved
headers.add(header)
index = nextIndex
}
return headers
}

private fun parseAuthorizationHeader(
marychatte marked this conversation as resolved.
Show resolved Hide resolved
headerValue: String,
startIndex: Int,
): Pair<Int, HttpAuthHeader> {
var index = headerValue.skipSpaces(startIndex)

// Auth scheme
val schemeStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}
val authScheme = headerValue.substring(schemeStartIndex until index)

if (authScheme.isBlank()) {
throw ParseException("Invalid authScheme value: it should be token, can't be blank")
}

val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index)
if (isEndOfChallenge) {
return endChallengeIndex to HttpAuthHeader.Parameterized(authScheme, emptyList())
}

val (nextIndex, token68) = matchToken68(headerValue, endChallengeIndex)
if (token68 != null) {
return nextIndex to HttpAuthHeader.Single(authScheme, token68)
}

val (nextIndexChallenge, parameters) = matchParameters(headerValue, index)
return nextIndexChallenge to HttpAuthHeader.Parameterized(authScheme, parameters)
}

private fun matchParameters(headerValue: String, startIndex: Int): Pair<Int, Map<String, String>> {
val result = mutableMapOf<String, String>()

var index = startIndex
while (index > 0 && index < headerValue.length) {
index = matchParameter(headerValue, index, result)
index = headerValue.skipDelimiter(index, ',')
val (nextIndex, wasParameter) = matchParameter(headerValue, index, result)
if (wasParameter) {
index = headerValue.skipDelimiter(nextIndex, ',')
} else {
return nextIndex to result
}
}

return result
return index to result
}

private fun matchParameter(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
private fun matchParameter(
headerValue: String,
startIndex: Int,
parameters: MutableMap<String, String>
): Pair<Int, Boolean> {
val keyStart = headerValue.skipSpaces(startIndex)
var index = keyStart

// Take key
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

val key = headerValue.substring(keyStart until index)

// Take '='
// Check if new challenge
index = headerValue.skipSpaces(index)
if (index >= headerValue.length || headerValue[index] != '=') {
throw ParseException("Expected `=` after parameter key '$key': $headerValue")
if (index == headerValue.length || headerValue[index] != '=') {
return keyStart to false
}

// Take '='
index++
index = headerValue.skipSpaces(index)

Expand Down Expand Up @@ -113,10 +173,10 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut
parameters[key] = if (quoted) value.unescaped() else value

if (quoted) index++
return index
return index to true
}

private fun matchToken68(headerValue: String, startIndex: Int): String? {
private fun matchToken68(headerValue: String, startIndex: Int): Pair<Int, String?> {
var index = startIndex

while (index < headerValue.length && headerValue[index].isToken68()) {
Expand All @@ -127,12 +187,14 @@ private fun matchToken68(headerValue: String, startIndex: Int): String? {
index++
}

val onlySpaceRemaining = (index until headerValue.length).all { headerValue[it] == ' ' }
if (onlySpaceRemaining) {
return headerValue.substring(startIndex until index)
}
val token68 = headerValue.substring(startIndex until index)

return null
val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index)
return if (isEndOfChallenge) {
endChallengeIndex to token68
} else {
startIndex to null
}
}

/**
Expand Down Expand Up @@ -355,13 +417,11 @@ private fun String.unescaped() = replace(escapeRegex) { it.value.takeLast(1) }
private fun String.skipDelimiter(startIndex: Int, delimiter: Char): Int {
var index = skipSpaces(startIndex)

while (index < length && this[index] != delimiter) {
index++
}

if (index == length) return -1
index++
if (this[index] != delimiter)
throw ParseException("Expected delimiter $delimiter at position $index, but found ${this[index]}")

index++
return skipSpaces(index)
}

Expand All @@ -374,6 +434,14 @@ private fun String.skipSpaces(startIndex: Int): Int {
return index
}

private fun String.isEndOfChallenge(startIndex: Int): Pair<Int, Boolean> {
val index = skipSpaces(startIndex)
if (index == length) return -1 to true
if (this[index] == ',') return index + 1 to true

return index to false
}

private fun Char.isToken68(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN68_EXTRA

private fun Char.isToken(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN_EXTRA
Expand Up @@ -9,19 +9,23 @@ import kotlin.random.*
import kotlin.test.*

class AuthorizeHeaderParserTest {
@Test fun empty() {
@Test
fun empty() {
testParserParameterized("Basic", emptyMap(), "Basic")
}

@Test fun emptyWithTrailingSpaces() {
@Test
fun emptyWithTrailingSpaces() {
testParserParameterized("Basic", emptyMap(), "Basic ")
}

@Test fun singleSimple() {
@Test
fun singleSimple() {
testParserSingle("Basic", "abc==", "Basic abc==")
}

@Test fun testParameterizedSimple() {
@Test
fun testParameterizedSimple() {
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1")
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a =1")
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a = 1")
Expand All @@ -30,27 +34,62 @@ class AuthorizeHeaderParserTest {
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1 ")
}

@Test fun testParameterizedSimpleTwoParams() {
@Test
fun testParameterizedSimpleTwoParams() {
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1, b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1,b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 ,b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 , b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 , b=2 ")
}

@Test fun testParameterizedQuoted() {
@Test
fun testParameterizedQuoted() {
testParserParameterized("Basic", mapOf("a" to "1 2"), "Basic a=\"1 2\"")
}

@Test fun testParameterizedQuotedEscaped() {
@Test
fun testParameterizedQuotedEscaped() {
testParserParameterized("Basic", mapOf("a" to "1 \" 2"), "Basic a=\"1 \\\" 2\"")
testParserParameterized("Basic", mapOf("a" to "1 A 2"), "Basic a=\"1 \\A 2\"")
}

@Test fun testParameterizedQuotedEscapedInTheMiddle() {
@Test
fun testParameterizedQuotedEscapedInTheMiddle() {
testParserParameterized("Basic", mapOf("a" to "1 \" 2", "b" to "2"), "Basic a=\"1 \\\" 2\", b= 2")
}

@Test
fun testMultipleChallengesParameters() {
val expected = listOf(
HttpAuthHeader.Parameterized("Digest", emptyMap()),
HttpAuthHeader.Parameterized("Bearer", mapOf("1" to "2", "3" to "4")),
HttpAuthHeader.Parameterized("Basic", emptyMap()),
)
testParserMultipleChallenges(expected, "Digest, Bearer 1 = 2, 3=4, Basic ")
}

@Test
fun testMultipleChallengesSingle() {
val expected = listOf(
HttpAuthHeader.Single("Bearer", "abc=="),
HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")),
HttpAuthHeader.Single("Basic", "def==="),
HttpAuthHeader.Parameterized("Digest", emptyMap())
)
testParserMultipleChallenges(expected, "Bearer abc==, Bearer abc=def, Basic def===, Digest")
}

@Test
fun testMultipleChallengesAllHeaders() {
val expected = listOf(
HttpAuthHeader.Parameterized("Basic", emptyMap()),
HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")),
HttpAuthHeader.Single("Digest", "abc==")
)
testParserMultipleChallenges(expected, "Basic, Bearer abc=def,Digest abc==")
}

private fun testParserSingle(scheme: String, value: String, headerValue: String) {
val actual = parseAuthorizationHeader(headerValue)!!

Expand All @@ -75,11 +114,31 @@ class AuthorizeHeaderParserTest {
}
}

private fun testParserMultipleChallenges(expected: List<HttpAuthHeader>, headerValue: String) {
val actual = parseAuthorizationHeaders(headerValue)

assertEquals(expected.size, actual.size)
(expected zip actual).forEach { (expectedHeader, actualHeader) ->
if (expectedHeader is HttpAuthHeader.Single) {
assertIs<HttpAuthHeader.Single>(actualHeader)

assertEquals(expectedHeader.blob, actualHeader.blob)
}
if (expectedHeader is HttpAuthHeader.Parameterized) {
assertIs<HttpAuthHeader.Parameterized>(actualHeader)
assertEquals(
expectedHeader.parameters.associateBy({ it.name }, { it.value }),
actualHeader.parameters.associateBy({ it.name }, { it.value })
)
}
}
}

private fun Random.nextString(
length: Int,
possible: Iterable<Char> = ('a'..'z') + ('A'..'Z') + ('0'..'9')
) = possible.toList().let { possibleElements ->
(0..length - 1).map { nextFrom(possibleElements) }.joinToString("")
(0 until length).map { nextFrom(possibleElements) }.joinToString("")
}

private fun Random.nextString(length: Int, possible: String) = nextString(length, possible.toList())
Expand Down