From 07a69b422c388065ce77d7b17945e75a0ef729bf Mon Sep 17 00:00:00 2001 From: Sam Brannen Date: Wed, 5 Aug 2020 17:24:18 +0200 Subject: [PATCH] Support @RestControllerAdvice in Standalone MockMvc again Since Spring Framework 5.2, @RestControllerAdvice registered with MockMvc when using MockMvcBuilders.standaloneSetup() has no longer been properly supported if annotation attributes were declared in the @RestControllerAdvice annotation. Prior to 5.2, this was not an issue. The cause for this regression is two-fold. 1. Commit 50c257794f7845829ac9ce78a102ef94e7e28a2e refactored DefaultListableBeanFactory so that findAnnotationOnBean() supports merged annotations; however, that commit did not refactor StaticListableBeanFactory#findAnnotationOnBean() to support merged annotations. 2. Commit 978adbdae749566fbf458f4f72847dfc0b5aabf7 refactored ControllerAdviceBean so that a merged @ControllerAdvice annotation is only looked up via ApplicationContext#findAnnotationOnBean(). The latter relies on the fact that findAnnotationOnBean() supports merged annotations (e.g., @RestControllerAdvice as a merged instance of @ControllerAdvice). Behind the scenes, MockMvcBuilders.standaloneSetup() creates a StubWebApplicationContext which internally uses a StubBeanFactory which extends StaticListableBeanFactory. Consequently, since the implementation of findAnnotationOnBean() in StaticListableBeanFactory was not updated to support merged annotations like it was in DefaultListableBeanFactory, we only see this regression with the standalone MockMvc support and not with MockMvc support for an existing WebApplicationContext or with standard Spring applications using an ApplicationContext that uses DefaultListableBeanFactory. This commit fixes this regression by supporting merged annotations in StaticListableBeanFactory#findAnnotationOnBean() as well. Closes gh-25520 --- .../support/StaticListableBeanFactory.java | 4 +- .../beans/factory/BeanFactoryUtilsTests.java | 59 +++++++ .../standalone/ExceptionHandlerTests.java | 156 ++++++++++++++++-- 3 files changed, 204 insertions(+), 15 deletions(-) diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java index 6262215b37a0..25c4cd59c550 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java @@ -37,7 +37,7 @@ import org.springframework.beans.factory.SmartFactoryBean; import org.springframework.core.OrderComparator; import org.springframework.core.ResolvableType; -import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; @@ -450,7 +450,7 @@ public A findAnnotationOnBean(String beanName, Class a throws NoSuchBeanDefinitionException { Class beanType = getType(beanName); - return (beanType != null ? AnnotationUtils.findAnnotation(beanType, annotationType) : null); + return (beanType != null ? AnnotatedElementUtils.findMergedAnnotation(beanType, annotationType) : null); } } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java index a84a028eae55..d70a7bbd1bd7 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java @@ -16,6 +16,8 @@ package org.springframework.beans.factory; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -33,6 +35,7 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.DummyFactory; import org.springframework.cglib.proxy.NoOp; +import org.springframework.core.annotation.AliasFor; import org.springframework.core.io.Resource; import org.springframework.util.ObjectUtils; @@ -324,6 +327,33 @@ public void testIntDependencies() { assertThat(Arrays.equals(new String[] { "buffer" }, deps)).isTrue(); } + @Test + public void findAnnotationOnBean() { + this.listableBeanFactory.registerSingleton("controllerAdvice", new ControllerAdviceClass()); + this.listableBeanFactory.registerSingleton("restControllerAdvice", new RestControllerAdviceClass()); + testFindAnnotationOnBean(this.listableBeanFactory); + } + + @Test // gh-25520 + public void findAnnotationOnBeanWithStaticFactory() { + StaticListableBeanFactory lbf = new StaticListableBeanFactory(); + lbf.addBean("controllerAdvice", new ControllerAdviceClass()); + lbf.addBean("restControllerAdvice", new RestControllerAdviceClass()); + testFindAnnotationOnBean(lbf); + } + + private void testFindAnnotationOnBean(ListableBeanFactory lbf) { + assertControllerAdvice(lbf, "controllerAdvice"); + assertControllerAdvice(lbf, "restControllerAdvice"); + } + + private void assertControllerAdvice(ListableBeanFactory lbf, String beanName) { + ControllerAdvice controllerAdvice = lbf.findAnnotationOnBean(beanName, ControllerAdvice.class); + assertThat(controllerAdvice).isNotNull(); + assertThat(controllerAdvice.value()).isEqualTo("com.example"); + assertThat(controllerAdvice.basePackage()).isEqualTo("com.example"); + } + @Test public void isSingletonAndIsPrototypeWithStaticFactory() { StaticListableBeanFactory lbf = new StaticListableBeanFactory(); @@ -393,6 +423,35 @@ public void isSingletonAndIsPrototypeWithStaticFactory() { } + @Retention(RetentionPolicy.RUNTIME) + @interface ControllerAdvice { + + @AliasFor("basePackage") + String value() default ""; + + @AliasFor("value") + String basePackage() default ""; + } + + @Retention(RetentionPolicy.RUNTIME) + @ControllerAdvice + @interface RestControllerAdvice { + + @AliasFor(annotation = ControllerAdvice.class) + String value() default ""; + + @AliasFor(annotation = ControllerAdvice.class) + String basePackage() default ""; + } + + @ControllerAdvice("com.example") + static class ControllerAdviceClass { + } + + @RestControllerAdvice("com.example") + static class RestControllerAdviceClass { + } + static class TestBeanSmartFactoryBean implements SmartFactoryBean { private final TestBean testBean = new TestBean("enigma", 42); diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java index 7d7b335b2243..1e21a37c7089 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java @@ -16,16 +16,23 @@ package org.springframework.test.web.servlet.samples.standalone; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.http.MediaType; import org.springframework.stereotype.Controller; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.RestControllerAdvice; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup; @@ -37,28 +44,32 @@ */ class ExceptionHandlerTests { - @Test - void localExceptionHandlerMethod() throws Exception { - standaloneSetup(new PersonController()).build() - .perform(get("/person/Clyde")) + @Nested + class MvcTests { + + @Test + void localExceptionHandlerMethod() throws Exception { + standaloneSetup(new PersonController()).build() + .perform(get("/person/Clyde")) .andExpect(status().isOk()) .andExpect(forwardedUrl("errorView")); - } + } - @Test - void globalExceptionHandlerMethod() throws Exception { - standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build() + @Test + void globalExceptionHandlerMethod() throws Exception { + standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build() .perform(get("/person/Bonnie")) .andExpect(status().isOk()) .andExpect(forwardedUrl("globalErrorView")); - } + } - @Test - void globalExceptionHandlerMethodUsingClassArgument() throws Exception { - standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build() + @Test + void globalExceptionHandlerMethodUsingClassArgument() throws Exception { + standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build() .perform(get("/person/Bonnie")) .andExpect(status().isOk()) .andExpect(forwardedUrl("globalErrorView")); + } } @@ -82,7 +93,6 @@ String handleException(IllegalArgumentException exception) { } } - @ControllerAdvice private static class GlobalExceptionHandler { @@ -92,4 +102,124 @@ String handleException(IllegalStateException exception) { } } + + @Nested + class RestTests { + + @Test + void noException() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Yoda").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.name").value("Yoda")); + } + + @Test + void localExceptionHandlerMethod() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Luke").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("local - IllegalArgumentException")); + } + + @Test + void globalExceptionHandlerMethod() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class).build() + .perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("global - IllegalStateException")); + } + + @Test + void globalRestPersonControllerExceptionHandlerTakesPrecedenceOverGlobalExceptionHandler() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("globalPersonController - IllegalStateException")); + } + + @Test // gh-25520 + void noHandlerFound() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class) + .addDispatcherServletCustomizer(dispatcherServlet -> dispatcherServlet.setThrowExceptionIfNoHandlerFound(true)) + .build() + .perform(get("/bogus").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("global - NoHandlerFoundException")); + } + } + + + @RestController + private static class RestPersonController { + + @GetMapping("/person/{name}") + Person get(@PathVariable String name) { + switch (name) { + case "Luke": + throw new IllegalArgumentException(); + case "Leia": + throw new IllegalStateException(); + default: + return new Person("Yoda"); + } + } + + @ExceptionHandler + Error handleException(IllegalArgumentException exception) { + return new Error("local - " + exception.getClass().getSimpleName()); + } + } + + @RestControllerAdvice(assignableTypes = RestPersonController.class) + @Order(Ordered.HIGHEST_PRECEDENCE) + private static class RestPersonControllerExceptionHandler { + + @ExceptionHandler + Error handleException(Throwable exception) { + return new Error("globalPersonController - " + exception.getClass().getSimpleName()); + } + } + + @RestControllerAdvice + @Order(Ordered.LOWEST_PRECEDENCE) + private static class RestGlobalExceptionHandler { + + @ExceptionHandler + Error handleException(Throwable exception) { + return new Error( "global - " + exception.getClass().getSimpleName()); + } + } + + static class Person { + + private final String name; + + Person(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + static class Error { + + private final String error; + + Error(String error) { + this.error = error; + } + + public String getError() { + return error; + } + } + }