Skip to content

Commit

Permalink
KTOR-4578 Allow overriding HSTS settings per host (#3029)
Browse files Browse the repository at this point in the history
  • Loading branch information
rescribet authored and rsinukov committed Jun 30, 2022
1 parent d514083 commit 380762b
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 13 deletions.
@@ -1,4 +1,9 @@
public final class io/ktor/server/plugins/hsts/HSTSConfig {
public final class io/ktor/server/plugins/hsts/HSTSConfig : io/ktor/server/plugins/hsts/HSTSHostConfig {
public fun <init> ()V
public final fun withHost (Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V
}

public class io/ktor/server/plugins/hsts/HSTSHostConfig {
public fun <init> ()V
public final fun getCustomDirectives ()Ljava/util/Map;
public final fun getIncludeSubDomains ()Z
Expand Down
Expand Up @@ -7,14 +7,15 @@ package io.ktor.server.plugins.hsts
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.plugins.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.util.*

/**
* A configuration for the [HSTS] plugin.
* A configuration for the [HSTS] settings for a host.
*/
@KtorDsl
public class HSTSConfig {
public open class HSTSHostConfig {
/**
* Specifies the `preload` HSTS directive, which allows you to include your domain name
* in the HSTS preload list.
Expand All @@ -41,6 +42,24 @@ public class HSTSConfig {
public val customDirectives: MutableMap<String, String?> = HashMap()
}

/**
* A configuration for the [HSTS] plugin.
*/
@KtorDsl
public class HSTSConfig : HSTSHostConfig() {
/**
* @see [withHost]
*/
internal val hostSpecific: MutableMap<String, HSTSHostConfig> = HashMap()

/**
* Set specific configuration for a [host].
*/
public fun withHost(host: String, configure: HSTSHostConfig.() -> Unit) {
this.hostSpecific[host] = HSTSHostConfig().apply(configure)
}
}

internal const val DEFAULT_HSTS_MAX_AGE: Long = 365L * 24 * 3600 // 365 days

/**
Expand All @@ -56,22 +75,19 @@ internal const val DEFAULT_HSTS_MAX_AGE: Long = 365L * 24 * 3600 // 365 days
* You can learn more from [HSTS](https://ktor.io/docs/hsts.html).
*/
public val HSTS: RouteScopedPlugin<HSTSConfig> = createRouteScopedPlugin("HSTS", ::HSTSConfig) {
/**
* A constructed `Strict-Transport-Security` header value.
*/
val headerValue: String = buildString {
fun constructHeaderValue(config: HSTSHostConfig) = buildString {
append("max-age=")
append(pluginConfig.maxAgeInSeconds)
append(config.maxAgeInSeconds)

if (pluginConfig.includeSubDomains) {
if (config.includeSubDomains) {
append("; includeSubDomains")
}
if (pluginConfig.preload) {
if (config.preload) {
append("; preload")
}

if (pluginConfig.customDirectives.isNotEmpty()) {
pluginConfig.customDirectives.entries.joinTo(this, separator = "; ", prefix = "; ") {
if (config.customDirectives.isNotEmpty()) {
config.customDirectives.entries.joinTo(this, separator = "; ", prefix = "; ") {
if (it.value != null) {
"${it.key.escapeIfNeeded()}=${it.value?.escapeIfNeeded()}"
} else {
Expand All @@ -81,9 +97,19 @@ public val HSTS: RouteScopedPlugin<HSTSConfig> = createRouteScopedPlugin("HSTS",
}
}

/**
* A constructed default `Strict-Transport-Security` header value.
*/
val headerValue: String = constructHeaderValue(pluginConfig)

val hostHeaderValues: Map<String, String> = pluginConfig.hostSpecific.mapValues { constructHeaderValue(it.value) }

onCall { call ->
if (call.request.origin.run { scheme == "https" && port == 443 }) {
call.response.header(HttpHeaders.StrictTransportSecurity, headerValue)
call.response.header(
HttpHeaders.StrictTransportSecurity,
hostHeaderValues[call.request.host()] ?: headerValue
)
}
}
}
Expand Up @@ -160,6 +160,40 @@ class HSTSTest {
}
}

@Test
fun testHttpsHostOverride() = testApplication {
application {
testApp {
customDirectives.clear()
includeSubDomains = true

withHost("differing") {
maxAgeInSeconds = 10
preload = true
includeSubDomains = false
}
}
}

client.get("/") {
header(HttpHeaders.XForwardedProto, "https")
}.let {
assertEquals(
"max-age=10; includeSubDomains; preload",
it.headers[HttpHeaders.StrictTransportSecurity]
)
}
client.get("/") {
header(HttpHeaders.XForwardedProto, "https")
header(HttpHeaders.XForwardedHost, "differing")
}.let {
assertEquals(
"max-age=10; preload",
it.headers[HttpHeaders.StrictTransportSecurity]
)
}
}

private fun Application.testApp(block: HSTSConfig.() -> Unit = {}) {
install(XForwardedHeaders)
install(HSTS) {
Expand Down

0 comments on commit 380762b

Please sign in to comment.