diff --git a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java index 6262215b37a0..25c4cd59c550 100644 --- a/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java +++ b/spring-beans/src/main/java/org/springframework/beans/factory/support/StaticListableBeanFactory.java @@ -37,7 +37,7 @@ import org.springframework.beans.factory.SmartFactoryBean; import org.springframework.core.OrderComparator; import org.springframework.core.ResolvableType; -import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; @@ -450,7 +450,7 @@ public A findAnnotationOnBean(String beanName, Class a throws NoSuchBeanDefinitionException { Class beanType = getType(beanName); - return (beanType != null ? AnnotationUtils.findAnnotation(beanType, annotationType) : null); + return (beanType != null ? AnnotatedElementUtils.findMergedAnnotation(beanType, annotationType) : null); } } diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java index a84a028eae55..d70a7bbd1bd7 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/BeanFactoryUtilsTests.java @@ -16,6 +16,8 @@ package org.springframework.beans.factory; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -33,6 +35,7 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.beans.testfixture.beans.factory.DummyFactory; import org.springframework.cglib.proxy.NoOp; +import org.springframework.core.annotation.AliasFor; import org.springframework.core.io.Resource; import org.springframework.util.ObjectUtils; @@ -324,6 +327,33 @@ public void testIntDependencies() { assertThat(Arrays.equals(new String[] { "buffer" }, deps)).isTrue(); } + @Test + public void findAnnotationOnBean() { + this.listableBeanFactory.registerSingleton("controllerAdvice", new ControllerAdviceClass()); + this.listableBeanFactory.registerSingleton("restControllerAdvice", new RestControllerAdviceClass()); + testFindAnnotationOnBean(this.listableBeanFactory); + } + + @Test // gh-25520 + public void findAnnotationOnBeanWithStaticFactory() { + StaticListableBeanFactory lbf = new StaticListableBeanFactory(); + lbf.addBean("controllerAdvice", new ControllerAdviceClass()); + lbf.addBean("restControllerAdvice", new RestControllerAdviceClass()); + testFindAnnotationOnBean(lbf); + } + + private void testFindAnnotationOnBean(ListableBeanFactory lbf) { + assertControllerAdvice(lbf, "controllerAdvice"); + assertControllerAdvice(lbf, "restControllerAdvice"); + } + + private void assertControllerAdvice(ListableBeanFactory lbf, String beanName) { + ControllerAdvice controllerAdvice = lbf.findAnnotationOnBean(beanName, ControllerAdvice.class); + assertThat(controllerAdvice).isNotNull(); + assertThat(controllerAdvice.value()).isEqualTo("com.example"); + assertThat(controllerAdvice.basePackage()).isEqualTo("com.example"); + } + @Test public void isSingletonAndIsPrototypeWithStaticFactory() { StaticListableBeanFactory lbf = new StaticListableBeanFactory(); @@ -393,6 +423,35 @@ public void isSingletonAndIsPrototypeWithStaticFactory() { } + @Retention(RetentionPolicy.RUNTIME) + @interface ControllerAdvice { + + @AliasFor("basePackage") + String value() default ""; + + @AliasFor("value") + String basePackage() default ""; + } + + @Retention(RetentionPolicy.RUNTIME) + @ControllerAdvice + @interface RestControllerAdvice { + + @AliasFor(annotation = ControllerAdvice.class) + String value() default ""; + + @AliasFor(annotation = ControllerAdvice.class) + String basePackage() default ""; + } + + @ControllerAdvice("com.example") + static class ControllerAdviceClass { + } + + @RestControllerAdvice("com.example") + static class RestControllerAdviceClass { + } + static class TestBeanSmartFactoryBean implements SmartFactoryBean { private final TestBean testBean = new TestBean("enigma", 42); diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java index 7d7b335b2243..1e21a37c7089 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/ExceptionHandlerTests.java @@ -16,16 +16,23 @@ package org.springframework.test.web.servlet.samples.standalone; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.http.MediaType; import org.springframework.stereotype.Controller; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.RestControllerAdvice; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup; @@ -37,28 +44,32 @@ */ class ExceptionHandlerTests { - @Test - void localExceptionHandlerMethod() throws Exception { - standaloneSetup(new PersonController()).build() - .perform(get("/person/Clyde")) + @Nested + class MvcTests { + + @Test + void localExceptionHandlerMethod() throws Exception { + standaloneSetup(new PersonController()).build() + .perform(get("/person/Clyde")) .andExpect(status().isOk()) .andExpect(forwardedUrl("errorView")); - } + } - @Test - void globalExceptionHandlerMethod() throws Exception { - standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build() + @Test + void globalExceptionHandlerMethod() throws Exception { + standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build() .perform(get("/person/Bonnie")) .andExpect(status().isOk()) .andExpect(forwardedUrl("globalErrorView")); - } + } - @Test - void globalExceptionHandlerMethodUsingClassArgument() throws Exception { - standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build() + @Test + void globalExceptionHandlerMethodUsingClassArgument() throws Exception { + standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build() .perform(get("/person/Bonnie")) .andExpect(status().isOk()) .andExpect(forwardedUrl("globalErrorView")); + } } @@ -82,7 +93,6 @@ String handleException(IllegalArgumentException exception) { } } - @ControllerAdvice private static class GlobalExceptionHandler { @@ -92,4 +102,124 @@ String handleException(IllegalStateException exception) { } } + + @Nested + class RestTests { + + @Test + void noException() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Yoda").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.name").value("Yoda")); + } + + @Test + void localExceptionHandlerMethod() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Luke").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("local - IllegalArgumentException")); + } + + @Test + void globalExceptionHandlerMethod() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class).build() + .perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("global - IllegalStateException")); + } + + @Test + void globalRestPersonControllerExceptionHandlerTakesPrecedenceOverGlobalExceptionHandler() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build() + .perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("globalPersonController - IllegalStateException")); + } + + @Test // gh-25520 + void noHandlerFound() throws Exception { + standaloneSetup(RestPersonController.class) + .setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class) + .addDispatcherServletCustomizer(dispatcherServlet -> dispatcherServlet.setThrowExceptionIfNoHandlerFound(true)) + .build() + .perform(get("/bogus").accept(MediaType.APPLICATION_JSON)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.error").value("global - NoHandlerFoundException")); + } + } + + + @RestController + private static class RestPersonController { + + @GetMapping("/person/{name}") + Person get(@PathVariable String name) { + switch (name) { + case "Luke": + throw new IllegalArgumentException(); + case "Leia": + throw new IllegalStateException(); + default: + return new Person("Yoda"); + } + } + + @ExceptionHandler + Error handleException(IllegalArgumentException exception) { + return new Error("local - " + exception.getClass().getSimpleName()); + } + } + + @RestControllerAdvice(assignableTypes = RestPersonController.class) + @Order(Ordered.HIGHEST_PRECEDENCE) + private static class RestPersonControllerExceptionHandler { + + @ExceptionHandler + Error handleException(Throwable exception) { + return new Error("globalPersonController - " + exception.getClass().getSimpleName()); + } + } + + @RestControllerAdvice + @Order(Ordered.LOWEST_PRECEDENCE) + private static class RestGlobalExceptionHandler { + + @ExceptionHandler + Error handleException(Throwable exception) { + return new Error( "global - " + exception.getClass().getSimpleName()); + } + } + + static class Person { + + private final String name; + + Person(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + static class Error { + + private final String error; + + Error(String error) { + this.error = error; + } + + public String getError() { + return error; + } + } + }