Skip to content

Commit

Permalink
Improve parity between Java and Kotlin router DSL
Browse files Browse the repository at this point in the history
This commit adds following functions to the Kotlin DSL:
add, filter, before, after and onError.

Closes spring-projectsgh-23524
  • Loading branch information
sdeleuze committed Sep 17, 2019
1 parent 7a1a8e1 commit 007940f
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 27 deletions.
Expand Up @@ -17,6 +17,7 @@
package org.springframework.web.reactive.function.server

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.reactive.awaitFirst
import kotlinx.coroutines.reactor.mono
import org.springframework.core.io.Resource
import org.springframework.http.HttpMethod
Expand Down Expand Up @@ -64,7 +65,8 @@ fun coRouter(routes: (CoRouterFunctionDsl.() -> Unit)) =
*/
class CoRouterFunctionDsl(private val init: (CoRouterFunctionDsl.() -> Unit)) {

private val builder = RouterFunctions.route()
@PublishedApi
internal val builder = RouterFunctions.route()

/**
* Return a composed request predicate that tests against both this predicate AND
Expand Down Expand Up @@ -510,6 +512,80 @@ class CoRouterFunctionDsl(private val init: (CoRouterFunctionDsl.() -> Unit)) {
}
}

/**
* Merge externally defined router functions into this one.
* @param routerFunction the router function to be added
* @since 5.2
*/
fun add(routerFunction: RouterFunction<ServerResponse>) {
builder.add(routerFunction)
}

/**
* Filters all routes created by this router with the given filter function. Filter
* functions are typically used to address cross-cutting concerns, such as logging,
* security, etc.
* @param filterFunction the function to filter all routes built by this router
* @since 5.2
*/
fun filter(filterFunction: suspend (ServerRequest, suspend (ServerRequest) -> ServerResponse) -> ServerResponse) {
builder.filter { serverRequest, handlerFunction ->
mono(Dispatchers.Unconfined) {
filterFunction(serverRequest) {
handlerFunction.handle(serverRequest).awaitFirst()
}
}
}
}

/**
* Filter the request object for all routes created by this builder with the given request
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param requestProcessor a function that transforms the request
* @since 5.2
*/
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
builder.before(requestProcessor)
}

/**
* Filter the response object for all routes created by this builder with the given response
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param responseProcessor a function that transforms the response
* @since 5.2
*/
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
builder.after(responseProcessor)
}

/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param predicate the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
fun onError(predicate: (Throwable) -> Boolean, responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError(predicate) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
}
}

/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param E the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
inline fun <reified E : Throwable> onError(noinline responseProvider: suspend (Throwable, ServerRequest) -> ServerResponse) {
builder.onError({it is E}) { throwable, request ->
mono(Dispatchers.Unconfined) { responseProvider.invoke(throwable, request) }
}
}

/**
* Return a composed routing function created from all the registered routes.
*/
Expand Down
Expand Up @@ -62,7 +62,8 @@ fun router(routes: RouterFunctionDsl.() -> Unit) = RouterFunctionDsl(routes).bui
*/
class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {

private val builder = RouterFunctions.route()
@PublishedApi
internal val builder = RouterFunctions.route()

/**
* Return a composed request predicate that tests against both this predicate AND
Expand Down Expand Up @@ -505,6 +506,83 @@ class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
builder.resources(lookupFunction)
}

/**
* Merge externally defined router functions into this one.
* @param routerFunction the router function to be added
* @since 5.2
*/
fun add(routerFunction: RouterFunction<ServerResponse>) {
builder.add(routerFunction)
}

/**
* Filters all routes created by this router with the given filter function. Filter
* functions are typically used to address cross-cutting concerns, such as logging,
* security, etc.
* @param filterFunction the function to filter all routes built by this router
* @since 5.2
*/
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> Mono<ServerResponse>) -> Mono<ServerResponse>) {
builder.filter { request, next ->
filterFunction(request) {
next.handle(request)
}
}
}

/**
* Filter the request object for all routes created by this builder with the given request
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param requestProcessor a function that transforms the request
* @since 5.2
*/
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
builder.before(requestProcessor)
}

/**
* Filter the response object for all routes created by this builder with the given response
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param responseProcessor a function that transforms the response
* @since 5.2
*/
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
builder.after(responseProcessor)
}

/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param predicate the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
fun onError(predicate: (Throwable) -> Boolean, responseProvider: (Throwable, ServerRequest) -> Mono<ServerResponse>) {
builder.onError(predicate, responseProvider)
}

/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param E the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
inline fun <reified E : Throwable> onError(noinline responseProvider: (Throwable, ServerRequest) -> Mono<ServerResponse>) {
builder.onError({it is E}, responseProvider)
}

/**
* Return a composed routing function created from all the registered routes.
* @since 5.1
*/
internal fun build(): RouterFunction<ServerResponse> {
init()
return builder.build()
}

/**
* Create a builder with the status code and headers of the given response.
* @param other the response to copy the status and headers from
Expand Down Expand Up @@ -621,13 +699,4 @@ class RouterFunctionDsl(private val init: RouterFunctionDsl.() -> Unit) {
fun unprocessableEntity(): ServerResponse.BodyBuilder =
ServerResponse.unprocessableEntity()

/**
* Return a composed routing function created from all the registered routes.
* @since 5.1
*/
internal fun build(): RouterFunction<ServerResponse> {
init()
return builder.build()
}

}
Expand Up @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.reactive.function.server.MockServerRequest.builder
import reactor.test.StepVerifier
Expand Down Expand Up @@ -172,6 +173,28 @@ class CoRouterFunctionDslTests {
}
path("/baz", ::handle)
GET("/rendering") { RenderingResponse.create("index").buildAndAwait() }
add(otherRouter)
}

private val otherRouter = router {
"/other" {
ok().build()
}
filter { request, next ->
next(request)
}
before {
it
}
after { _, response ->
response
}
onError({it is IllegalStateException}) { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
onError<IllegalStateException> { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
}

@Suppress("UNUSED_PARAMETER")
Expand Down
Expand Up @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test
import org.springframework.core.io.ClassPathResource
import org.springframework.http.HttpHeaders.*
import org.springframework.http.HttpMethod.*
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType.*
import org.springframework.web.reactive.function.server.MockServerRequest.builder
import reactor.core.publisher.Mono
Expand Down Expand Up @@ -173,6 +174,28 @@ class RouterFunctionDslTests {
}
path("/baz", ::handle)
GET("/rendering") { RenderingResponse.create("index").build() }
add(otherRouter)
}

private val otherRouter = router {
"/other" {
ok().build()
}
filter { request, next ->
next(request)
}
before {
it
}
after { _, response ->
response
}
onError({it is IllegalStateException}) { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
onError<IllegalStateException> { _, _ ->
ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build()
}
}

@Suppress("UNUSED_PARAMETER")
Expand Down
Expand Up @@ -60,7 +60,8 @@ fun router(routes: (RouterFunctionDsl.() -> Unit)) = RouterFunctionDsl(routes).b
*/
class RouterFunctionDsl(private val init: (RouterFunctionDsl.() -> Unit)) {

private val builder = RouterFunctions.route()
@PublishedApi
internal val builder = RouterFunctions.route()

/**
* Return a composed request predicate that tests against both this predicate AND
Expand Down Expand Up @@ -504,6 +505,74 @@ class RouterFunctionDsl(private val init: (RouterFunctionDsl.() -> Unit)) {
}
}

/**
* Merge externally defined router functions into this one.
* @param routerFunction the router function to be added
* @since 5.2
*/
fun add(routerFunction: RouterFunction<ServerResponse>) {
builder.add(routerFunction)
}

/**
* Filters all routes created by this router with the given filter function. Filter
* functions are typically used to address cross-cutting concerns, such as logging,
* security, etc.
* @param filterFunction the function to filter all routes built by this router
* @since 5.2
*/
fun filter(filterFunction: (ServerRequest, (ServerRequest) -> ServerResponse) -> ServerResponse) {
builder.filter { request, next ->
filterFunction(request) {
next.handle(request)
}
}
}

/**
* Filter the request object for all routes created by this builder with the given request
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param requestProcessor a function that transforms the request
* @since 5.2
*/
fun before(requestProcessor: (ServerRequest) -> ServerRequest) {
builder.before(requestProcessor)
}

/**
* Filter the response object for all routes created by this builder with the given response
* processing function. Filters are typically used to address cross-cutting concerns, such
* as logging, security, etc.
* @param responseProcessor a function that transforms the response
* @since 5.2
*/
fun after(responseProcessor: (ServerRequest, ServerResponse) -> ServerResponse) {
builder.after(responseProcessor)
}

/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param predicate the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
fun onError(predicate: (Throwable) -> Boolean, responseProvider: (Throwable, ServerRequest) -> ServerResponse) {
builder.onError(predicate, responseProvider)
}

/**
* Filters all exceptions that match the predicate by applying the given response provider
* function.
* @param E the type of exception to filter
* @param responseProvider a function that creates a response
* @since 5.2
*/
inline fun <reified E : Throwable> onError(noinline responseProvider: (Throwable, ServerRequest) -> ServerResponse) {
builder.onError({it is E}, responseProvider)
}

/**
* Return a composed routing function created from all the registered routes.
*/
Expand Down

0 comments on commit 007940f

Please sign in to comment.