Skip to content

Commit

Permalink
Copy headers from part in MultipartBodyBuilder
Browse files Browse the repository at this point in the history
This commit makes sure that Part.headers() is copied over when adding a
part in the MultipartBodyBuilder.

Closes gh-26410
  • Loading branch information
poutsma committed Jan 21, 2021
1 parent daa5465 commit e537844
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.codec.multipart.FilePart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -129,13 +128,14 @@ public PartBuilder part(String name, Object part, @Nullable MediaType contentTyp
Assert.notNull(part, "'part' must not be null");

if (part instanceof Part) {
PartBuilder builder = asyncPart(name, ((Part) part).content(), DataBuffer.class);
Part partObject = (Part) part;
PartBuilder builder = asyncPart(name, partObject.content(), DataBuffer.class);
if (!partObject.headers().isEmpty()) {
builder.headers(headers -> headers.putAll(partObject.headers()));
}
if (contentType != null) {
builder.contentType(contentType);
}
if (part instanceof FilePart) {
builder.filename(((FilePart) part).filename());
}
return builder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.core.testfixture.io.buffer.AbstractLeakCheckingTests;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.client.MultipartBodyBuilder;
import org.springframework.http.codec.ClientCodecConfigurer;
Expand Down Expand Up @@ -102,8 +103,12 @@ public String getFilename() {
this.bufferFactory.wrap("Cc".getBytes(StandardCharsets.UTF_8))
);
FilePart mockPart = mock(FilePart.class);
HttpHeaders partHeaders = new HttpHeaders();
partHeaders.setContentType(MediaType.TEXT_PLAIN);
partHeaders.setContentDispositionFormData("filePublisher", "file.txt");
partHeaders.add("foo", "bar");
given(mockPart.headers()).willReturn(partHeaders);
given(mockPart.content()).willReturn(bufferPublisher);
given(mockPart.filename()).willReturn("file.txt");

MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
bodyBuilder.part("name 1", "value 1");
Expand Down Expand Up @@ -166,6 +171,7 @@ public String getFilename() {

part = requestParts.getFirst("filePublisher");
assertThat(part.name()).isEqualTo("filePublisher");
assertThat(part.headers()).containsEntry("foo", Collections.singletonList("bar"));
assertThat(((FilePart) part).filename()).isEqualTo("file.txt");
value = decodeToString(part);
assertThat(value).isEqualTo("AaBbCc");
Expand Down

0 comments on commit e537844

Please sign in to comment.