Skip to content

Commit

Permalink
WebMvc respects RouterFunction beans ordering
Browse files Browse the repository at this point in the history
Closes gh-28595
  • Loading branch information
rstoyanchev committed Jun 14, 2022
1 parent 97854d9 commit 52d0681
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 21 deletions.
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -119,12 +119,11 @@ protected void initRouterFunctions() {
}

private List<RouterFunction<?>> routerFunctions() {
List<RouterFunction<?>> functions = obtainApplicationContext()
return obtainApplicationContext()
.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>)router)
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toList());
return (!CollectionUtils.isEmpty(functions) ? functions : Collections.emptyList());
}

private void logRouterFunctions(List<RouterFunction<?>> routerFunctions) {
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.function.server.HandlerFunction;
Expand Down Expand Up @@ -72,6 +73,23 @@ void noMatch() {
.verify();
}

@Test
void empty() throws Exception {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.refresh();

RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setMessageReaders(this.codecConfigurer.getReaders());
mapping.setApplicationContext(context);
mapping.afterPropertiesSet();

Mono<Object> result = mapping.getHandler(createExchange("https://example.com/match"));

StepVerifier.create(result)
.expectComplete()
.verify();
}

@Test
void changeParser() throws Exception {
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,11 +19,10 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import javax.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.core.SpringProperties;
Expand Down Expand Up @@ -135,7 +134,7 @@ public void setDetectHandlerFunctionsInAncestorContexts(boolean detectHandlerFun
@Override
public void afterPropertiesSet() throws Exception {
if (this.routerFunction == null) {
initRouterFunction();
initRouterFunctions();
}
if (CollectionUtils.isEmpty(this.messageConverters)) {
initMessageConverters();
Expand All @@ -154,20 +153,39 @@ public void afterPropertiesSet() throws Exception {
* Detect a all {@linkplain RouterFunction router functions} in the
* current application context.
*/
@SuppressWarnings({"rawtypes", "unchecked"})
private void initRouterFunction() {
ApplicationContext applicationContext = obtainApplicationContext();
Map<String, RouterFunction> beans =
(this.detectHandlerFunctionsInAncestorContexts ?
BeanFactoryUtils.beansOfTypeIncludingAncestors(applicationContext, RouterFunction.class) :
applicationContext.getBeansOfType(RouterFunction.class));
List<RouterFunction> routerFunctions = new ArrayList<>(beans.values());
private void initRouterFunctions() {
List<RouterFunction<?>> routerFunctions = routerFunctions();
this.routerFunction = routerFunctions.stream().reduce(RouterFunction::andOther).orElse(null);
logRouterFunctions(routerFunctions);
}

@SuppressWarnings("rawtypes")
private void logRouterFunctions(List<RouterFunction> routerFunctions) {
private List<RouterFunction<?>> routerFunctions() {
List<RouterFunction<?>> routerFunctions = new ArrayList<>();
if (this.detectHandlerFunctionsInAncestorContexts) {
detectRouterFunctionsInAncestorContexts(obtainApplicationContext(), routerFunctions);
}
obtainApplicationContext()
.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toCollection(() -> routerFunctions));
return routerFunctions;
}

private void detectRouterFunctionsInAncestorContexts(
ApplicationContext applicationContext, List<RouterFunction<?>> routerFunctions) {

ApplicationContext parentContext = applicationContext.getParent();
if (parentContext != null) {
detectRouterFunctionsInAncestorContexts(parentContext, routerFunctions);
parentContext.getBeanProvider(RouterFunction.class)
.orderedStream()
.map(router -> (RouterFunction<?>) router)
.collect(Collectors.toCollection(() -> routerFunctions));
}
}

private void logRouterFunctions(List<RouterFunction<?>> routerFunctions) {
if (mappingsLogger.isDebugEnabled()) {
routerFunctions.forEach(function -> mappingsLogger.debug("Mapped " + function));
}
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,7 +21,10 @@
import java.util.Optional;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.HandlerMapping;
Expand All @@ -41,7 +44,7 @@
*/
class RouterFunctionMappingTests {

private List<HttpMessageConverter<?>> messageConverters = Collections.emptyList();
private final List<HttpMessageConverter<?>> messageConverters = Collections.emptyList();

@Test
void normal() throws Exception {
Expand Down Expand Up @@ -71,6 +74,65 @@ void noMatch() throws Exception {
assertThat(result).isNull();
}

@Test
void empty() throws Exception {
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();
context.refresh();

RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setMessageConverters(this.messageConverters);
mapping.setApplicationContext(context);
mapping.afterPropertiesSet();

MockHttpServletRequest request = createTestRequest("/match");
HandlerExecutionChain result = mapping.getHandler(request);

assertThat(result).isNull();
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void detectHandlerFunctionsInAncestorContexts(boolean detect) throws Exception {
HandlerFunction<ServerResponse> function1 = request -> ServerResponse.ok().build();
HandlerFunction<ServerResponse> function2 = request -> ServerResponse.ok().build();
HandlerFunction<ServerResponse> function3 = request -> ServerResponse.ok().build();

AnnotationConfigApplicationContext context1 = new AnnotationConfigApplicationContext();
context1.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn1", function1).build());
context1.refresh();

AnnotationConfigApplicationContext context2 = new AnnotationConfigApplicationContext();
context2.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn2", function2).build());
context2.setParent(context1);
context2.refresh();

AnnotationConfigApplicationContext context3 = new AnnotationConfigApplicationContext();
context3.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn3", function3).build());
context3.setParent(context2);
context3.refresh();

RouterFunctionMapping mapping = new RouterFunctionMapping();
mapping.setDetectHandlerFunctionsInAncestorContexts(detect);
mapping.setMessageConverters(this.messageConverters);
mapping.setApplicationContext(context3);
mapping.afterPropertiesSet();

HandlerExecutionChain chain1 = mapping.getHandler(createTestRequest("/fn1"));
HandlerExecutionChain chain2 = mapping.getHandler(createTestRequest("/fn2"));
if (detect) {
assertThat(chain1).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function1);
assertThat(chain2).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function2);
}
else {
assertThat(chain1).isNull();
assertThat(chain2).isNull();
}

HandlerExecutionChain chain3 = mapping.getHandler(createTestRequest("/fn3"));
assertThat(chain3).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function3);

}

@Test
void changeParser() throws Exception {
HandlerFunction<ServerResponse> handlerFunction = request -> ServerResponse.ok().build();
Expand Down

0 comments on commit 52d0681

Please sign in to comment.