Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean Content-Encoding response header in WebFlux error handler #19372

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -34,6 +34,7 @@
import org.springframework.context.ApplicationContext;
import org.springframework.core.NestedExceptionUtils;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpLogging;
import org.springframework.http.HttpStatus;
import org.springframework.http.codec.HttpMessageReader;
Expand Down Expand Up @@ -71,6 +72,20 @@ public abstract class AbstractErrorWebExceptionHandler implements ErrorWebExcept
DISCONNECTED_CLIENT_EXCEPTIONS = Collections.unmodifiableSet(exceptions);
}

private static final Set<String> RESPONSE_CONTENT_HEADERS;

static {
Set<String> headers = new HashSet<>();
headers.add(HttpHeaders.CONTENT_ENCODING);
headers.add(HttpHeaders.CONTENT_DISPOSITION);
headers.add(HttpHeaders.CONTENT_LANGUAGE);
headers.add(HttpHeaders.CONTENT_LENGTH);
headers.add(HttpHeaders.CONTENT_LOCATION);
headers.add(HttpHeaders.CONTENT_RANGE);
headers.add(HttpHeaders.CONTENT_TYPE);
RESPONSE_CONTENT_HEADERS = Collections.unmodifiableSet(headers);
}

private static final Log logger = HttpLogging.forLogName(AbstractErrorWebExceptionHandler.class);

private final ApplicationContext applicationContext;
Expand Down Expand Up @@ -293,11 +308,16 @@ private String formatRequest(ServerRequest request) {
}

private Mono<? extends Void> write(ServerWebExchange exchange, ServerResponse response) {
// force content-type since writeTo won't overwrite response header values
exchange.getResponse().getHeaders().setContentType(response.headers().getContentType());
clearResponseContentHeaders(exchange.getResponse().getHeaders());
return response.writeTo(exchange, new ResponseContext());
}

private void clearResponseContentHeaders(HttpHeaders headers) {
for (String contentHeader : RESPONSE_CONTENT_HEADERS) {
headers.remove(contentHeader);
}
}

private class ResponseContext implements ServerResponse.Context {

@Override
Expand Down
Expand Up @@ -16,6 +16,10 @@

package org.springframework.boot.autoconfigure.web.reactive.error;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

import javax.validation.Valid;

import org.hamcrest.Matchers;
Expand All @@ -32,8 +36,10 @@
import org.springframework.boot.test.context.runner.ReactiveWebApplicationContextRunner;
import org.springframework.boot.testsupport.rule.OutputCapture;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.test.web.reactive.server.HeaderAssertions;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
Expand All @@ -54,6 +60,20 @@
*/
public class DefaultErrorWebExceptionHandlerIntegrationTests {

private static final Set<String> RESPONSE_CONTENT_HEADERS;

static {
Set<String> headers = new HashSet<>();
headers.add(HttpHeaders.CONTENT_ENCODING);
headers.add(HttpHeaders.CONTENT_DISPOSITION);
headers.add(HttpHeaders.CONTENT_LANGUAGE);
headers.add(HttpHeaders.CONTENT_LENGTH);
headers.add(HttpHeaders.CONTENT_LOCATION);
headers.add(HttpHeaders.CONTENT_RANGE);
headers.add(HttpHeaders.CONTENT_TYPE);
RESPONSE_CONTENT_HEADERS = Collections.unmodifiableSet(headers);
}

@Rule
public OutputCapture outputCapture = new OutputCapture();

Expand Down Expand Up @@ -256,6 +276,18 @@ public void invalidAcceptMediaType() {
});
}

@Test
public void contentHeadersWasCleared() {
this.contextRunner.run((context) -> {
WebTestClient client = WebTestClient.bindToApplicationContext(context).build();
HeaderAssertions headerAssertions = client.get().uri("/contentHeader").exchange().expectStatus()
.isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR).expectHeader();
RESPONSE_CONTENT_HEADERS.stream()
.filter((h) -> !h.equals(HttpHeaders.CONTENT_TYPE) && !h.equals(HttpHeaders.CONTENT_LENGTH))
.forEach(headerAssertions::doesNotExist);
});
}

@Configuration
public static class Application {

Expand All @@ -278,6 +310,15 @@ public Mono<Void> commit(ServerWebExchange exchange) {
.then(Mono.error(new IllegalStateException("already committed!")));
}

@GetMapping("/contentHeader")
public Mono<Void> contentHeader(ServerWebExchange exchange) {
HttpHeaders headers = exchange.getResponse().getHeaders();
for (String contentHeader : RESPONSE_CONTENT_HEADERS) {
headers.set(contentHeader, "value");
}
throw new IllegalStateException("Expected!");
}

@GetMapping("/html")
public String htmlEscape() {
throw new IllegalStateException("<script>");
Expand Down