diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth/api/ktor-server-auth.api b/ktor-server/ktor-server-plugins/ktor-server-auth/api/ktor-server-auth.api index 398bce623b..ce44f63699 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth/api/ktor-server-auth.api +++ b/ktor-server/ktor-server-plugins/ktor-server-auth/api/ktor-server-auth.api @@ -358,6 +358,8 @@ public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth1a : io/kto public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth2 : io/ktor/server/auth/OAuthAccessTokenResponse { public fun (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;)V public synthetic fun (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;Ljava/lang/String;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; public final fun component2 ()Ljava/lang/String; public final fun component3 ()J @@ -370,6 +372,7 @@ public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth2 : io/ktor public final fun getExpiresIn ()J public final fun getExtraParameters ()Lio/ktor/http/Parameters; public final fun getRefreshToken ()Ljava/lang/String; + public final fun getState ()Ljava/lang/String; public final fun getTokenType ()Ljava/lang/String; public fun hashCode ()I public fun toString ()Ljava/lang/String; @@ -458,8 +461,10 @@ public final class io/ktor/server/auth/OAuthServerSettings$OAuth1aServerSettings } public final class io/ktor/server/auth/OAuthServerSettings$OAuth2ServerSettings : io/ktor/server/auth/OAuthServerSettings { - public fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function3;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function3;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLkotlin/jvm/functions/Function1;)V public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getAccessTokenInterceptor ()Lkotlin/jvm/functions/Function1; @@ -473,6 +478,7 @@ public final class io/ktor/server/auth/OAuthServerSettings$OAuth2ServerSettings public final fun getExtraAuthParameters ()Ljava/util/List; public final fun getExtraTokenParameters ()Ljava/util/List; public final fun getNonceManager ()Lio/ktor/util/NonceManager; + public final fun getOnStateCreated ()Lkotlin/jvm/functions/Function3; public final fun getPassParamsInURL ()Z public final fun getRequestMethod ()Lio/ktor/http/HttpMethod; } diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth.kt b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth.kt index bdc3ab03db..7533fc6373 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth.kt @@ -84,7 +84,8 @@ public sealed class OAuthServerSettings(public val name: String, public val vers public val passParamsInURL: Boolean = false, public val extraAuthParameters: List> = emptyList(), public val extraTokenParameters: List> = emptyList(), - public val accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {} + public val accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {}, + public val onStateCreated: suspend (call: ApplicationCall, state: String) -> Unit = { _, _ -> } ) : OAuthServerSettings(name, OAuthVersion.V20) { @Deprecated("This constructor will be removed", level = DeprecationLevel.HIDDEN) @@ -117,6 +118,40 @@ public sealed class OAuthServerSettings(public val name: String, public val vers emptyList(), accessTokenInterceptor ) + + @Deprecated("This constructor will be removed", level = DeprecationLevel.HIDDEN) + public constructor( + name: String, + authorizeUrl: String, + accessTokenUrl: String, + requestMethod: HttpMethod = HttpMethod.Get, + clientId: String, + clientSecret: String, + defaultScopes: List = emptyList(), + accessTokenRequiresBasicAuth: Boolean = false, + nonceManager: NonceManager = GenerateOnlyNonceManager, + authorizeUrlInterceptor: URLBuilder.() -> Unit = {}, + passParamsInURL: Boolean = false, + extraAuthParameters: List> = emptyList(), + extraTokenParameters: List> = emptyList(), + accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {}, + ) : this( + name, + authorizeUrl, + accessTokenUrl, + requestMethod, + clientId, + clientSecret, + defaultScopes, + accessTokenRequiresBasicAuth, + nonceManager, + authorizeUrlInterceptor, + passParamsInURL, + extraAuthParameters, + extraTokenParameters, + accessTokenInterceptor, + { _, _ -> } + ) } } @@ -161,6 +196,7 @@ public sealed class OAuthAccessTokenResponse : Principal { * @property tokenType OAuth2 token type (usually Bearer) * @property expiresIn token expiration timestamp * @property refreshToken to be used to refresh access token after expiration + * @property state generated state used for the OAuth procedure * @property extraParameters contains additional parameters provided by the server */ public data class OAuth2( @@ -168,8 +204,23 @@ public sealed class OAuthAccessTokenResponse : Principal { val tokenType: String, val expiresIn: Long, val refreshToken: String?, - val extraParameters: Parameters = Parameters.Empty - ) : OAuthAccessTokenResponse() + val extraParameters: Parameters = Parameters.Empty, + ) : OAuthAccessTokenResponse() { + + public var state: String? = null + private set + + public constructor( + accessToken: String, + tokenType: String, + expiresIn: Long, + refreshToken: String?, + extraParameters: Parameters = Parameters.Empty, + state: String? = null + ) : this(accessToken, tokenType, expiresIn, refreshToken, extraParameters) { + this.state = state + } + } } /** @@ -214,6 +265,7 @@ public suspend fun PipelineContext.oauthRespondRedirect( call.redirectAuthenticateOAuth1a(provider, requestToken) } } + is OAuthServerSettings.OAuth2ServerSettings -> { call.redirectAuthenticateOAuth2( provider, @@ -275,6 +327,7 @@ public suspend fun PipelineContext.oauthHandleCallback( } } } + is OAuthServerSettings.OAuth2ServerSettings -> { val code = call.oauth2HandleCallback() if (code == null) { diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth2.kt b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth2.kt index e38ac8d74c..b31ef11a81 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth2.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuth2.kt @@ -265,6 +265,7 @@ private suspend fun oauth2RequestAccessToken( accessToken = contentDecoded[OAuth2ResponseParameters.AccessToken] ?: throw OAuth2Exception.MissingAccessToken(), tokenType = contentDecoded[OAuth2ResponseParameters.TokenType] ?: "", + state = state, expiresIn = contentDecoded[OAuth2ResponseParameters.ExpiresIn]?.toLong() ?: 0L, refreshToken = contentDecoded[OAuth2ResponseParameters.RefreshToken], extraParameters = contentDecoded diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuthProcedure.kt b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuthProcedure.kt index e8e47d34d8..c1c5046988 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuthProcedure.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/src/io/ktor/server/auth/OAuthProcedure.kt @@ -85,10 +85,12 @@ internal suspend fun OAuthAuthenticationProvider.oauth2(authProviderName: String cause ?: return @Suppress("NAME_SHADOWING") context.challenge(OAuthKey, cause) { challenge, call -> + val state = provider.nonceManager.newNonce() + provider.onStateCreated(call, state) call.redirectAuthenticateOAuth2( provider, callbackRedirectUrl, - state = provider.nonceManager.newNonce(), + state = state, scopes = provider.defaultScopes, extraParameters = provider.extraAuthParameters, interceptor = provider.authorizeUrlInterceptor diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/test/io/ktor/tests/auth/OAuth2Test.kt b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/test/io/ktor/tests/auth/OAuth2Test.kt index d9eb2682e3..b471df4d64 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/test/io/ktor/tests/auth/OAuth2Test.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth/jvm/test/io/ktor/tests/auth/OAuth2Test.kt @@ -5,8 +5,11 @@ package io.ktor.tests.auth import io.ktor.client.* +import io.ktor.client.call.* import io.ktor.client.engine.mock.* +import io.ktor.client.plugins.cookies.* import io.ktor.client.request.* +import io.ktor.client.statement.* import io.ktor.http.* import io.ktor.http.auth.* import io.ktor.http.content.* @@ -15,6 +18,7 @@ import io.ktor.server.auth.* import io.ktor.server.request.* import io.ktor.server.response.* import io.ktor.server.routing.* +import io.ktor.server.sessions.* import io.ktor.server.testing.* import io.ktor.server.testing.client.* import io.ktor.util.* @@ -148,7 +152,8 @@ class OAuth2Test { when (state) { null -> parametersOf("noState", "Had no state") else -> Parameters.Empty - } + }, + state ) } else if (grantType == OAuthGrantTypes.Password) { if (userName != "user1") { @@ -795,6 +800,78 @@ class OAuth2Test { assertEquals("Had no state", result) } + @Test + fun testApplicationState() = testApplication { + class UserSession(val token: String) + + val client = createClient { + install(HttpCookies) + } + val redirects = mutableMapOf() + externalServices { + hosts("http://oauth.com") { + routing { + post("/oauth/access_token") { + call.respondText("access_token=a_token", ContentType.Application.FormUrlEncoded) + } + get("/oauth/authorize") { + val state = call.parameters["state"]!! + call.respondText("code=code&state=$state", ContentType.Application.FormUrlEncoded) + } + } + } + } + install(Sessions) { + cookie("user_session") + } + install(Authentication) { + oauth("login") { + this@oauth.client = client + urlProvider = { "http://localhost/login" } + providerLookup = { + OAuthServerSettings.OAuth2ServerSettings( + name = "oauth2", + authorizeUrl = "http://oauth.com/oauth/authorize", + accessTokenUrl = "http://oauth.com/oauth/access_token", + clientId = "clientId1", + clientSecret = "clientSecret1", + requestMethod = HttpMethod.Post, + onStateCreated = { call, state -> + redirects[state] = call.request.queryParameters["redirectUrl"]!! + } + ) + } + } + } + routing { + authenticate("login") { + get("login") { + val state = call.principal()!!.state!! + call.sessions.set(UserSession(state)) + val redirect = redirects[state]!! + call.respondRedirect(redirect) + } + } + get("{path}") { + val session = call.sessions.get() + if (session == null) { + val redirectUrl = URLBuilder("http://localhost/login").run { + parameters.append("redirectUrl", call.request.uri) + build() + } + call.respondRedirect(redirectUrl) + return@get + } + call.respond(call.parameters["path"]!!) + } + } + val request1Auth = client.get("/some-url").body().let { parseQueryString(it) } + val code1 = request1Auth["code"]!! + val state1 = request1Auth["state"]!! + val response1 = client.get("/login?code=$code1&state=$state1") + assertEquals("some-url", response1.bodyAsText()) + } + private fun waitExecutor() { val latch = CountDownLatch(1) executor.submit {