diff --git a/buildSrc/src/main/kotlin/test/server/tests/Auth.kt b/buildSrc/src/main/kotlin/test/server/tests/Auth.kt index 54be05b17b..4f8906ac54 100644 --- a/buildSrc/src/main/kotlin/test/server/tests/Auth.kt +++ b/buildSrc/src/main/kotlin/test/server/tests/Auth.kt @@ -152,6 +152,41 @@ internal fun Application.authTestServer() { call.respond("OK") } } + + route("multiple") { + get("header") { + val token = call.request.headers[HttpHeaders.Authorization] + + if (token.isNullOrEmpty() || token.contains("Invalid")) { + call.response.header( + HttpHeaders.WWWAuthenticate, + "Basic realm=\"TestServer\", charset=UTF-8, Digest, Bearer realm=\"my-server\"" + ) + call.respond(HttpStatusCode.Unauthorized) + return@get + } + + call.respond("OK") + } + get("headers") { + val token = call.request.headers[HttpHeaders.Authorization] + + if (token.isNullOrEmpty() || token.contains("Invalid")) { + call.response.header( + HttpHeaders.WWWAuthenticate, + "Basic realm=\"TestServer\", charset=UTF-8, Digest" + ) + call.response.header( + HttpHeaders.WWWAuthenticate, + "Bearer realm=\"my-server\"" + ) + call.respond(HttpStatusCode.Unauthorized) + return@get + } + + call.respond("OK") + } + } } } } diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt index 2045624092..12aedd7235 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt @@ -54,14 +54,26 @@ public class Auth private constructor( val candidateProviders = HashSet(plugin.providers) while (call.response.status == HttpStatusCode.Unauthorized) { - val headerValue = call.response.headers[HttpHeaders.WWWAuthenticate] + val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate) + val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList() - val authHeader = headerValue?.let { parseAuthorizationHeader(headerValue) } - val provider = when { - authHeader == null && candidateProviders.size == 1 -> candidateProviders.first() - authHeader == null -> return@intercept call - else -> candidateProviders.find { it.isApplicable(authHeader) } ?: return@intercept call + var providerOrNull: AuthProvider? = null + var authHeader: HttpAuthHeader? = null + + when { + authHeaders.isEmpty() && candidateProviders.size == 1 -> { + providerOrNull = candidateProviders.first() + } + + authHeaders.isEmpty() -> return@intercept call + + else -> authHeader = authHeaders.find { header -> + providerOrNull = candidateProviders.find { it.isApplicable(header) } + providerOrNull != null + } } + val provider = providerOrNull ?: return@intercept call + if (!provider.refreshToken(call.response)) return@intercept call candidateProviders.remove(provider) diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt b/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt index d44fc1a74b..45aff885ea 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt @@ -535,4 +535,73 @@ class AuthTest : ClientLoader() { assertEquals(2, loadCount) } } + + @Test + fun testMultipleChallengesInHeader() = clientTests { + config { + install(Auth) { + basic { + credentials { BasicAuthCredentials("Invalid", "Invalid") } + } + bearer { + loadTokens { BearerTokens("test", "test") } + } + } + } + test { client -> + val responseOneHeader = client.get("$TEST_SERVER/auth/multiple/header").bodyAsText() + assertEquals("OK", responseOneHeader) + } + } + + @Test + fun testMultipleChallengesInHeaders() = clientTests { + config { + install(Auth) { + basic { + credentials { BasicAuthCredentials("Invalid", "Invalid") } + } + bearer { + loadTokens { BearerTokens("test", "test") } + } + } + } + test { client -> + val responseMultipleHeaders = client.get("$TEST_SERVER/auth/multiple/headers").bodyAsText() + assertEquals("OK", responseMultipleHeaders) + } + } + + @Test + fun testMultipleChallengesInHeaderUnauthorized() = clientTests { + test { client -> + val response = client.get("$TEST_SERVER/auth/multiple/header") + assertEquals(HttpStatusCode.Unauthorized, response.status) + response.headers[HttpHeaders.WWWAuthenticate]?.also { + assertTrue { it.contains("Bearer") } + assertTrue { it.contains("Basic") } + assertTrue { it.contains("Digest") } + } ?: run { + fail("Expected WWWAuthenticate header") + } + } + } + + @Test + fun testMultipleChallengesInMultipleHeadersUnauthorized() = clientTests(listOf("Js")) { + test { client -> + val response = client.get("$TEST_SERVER/auth/multiple/headers") + assertEquals(HttpStatusCode.Unauthorized, response.status) + response.headers.getAll(HttpHeaders.WWWAuthenticate)?.let { + assertEquals(2, it.size) + it.joinToString().let { header -> + assertTrue { header.contains("Basic") } + assertTrue { header.contains("Digest") } + assertTrue { header.contains("Bearer") } + } + } ?: run { + fail("Expected WWWAuthenticate header") + } + } + } } diff --git a/ktor-http/api/ktor-http.api b/ktor-http/api/ktor-http.api index c6bf689756..998d5f8675 100644 --- a/ktor-http/api/ktor-http.api +++ b/ktor-http/api/ktor-http.api @@ -1114,6 +1114,7 @@ public final class io/ktor/http/auth/HttpAuthHeader$Single : io/ktor/http/auth/H public final class io/ktor/http/auth/HttpAuthHeaderKt { public static final fun parseAuthorizationHeader (Ljava/lang/String;)Lio/ktor/http/auth/HttpAuthHeader; + public static final fun parseAuthorizationHeaders (Ljava/lang/String;)Ljava/util/List; } public final class io/ktor/http/content/ByteArrayContent : io/ktor/http/content/OutgoingContent$ByteArrayContent { diff --git a/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt b/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt index 68d8138a7e..0fa08090d5 100644 --- a/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt +++ b/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt @@ -8,7 +8,6 @@ import io.ktor.http.* import io.ktor.http.parsing.* import io.ktor.util.* import io.ktor.utils.io.charsets.* -import kotlin.native.concurrent.* private val TOKEN_EXTRA = setOf('!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~') private val TOKEN68_EXTRA = setOf('-', '.', '_', '~', '+', '/') @@ -19,12 +18,14 @@ private val escapeRegex: Regex = "\\\\.".toRegex() * Parses an authorization header [headerValue] into a [HttpAuthHeader]. * @return [HttpAuthHeader] or `null` if argument string is blank. * @throws [ParseException] on invalid header + * + * @see [parseAuthorizationHeaders] */ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? { var index = 0 index = headerValue.skipSpaces(index) - var tokenStartIndex = index + val tokenStartIndex = index while (index < headerValue.length && headerValue[index].isToken()) { index++ } @@ -32,7 +33,6 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? { // Auth scheme val authScheme = headerValue.substring(tokenStartIndex until index) index = headerValue.skipSpaces(index) - tokenStartIndex = index if (authScheme.isBlank()) { return null @@ -42,28 +42,114 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? { return HttpAuthHeader.Parameterized(authScheme, emptyList()) } - val token68 = matchToken68(headerValue, index) - if (token68 != null) { - return HttpAuthHeader.Single(authScheme, token68) + val token68EndIndex = matchToken68(headerValue, index) + val token68 = headerValue.substring(index until token68EndIndex).trim() + if (token68.isNotEmpty()) { + if (token68EndIndex == headerValue.length) { + return HttpAuthHeader.Single(authScheme, token68) + } } - val parameters = matchParameters(headerValue, tokenStartIndex) - return HttpAuthHeader.Parameterized(authScheme, parameters) + val parameters = mutableMapOf() + val endIndex = matchParameters(headerValue, index, parameters) + return if (endIndex == -1) HttpAuthHeader.Parameterized(authScheme, parameters) else + throw ParseException("Function parseAuthorizationHeader can parse only one header") +} + +/** + * Parses an authorization header [headerValue] into a list of [HttpAuthHeader]. + * @return a list of [HttpAuthHeader] + * @throws [ParseException] on invalid header + */ +@InternalAPI +public fun parseAuthorizationHeaders(headerValue: String): List { + var index = 0 + val headers = mutableListOf() + while (index != -1) { + index = parseAuthorizationHeader(headerValue, index, headers) + } + return headers } -private fun matchParameters(headerValue: String, startIndex: Int): Map { - val result = mutableMapOf() +private fun parseAuthorizationHeader( + headerValue: String, + startIndex: Int, + headers: MutableList +): Int { + var index = headerValue.skipSpaces(startIndex) + + // Auth scheme + val schemeStartIndex = index + while (index < headerValue.length && headerValue[index].isToken()) { + index++ + } + val authScheme = headerValue.substring(schemeStartIndex until index) + + if (authScheme.isBlank()) { + throw ParseException("Invalid authScheme value: it should be token, can't be blank") + } + index = headerValue.skipSpaces(index) + + nextChallengeIndex(headers, HttpAuthHeader.Parameterized(authScheme, emptyList()), index, headerValue)?.let { + return it + } + + val token68EndIndex = matchToken68(headerValue, index) + val token68 = headerValue.substring(index until token68EndIndex).trim() + if (token68.isNotEmpty()) { + nextChallengeIndex(headers, HttpAuthHeader.Single(authScheme, token68), token68EndIndex, headerValue)?.let { + return it + } + } + val parameters = mutableMapOf() + val nextIndexChallenge = matchParameters(headerValue, index, parameters) + headers.add(HttpAuthHeader.Parameterized(authScheme, parameters)) + return nextIndexChallenge +} + +/** + * Check for the ending of the current challenge in a header + * @return -1 if at the end of the header + * @return null if the challenge is not ended + * @return a positive number - the index of the beginning of the next challenge + */ +private fun nextChallengeIndex( + headers: MutableList, + header: HttpAuthHeader, + index: Int, + headerValue: String +): Int? { + if (index == headerValue.length || headerValue[index] == ',') { + headers.add(header) + return when { + index == headerValue.length -> -1 + headerValue[index] == ',' -> index + 1 + else -> error("") // unreachable code + } + } + return null +} + +private fun matchParameters(headerValue: String, startIndex: Int, parameters: MutableMap): Int { var index = startIndex while (index > 0 && index < headerValue.length) { - index = matchParameter(headerValue, index, result) - index = headerValue.skipDelimiter(index, ',') + val nextIndex = matchParameter(headerValue, index, parameters) + if (nextIndex == index) { + return index + } else { + index = headerValue.skipDelimiter(nextIndex, ',') + } } - return result + return index } -private fun matchParameter(headerValue: String, startIndex: Int, parameters: MutableMap): Int { +private fun matchParameter( + headerValue: String, + startIndex: Int, + parameters: MutableMap +): Int { val keyStart = headerValue.skipSpaces(startIndex) var index = keyStart @@ -71,15 +157,15 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut while (index < headerValue.length && headerValue[index].isToken()) { index++ } - val key = headerValue.substring(keyStart until index) - // Take '=' + // Check if new challenge index = headerValue.skipSpaces(index) - if (index >= headerValue.length || headerValue[index] != '=') { - throw ParseException("Expected `=` after parameter key '$key': $headerValue") + if (index == headerValue.length || headerValue[index] != '=') { + return startIndex } + // Take '=' index++ index = headerValue.skipSpaces(index) @@ -116,8 +202,8 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut return index } -private fun matchToken68(headerValue: String, startIndex: Int): String? { - var index = startIndex +private fun matchToken68(headerValue: String, startIndex: Int): Int { + var index = headerValue.skipSpaces(startIndex) while (index < headerValue.length && headerValue[index].isToken68()) { index++ @@ -127,12 +213,7 @@ private fun matchToken68(headerValue: String, startIndex: Int): String? { index++ } - val onlySpaceRemaining = (index until headerValue.length).all { headerValue[it] == ' ' } - if (onlySpaceRemaining) { - return headerValue.substring(startIndex until index) - } - - return null + return headerValue.skipSpaces(index) } /** @@ -355,13 +436,11 @@ private fun String.unescaped() = replace(escapeRegex) { it.value.takeLast(1) } private fun String.skipDelimiter(startIndex: Int, delimiter: Char): Int { var index = skipSpaces(startIndex) - while (index < length && this[index] != delimiter) { - index++ - } - if (index == length) return -1 - index++ + if (this[index] != delimiter) + throw ParseException("Expected delimiter $delimiter at position $index, but found ${this[index]}") + index++ return skipSpaces(index) } diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt index b5bce68873..2179514e38 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth/jvmAndNix/test/io/ktor/tests/auth/AuthorizeHeaderParserTest.kt @@ -5,23 +5,28 @@ package io.ktor.tests.auth import io.ktor.http.auth.* +import io.ktor.util.* import kotlin.random.* import kotlin.test.* class AuthorizeHeaderParserTest { - @Test fun empty() { + @Test + fun empty() { testParserParameterized("Basic", emptyMap(), "Basic") } - @Test fun emptyWithTrailingSpaces() { + @Test + fun emptyWithTrailingSpaces() { testParserParameterized("Basic", emptyMap(), "Basic ") } - @Test fun singleSimple() { + @Test + fun singleSimple() { testParserSingle("Basic", "abc==", "Basic abc==") } - @Test fun testParameterizedSimple() { + @Test + fun testParameterizedSimple() { testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1") testParserParameterized("Basic", mapOf("a" to "1"), "Basic a =1") testParserParameterized("Basic", mapOf("a" to "1"), "Basic a = 1") @@ -30,7 +35,8 @@ class AuthorizeHeaderParserTest { testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1 ") } - @Test fun testParameterizedSimpleTwoParams() { + @Test + fun testParameterizedSimpleTwoParams() { testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1, b=2") testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1,b=2") testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 ,b=2") @@ -38,19 +44,53 @@ class AuthorizeHeaderParserTest { testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 , b=2 ") } - @Test fun testParameterizedQuoted() { + @Test + fun testParameterizedQuoted() { testParserParameterized("Basic", mapOf("a" to "1 2"), "Basic a=\"1 2\"") } - @Test fun testParameterizedQuotedEscaped() { + @Test + fun testParameterizedQuotedEscaped() { testParserParameterized("Basic", mapOf("a" to "1 \" 2"), "Basic a=\"1 \\\" 2\"") testParserParameterized("Basic", mapOf("a" to "1 A 2"), "Basic a=\"1 \\A 2\"") } - @Test fun testParameterizedQuotedEscapedInTheMiddle() { + @Test + fun testParameterizedQuotedEscapedInTheMiddle() { testParserParameterized("Basic", mapOf("a" to "1 \" 2", "b" to "2"), "Basic a=\"1 \\\" 2\", b= 2") } + @Test + fun testMultipleChallengesParameters() { + val expected = listOf( + HttpAuthHeader.Parameterized("Digest", emptyMap()), + HttpAuthHeader.Parameterized("Bearer", mapOf("1" to "2", "3" to "4")), + HttpAuthHeader.Parameterized("Basic", emptyMap()), + ) + testParserMultipleChallenges(expected, "Digest, Bearer 1 = 2, 3=4, Basic ") + } + + @Test + fun testMultipleChallengesSingle() { + val expected = listOf( + HttpAuthHeader.Single("Bearer", "abc=="), + HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")), + HttpAuthHeader.Single("Basic", "def==="), + HttpAuthHeader.Parameterized("Digest", emptyMap()) + ) + testParserMultipleChallenges(expected, "Bearer abc==, Bearer abc=def, Basic def===, Digest") + } + + @Test + fun testMultipleChallengesAllHeaders() { + val expected = listOf( + HttpAuthHeader.Parameterized("Basic", emptyMap()), + HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")), + HttpAuthHeader.Single("Digest", "abc==") + ) + testParserMultipleChallenges(expected, "Basic, Bearer abc=def,Digest abc==") + } + private fun testParserSingle(scheme: String, value: String, headerValue: String) { val actual = parseAuthorizationHeader(headerValue)!! @@ -75,11 +115,32 @@ class AuthorizeHeaderParserTest { } } + @OptIn(InternalAPI::class) + private fun testParserMultipleChallenges(expected: List, headerValue: String) { + val actual = parseAuthorizationHeaders(headerValue) + + assertEquals(expected.size, actual.size) + (expected zip actual).forEach { (expectedHeader, actualHeader) -> + if (expectedHeader is HttpAuthHeader.Single) { + assertIs(actualHeader) + + assertEquals(expectedHeader.blob, actualHeader.blob) + } + if (expectedHeader is HttpAuthHeader.Parameterized) { + assertIs(actualHeader) + assertEquals( + expectedHeader.parameters.associateBy({ it.name }, { it.value }), + actualHeader.parameters.associateBy({ it.name }, { it.value }) + ) + } + } + } + private fun Random.nextString( length: Int, possible: Iterable = ('a'..'z') + ('A'..'Z') + ('0'..'9') ) = possible.toList().let { possibleElements -> - (0..length - 1).map { nextFrom(possibleElements) }.joinToString("") + (0 until length).map { nextFrom(possibleElements) }.joinToString("") } private fun Random.nextString(length: Int, possible: String) = nextString(length, possible.toList())