Skip to content

Commit

Permalink
Support @RestControllerAdvice in Standalone MockMvc again
Browse files Browse the repository at this point in the history
Since Spring Framework 5.2, @RestControllerAdvice registered with
MockMvc when using MockMvcBuilders.standaloneSetup() has no longer been
properly supported if annotation attributes were declared in the
@RestControllerAdvice annotation. Prior to 5.2, this was not an issue.

The cause for this regression is two-fold.

1. Commit 50c2577 refactored
   DefaultListableBeanFactory so that findAnnotationOnBean() supports
   merged annotations; however, that commit did not refactor
   StaticListableBeanFactory#findAnnotationOnBean() to support merged
   annotations.

2. Commit 978adbd refactored
   ControllerAdviceBean so that a merged @ControllerAdvice annotation
   is only looked up via ApplicationContext#findAnnotationOnBean().

The latter relies on the fact that findAnnotationOnBean() supports
merged annotations (e.g., @RestControllerAdvice as a merged instance of
@ControllerAdvice). Behind the scenes, MockMvcBuilders.standaloneSetup()
creates a StubWebApplicationContext which internally uses a
StubBeanFactory which extends StaticListableBeanFactory. Consequently,
since the implementation of findAnnotationOnBean() in
StaticListableBeanFactory was not updated to support merged annotations
like it was in DefaultListableBeanFactory, we only see this regression
with the standalone MockMvc support and not with MockMvc support for an
existing WebApplicationContext or with standard Spring applications
using an ApplicationContext that uses DefaultListableBeanFactory.

This commit fixes this regression by supporting merged annotations in
StaticListableBeanFactory#findAnnotationOnBean() as well.

Closes spring-projectsgh-25520
  • Loading branch information
sbrannen authored and xcl(徐程林) committed Aug 16, 2020
1 parent 397b936 commit 07a69b4
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 15 deletions.
Expand Up @@ -37,7 +37,7 @@
import org.springframework.beans.factory.SmartFactoryBean;
import org.springframework.core.OrderComparator;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
Expand Down Expand Up @@ -450,7 +450,7 @@ public <A extends Annotation> A findAnnotationOnBean(String beanName, Class<A> a
throws NoSuchBeanDefinitionException {

Class<?> beanType = getType(beanName);
return (beanType != null ? AnnotationUtils.findAnnotation(beanType, annotationType) : null);
return (beanType != null ? AnnotatedElementUtils.findMergedAnnotation(beanType, annotationType) : null);
}

}
Expand Up @@ -16,6 +16,8 @@

package org.springframework.beans.factory;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -33,6 +35,7 @@
import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.beans.testfixture.beans.factory.DummyFactory;
import org.springframework.cglib.proxy.NoOp;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.io.Resource;
import org.springframework.util.ObjectUtils;

Expand Down Expand Up @@ -324,6 +327,33 @@ public void testIntDependencies() {
assertThat(Arrays.equals(new String[] { "buffer" }, deps)).isTrue();
}

@Test
public void findAnnotationOnBean() {
this.listableBeanFactory.registerSingleton("controllerAdvice", new ControllerAdviceClass());
this.listableBeanFactory.registerSingleton("restControllerAdvice", new RestControllerAdviceClass());
testFindAnnotationOnBean(this.listableBeanFactory);
}

@Test // gh-25520
public void findAnnotationOnBeanWithStaticFactory() {
StaticListableBeanFactory lbf = new StaticListableBeanFactory();
lbf.addBean("controllerAdvice", new ControllerAdviceClass());
lbf.addBean("restControllerAdvice", new RestControllerAdviceClass());
testFindAnnotationOnBean(lbf);
}

private void testFindAnnotationOnBean(ListableBeanFactory lbf) {
assertControllerAdvice(lbf, "controllerAdvice");
assertControllerAdvice(lbf, "restControllerAdvice");
}

private void assertControllerAdvice(ListableBeanFactory lbf, String beanName) {
ControllerAdvice controllerAdvice = lbf.findAnnotationOnBean(beanName, ControllerAdvice.class);
assertThat(controllerAdvice).isNotNull();
assertThat(controllerAdvice.value()).isEqualTo("com.example");
assertThat(controllerAdvice.basePackage()).isEqualTo("com.example");
}

@Test
public void isSingletonAndIsPrototypeWithStaticFactory() {
StaticListableBeanFactory lbf = new StaticListableBeanFactory();
Expand Down Expand Up @@ -393,6 +423,35 @@ public void isSingletonAndIsPrototypeWithStaticFactory() {
}


@Retention(RetentionPolicy.RUNTIME)
@interface ControllerAdvice {

@AliasFor("basePackage")
String value() default "";

@AliasFor("value")
String basePackage() default "";
}

@Retention(RetentionPolicy.RUNTIME)
@ControllerAdvice
@interface RestControllerAdvice {

@AliasFor(annotation = ControllerAdvice.class)
String value() default "";

@AliasFor(annotation = ControllerAdvice.class)
String basePackage() default "";
}

@ControllerAdvice("com.example")
static class ControllerAdviceClass {
}

@RestControllerAdvice("com.example")
static class RestControllerAdviceClass {
}

static class TestBeanSmartFactoryBean implements SmartFactoryBean<TestBean> {

private final TestBean testBean = new TestBean("enigma", 42);
Expand Down
Expand Up @@ -16,16 +16,23 @@

package org.springframework.test.web.servlet.samples.standalone;

import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;

import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.forwardedUrl;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup;

Expand All @@ -37,28 +44,32 @@
*/
class ExceptionHandlerTests {

@Test
void localExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).build()
.perform(get("/person/Clyde"))
@Nested
class MvcTests {

@Test
void localExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).build()
.perform(get("/person/Clyde"))
.andExpect(status().isOk())
.andExpect(forwardedUrl("errorView"));
}
}

@Test
void globalExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build()
@Test
void globalExceptionHandlerMethod() throws Exception {
standaloneSetup(new PersonController()).setControllerAdvice(new GlobalExceptionHandler()).build()
.perform(get("/person/Bonnie"))
.andExpect(status().isOk())
.andExpect(forwardedUrl("globalErrorView"));
}
}

@Test
void globalExceptionHandlerMethodUsingClassArgument() throws Exception {
standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build()
@Test
void globalExceptionHandlerMethodUsingClassArgument() throws Exception {
standaloneSetup(PersonController.class).setControllerAdvice(GlobalExceptionHandler.class).build()
.perform(get("/person/Bonnie"))
.andExpect(status().isOk())
.andExpect(forwardedUrl("globalErrorView"));
}
}


Expand All @@ -82,7 +93,6 @@ String handleException(IllegalArgumentException exception) {
}
}


@ControllerAdvice
private static class GlobalExceptionHandler {

Expand All @@ -92,4 +102,124 @@ String handleException(IllegalStateException exception) {
}
}


@Nested
class RestTests {

@Test
void noException() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build()
.perform(get("/person/Yoda").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.name").value("Yoda"));
}

@Test
void localExceptionHandlerMethod() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build()
.perform(get("/person/Luke").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("local - IllegalArgumentException"));
}

@Test
void globalExceptionHandlerMethod() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class).build()
.perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("global - IllegalStateException"));
}

@Test
void globalRestPersonControllerExceptionHandlerTakesPrecedenceOverGlobalExceptionHandler() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class).build()
.perform(get("/person/Leia").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("globalPersonController - IllegalStateException"));
}

@Test // gh-25520
void noHandlerFound() throws Exception {
standaloneSetup(RestPersonController.class)
.setControllerAdvice(RestGlobalExceptionHandler.class, RestPersonControllerExceptionHandler.class)
.addDispatcherServletCustomizer(dispatcherServlet -> dispatcherServlet.setThrowExceptionIfNoHandlerFound(true))
.build()
.perform(get("/bogus").accept(MediaType.APPLICATION_JSON))
.andExpect(status().isOk())
.andExpect(jsonPath("$.error").value("global - NoHandlerFoundException"));
}
}


@RestController
private static class RestPersonController {

@GetMapping("/person/{name}")
Person get(@PathVariable String name) {
switch (name) {
case "Luke":
throw new IllegalArgumentException();
case "Leia":
throw new IllegalStateException();
default:
return new Person("Yoda");
}
}

@ExceptionHandler
Error handleException(IllegalArgumentException exception) {
return new Error("local - " + exception.getClass().getSimpleName());
}
}

@RestControllerAdvice(assignableTypes = RestPersonController.class)
@Order(Ordered.HIGHEST_PRECEDENCE)
private static class RestPersonControllerExceptionHandler {

@ExceptionHandler
Error handleException(Throwable exception) {
return new Error("globalPersonController - " + exception.getClass().getSimpleName());
}
}

@RestControllerAdvice
@Order(Ordered.LOWEST_PRECEDENCE)
private static class RestGlobalExceptionHandler {

@ExceptionHandler
Error handleException(Throwable exception) {
return new Error( "global - " + exception.getClass().getSimpleName());
}
}

static class Person {

private final String name;

Person(String name) {
this.name = name;
}

public String getName() {
return name;
}
}

static class Error {

private final String error;

Error(String error) {
this.error = error;
}

public String getError() {
return error;
}
}

}

0 comments on commit 07a69b4

Please sign in to comment.