Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-5216 Parse header with multiple challenges #3277

Merged
merged 3 commits into from Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions buildSrc/src/main/kotlin/test/server/tests/Auth.kt
Expand Up @@ -152,6 +152,41 @@ internal fun Application.authTestServer() {
call.respond("OK")
}
}

route("multiple") {
get("header") {
val token = call.request.headers[HttpHeaders.Authorization]

if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest, Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}

call.respond("OK")
}
get("headers") {
val token = call.request.headers[HttpHeaders.Authorization]

if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest"
)
call.response.header(
HttpHeaders.WWWAuthenticate,
"Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}

call.respond("OK")
}
}
}
}
}
Expand Down
Expand Up @@ -54,14 +54,26 @@ public class Auth private constructor(
val candidateProviders = HashSet(plugin.providers)

while (call.response.status == HttpStatusCode.Unauthorized) {
val headerValue = call.response.headers[HttpHeaders.WWWAuthenticate]
val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate)
marychatte marked this conversation as resolved.
Show resolved Hide resolved
val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList()

val authHeader = headerValue?.let { parseAuthorizationHeader(headerValue) }
val provider = when {
authHeader == null && candidateProviders.size == 1 -> candidateProviders.first()
authHeader == null -> return@intercept call
else -> candidateProviders.find { it.isApplicable(authHeader) } ?: return@intercept call
var providerOrNull: AuthProvider? = null
var authHeader: HttpAuthHeader? = null

when {
authHeaders.isEmpty() && candidateProviders.size == 1 -> {
providerOrNull = candidateProviders.first()
}

authHeaders.isEmpty() -> return@intercept call

else -> authHeader = authHeaders.find { header ->
providerOrNull = candidateProviders.find { it.isApplicable(header) }
providerOrNull != null
}
}
val provider = providerOrNull ?: return@intercept call

if (!provider.refreshToken(call.response)) return@intercept call

candidateProviders.remove(provider)
Expand Down
Expand Up @@ -535,4 +535,73 @@ class AuthTest : ClientLoader() {
assertEquals(2, loadCount)
}
}

@Test
fun testMultipleChallengesInHeader() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseOneHeader = client.get("$TEST_SERVER/auth/multiple/header").bodyAsText()
assertEquals("OK", responseOneHeader)
}
}

@Test
fun testMultipleChallengesInHeaders() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseMultipleHeaders = client.get("$TEST_SERVER/auth/multiple/headers").bodyAsText()
assertEquals("OK", responseMultipleHeaders)
}
}

@Test
fun testMultipleChallengesInHeaderUnauthorized() = clientTests {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/header")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers[HttpHeaders.WWWAuthenticate]?.also {
assertTrue { it.contains("Bearer") }
assertTrue { it.contains("Basic") }
assertTrue { it.contains("Digest") }
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}

@Test
fun testMultipleChallengesInMultipleHeadersUnauthorized() = clientTests(listOf("Js")) {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/headers")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers.getAll(HttpHeaders.WWWAuthenticate)?.let {
assertEquals(2, it.size)
it.joinToString().let { header ->
assertTrue { header.contains("Basic") }
assertTrue { header.contains("Digest") }
assertTrue { header.contains("Bearer") }
}
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}
}
1 change: 1 addition & 0 deletions ktor-http/api/ktor-http.api
Expand Up @@ -1114,6 +1114,7 @@ public final class io/ktor/http/auth/HttpAuthHeader$Single : io/ktor/http/auth/H

public final class io/ktor/http/auth/HttpAuthHeaderKt {
public static final fun parseAuthorizationHeader (Ljava/lang/String;)Lio/ktor/http/auth/HttpAuthHeader;
public static final fun parseAuthorizationHeaders (Ljava/lang/String;)Ljava/util/List;
}

public final class io/ktor/http/content/ByteArrayContent : io/ktor/http/content/OutgoingContent$ByteArrayContent {
Expand Down
141 changes: 110 additions & 31 deletions ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt
Expand Up @@ -8,7 +8,6 @@ import io.ktor.http.*
import io.ktor.http.parsing.*
import io.ktor.util.*
import io.ktor.utils.io.charsets.*
import kotlin.native.concurrent.*

private val TOKEN_EXTRA = setOf('!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~')
private val TOKEN68_EXTRA = setOf('-', '.', '_', '~', '+', '/')
Expand All @@ -19,20 +18,21 @@ private val escapeRegex: Regex = "\\\\.".toRegex()
* Parses an authorization header [headerValue] into a [HttpAuthHeader].
* @return [HttpAuthHeader] or `null` if argument string is blank.
* @throws [ParseException] on invalid header
*
* @see [parseAuthorizationHeaders]
*/
public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
var index = 0
index = headerValue.skipSpaces(index)

var tokenStartIndex = index
val tokenStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

// Auth scheme
val authScheme = headerValue.substring(tokenStartIndex until index)
index = headerValue.skipSpaces(index)
tokenStartIndex = index

if (authScheme.isBlank()) {
return null
Expand All @@ -42,44 +42,130 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
return HttpAuthHeader.Parameterized(authScheme, emptyList())
}

val token68 = matchToken68(headerValue, index)
if (token68 != null) {
return HttpAuthHeader.Single(authScheme, token68)
val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
if (token68EndIndex == headerValue.length) {
return HttpAuthHeader.Single(authScheme, token68)
}
}

val parameters = matchParameters(headerValue, tokenStartIndex)
return HttpAuthHeader.Parameterized(authScheme, parameters)
val parameters = mutableMapOf<String, String>()
val endIndex = matchParameters(headerValue, index, parameters)
return if (endIndex == -1) HttpAuthHeader.Parameterized(authScheme, parameters) else
throw ParseException("Function parseAuthorizationHeader can parse only one header")
}

/**
* Parses an authorization header [headerValue] into a list of [HttpAuthHeader].
* @return a list of [HttpAuthHeader]
* @throws [ParseException] on invalid header
*/
@InternalAPI
public fun parseAuthorizationHeaders(headerValue: String): List<HttpAuthHeader> {
marychatte marked this conversation as resolved.
Show resolved Hide resolved
var index = 0
val headers = mutableListOf<HttpAuthHeader>()
while (index != -1) {
index = parseAuthorizationHeader(headerValue, index, headers)
}
return headers
}

private fun matchParameters(headerValue: String, startIndex: Int): Map<String, String> {
val result = mutableMapOf<String, String>()
private fun parseAuthorizationHeader(
marychatte marked this conversation as resolved.
Show resolved Hide resolved
headerValue: String,
startIndex: Int,
headers: MutableList<HttpAuthHeader>
): Int {
var index = headerValue.skipSpaces(startIndex)

// Auth scheme
val schemeStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}
val authScheme = headerValue.substring(schemeStartIndex until index)

if (authScheme.isBlank()) {
throw ParseException("Invalid authScheme value: it should be token, can't be blank")
}
index = headerValue.skipSpaces(index)

nextChallengeIndex(headers, HttpAuthHeader.Parameterized(authScheme, emptyList()), index, headerValue)?.let {
return it
}

val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
nextChallengeIndex(headers, HttpAuthHeader.Single(authScheme, token68), token68EndIndex, headerValue)?.let {
return it
}
}

val parameters = mutableMapOf<String, String>()
val nextIndexChallenge = matchParameters(headerValue, index, parameters)
headers.add(HttpAuthHeader.Parameterized(authScheme, parameters))
return nextIndexChallenge
}

/**
* Check for the ending of the current challenge in a header
* @return -1 if at the end of the header
* @return null if the challenge is not ended
* @return a positive number - the index of the beginning of the next challenge
*/
private fun nextChallengeIndex(
headers: MutableList<HttpAuthHeader>,
header: HttpAuthHeader,
index: Int,
headerValue: String
): Int? {
marychatte marked this conversation as resolved.
Show resolved Hide resolved
if (index == headerValue.length || headerValue[index] == ',') {
headers.add(header)
return when {
index == headerValue.length -> -1
headerValue[index] == ',' -> index + 1
else -> error("") // unreachable code
}
}
return null
}

private fun matchParameters(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
var index = startIndex
while (index > 0 && index < headerValue.length) {
index = matchParameter(headerValue, index, result)
index = headerValue.skipDelimiter(index, ',')
val nextIndex = matchParameter(headerValue, index, parameters)
if (nextIndex == index) {
return index
} else {
index = headerValue.skipDelimiter(nextIndex, ',')
}
}

return result
return index
}

private fun matchParameter(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
private fun matchParameter(
headerValue: String,
startIndex: Int,
parameters: MutableMap<String, String>
): Int {
val keyStart = headerValue.skipSpaces(startIndex)
var index = keyStart

// Take key
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

val key = headerValue.substring(keyStart until index)

// Take '='
// Check if new challenge
index = headerValue.skipSpaces(index)
if (index >= headerValue.length || headerValue[index] != '=') {
throw ParseException("Expected `=` after parameter key '$key': $headerValue")
if (index == headerValue.length || headerValue[index] != '=') {
return startIndex
}

// Take '='
index++
index = headerValue.skipSpaces(index)

Expand Down Expand Up @@ -116,8 +202,8 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut
return index
}

private fun matchToken68(headerValue: String, startIndex: Int): String? {
var index = startIndex
private fun matchToken68(headerValue: String, startIndex: Int): Int {
var index = headerValue.skipSpaces(startIndex)

while (index < headerValue.length && headerValue[index].isToken68()) {
index++
Expand All @@ -127,12 +213,7 @@ private fun matchToken68(headerValue: String, startIndex: Int): String? {
index++
}

val onlySpaceRemaining = (index until headerValue.length).all { headerValue[it] == ' ' }
if (onlySpaceRemaining) {
return headerValue.substring(startIndex until index)
}

return null
return headerValue.skipSpaces(index)
}

/**
Expand Down Expand Up @@ -355,13 +436,11 @@ private fun String.unescaped() = replace(escapeRegex) { it.value.takeLast(1) }
private fun String.skipDelimiter(startIndex: Int, delimiter: Char): Int {
var index = skipSpaces(startIndex)

while (index < length && this[index] != delimiter) {
index++
}

if (index == length) return -1
index++
if (this[index] != delimiter)
throw ParseException("Expected delimiter $delimiter at position $index, but found ${this[index]}")

index++
return skipSpaces(index)
}

Expand Down