From 02b53300faad0f3e65ad29ac7699787d2bd0ca81 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Fri, 8 May 2020 09:37:37 +0100 Subject: [PATCH] HttpHeaders#equals handles wrapping correctly Closes gh-25034 --- .../org/springframework/http/HttpHeaders.java | 32 +++++++++---------- .../http/ReadOnlyHttpHeaders.java | 6 ++-- .../http/HttpHeadersTests.java | 8 +++++ 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index eaf831696a83..ff1e9b0cbd94 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -1747,8 +1747,14 @@ public boolean equals(@Nullable Object other) { if (!(other instanceof HttpHeaders)) { return false; } - HttpHeaders otherHeaders = (HttpHeaders) other; - return this.headers.equals(otherHeaders.headers); + return unwrap(this).equals(unwrap((HttpHeaders) other)); + } + + private static MultiValueMap unwrap(HttpHeaders headers) { + while (headers.headers instanceof HttpHeaders) { + headers = (HttpHeaders) headers.headers; + } + return headers.headers; } @Override @@ -1763,20 +1769,17 @@ public String toString() { /** - * Return an {@code HttpHeaders} object that can only be read, not written to. + * Apply a read-only {@code HttpHeaders} wrapper around the given headers. */ - public static HttpHeaders readOnlyHttpHeaders(HttpHeaders headers) { + public static HttpHeaders readOnlyHttpHeaders(MultiValueMap headers) { Assert.notNull(headers, "HttpHeaders must not be null"); - if (headers instanceof ReadOnlyHttpHeaders) { - return headers; - } - else { - return new ReadOnlyHttpHeaders(headers); - } + return (headers instanceof ReadOnlyHttpHeaders ? + (HttpHeaders) headers : new ReadOnlyHttpHeaders(headers)); } /** - * Return an {@code HttpHeaders} object that can be read and written to. + * Remove any read-only wrapper that may have been previously applied around + * the given headers via {@link #readOnlyHttpHeaders(MultiValueMap)}. * @since 5.1.1 */ public static HttpHeaders writableHttpHeaders(HttpHeaders headers) { @@ -1784,12 +1787,7 @@ public static HttpHeaders writableHttpHeaders(HttpHeaders headers) { if (headers == EMPTY) { return new HttpHeaders(); } - else if (headers instanceof ReadOnlyHttpHeaders) { - return new HttpHeaders(headers.headers); - } - else { - return headers; - } + return (headers instanceof ReadOnlyHttpHeaders ? new HttpHeaders(headers.headers) : headers); } /** diff --git a/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java index 3fb73408630c..299ff406f9b0 100644 --- a/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,8 +46,8 @@ class ReadOnlyHttpHeaders extends HttpHeaders { private List cachedAccept; - ReadOnlyHttpHeaders(HttpHeaders headers) { - super(headers.headers); + ReadOnlyHttpHeaders(MultiValueMap headers) { + super(headers); } diff --git a/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java b/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java index 24da0fee600b..f9bebd095a09 100644 --- a/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java +++ b/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java @@ -703,4 +703,12 @@ public void readOnlyHttpHeadersRetainEntrySetOrder() { assertThat(readOnlyHttpHeaders.entrySet()).extracting(Entry::getKey).containsExactly(expectedKeys); } + @Test // gh-25034 + public void equalsUnwrapsHttpHeaders() { + HttpHeaders headers1 = new HttpHeaders(); + HttpHeaders headers2 = new HttpHeaders(new HttpHeaders(headers1)); + + assertThat(headers1).isEqualTo(headers2); + assertThat(headers2).isEqualTo(headers1); + } }