diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java index e9f78b295611..c2a14390de8d 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/server/support/RouterFunctionMapping.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -119,12 +119,11 @@ protected void initRouterFunctions() { } private List> routerFunctions() { - List> functions = obtainApplicationContext() + return obtainApplicationContext() .getBeanProvider(RouterFunction.class) .orderedStream() - .map(router -> (RouterFunction)router) + .map(router -> (RouterFunction) router) .collect(Collectors.toList()); - return (!CollectionUtils.isEmpty(functions) ? functions : Collections.emptyList()); } private void logRouterFunctions(List> routerFunctions) { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/support/RouterFunctionMappingTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/support/RouterFunctionMappingTests.java index a7d5fce5a17d..90b9e38d5272 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/support/RouterFunctionMappingTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/support/RouterFunctionMappingTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.reactive.function.server.HandlerFunction; @@ -72,6 +73,23 @@ void noMatch() { .verify(); } + @Test + void empty() throws Exception { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.refresh(); + + RouterFunctionMapping mapping = new RouterFunctionMapping(); + mapping.setMessageReaders(this.codecConfigurer.getReaders()); + mapping.setApplicationContext(context); + mapping.afterPropertiesSet(); + + Mono result = mapping.getHandler(createExchange("https://example.com/match")); + + StepVerifier.create(result) + .expectComplete() + .verify(); + } + @Test void changeParser() throws Exception { HandlerFunction handlerFunction = request -> ServerResponse.ok().build(); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/RouterFunctionMapping.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/RouterFunctionMapping.java index 7597f0af686d..c6e23129306d 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/RouterFunctionMapping.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/function/support/RouterFunctionMapping.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,11 +19,10 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; +import java.util.stream.Collectors; import javax.servlet.http.HttpServletRequest; -import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.ApplicationContext; import org.springframework.core.SpringProperties; @@ -135,7 +134,7 @@ public void setDetectHandlerFunctionsInAncestorContexts(boolean detectHandlerFun @Override public void afterPropertiesSet() throws Exception { if (this.routerFunction == null) { - initRouterFunction(); + initRouterFunctions(); } if (CollectionUtils.isEmpty(this.messageConverters)) { initMessageConverters(); @@ -154,20 +153,39 @@ public void afterPropertiesSet() throws Exception { * Detect a all {@linkplain RouterFunction router functions} in the * current application context. */ - @SuppressWarnings({"rawtypes", "unchecked"}) - private void initRouterFunction() { - ApplicationContext applicationContext = obtainApplicationContext(); - Map beans = - (this.detectHandlerFunctionsInAncestorContexts ? - BeanFactoryUtils.beansOfTypeIncludingAncestors(applicationContext, RouterFunction.class) : - applicationContext.getBeansOfType(RouterFunction.class)); - List routerFunctions = new ArrayList<>(beans.values()); + private void initRouterFunctions() { + List> routerFunctions = routerFunctions(); this.routerFunction = routerFunctions.stream().reduce(RouterFunction::andOther).orElse(null); logRouterFunctions(routerFunctions); } - @SuppressWarnings("rawtypes") - private void logRouterFunctions(List routerFunctions) { + private List> routerFunctions() { + List> routerFunctions = new ArrayList<>(); + if (this.detectHandlerFunctionsInAncestorContexts) { + detectRouterFunctionsInAncestorContexts(obtainApplicationContext(), routerFunctions); + } + obtainApplicationContext() + .getBeanProvider(RouterFunction.class) + .orderedStream() + .map(router -> (RouterFunction) router) + .collect(Collectors.toCollection(() -> routerFunctions)); + return routerFunctions; + } + + private void detectRouterFunctionsInAncestorContexts( + ApplicationContext applicationContext, List> routerFunctions) { + + ApplicationContext parentContext = applicationContext.getParent(); + if (parentContext != null) { + detectRouterFunctionsInAncestorContexts(parentContext, routerFunctions); + parentContext.getBeanProvider(RouterFunction.class) + .orderedStream() + .map(router -> (RouterFunction) router) + .collect(Collectors.toCollection(() -> routerFunctions)); + } + } + + private void logRouterFunctions(List> routerFunctions) { if (mappingsLogger.isDebugEnabled()) { routerFunctions.forEach(function -> mappingsLogger.debug("Mapped " + function)); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/support/RouterFunctionMappingTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/support/RouterFunctionMappingTests.java index 165e346b799e..8b8ff09eec41 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/function/support/RouterFunctionMappingTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/function/support/RouterFunctionMappingTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,10 @@ import java.util.Optional; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.HandlerMapping; @@ -41,7 +44,7 @@ */ class RouterFunctionMappingTests { - private List> messageConverters = Collections.emptyList(); + private final List> messageConverters = Collections.emptyList(); @Test void normal() throws Exception { @@ -71,6 +74,65 @@ void noMatch() throws Exception { assertThat(result).isNull(); } + @Test + void empty() throws Exception { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.refresh(); + + RouterFunctionMapping mapping = new RouterFunctionMapping(); + mapping.setMessageConverters(this.messageConverters); + mapping.setApplicationContext(context); + mapping.afterPropertiesSet(); + + MockHttpServletRequest request = createTestRequest("/match"); + HandlerExecutionChain result = mapping.getHandler(request); + + assertThat(result).isNull(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void detectHandlerFunctionsInAncestorContexts(boolean detect) throws Exception { + HandlerFunction function1 = request -> ServerResponse.ok().build(); + HandlerFunction function2 = request -> ServerResponse.ok().build(); + HandlerFunction function3 = request -> ServerResponse.ok().build(); + + AnnotationConfigApplicationContext context1 = new AnnotationConfigApplicationContext(); + context1.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn1", function1).build()); + context1.refresh(); + + AnnotationConfigApplicationContext context2 = new AnnotationConfigApplicationContext(); + context2.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn2", function2).build()); + context2.setParent(context1); + context2.refresh(); + + AnnotationConfigApplicationContext context3 = new AnnotationConfigApplicationContext(); + context3.registerBean(RouterFunction.class, () -> RouterFunctions.route().GET("/fn3", function3).build()); + context3.setParent(context2); + context3.refresh(); + + RouterFunctionMapping mapping = new RouterFunctionMapping(); + mapping.setDetectHandlerFunctionsInAncestorContexts(detect); + mapping.setMessageConverters(this.messageConverters); + mapping.setApplicationContext(context3); + mapping.afterPropertiesSet(); + + HandlerExecutionChain chain1 = mapping.getHandler(createTestRequest("/fn1")); + HandlerExecutionChain chain2 = mapping.getHandler(createTestRequest("/fn2")); + if (detect) { + assertThat(chain1).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function1); + assertThat(chain2).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function2); + } + else { + assertThat(chain1).isNull(); + assertThat(chain2).isNull(); + } + + HandlerExecutionChain chain3 = mapping.getHandler(createTestRequest("/fn3")); + assertThat(chain3).isNotNull().extracting(HandlerExecutionChain::getHandler).isSameAs(function3); + + } + @Test void changeParser() throws Exception { HandlerFunction handlerFunction = request -> ServerResponse.ok().build();