diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index 0ead22c22f5b..204d0cdc70d3 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.io.Serializable; import java.net.URI; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -154,7 +155,7 @@ protected final MockHttpServletRequest createServletRequest(ServletContext servl String filename = part.getSubmittedFileName(); InputStream is = part.getInputStream(); if (filename != null) { - request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is)); + request.addFile(new MockStandardMultipartFile(part, filename)); } else { InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset)); @@ -179,4 +180,51 @@ private Charset getCharsetOrDefault(Part part, Charset defaultCharset) { } return defaultCharset; } + + /** + * Spring MultipartFile adapter, wrapping a Servlet Part object. + */ + @SuppressWarnings("serial") + private static class MockStandardMultipartFile extends MockMultipartFile implements Part, Serializable { + + private final Part part; + + private final String filename; + + public MockStandardMultipartFile(Part part, String filename) throws IOException { + super(part.getName(), part.getInputStream()); + this.part = part; + this.filename = filename; + } + + @Override + public String getSubmittedFileName() { + return this.part.getSubmittedFileName(); + } + + @Override + public void write(String fileName) throws IOException { + this.part.write(fileName); + } + + @Override + public void delete() throws IOException { + this.part.delete(); + } + + @Override + public String getHeader(String name) { + return this.part.getHeader(name); + } + + @Override + public Collection getHeaders(String name) { + return this.part.getHeaders(name); + } + + @Override + public Collection getHeaderNames() { + return this.part.getHeaderNames(); + } + } } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java index e61d50f0d6e4..68c7e7f1d844 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/MultipartControllerTests.java @@ -38,6 +38,7 @@ import org.springframework.stereotype.Controller; import org.springframework.test.web.servlet.MockMvc; import org.springframework.ui.Model; +import org.springframework.util.StreamUtils; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RequestParam; @@ -239,6 +240,38 @@ public void multipartRequestWithServletParts() throws Exception { .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); } + @Test + public void multipartRequestWithServletPartsForPartAttribute() throws Exception { + byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8); + MockPart filePart = new MockPart("file", "orig", fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockPart jsonPart = new MockPart("json", json); + jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/partattr").part(filePart).part(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + + @Test + public void multipartRequestWithServletPartsForMultipartFileAttribute() throws Exception { + byte[] fileContent = "foo".getBytes(StandardCharsets.UTF_8); + MockPart filePart = new MockPart("file", "orig", fileContent); + + byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); + MockPart jsonPart = new MockPart("json", json); + jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); + + standaloneSetup(new MultipartController()).build() + .perform(multipart("/multipartfileattr").part(filePart).part(jsonPart)) + .andExpect(status().isFound()) + .andExpect(model().attribute("fileContent", fileContent)) + .andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah"))); + } + @Test // SPR-13317 public void multipartRequestWrapped() throws Exception { byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8); @@ -341,29 +374,79 @@ public String processOptionalFileList(@RequestParam Optional return "redirect:/index"; } - @RequestMapping(value = "/part", method = RequestMethod.POST) - public String processPart(@RequestParam Part part, - @RequestPart Map json, Model model) throws IOException { + @RequestMapping(value = "/json", method = RequestMethod.POST) + public String processMultipart(@RequestPart Map json, Model model) { + model.addAttribute("json", json); + return "redirect:/index"; + } - model.addAttribute("fileContent", part.getInputStream()); - model.addAttribute("jsonContent", json); + @RequestMapping(value = "/partattr") + public String processPartAttribute(PartForm form, + @RequestPart(required = false) Map json, Model model) throws IOException { + + if (form != null) { + Part part = form.getFile(); + if (0 != part.getSize()) { + byte[] fileContent = StreamUtils.copyToByteArray(part.getInputStream()); + model.addAttribute("fileContent", fileContent); + } + } + if (json != null) { + model.addAttribute("jsonContent", json); + } return "redirect:/index"; } - @RequestMapping(value = "/json", method = RequestMethod.POST) - public String processMultipart(@RequestPart Map json, Model model) { - model.addAttribute("json", json); + @RequestMapping(value = "/multipartfileattr") + public String processMultipartFileAttribute(MultipartFileForm form, + @RequestPart(required = false) Map json, Model model) throws IOException { + + if (form != null) { + MultipartFile file = form.getFile(); + if (!file.isEmpty()) { + model.addAttribute("fileContent", file.getBytes()); + } + } + if (json != null) { + model.addAttribute("jsonContent", json); + } + return "redirect:/index"; } } + private static class PartForm { + + private Part file; + + public PartForm(Part file) { + this.file = file; + } + + public Part getFile() { + return file; + } + } + + private static class MultipartFileForm { + + private MultipartFile file; + + public MultipartFileForm(MultipartFile file) { + this.file = file; + } + + public MultipartFile getFile() { + return file; + } + } private static class RequestWrappingFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, - FilterChain filterChain) throws IOException, ServletException { + FilterChain filterChain) throws IOException, ServletException { request = new HttpServletRequestWrapper(request); filterChain.doFilter(request, response);