Skip to content

Commit

Permalink
Add tests and remove Pair type
Browse files Browse the repository at this point in the history
  • Loading branch information
marychatte committed Dec 4, 2022
1 parent f6906de commit cfd97b1
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 67 deletions.
35 changes: 35 additions & 0 deletions buildSrc/src/main/kotlin/test/server/tests/Auth.kt
Expand Up @@ -152,6 +152,41 @@ internal fun Application.authTestServer() {
call.respond("OK")
}
}

route("multiple") {
get("header") {
val token = call.request.headers[HttpHeaders.Authorization]

if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest, Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}

call.respond("OK")
}
get("headers") {
val token = call.request.headers[HttpHeaders.Authorization]

if (token.isNullOrEmpty() || token.contains("Invalid")) {
call.response.header(
HttpHeaders.WWWAuthenticate,
"Basic realm=\"TestServer\", charset=UTF-8, Digest"
)
call.response.header(
HttpHeaders.WWWAuthenticate,
"Bearer realm=\"my-server\""
)
call.respond(HttpStatusCode.Unauthorized)
return@get
}

call.respond("OK")
}
}
}
}
}
Expand Down
Expand Up @@ -57,11 +57,23 @@ public class Auth private constructor(
val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate)
val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList()

val (provider, authHeader) = when {
authHeaders.isEmpty() && candidateProviders.size == 1 -> candidateProviders.first() to null
var providerOrNull: AuthProvider? = null
var authHeader: HttpAuthHeader? = null

when {
authHeaders.isEmpty() && candidateProviders.size == 1 -> {
providerOrNull = candidateProviders.first()
}

authHeaders.isEmpty() -> return@intercept call
else -> findProviderAndHeader(candidateProviders, authHeaders) ?: return@intercept call

else -> authHeader = authHeaders.find { header ->
providerOrNull = candidateProviders.find { it.isApplicable(header) }
providerOrNull != null
}
}
val provider = providerOrNull ?: return@intercept call

if (!provider.refreshToken(call.response)) return@intercept call

candidateProviders.remove(provider)
Expand All @@ -76,21 +88,6 @@ public class Auth private constructor(
return@intercept call
}
}

private fun findProviderAndHeader(
providers: Collection<AuthProvider>,
authHeaders: List<HttpAuthHeader>
): Pair<AuthProvider, HttpAuthHeader>? {
authHeaders.forEach { header ->
providers.forEach { provider ->
if (provider.isApplicable(header)) {
return provider to header
}
}
}

return null
}
}
}

Expand Down
Expand Up @@ -535,4 +535,73 @@ class AuthTest : ClientLoader() {
assertEquals(2, loadCount)
}
}

@Test
fun testMultipleChallengesInHeader() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseOneHeader = client.get("$TEST_SERVER/auth/multiple/header").bodyAsText()
assertEquals("OK", responseOneHeader)
}
}

@Test
fun testMultipleChallengesInHeaders() = clientTests {
config {
install(Auth) {
basic {
credentials { BasicAuthCredentials("Invalid", "Invalid") }
}
bearer {
loadTokens { BearerTokens("test", "test") }
}
}
}
test { client ->
val responseMultipleHeaders = client.get("$TEST_SERVER/auth/multiple/headers").bodyAsText()
assertEquals("OK", responseMultipleHeaders)
}
}

@Test
fun testMultipleChallengesInHeaderUnauthorized() = clientTests {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/header")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers[HttpHeaders.WWWAuthenticate]?.also {
assertTrue { it.contains("Bearer") }
assertTrue { it.contains("Basic") }
assertTrue { it.contains("Digest") }
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}

@Test
fun testMultipleChallengesInMultipleHeadersUnauthorized() = clientTests(listOf("Js")) {
test { client ->
val response = client.get("$TEST_SERVER/auth/multiple/headers")
assertEquals(HttpStatusCode.Unauthorized, response.status)
response.headers.getAll(HttpHeaders.WWWAuthenticate)?.let {
assertEquals(2, it.size)
it.joinToString().let { header ->
assertTrue { header.contains("Basic") }
assertTrue { header.contains("Digest") }
assertTrue { header.contains("Bearer") }
}
} ?: run {
fail("Expected WWWAuthenticate header")
}
}
}
}
102 changes: 53 additions & 49 deletions ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt
Expand Up @@ -42,17 +42,17 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
return HttpAuthHeader.Parameterized(authScheme, emptyList())
}

val (indexAfterToken68, token68) = matchToken68(headerValue, index)
if (token68 != null) {
return checkSingleHeader(indexAfterToken68, HttpAuthHeader.Single(authScheme, token68))
val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
if (token68EndIndex == headerValue.length) {
return HttpAuthHeader.Single(authScheme, token68)
}
}

val (endIndex, parameters) = matchParameters(headerValue, index)
return checkSingleHeader(endIndex, HttpAuthHeader.Parameterized(authScheme, parameters))
}

private fun checkSingleHeader(endIndex: Int, header: HttpAuthHeader): HttpAuthHeader {
return if (endIndex == -1) header else
val parameters = mutableMapOf<String, String>()
val endIndex = matchParameters(headerValue, index, parameters)
return if (endIndex == -1) HttpAuthHeader.Parameterized(authScheme, parameters) else
throw ParseException("Function parseAuthorizationHeader can parse only one header")
}

Expand All @@ -65,17 +65,16 @@ public fun parseAuthorizationHeaders(headerValue: String): List<HttpAuthHeader>
var index = 0
val headers = mutableListOf<HttpAuthHeader>()
while (index != -1) {
val (nextIndex, header) = parseAuthorizationHeader(headerValue, index)
headers.add(header)
index = nextIndex
index = parseAuthorizationHeader(headerValue, index, headers)
}
return headers
}

private fun parseAuthorizationHeader(
headerValue: String,
startIndex: Int,
): Pair<Int, HttpAuthHeader> {
headers: MutableList<HttpAuthHeader>
): Int {
var index = headerValue.skipSpaces(startIndex)

// Auth scheme
Expand All @@ -88,42 +87,62 @@ private fun parseAuthorizationHeader(
if (authScheme.isBlank()) {
throw ParseException("Invalid authScheme value: it should be token, can't be blank")
}
index = headerValue.skipSpaces(index)

val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index)
if (isEndOfChallenge) {
return endChallengeIndex to HttpAuthHeader.Parameterized(authScheme, emptyList())
nextChallengeIndex(headers, HttpAuthHeader.Parameterized(authScheme, emptyList()), index, headerValue)?.let {
return it
}

val (nextIndex, token68) = matchToken68(headerValue, endChallengeIndex)
if (token68 != null) {
return nextIndex to HttpAuthHeader.Single(authScheme, token68)
val token68EndIndex = matchToken68(headerValue, index)
val token68 = headerValue.substring(index until token68EndIndex).trim()
if (token68.isNotEmpty()) {
nextChallengeIndex(headers, HttpAuthHeader.Single(authScheme, token68), token68EndIndex, headerValue)?.let {
return it
}
}

val (nextIndexChallenge, parameters) = matchParameters(headerValue, index)
return nextIndexChallenge to HttpAuthHeader.Parameterized(authScheme, parameters)
val parameters = mutableMapOf<String, String>()
val nextIndexChallenge = matchParameters(headerValue, index, parameters)
headers.add(HttpAuthHeader.Parameterized(authScheme, parameters))
return nextIndexChallenge
}

private fun matchParameters(headerValue: String, startIndex: Int): Pair<Int, Map<String, String>> {
val result = mutableMapOf<String, String>()
private fun nextChallengeIndex(
headers: MutableList<HttpAuthHeader>,
header: HttpAuthHeader,
index: Int,
headerValue: String
): Int? {
if (index == headerValue.length || headerValue[index] == ',') {
headers.add(header)
return when {
index == headerValue.length -> -1
headerValue[index] == ',' -> index + 1
else -> error("") // unreachable code
}
}
return null
}

private fun matchParameters(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
var index = startIndex
while (index > 0 && index < headerValue.length) {
val (nextIndex, wasParameter) = matchParameter(headerValue, index, result)
if (wasParameter) {
index = headerValue.skipDelimiter(nextIndex, ',')
val nextIndex = matchParameter(headerValue, index, parameters)
if (nextIndex == index) {
return index
} else {
return nextIndex to result
index = headerValue.skipDelimiter(nextIndex, ',')
}
}

return index to result
return index
}

private fun matchParameter(
headerValue: String,
startIndex: Int,
parameters: MutableMap<String, String>
): Pair<Int, Boolean> {
): Int {
val keyStart = headerValue.skipSpaces(startIndex)
var index = keyStart

Expand All @@ -136,7 +155,7 @@ private fun matchParameter(
// Check if new challenge
index = headerValue.skipSpaces(index)
if (index == headerValue.length || headerValue[index] != '=') {
return keyStart to false
return startIndex
}

// Take '='
Expand Down Expand Up @@ -173,11 +192,11 @@ private fun matchParameter(
parameters[key] = if (quoted) value.unescaped() else value

if (quoted) index++
return index to true
return index
}

private fun matchToken68(headerValue: String, startIndex: Int): Pair<Int, String?> {
var index = startIndex
private fun matchToken68(headerValue: String, startIndex: Int): Int {
var index = headerValue.skipSpaces(startIndex)

while (index < headerValue.length && headerValue[index].isToken68()) {
index++
Expand All @@ -187,14 +206,7 @@ private fun matchToken68(headerValue: String, startIndex: Int): Pair<Int, String
index++
}

val token68 = headerValue.substring(startIndex until index)

val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index)
return if (isEndOfChallenge) {
endChallengeIndex to token68
} else {
startIndex to null
}
return headerValue.skipSpaces(index)
}

/**
Expand Down Expand Up @@ -434,14 +446,6 @@ private fun String.skipSpaces(startIndex: Int): Int {
return index
}

private fun String.isEndOfChallenge(startIndex: Int): Pair<Int, Boolean> {
val index = skipSpaces(startIndex)
if (index == length) return -1 to true
if (this[index] == ',') return index + 1 to true

return index to false
}

private fun Char.isToken68(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN68_EXTRA

private fun Char.isToken(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN_EXTRA

0 comments on commit cfd97b1

Please sign in to comment.