Skip to content

Commit

Permalink
KTOR-5216 Parse header with multiple challenges (#3277)
Browse files Browse the repository at this point in the history
Parse header with multiple challenges
  • Loading branch information
marychatte committed Dec 6, 2022
1 parent bfde300 commit dab18c0
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 46 deletions.
35 changes: 35 additions & 0 deletions buildSrc/src/main/kotlin/test/server/tests/Auth.kt
Expand Up @@ -152,6 +152,41 @@ internal fun Application.authTestServer() {
call.respond("OK")
}
}

route("multiple") {
get("header") {
val token = call.request.headers[HttpHeaders.Authorization]

if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest, Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}

call.respond("OK")
}
get("headers") {
val token = call.request.headers[HttpHeaders.Authorization]

if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest"
)
call.response.header(
HttpHeaders.WWWAuthenticate,
"Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}

call.respond("OK")
}
}
}
}
}
Expand Down
Expand Up @@ -54,14 +54,26 @@ public class Auth private constructor(
val candidateProviders = HashSet(plugin.providers)

while (call.response.status == HttpStatusCode.Unauthorized) {
val headerValue = call.response.headers[HttpHeaders.WWWAuthenticate]
val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate)
val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList()

val authHeader = headerValue?.let { parseAuthorizationHeader(headerValue) }
val provider = when {
authHeader == null && candidateProviders.size == 1 -> candidateProviders.first()
authHeader == null -> return@intercept call
else -> candidateProviders.find { it.isApplicable(authHeader) } ?: return@intercept call
var providerOrNull: AuthProvider? = null
var authHeader: HttpAuthHeader? = null

when {
authHeaders.isEmpty() && candidateProviders.size == 1 -> {
providerOrNull = candidateProviders.first()
}

authHeaders.isEmpty() -> return@intercept call

else -> authHeader = authHeaders.find { header ->
providerOrNull = candidateProviders.find { it.isApplicable(header) }
providerOrNull != null
}
}
val provider = providerOrNull ?: return@intercept call

if (!provider.refreshToken(call.response)) return@intercept call

candidateProviders.remove(provider)
Expand Down
Expand Up @@ -535,4 +535,73 @@ class AuthTest : ClientLoader() {
assertEquals(2, loadCount)
}
}

@Test
fun testMultipleChallengesInHeader() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseOneHeader = client.get("$TEST_SERVER/auth/multiple/header").bodyAsText()
assertEquals("OK", responseOneHeader)
}
}

@Test
fun testMultipleChallengesInHeaders() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseMultipleHeaders = client.get("$TEST_SERVER/auth/multiple/headers").bodyAsText()
assertEquals("OK", responseMultipleHeaders)
}
}

@Test
fun testMultipleChallengesInHeaderUnauthorized() = clientTests {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/header")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers[HttpHeaders.WWWAuthenticate]?.also {
assertTrue { it.contains("Bearer") }
assertTrue { it.contains("Basic") }
assertTrue { it.contains("Digest") }
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}

@Test
fun testMultipleChallengesInMultipleHeadersUnauthorized() = clientTests(listOf("Js")) {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/headers")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers.getAll(HttpHeaders.WWWAuthenticate)?.let {
assertEquals(2, it.size)
it.joinToString().let { header ->
assertTrue { header.contains("Basic") }
assertTrue { header.contains("Digest") }
assertTrue { header.contains("Bearer") }
}
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}
}
1 change: 1 addition & 0 deletions ktor-http/api/ktor-http.api
Expand Up @@ -1114,6 +1114,7 @@ public final class io/ktor/http/auth/HttpAuthHeader$Single : io/ktor/http/auth/H

public final class io/ktor/http/auth/HttpAuthHeaderKt {
public static final fun parseAuthorizationHeader (Ljava/lang/String;)Lio/ktor/http/auth/HttpAuthHeader;
public static final fun parseAuthorizationHeaders (Ljava/lang/String;)Ljava/util/List;
}

public final class io/ktor/http/content/ByteArrayContent : io/ktor/http/content/OutgoingContent$ByteArrayContent {
Expand Down
141 changes: 110 additions & 31 deletions ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt
Expand Up @@ -8,7 +8,6 @@ import io.ktor.http.*
import io.ktor.http.parsing.*
import io.ktor.util.*
import io.ktor.utils.io.charsets.*
import kotlin.native.concurrent.*

private val TOKEN_EXTRA = setOf('!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~')
private val TOKEN68_EXTRA = setOf('-', '.', '_', '~', '+', '/')
Expand All @@ -19,20 +18,21 @@ private val escapeRegex: Regex = "\\\\.".toRegex()
* Parses an authorization header [headerValue] into a [HttpAuthHeader].
* @return [HttpAuthHeader] or `null` if argument string is blank.
* @throws [ParseException] on invalid header
*
* @see [parseAuthorizationHeaders]
*/
public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
var index = 0
index = headerValue.skipSpaces(index)

var tokenStartIndex = index
val tokenStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

// Auth scheme
val authScheme = headerValue.substring(tokenStartIndex until index)
index = headerValue.skipSpaces(index)
tokenStartIndex = index

if (authScheme.isBlank()) {
return null
Expand All @@ -42,44 +42,130 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
return HttpAuthHeader.Parameterized(authScheme, emptyList())
}

val token68 = matchToken68(headerValue, index)
if (token68 != null) {
return HttpAuthHeader.Single(authScheme, token68)
val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
if (token68EndIndex == headerValue.length) {
return HttpAuthHeader.Single(authScheme, token68)
}
}

val parameters = matchParameters(headerValue, tokenStartIndex)
return HttpAuthHeader.Parameterized(authScheme, parameters)
val parameters = mutableMapOf<String, String>()
val endIndex = matchParameters(headerValue, index, parameters)
return if (endIndex == -1) HttpAuthHeader.Parameterized(authScheme, parameters) else
throw ParseException("Function parseAuthorizationHeader can parse only one header")
}

/**
* Parses an authorization header [headerValue] into a list of [HttpAuthHeader].
* @return a list of [HttpAuthHeader]
* @throws [ParseException] on invalid header
*/
@InternalAPI
public fun parseAuthorizationHeaders(headerValue: String): List<HttpAuthHeader> {
var index = 0
val headers = mutableListOf<HttpAuthHeader>()
while (index != -1) {
index = parseAuthorizationHeader(headerValue, index, headers)
}
return headers
}

private fun matchParameters(headerValue: String, startIndex: Int): Map<String, String> {
val result = mutableMapOf<String, String>()
private fun parseAuthorizationHeader(
headerValue: String,
startIndex: Int,
headers: MutableList<HttpAuthHeader>
): Int {
var index = headerValue.skipSpaces(startIndex)

// Auth scheme
val schemeStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}
val authScheme = headerValue.substring(schemeStartIndex until index)

if (authScheme.isBlank()) {
throw ParseException("Invalid authScheme value: it should be token, can't be blank")
}
index = headerValue.skipSpaces(index)

nextChallengeIndex(headers, HttpAuthHeader.Parameterized(authScheme, emptyList()), index, headerValue)?.let {
return it
}

val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
nextChallengeIndex(headers, HttpAuthHeader.Single(authScheme, token68), token68EndIndex, headerValue)?.let {
return it
}
}

val parameters = mutableMapOf<String, String>()
val nextIndexChallenge = matchParameters(headerValue, index, parameters)
headers.add(HttpAuthHeader.Parameterized(authScheme, parameters))
return nextIndexChallenge
}

/**
* Check for the ending of the current challenge in a header
* @return -1 if at the end of the header
* @return null if the challenge is not ended
* @return a positive number - the index of the beginning of the next challenge
*/
private fun nextChallengeIndex(
headers: MutableList<HttpAuthHeader>,
header: HttpAuthHeader,
index: Int,
headerValue: String
): Int? {
if (index == headerValue.length || headerValue[index] == ',') {
headers.add(header)
return when {
index == headerValue.length -> -1
headerValue[index] == ',' -> index + 1
else -> error("") // unreachable code
}
}
return null
}

private fun matchParameters(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
var index = startIndex
while (index > 0 && index < headerValue.length) {
index = matchParameter(headerValue, index, result)
index = headerValue.skipDelimiter(index, ',')
val nextIndex = matchParameter(headerValue, index, parameters)
if (nextIndex == index) {
return index
} else {
index = headerValue.skipDelimiter(nextIndex, ',')
}
}

return result
return index
}

private fun matchParameter(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
private fun matchParameter(
headerValue: String,
startIndex: Int,
parameters: MutableMap<String, String>
): Int {
val keyStart = headerValue.skipSpaces(startIndex)
var index = keyStart

// Take key
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

val key = headerValue.substring(keyStart until index)

// Take '='
// Check if new challenge
index = headerValue.skipSpaces(index)
if (index >= headerValue.length || headerValue[index] != '=') {
throw ParseException("Expected `=` after parameter key '$key': $headerValue")
if (index == headerValue.length || headerValue[index] != '=') {
return startIndex
}

// Take '='
index++
index = headerValue.skipSpaces(index)

Expand Down Expand Up @@ -116,8 +202,8 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut
return index
}

private fun matchToken68(headerValue: String, startIndex: Int): String? {
var index = startIndex
private fun matchToken68(headerValue: String, startIndex: Int): Int {
var index = headerValue.skipSpaces(startIndex)

while (index < headerValue.length && headerValue[index].isToken68()) {
index++
Expand All @@ -127,12 +213,7 @@ private fun matchToken68(headerValue: String, startIndex: Int): String? {
index++
}

val onlySpaceRemaining = (index until headerValue.length).all { headerValue[it] == ' ' }
if (onlySpaceRemaining) {
return headerValue.substring(startIndex until index)
}

return null
return headerValue.skipSpaces(index)
}

/**
Expand Down Expand Up @@ -355,13 +436,11 @@ private fun String.unescaped() = replace(escapeRegex) { it.value.takeLast(1) }
private fun String.skipDelimiter(startIndex: Int, delimiter: Char): Int {
var index = skipSpaces(startIndex)

while (index < length && this[index] != delimiter) {
index++
}

if (index == length) return -1
index++
if (this[index] != delimiter)
throw ParseException("Expected delimiter $delimiter at position $index, but found ${this[index]}")

index++
return skipSpaces(index)
}

Expand Down

0 comments on commit dab18c0

Please sign in to comment.