Skip to content

Commit

Permalink
Parse header with multiple challenges
Browse files Browse the repository at this point in the history
  • Loading branch information
marychatte committed Nov 29, 2022
1 parent b40bc53 commit f6906de
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 43 deletions.
Expand Up @@ -54,13 +54,13 @@ public class Auth private constructor(
val candidateProviders = HashSet(plugin.providers)

while (call.response.status == HttpStatusCode.Unauthorized) {
val headerValue = call.response.headers[HttpHeaders.WWWAuthenticate]
val headerValues = call.response.headers.getAll(HttpHeaders.WWWAuthenticate)
val authHeaders = headerValues?.map { parseAuthorizationHeaders(it) }?.flatten() ?: emptyList()

val authHeader = headerValue?.let { parseAuthorizationHeader(headerValue) }
val provider = when {
authHeader == null && candidateProviders.size == 1 -> candidateProviders.first()
authHeader == null -> return@intercept call
else -> candidateProviders.find { it.isApplicable(authHeader) } ?: return@intercept call
val (provider, authHeader) = when {
authHeaders.isEmpty() && candidateProviders.size == 1 -> candidateProviders.first() to null
authHeaders.isEmpty() -> return@intercept call
else -> findProviderAndHeader(candidateProviders, authHeaders) ?: return@intercept call
}
if (!provider.refreshToken(call.response)) return@intercept call

Expand All @@ -76,6 +76,21 @@ public class Auth private constructor(
return@intercept call
}
}

private fun findProviderAndHeader(
providers: Collection<AuthProvider>,
authHeaders: List<HttpAuthHeader>
): Pair<AuthProvider, HttpAuthHeader>? {
authHeaders.forEach { header ->
providers.forEach { provider ->
if (provider.isApplicable(header)) {
return provider to header
}
}
}

return null
}
}
}

Expand Down
1 change: 1 addition & 0 deletions ktor-http/api/ktor-http.api
Expand Up @@ -1114,6 +1114,7 @@ public final class io/ktor/http/auth/HttpAuthHeader$Single : io/ktor/http/auth/H

public final class io/ktor/http/auth/HttpAuthHeaderKt {
public static final fun parseAuthorizationHeader (Ljava/lang/String;)Lio/ktor/http/auth/HttpAuthHeader;
public static final fun parseAuthorizationHeaders (Ljava/lang/String;)Ljava/util/List;
}

public final class io/ktor/http/content/ByteArrayContent : io/ktor/http/content/OutgoingContent$ByteArrayContent {
Expand Down
124 changes: 96 additions & 28 deletions ktor-http/common/src/io/ktor/http/auth/HttpAuthHeader.kt
Expand Up @@ -8,7 +8,6 @@ import io.ktor.http.*
import io.ktor.http.parsing.*
import io.ktor.util.*
import io.ktor.utils.io.charsets.*
import kotlin.native.concurrent.*

private val TOKEN_EXTRA = setOf('!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~')
private val TOKEN68_EXTRA = setOf('-', '.', '_', '~', '+', '/')
Expand All @@ -19,20 +18,21 @@ private val escapeRegex: Regex = "\\\\.".toRegex()
* Parses an authorization header [headerValue] into a [HttpAuthHeader].
* @return [HttpAuthHeader] or `null` if argument string is blank.
* @throws [ParseException] on invalid header
*
* @see [parseAuthorizationHeaders]
*/
public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
var index = 0
index = headerValue.skipSpaces(index)

var tokenStartIndex = index
val tokenStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

// Auth scheme
val authScheme = headerValue.substring(tokenStartIndex until index)
index = headerValue.skipSpaces(index)
tokenStartIndex = index

if (authScheme.isBlank()) {
return null
Expand All @@ -42,44 +42,104 @@ public fun parseAuthorizationHeader(headerValue: String): HttpAuthHeader? {
return HttpAuthHeader.Parameterized(authScheme, emptyList())
}

val token68 = matchToken68(headerValue, index)
val (indexAfterToken68, token68) = matchToken68(headerValue, index)
if (token68 != null) {
return HttpAuthHeader.Single(authScheme, token68)
return checkSingleHeader(indexAfterToken68, HttpAuthHeader.Single(authScheme, token68))
}

val parameters = matchParameters(headerValue, tokenStartIndex)
return HttpAuthHeader.Parameterized(authScheme, parameters)
val (endIndex, parameters) = matchParameters(headerValue, index)
return checkSingleHeader(endIndex, HttpAuthHeader.Parameterized(authScheme, parameters))
}

private fun checkSingleHeader(endIndex: Int, header: HttpAuthHeader): HttpAuthHeader {
return if (endIndex == -1) header else
throw ParseException("Function parseAuthorizationHeader can parse only one header")
}

private fun matchParameters(headerValue: String, startIndex: Int): Map<String, String> {
/**
* Parses an authorization header [headerValue] into a list of [HttpAuthHeader].
* @return a list of [HttpAuthHeader]
* @throws [ParseException] on invalid header
*/
public fun parseAuthorizationHeaders(headerValue: String): List<HttpAuthHeader> {
var index = 0
val headers = mutableListOf<HttpAuthHeader>()
while (index != -1) {
val (nextIndex, header) = parseAuthorizationHeader(headerValue, index)
headers.add(header)
index = nextIndex
}
return headers
}

private fun parseAuthorizationHeader(
headerValue: String,
startIndex: Int,
): Pair<Int, HttpAuthHeader> {
var index = headerValue.skipSpaces(startIndex)

// Auth scheme
val schemeStartIndex = index
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}
val authScheme = headerValue.substring(schemeStartIndex until index)

if (authScheme.isBlank()) {
throw ParseException("Invalid authScheme value: it should be token, can't be blank")
}

val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index)
if (isEndOfChallenge) {
return endChallengeIndex to HttpAuthHeader.Parameterized(authScheme, emptyList())
}

val (nextIndex, token68) = matchToken68(headerValue, endChallengeIndex)
if (token68 != null) {
return nextIndex to HttpAuthHeader.Single(authScheme, token68)
}

val (nextIndexChallenge, parameters) = matchParameters(headerValue, index)
return nextIndexChallenge to HttpAuthHeader.Parameterized(authScheme, parameters)
}

private fun matchParameters(headerValue: String, startIndex: Int): Pair<Int, Map<String, String>> {
val result = mutableMapOf<String, String>()

var index = startIndex
while (index > 0 && index < headerValue.length) {
index = matchParameter(headerValue, index, result)
index = headerValue.skipDelimiter(index, ',')
val (nextIndex, wasParameter) = matchParameter(headerValue, index, result)
if (wasParameter) {
index = headerValue.skipDelimiter(nextIndex, ',')
} else {
return nextIndex to result
}
}

return result
return index to result
}

private fun matchParameter(headerValue: String, startIndex: Int, parameters: MutableMap<String, String>): Int {
private fun matchParameter(
headerValue: String,
startIndex: Int,
parameters: MutableMap<String, String>
): Pair<Int, Boolean> {
val keyStart = headerValue.skipSpaces(startIndex)
var index = keyStart

// Take key
while (index < headerValue.length && headerValue[index].isToken()) {
index++
}

val key = headerValue.substring(keyStart until index)

// Take '='
// Check if new challenge
index = headerValue.skipSpaces(index)
if (index >= headerValue.length || headerValue[index] != '=') {
throw ParseException("Expected `=` after parameter key '$key': $headerValue")
if (index == headerValue.length || headerValue[index] != '=') {
return keyStart to false
}

// Take '='
index++
index = headerValue.skipSpaces(index)

Expand Down Expand Up @@ -113,10 +173,10 @@ private fun matchParameter(headerValue: String, startIndex: Int, parameters: Mut
parameters[key] = if (quoted) value.unescaped() else value

if (quoted) index++
return index
return index to true
}

private fun matchToken68(headerValue: String, startIndex: Int): String? {
private fun matchToken68(headerValue: String, startIndex: Int): Pair<Int, String?> {
var index = startIndex

while (index < headerValue.length && headerValue[index].isToken68()) {
Expand All @@ -127,12 +187,14 @@ private fun matchToken68(headerValue: String, startIndex: Int): String? {
index++
}

val onlySpaceRemaining = (index until headerValue.length).all { headerValue[it] == ' ' }
if (onlySpaceRemaining) {
return headerValue.substring(startIndex until index)
}
val token68 = headerValue.substring(startIndex until index)

return null
val (endChallengeIndex, isEndOfChallenge) = headerValue.isEndOfChallenge(index)
return if (isEndOfChallenge) {
endChallengeIndex to token68
} else {
startIndex to null
}
}

/**
Expand Down Expand Up @@ -355,13 +417,11 @@ private fun String.unescaped() = replace(escapeRegex) { it.value.takeLast(1) }
private fun String.skipDelimiter(startIndex: Int, delimiter: Char): Int {
var index = skipSpaces(startIndex)

while (index < length && this[index] != delimiter) {
index++
}

if (index == length) return -1
index++
if (this[index] != delimiter)
throw ParseException("Expected delimiter $delimiter at position $index, but found ${this[index]}")

index++
return skipSpaces(index)
}

Expand All @@ -374,6 +434,14 @@ private fun String.skipSpaces(startIndex: Int): Int {
return index
}

private fun String.isEndOfChallenge(startIndex: Int): Pair<Int, Boolean> {
val index = skipSpaces(startIndex)
if (index == length) return -1 to true
if (this[index] == ',') return index + 1 to true

return index to false
}

private fun Char.isToken68(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN68_EXTRA

private fun Char.isToken(): Boolean = (this in 'a'..'z') || (this in 'A'..'Z') || isDigit() || this in TOKEN_EXTRA
Expand Up @@ -9,19 +9,23 @@ import kotlin.random.*
import kotlin.test.*

class AuthorizeHeaderParserTest {
@Test fun empty() {
@Test
fun empty() {
testParserParameterized("Basic", emptyMap(), "Basic")
}

@Test fun emptyWithTrailingSpaces() {
@Test
fun emptyWithTrailingSpaces() {
testParserParameterized("Basic", emptyMap(), "Basic ")
}

@Test fun singleSimple() {
@Test
fun singleSimple() {
testParserSingle("Basic", "abc==", "Basic abc==")
}

@Test fun testParameterizedSimple() {
@Test
fun testParameterizedSimple() {
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1")
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a =1")
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a = 1")
Expand All @@ -30,27 +34,62 @@ class AuthorizeHeaderParserTest {
testParserParameterized("Basic", mapOf("a" to "1"), "Basic a=1 ")
}

@Test fun testParameterizedSimpleTwoParams() {
@Test
fun testParameterizedSimpleTwoParams() {
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1, b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1,b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 ,b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 , b=2")
testParserParameterized("Basic", mapOf("a" to "1", "b" to "2"), "Basic a=1 , b=2 ")
}

@Test fun testParameterizedQuoted() {
@Test
fun testParameterizedQuoted() {
testParserParameterized("Basic", mapOf("a" to "1 2"), "Basic a=\"1 2\"")
}

@Test fun testParameterizedQuotedEscaped() {
@Test
fun testParameterizedQuotedEscaped() {
testParserParameterized("Basic", mapOf("a" to "1 \" 2"), "Basic a=\"1 \\\" 2\"")
testParserParameterized("Basic", mapOf("a" to "1 A 2"), "Basic a=\"1 \\A 2\"")
}

@Test fun testParameterizedQuotedEscapedInTheMiddle() {
@Test
fun testParameterizedQuotedEscapedInTheMiddle() {
testParserParameterized("Basic", mapOf("a" to "1 \" 2", "b" to "2"), "Basic a=\"1 \\\" 2\", b= 2")
}

@Test
fun testMultipleChallengesParameters() {
val expected = listOf(
HttpAuthHeader.Parameterized("Digest", emptyMap()),
HttpAuthHeader.Parameterized("Bearer", mapOf("1" to "2", "3" to "4")),
HttpAuthHeader.Parameterized("Basic", emptyMap()),
)
testParserMultipleChallenges(expected, "Digest, Bearer 1 = 2, 3=4, Basic ")
}

@Test
fun testMultipleChallengesSingle() {
val expected = listOf(
HttpAuthHeader.Single("Bearer", "abc=="),
HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")),
HttpAuthHeader.Single("Basic", "def==="),
HttpAuthHeader.Parameterized("Digest", emptyMap())
)
testParserMultipleChallenges(expected, "Bearer abc==, Bearer abc=def, Basic def===, Digest")
}

@Test
fun testMultipleChallengesAllHeaders() {
val expected = listOf(
HttpAuthHeader.Parameterized("Basic", emptyMap()),
HttpAuthHeader.Parameterized("Bearer", mapOf("abc" to "def")),
HttpAuthHeader.Single("Digest", "abc==")
)
testParserMultipleChallenges(expected, "Basic, Bearer abc=def,Digest abc==")
}

private fun testParserSingle(scheme: String, value: String, headerValue: String) {
val actual = parseAuthorizationHeader(headerValue)!!

Expand All @@ -75,11 +114,31 @@ class AuthorizeHeaderParserTest {
}
}

private fun testParserMultipleChallenges(expected: List<HttpAuthHeader>, headerValue: String) {
val actual = parseAuthorizationHeaders(headerValue)

assertEquals(expected.size, actual.size)
(expected zip actual).forEach { (expectedHeader, actualHeader) ->
if (expectedHeader is HttpAuthHeader.Single) {
assertIs<HttpAuthHeader.Single>(actualHeader)

assertEquals(expectedHeader.blob, actualHeader.blob)
}
if (expectedHeader is HttpAuthHeader.Parameterized) {
assertIs<HttpAuthHeader.Parameterized>(actualHeader)
assertEquals(
expectedHeader.parameters.associateBy({ it.name }, { it.value }),
actualHeader.parameters.associateBy({ it.name }, { it.value })
)
}
}
}

private fun Random.nextString(
length: Int,
possible: Iterable<Char> = ('a'..'z') + ('A'..'Z') + ('0'..'9')
) = possible.toList().let { possibleElements ->
(0..length - 1).map { nextFrom(possibleElements) }.joinToString("")
(0 until length).map { nextFrom(possibleElements) }.joinToString("")
}

private fun Random.nextString(length: Int, possible: String) = nextString(length, possible.toList())
Expand Down

0 comments on commit f6906de

Please sign in to comment.