Skip to content

Commit

Permalink
Replace MockMultipartFile with StandardMockMultipartFile
Browse files Browse the repository at this point in the history
  • Loading branch information
binchoo committed Jan 13, 2022
1 parent ec9ce5b commit a1b2262
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 10 deletions.
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -154,7 +155,7 @@ protected final MockHttpServletRequest createServletRequest(ServletContext servl
String filename = part.getSubmittedFileName();
InputStream is = part.getInputStream();
if (filename != null) {
request.addFile(new MockMultipartFile(name, filename, part.getContentType(), is));
request.addFile(new MockStandardMultipartFile(part, filename));
}
else {
InputStreamReader reader = new InputStreamReader(is, getCharsetOrDefault(part, defaultCharset));
Expand All @@ -179,4 +180,51 @@ private Charset getCharsetOrDefault(Part part, Charset defaultCharset) {
}
return defaultCharset;
}

/**
* Spring MultipartFile adapter, wrapping a Servlet Part object.
*/
@SuppressWarnings("serial")
private static class MockStandardMultipartFile extends MockMultipartFile implements Part, Serializable {

private final Part part;

private final String filename;

public MockStandardMultipartFile(Part part, String filename) throws IOException {
super(part.getName(), part.getInputStream());
this.part = part;
this.filename = filename;
}

@Override
public String getSubmittedFileName() {
return this.part.getSubmittedFileName();
}

@Override
public void write(String fileName) throws IOException {
this.part.write(fileName);
}

@Override
public void delete() throws IOException {
this.part.delete();
}

@Override
public String getHeader(String name) {
return this.part.getHeader(name);
}

@Override
public Collection<String> getHeaders(String name) {
return this.part.getHeaders(name);
}

@Override
public Collection<String> getHeaderNames() {
return this.part.getHeaderNames();
}
}
}
Expand Up @@ -38,6 +38,7 @@
import org.springframework.stereotype.Controller;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.ui.Model;
import org.springframework.util.StreamUtils;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
Expand Down Expand Up @@ -239,6 +240,38 @@ public void multipartRequestWithServletParts() throws Exception {
.andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah")));
}

@Test
public void multipartRequestWithServletPartsForPartAttribute() throws Exception {
byte[] fileContent = "bar".getBytes(StandardCharsets.UTF_8);
MockPart filePart = new MockPart("file", "orig", fileContent);

byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8);
MockPart jsonPart = new MockPart("json", json);
jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON);

standaloneSetup(new MultipartController()).build()
.perform(multipart("/partattr").part(filePart).part(jsonPart))
.andExpect(status().isFound())
.andExpect(model().attribute("fileContent", fileContent))
.andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah")));
}

@Test
public void multipartRequestWithServletPartsForMultipartFileAttribute() throws Exception {
byte[] fileContent = "foo".getBytes(StandardCharsets.UTF_8);
MockPart filePart = new MockPart("file", "orig", fileContent);

byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8);
MockPart jsonPart = new MockPart("json", json);
jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON);

standaloneSetup(new MultipartController()).build()
.perform(multipart("/multipartfileattr").part(filePart).part(jsonPart))
.andExpect(status().isFound())
.andExpect(model().attribute("fileContent", fileContent))
.andExpect(model().attribute("jsonContent", Collections.singletonMap("name", "yeeeah")));
}

@Test // SPR-13317
public void multipartRequestWrapped() throws Exception {
byte[] json = "{\"name\":\"yeeeah\"}".getBytes(StandardCharsets.UTF_8);
Expand Down Expand Up @@ -341,29 +374,79 @@ public String processOptionalFileList(@RequestParam Optional<List<MultipartFile>
return "redirect:/index";
}

@RequestMapping(value = "/part", method = RequestMethod.POST)
public String processPart(@RequestParam Part part,
@RequestPart Map<String, String> json, Model model) throws IOException {
@RequestMapping(value = "/json", method = RequestMethod.POST)
public String processMultipart(@RequestPart Map<String, String> json, Model model) {
model.addAttribute("json", json);
return "redirect:/index";
}

model.addAttribute("fileContent", part.getInputStream());
model.addAttribute("jsonContent", json);
@RequestMapping(value = "/partattr")
public String processPartAttribute(PartForm form,
@RequestPart(required = false) Map<String, String> json, Model model) throws IOException {

if (form != null) {
Part part = form.getFile();
if (0 != part.getSize()) {
byte[] fileContent = StreamUtils.copyToByteArray(part.getInputStream());
model.addAttribute("fileContent", fileContent);
}
}
if (json != null) {
model.addAttribute("jsonContent", json);
}

return "redirect:/index";
}

@RequestMapping(value = "/json", method = RequestMethod.POST)
public String processMultipart(@RequestPart Map<String, String> json, Model model) {
model.addAttribute("json", json);
@RequestMapping(value = "/multipartfileattr")
public String processMultipartFileAttribute(MultipartFileForm form,
@RequestPart(required = false) Map<String, String> json, Model model) throws IOException {

if (form != null) {
MultipartFile file = form.getFile();
if (!file.isEmpty()) {
model.addAttribute("fileContent", file.getBytes());
}
}
if (json != null) {
model.addAttribute("jsonContent", json);
}

return "redirect:/index";
}
}

private static class PartForm {

private Part file;

public PartForm(Part file) {
this.file = file;
}

public Part getFile() {
return file;
}
}

private static class MultipartFileForm {

private MultipartFile file;

public MultipartFileForm(MultipartFile file) {
this.file = file;
}

public MultipartFile getFile() {
return file;
}
}

private static class RequestWrappingFilter extends OncePerRequestFilter {

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws IOException, ServletException {
FilterChain filterChain) throws IOException, ServletException {

request = new HttpServletRequestWrapper(request);
filterChain.doFilter(request, response);
Expand Down

0 comments on commit a1b2262

Please sign in to comment.