diff --git a/buildSrc/src/main/kotlin/test/server/tests/Auth.kt b/buildSrc/src/main/kotlin/test/server/tests/Auth.kt index 54be05b17b4..4f8906ac549 100644 --- a/buildSrc/src/main/kotlin/test/server/tests/Auth.kt +++ b/buildSrc/src/main/kotlin/test/server/tests/Auth.kt @@ -152,6 +152,41 @@ internal fun Application.authTestServer() { call.respond("OK") } } + + route("multiple") { + get("header") { + val token = call.request.headers[HttpHeaders.Authorization] + + if (token.isNullOrEmpty() || token.contains("Invalid")) { + call.response.header( + HttpHeaders.WWWAuthenticate, + "Basic realm=\"TestServer\", charset=UTF-8, Digest, Bearer realm=\"my-server\"" + ) + call.respond(HttpStatusCode.Unauthorized) + return@get + } + + call.respond("OK") + } + get("headers") { + val token = call.request.headers[HttpHeaders.Authorization] + + if (token.isNullOrEmpty() || token.contains("Invalid")) { + call.response.header( + HttpHeaders.WWWAuthenticate, + "Basic realm=\"TestServer\", charset=UTF-8, Digest" + ) + call.response.header( + HttpHeaders.WWWAuthenticate, + "Bearer realm=\"my-server\"" + ) + call.respond(HttpStatusCode.Unauthorized) + return@get + } + + call.respond("OK") + } + } } } } diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt index 7186a28e0b9..c280b431f7e 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/common/src/io/ktor/client/plugins/auth/Auth.kt @@ -60,7 +60,19 @@ public class Auth private constructor( val (provider, authHeader) = when { authHeaders.isEmpty() && candidateProviders.size == 1 -> candidateProviders.first() to null authHeaders.isEmpty() -> return@intercept call - else -> findProviderAndHeader(candidateProviders, authHeaders) ?: return@intercept call + else -> { + var provider: AuthProvider? = null + val header = authHeaders.find { header -> + provider = candidateProviders.find { + it.isApplicable(header) + } + provider != null + } + if (provider == null || header == null) { + return@intercept call + } + provider to header + } } if (!provider.refreshToken(call.response)) return@intercept call @@ -76,21 +88,6 @@ public class Auth private constructor( return@intercept call } } - - private fun findProviderAndHeader( - providers: Collection, - authHeaders: List - ): Pair? { - authHeaders.forEach { header -> - providers.forEach { provider -> - if (provider.isApplicable(header)) { - return provider to header - } - } - } - - return null - } } } diff --git a/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt b/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt index d44fc1a74b1..7d2abc2a1e2 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-auth/common/test/io/ktor/client/plugins/auth/AuthTest.kt @@ -535,4 +535,57 @@ class AuthTest : ClientLoader() { assertEquals(2, loadCount) } } + + @Test + fun testMultipleChallenges() = clientTests { + config { + install(Auth) { + basic { + credentials { BasicAuthCredentials("Invalid", "Invalid") } + } + bearer { + loadTokens { BearerTokens("test", "test") } + } + } + } + test { client -> + val responseOneHeader = client.get("$TEST_SERVER/auth/multiple/header").bodyAsText() + val responseMultipleHeaders = client.get("$TEST_SERVER/auth/multiple/headers").bodyAsText() + assertEquals("OK", responseOneHeader) + assertEquals("OK", responseMultipleHeaders) + } + } + + @Test + fun testMultipleChallengesInHeader() = clientTests { + test { client -> + val response = client.get("$TEST_SERVER/auth/multiple/header") + assertEquals(HttpStatusCode.Unauthorized, response.status) + response.headers[HttpHeaders.WWWAuthenticate]?.also { + assertTrue { it.contains("Bearer") } + assertTrue { it.contains("Basic") } + assertTrue { it.contains("Digest") } + } ?: run { + fail("Expected WWWAuthenticate header") + } + } + } + + @Test + fun testMultipleChallengesInMultipleHeaders() = clientTests { + test { client -> + val response = client.get("$TEST_SERVER/auth/multiple/headers") + assertEquals(HttpStatusCode.Unauthorized, response.status) + response.headers.getAll(HttpHeaders.WWWAuthenticate)?.let { + assertEquals(it.size, 2) + it.joinToString().let { header -> + assertTrue { header.contains("Basic") } + assertTrue { header.contains("Digest") } + assertTrue { header.contains("Bearer") } + } + } ?: run { + fail("Expected WWWAuthenticate header") + } + } + } } diff --git a/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt b/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt index 1a3a3866767..47cbf25b625 100644 --- a/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt +++ b/ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt @@ -42,17 +42,17 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? { return HttpAuthHeader.Parameterized(authScheme, emptyList()) } - val (indexAfterToken68, token68) = matchToken68(headerValue, index) - if (token68 != null) { - return checkSingleHeader(indexAfterToken68, HttpAuthHeader.Single(authScheme, token68)) + val token68EndIndex = matchToken68(headerValue, index) + val token68 = headerValue.substring(index until token68EndIndex).trim() + if (token68.isNotEmpty()) { + if (token68EndIndex == headerValue.length) { + return HttpAuthHeader.Single(authScheme, token68) + } } - val (endIndex, parameters) = matchParameters(headerValue, index) - return checkSingleHeader(endIndex, HttpAuthHeader.Parameterized(authScheme, parameters)) -} - -private fun checkSingleHeader(endIndex: Int, header: HttpAuthHeader): HttpAuthHeader { - return if (endIndex == -1) header else + val parameters = mutableMapOf() + val endIndex = matchParameters(headerValue, index, parameters) + return if (endIndex == -1) HttpAuthHeader.Parameterized(authScheme, parameters) else throw ParseException("Function parseAuthorizationHeader can parse only one header") } @@ -65,9 +65,7 @@ public fun parseAuthorizationHeaders(headerValue: String): List var index = 0 val headers = mutableListOf() while (index != -1) { - val (nextIndex, header) = parseAuthorizationHeader(headerValue, index) - headers.add(header) - index = nextIndex + index = parseAuthorizationHeader(headerValue, index, headers) } return headers } @@ -75,7 +73,8 @@ public fun parseAuthorizationHeaders(headerValue: String): List private fun parseAuthorizationHeader( headerValue: String, startIndex: Int, -): Pair { + headers: MutableList +): Int { var index = headerValue.skipSpaces(startIndex) // Auth scheme @@ -88,42 +87,62 @@ private fun parseAuthorizationHeader( if (authScheme.isBlank()) { throw ParseException("Invalid authScheme value: it should be token, can't be blank") } + index = headerValue.skipSpaces(index) - val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index) - if (isEndOfChallenge) { - return endChallengeIndex to HttpAuthHeader.Parameterized(authScheme, emptyList()) + nextChallengeIndex(headers, HttpAuthHeader.Parameterized(authScheme, emptyList()), index, headerValue)?.let { + return it } - val (nextIndex, token68) = matchToken68(headerValue, endChallengeIndex) - if (token68 != null) { - return nextIndex to HttpAuthHeader.Single(authScheme, token68) + val token68EndIndex = matchToken68(headerValue, index) + val token68 = headerValue.substring(index until token68EndIndex).trim() + if (token68.isNotEmpty()) { + nextChallengeIndex(headers, HttpAuthHeader.Single(authScheme, token68), token68EndIndex, headerValue)?.let { + return it + } } - val (nextIndexChallenge, parameters) = matchParameters(headerValue, index) - return nextIndexChallenge to HttpAuthHeader.Parameterized(authScheme, parameters) + val parameters = mutableMapOf() + val nextIndexChallenge = matchParameters(headerValue, index, parameters) + headers.add(HttpAuthHeader.Parameterized(authScheme, parameters)) + return nextIndexChallenge } -private fun matchParameters(headerValue: String, startIndex: Int): Pair> { - val result = mutableMapOf() +private fun nextChallengeIndex( + headers: MutableList, + header: HttpAuthHeader, + index: Int, + headerValue: String +): Int? { + if (index == headerValue.length || headerValue[index] == ',') { + headers.add(header) + return when { + index == headerValue.length -> -1 + headerValue[index] == ',' -> index + 1 + else -> error("") // unreachable code + } + } + return null +} +private fun matchParameters(headerValue: String, startIndex: Int, parameters: MutableMap): Int { var index = startIndex while (index > 0 && index < headerValue.length) { - val (nextIndex, wasParameter) = matchParameter(headerValue, index, result) - if (wasParameter) { - index = headerValue.skipDelimiter(nextIndex, ',') + val nextIndex = matchParameter(headerValue, index, parameters) + if (nextIndex == index) { + return index } else { - return nextIndex to result + index = headerValue.skipDelimiter(nextIndex, ',') } } - return index to result + return index } private fun matchParameter( headerValue: String, startIndex: Int, parameters: MutableMap -): Pair { +): Int { val keyStart = headerValue.skipSpaces(startIndex) var index = keyStart @@ -136,7 +155,7 @@ private fun matchParameter( // Check if new challenge index = headerValue.skipSpaces(index) if (index == headerValue.length || headerValue[index] != '=') { - return keyStart to false + return startIndex } // Take '=' @@ -173,11 +192,11 @@ private fun matchParameter( parameters[key] = if (quoted) value.unescaped() else value if (quoted) index++ - return index to true + return index } -private fun matchToken68(headerValue: String, startIndex: Int): Pair { - var index = startIndex +private fun matchToken68(headerValue: String, startIndex: Int): Int { + var index = headerValue.skipSpaces(startIndex) while (index < headerValue.length && headerValue[index].isToken68()) { index++ @@ -187,14 +206,7 @@ private fun matchToken68(headerValue: String, startIndex: Int): Pair { - val index = skipSpaces(startIndex) - if (index == length) return -1 to true - if (this[index] == ',') return index + 1 to true - - return index to false -} - private fun Char.isToken68(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN68_EXTRA private fun Char.isToken(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN_EXTRA