Skip to content

Commit

Permalink
KTOR-5225 Add callbacks to save application state for OAuth2 (#3282)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsinukov committed Dec 2, 2022
1 parent d5d7872 commit 9e12c15
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 6 deletions.
Expand Up @@ -358,6 +358,8 @@ public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth1a : io/kto
public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth2 : io/ktor/server/auth/OAuthAccessTokenResponse {
public fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;Ljava/lang/String;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()Ljava/lang/String;
public final fun component2 ()Ljava/lang/String;
public final fun component3 ()J
Expand All @@ -370,6 +372,7 @@ public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth2 : io/ktor
public final fun getExpiresIn ()J
public final fun getExtraParameters ()Lio/ktor/http/Parameters;
public final fun getRefreshToken ()Ljava/lang/String;
public final fun getState ()Ljava/lang/String;
public final fun getTokenType ()Ljava/lang/String;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
Expand Down Expand Up @@ -458,8 +461,10 @@ public final class io/ktor/server/auth/OAuthServerSettings$OAuth1aServerSettings
}

public final class io/ktor/server/auth/OAuthServerSettings$OAuth2ServerSettings : io/ktor/server/auth/OAuthServerSettings {
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function3;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function3;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getAccessTokenInterceptor ()Lkotlin/jvm/functions/Function1;
Expand All @@ -473,6 +478,7 @@ public final class io/ktor/server/auth/OAuthServerSettings$OAuth2ServerSettings
public final fun getExtraAuthParameters ()Ljava/util/List;
public final fun getExtraTokenParameters ()Ljava/util/List;
public final fun getNonceManager ()Lio/ktor/util/NonceManager;
public final fun getOnStateCreated ()Lkotlin/jvm/functions/Function3;
public final fun getPassParamsInURL ()Z
public final fun getRequestMethod ()Lio/ktor/http/HttpMethod;
}
Expand Down
Expand Up @@ -84,7 +84,8 @@ public sealed class OAuthServerSettings(public val name: String, public val vers
public val passParamsInURL: Boolean = false,
public val extraAuthParameters: List<Pair<String, String>> = emptyList(),
public val extraTokenParameters: List<Pair<String, String>> = emptyList(),
public val accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {}
public val accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {},
public val onStateCreated: suspend (call: ApplicationCall, state: String) -> Unit = { _, _ -> }
) : OAuthServerSettings(name, OAuthVersion.V20) {

@Deprecated("This constructor will be removed", level = DeprecationLevel.HIDDEN)
Expand Down Expand Up @@ -117,6 +118,40 @@ public sealed class OAuthServerSettings(public val name: String, public val vers
emptyList(),
accessTokenInterceptor
)

@Deprecated("This constructor will be removed", level = DeprecationLevel.HIDDEN)
public constructor(
name: String,
authorizeUrl: String,
accessTokenUrl: String,
requestMethod: HttpMethod = HttpMethod.Get,
clientId: String,
clientSecret: String,
defaultScopes: List<String> = emptyList(),
accessTokenRequiresBasicAuth: Boolean = false,
nonceManager: NonceManager = GenerateOnlyNonceManager,
authorizeUrlInterceptor: URLBuilder.() -> Unit = {},
passParamsInURL: Boolean = false,
extraAuthParameters: List<Pair<String, String>> = emptyList(),
extraTokenParameters: List<Pair<String, String>> = emptyList(),
accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {},
) : this(
name,
authorizeUrl,
accessTokenUrl,
requestMethod,
clientId,
clientSecret,
defaultScopes,
accessTokenRequiresBasicAuth,
nonceManager,
authorizeUrlInterceptor,
passParamsInURL,
extraAuthParameters,
extraTokenParameters,
accessTokenInterceptor,
{ _, _ -> }
)
}
}

Expand Down Expand Up @@ -161,15 +196,31 @@ public sealed class OAuthAccessTokenResponse : Principal {
* @property tokenType OAuth2 token type (usually Bearer)
* @property expiresIn token expiration timestamp
* @property refreshToken to be used to refresh access token after expiration
* @property state generated state used for the OAuth procedure
* @property extraParameters contains additional parameters provided by the server
*/
public data class OAuth2(
val accessToken: String,
val tokenType: String,
val expiresIn: Long,
val refreshToken: String?,
val extraParameters: Parameters = Parameters.Empty
) : OAuthAccessTokenResponse()
val extraParameters: Parameters = Parameters.Empty,
) : OAuthAccessTokenResponse() {

public var state: String? = null
private set

public constructor(
accessToken: String,
tokenType: String,
expiresIn: Long,
refreshToken: String?,
extraParameters: Parameters = Parameters.Empty,
state: String? = null
) : this(accessToken, tokenType, expiresIn, refreshToken, extraParameters) {
this.state = state
}
}
}

/**
Expand Down Expand Up @@ -214,6 +265,7 @@ public suspend fun PipelineContext<Unit, ApplicationCall>.oauthRespondRedirect(
call.redirectAuthenticateOAuth1a(provider, requestToken)
}
}

is OAuthServerSettings.OAuth2ServerSettings -> {
call.redirectAuthenticateOAuth2(
provider,
Expand Down Expand Up @@ -275,6 +327,7 @@ public suspend fun PipelineContext<Unit, ApplicationCall>.oauthHandleCallback(
}
}
}

is OAuthServerSettings.OAuth2ServerSettings -> {
val code = call.oauth2HandleCallback()
if (code == null) {
Expand Down
Expand Up @@ -265,6 +265,7 @@ private suspend fun oauth2RequestAccessToken(
accessToken = contentDecoded[OAuth2ResponseParameters.AccessToken]
?: throw OAuth2Exception.MissingAccessToken(),
tokenType = contentDecoded[OAuth2ResponseParameters.TokenType] ?: "",
state = state,
expiresIn = contentDecoded[OAuth2ResponseParameters.ExpiresIn]?.toLong() ?: 0L,
refreshToken = contentDecoded[OAuth2ResponseParameters.RefreshToken],
extraParameters = contentDecoded
Expand Down
Expand Up @@ -85,10 +85,12 @@ internal suspend fun OAuthAuthenticationProvider.oauth2(authProviderName: String
cause ?: return
@Suppress("NAME_SHADOWING")
context.challenge(OAuthKey, cause) { challenge, call ->
val state = provider.nonceManager.newNonce()
provider.onStateCreated(call, state)
call.redirectAuthenticateOAuth2(
provider,
callbackRedirectUrl,
state = provider.nonceManager.newNonce(),
state = state,
scopes = provider.defaultScopes,
extraParameters = provider.extraAuthParameters,
interceptor = provider.authorizeUrlInterceptor
Expand Down
Expand Up @@ -5,8 +5,11 @@
package io.ktor.tests.auth

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.engine.mock.*
import io.ktor.client.plugins.cookies.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.auth.*
import io.ktor.http.content.*
Expand All @@ -15,6 +18,7 @@ import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.sessions.*
import io.ktor.server.testing.*
import io.ktor.server.testing.client.*
import io.ktor.util.*
Expand Down Expand Up @@ -148,7 +152,8 @@ class OAuth2Test {
when (state) {
null -> parametersOf("noState", "Had no state")
else -> Parameters.Empty
}
},
state
)
} else if (grantType == OAuthGrantTypes.Password) {
if (userName != "user1") {
Expand Down Expand Up @@ -795,6 +800,78 @@ class OAuth2Test {
assertEquals("Had no state", result)
}

@Test
fun testApplicationState() = testApplication {
class UserSession(val token: String)

val client = createClient {
install(HttpCookies)
}
val redirects = mutableMapOf<String, String>()
externalServices {
hosts("http://oauth.com") {
routing {
post("/oauth/access_token") {
call.respondText("access_token=a_token", ContentType.Application.FormUrlEncoded)
}
get("/oauth/authorize") {
val state = call.parameters["state"]!!
call.respondText("code=code&state=$state", ContentType.Application.FormUrlEncoded)
}
}
}
}
install(Sessions) {
cookie<UserSession>("user_session")
}
install(Authentication) {
oauth("login") {
this@oauth.client = client
urlProvider = { "http://localhost/login" }
providerLookup = {
OAuthServerSettings.OAuth2ServerSettings(
name = "oauth2",
authorizeUrl = "http://oauth.com/oauth/authorize",
accessTokenUrl = "http://oauth.com/oauth/access_token",
clientId = "clientId1",
clientSecret = "clientSecret1",
requestMethod = HttpMethod.Post,
onStateCreated = { call, state ->
redirects[state] = call.request.queryParameters["redirectUrl"]!!
}
)
}
}
}
routing {
authenticate("login") {
get("login") {
val state = call.principal<OAuthAccessTokenResponse.OAuth2>()!!.state!!
call.sessions.set(UserSession(state))
val redirect = redirects[state]!!
call.respondRedirect(redirect)
}
}
get("{path}") {
val session = call.sessions.get<UserSession>()
if (session == null) {
val redirectUrl = URLBuilder("http://localhost/login").run {
parameters.append("redirectUrl", call.request.uri)
build()
}
call.respondRedirect(redirectUrl)
return@get
}
call.respond(call.parameters["path"]!!)
}
}
val request1Auth = client.get("/some-url").body<String>().let { parseQueryString(it) }
val code1 = request1Auth["code"]!!
val state1 = request1Auth["state"]!!
val response1 = client.get("/login?code=$code1&state=$state1")
assertEquals("some-url", response1.bodyAsText())
}

private fun waitExecutor() {
val latch = CountDownLatch(1)
executor.submit {
Expand Down

0 comments on commit 9e12c15

Please sign in to comment.