Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-5225 Add callbacks to save application state for OAuth2 #3282

Merged
merged 1 commit into from Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -358,6 +358,8 @@ public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth1a : io/kto
public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth2 : io/ktor/server/auth/OAuthAccessTokenResponse {
public fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;Ljava/lang/String;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;JLjava/lang/String;Lio/ktor/http/Parameters;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun component1 ()Ljava/lang/String;
public final fun component2 ()Ljava/lang/String;
public final fun component3 ()J
Expand All @@ -370,6 +372,7 @@ public final class io/ktor/server/auth/OAuthAccessTokenResponse$OAuth2 : io/ktor
public final fun getExpiresIn ()J
public final fun getExtraParameters ()Lio/ktor/http/Parameters;
public final fun getRefreshToken ()Ljava/lang/String;
public final fun getState ()Ljava/lang/String;
public final fun getTokenType ()Ljava/lang/String;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
Expand Down Expand Up @@ -458,8 +461,10 @@ public final class io/ktor/server/auth/OAuthServerSettings$OAuth1aServerSettings
}

public final class io/ktor/server/auth/OAuthServerSettings$OAuth2ServerSettings : io/ktor/server/auth/OAuthServerSettings {
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function3;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function3;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Lio/ktor/http/HttpMethod;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;ZLio/ktor/util/NonceManager;Lkotlin/jvm/functions/Function1;ZLkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getAccessTokenInterceptor ()Lkotlin/jvm/functions/Function1;
Expand All @@ -473,6 +478,7 @@ public final class io/ktor/server/auth/OAuthServerSettings$OAuth2ServerSettings
public final fun getExtraAuthParameters ()Ljava/util/List;
public final fun getExtraTokenParameters ()Ljava/util/List;
public final fun getNonceManager ()Lio/ktor/util/NonceManager;
public final fun getOnStateCreated ()Lkotlin/jvm/functions/Function3;
public final fun getPassParamsInURL ()Z
public final fun getRequestMethod ()Lio/ktor/http/HttpMethod;
}
Expand Down
Expand Up @@ -84,7 +84,8 @@ public sealed class OAuthServerSettings(public val name: String, public val vers
public val passParamsInURL: Boolean = false,
public val extraAuthParameters: List<Pair<String, String>> = emptyList(),
public val extraTokenParameters: List<Pair<String, String>> = emptyList(),
public val accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {}
public val accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {},
public val onStateCreated: suspend (call: ApplicationCall, state: String) -> Unit = { _, _ -> }
) : OAuthServerSettings(name, OAuthVersion.V20) {

@Deprecated("This constructor will be removed", level = DeprecationLevel.HIDDEN)
Expand Down Expand Up @@ -117,6 +118,40 @@ public sealed class OAuthServerSettings(public val name: String, public val vers
emptyList(),
accessTokenInterceptor
)

@Deprecated("This constructor will be removed", level = DeprecationLevel.HIDDEN)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: "This constructor will be removed" -> "Binary compatibility with 2.x"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, deprecation messages should be for users explaining why this is deprecated, not for us explaining why it is kept.

Copy link
Contributor

@whyoleg whyoleg Dec 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, declarations with level=HIDDEN will be not visible to user, until they open source code. They even should not feel this change at all, as new parameter has default value, but ok.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but for consistency, I prefer it this way.

public constructor(
name: String,
authorizeUrl: String,
accessTokenUrl: String,
requestMethod: HttpMethod = HttpMethod.Get,
clientId: String,
clientSecret: String,
defaultScopes: List<String> = emptyList(),
accessTokenRequiresBasicAuth: Boolean = false,
nonceManager: NonceManager = GenerateOnlyNonceManager,
authorizeUrlInterceptor: URLBuilder.() -> Unit = {},
passParamsInURL: Boolean = false,
extraAuthParameters: List<Pair<String, String>> = emptyList(),
extraTokenParameters: List<Pair<String, String>> = emptyList(),
accessTokenInterceptor: HttpRequestBuilder.() -> Unit = {},
) : this(
name,
authorizeUrl,
accessTokenUrl,
requestMethod,
clientId,
clientSecret,
defaultScopes,
accessTokenRequiresBasicAuth,
nonceManager,
authorizeUrlInterceptor,
passParamsInURL,
extraAuthParameters,
extraTokenParameters,
accessTokenInterceptor,
{ _, _ -> }
)
}
}

Expand Down Expand Up @@ -161,15 +196,31 @@ public sealed class OAuthAccessTokenResponse : Principal {
* @property tokenType OAuth2 token type (usually Bearer)
* @property expiresIn token expiration timestamp
* @property refreshToken to be used to refresh access token after expiration
* @property state generated state used for the OAuth procedure
* @property extraParameters contains additional parameters provided by the server
*/
public data class OAuth2(
val accessToken: String,
val tokenType: String,
val expiresIn: Long,
val refreshToken: String?,
val extraParameters: Parameters = Parameters.Empty
) : OAuthAccessTokenResponse()
val extraParameters: Parameters = Parameters.Empty,
) : OAuthAccessTokenResponse() {

public var state: String? = null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a question: is it expected, that while OAuth2 is data class this new property will not appear in copy/componentN/equals/hashCode?

Copy link
Contributor Author

@rsinukov rsinukov Dec 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's really difficult to add it without breaking binary compatibility. In practice, I don't expect users to do copy/componentN on principals. And for equals and hashcode, having accessToken there should be enough.

private set

public constructor(
accessToken: String,
tokenType: String,
expiresIn: Long,
refreshToken: String?,
extraParameters: Parameters = Parameters.Empty,
state: String? = null
) : this(accessToken, tokenType, expiresIn, refreshToken, extraParameters) {
this.state = state
}
}
}

/**
Expand Down Expand Up @@ -214,6 +265,7 @@ public suspend fun PipelineContext<Unit, ApplicationCall>.oauthRespondRedirect(
call.redirectAuthenticateOAuth1a(provider, requestToken)
}
}

is OAuthServerSettings.OAuth2ServerSettings -> {
call.redirectAuthenticateOAuth2(
provider,
Expand Down Expand Up @@ -275,6 +327,7 @@ public suspend fun PipelineContext<Unit, ApplicationCall>.oauthHandleCallback(
}
}
}

is OAuthServerSettings.OAuth2ServerSettings -> {
val code = call.oauth2HandleCallback()
if (code == null) {
Expand Down
Expand Up @@ -265,6 +265,7 @@ private suspend fun oauth2RequestAccessToken(
accessToken = contentDecoded[OAuth2ResponseParameters.AccessToken]
?: throw OAuth2Exception.MissingAccessToken(),
tokenType = contentDecoded[OAuth2ResponseParameters.TokenType] ?: "",
state = state,
expiresIn = contentDecoded[OAuth2ResponseParameters.ExpiresIn]?.toLong() ?: 0L,
refreshToken = contentDecoded[OAuth2ResponseParameters.RefreshToken],
extraParameters = contentDecoded
Expand Down
Expand Up @@ -85,10 +85,12 @@ internal suspend fun OAuthAuthenticationProvider.oauth2(authProviderName: String
cause ?: return
@Suppress("NAME_SHADOWING")
context.challenge(OAuthKey, cause) { challenge, call ->
val state = provider.nonceManager.newNonce()
provider.onStateCreated(call, state)
call.redirectAuthenticateOAuth2(
provider,
callbackRedirectUrl,
state = provider.nonceManager.newNonce(),
state = state,
scopes = provider.defaultScopes,
extraParameters = provider.extraAuthParameters,
interceptor = provider.authorizeUrlInterceptor
Expand Down
Expand Up @@ -5,8 +5,11 @@
package io.ktor.tests.auth

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.engine.mock.*
import io.ktor.client.plugins.cookies.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.http.auth.*
import io.ktor.http.content.*
Expand All @@ -15,6 +18,7 @@ import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.sessions.*
import io.ktor.server.testing.*
import io.ktor.server.testing.client.*
import io.ktor.util.*
Expand Down Expand Up @@ -148,7 +152,8 @@ class OAuth2Test {
when (state) {
null -> parametersOf("noState", "Had no state")
else -> Parameters.Empty
}
},
state
)
} else if (grantType == OAuthGrantTypes.Password) {
if (userName != "user1") {
Expand Down Expand Up @@ -795,6 +800,78 @@ class OAuth2Test {
assertEquals("Had no state", result)
}

@Test
fun testApplicationState() = testApplication {
class UserSession(val token: String)

val client = createClient {
install(HttpCookies)
}
val redirects = mutableMapOf<String, String>()
externalServices {
hosts("http://oauth.com") {
routing {
post("/oauth/access_token") {
call.respondText("access_token=a_token", ContentType.Application.FormUrlEncoded)
}
get("/oauth/authorize") {
val state = call.parameters["state"]!!
call.respondText("code=code&state=$state", ContentType.Application.FormUrlEncoded)
}
}
}
}
install(Sessions) {
cookie<UserSession>("user_session")
}
install(Authentication) {
oauth("login") {
this@oauth.client = client
urlProvider = { "http://localhost/login" }
providerLookup = {
OAuthServerSettings.OAuth2ServerSettings(
name = "oauth2",
authorizeUrl = "http://oauth.com/oauth/authorize",
accessTokenUrl = "http://oauth.com/oauth/access_token",
clientId = "clientId1",
clientSecret = "clientSecret1",
requestMethod = HttpMethod.Post,
onStateCreated = { call, state ->
redirects[state] = call.request.queryParameters["redirectUrl"]!!
}
)
}
}
}
routing {
authenticate("login") {
get("login") {
val state = call.principal<OAuthAccessTokenResponse.OAuth2>()!!.state!!
call.sessions.set(UserSession(state))
val redirect = redirects[state]!!
call.respondRedirect(redirect)
}
}
get("{path}") {
val session = call.sessions.get<UserSession>()
if (session == null) {
val redirectUrl = URLBuilder("http://localhost/login").run {
parameters.append("redirectUrl", call.request.uri)
build()
}
call.respondRedirect(redirectUrl)
return@get
}
call.respond(call.parameters["path"]!!)
}
}
val request1Auth = client.get("/some-url").body<String>().let { parseQueryString(it) }
val code1 = request1Auth["code"]!!
val state1 = request1Auth["state"]!!
val response1 = client.get("/login?code=$code1&state=$state1")
assertEquals("some-url", response1.bodyAsText())
}

private fun waitExecutor() {
val latch = CountDownLatch(1)
executor.submit {
Expand Down