diff --git a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java index 2982b185ee1..509107c861c 100644 --- a/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java +++ b/http-client/src/main/java/io/micronaut/http/client/netty/DefaultHttpClient.java @@ -34,7 +34,6 @@ import io.micronaut.core.util.ArrayUtils; import io.micronaut.core.util.CollectionUtils; import io.micronaut.core.util.StringUtils; -import io.micronaut.core.util.SupplierUtil; import io.micronaut.http.HttpAttributes; import io.micronaut.http.HttpResponse; import io.micronaut.http.HttpResponseWrapper; @@ -90,6 +89,7 @@ import io.micronaut.http.sse.Event; import io.micronaut.http.uri.UriBuilder; import io.micronaut.http.uri.UriTemplate; +import io.micronaut.http.util.HttpHeadersUtil; import io.micronaut.jackson.databind.JacksonDatabindMapper; import io.micronaut.json.JsonMapper; import io.micronaut.json.codec.JsonMediaTypeCodec; @@ -183,9 +183,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; -import java.util.regex.Pattern; import java.util.stream.Collectors; - import static io.micronaut.scheduling.instrument.InvocationInstrumenter.NOOP; /** @@ -211,9 +209,6 @@ public class DefaultHttpClient implements private static final int DEFAULT_HTTP_PORT = 80; private static final int DEFAULT_HTTPS_PORT = 443; - private static final Supplier HEADER_MASK_PATTERNS = SupplierUtil.memoized(() -> - Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE) - ); /** * Which headers not to copy from the first request when redirecting to a second request. There doesn't * appear to be a spec for this. {@link java.net.HttpURLConnection} seems to drop all headers, but that would be a @@ -1754,7 +1749,7 @@ private void debugRequest(URI requestURI, io.netty.handler.codec.http.HttpReques private void traceRequest(io.micronaut.http.HttpRequest request, io.netty.handler.codec.http.HttpRequest nettyRequest) { HttpHeaders headers = nettyRequest.headers(); - traceHeaders(headers); + HttpHeadersUtil.trace(log, headers.names(), headers::getAll); if (io.micronaut.http.HttpMethod.permitsRequestBody(request.getMethod()) && request.getBody().isPresent() && nettyRequest instanceof FullHttpRequest) { FullHttpRequest fullHttpRequest = (FullHttpRequest) nettyRequest; ByteBuf content = fullHttpRequest.content(); @@ -1778,30 +1773,6 @@ private void traceChunk(ByteBuf content) { log.trace("----"); } - private void traceHeaders(HttpHeaders headers) { - for (String name : headers.names()) { - boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches(); - List all = headers.getAll(name); - if (all.size() > 1) { - for (String value : all) { - String maskedValue = isMasked ? mask(value) : value; - log.trace("{}: {}", name, maskedValue); - } - } else if (!all.isEmpty()) { - String maskedValue = isMasked ? mask(all.get(0)) : all.get(0); - log.trace("{}: {}", name, maskedValue); - } - } - } - - @Nullable - private String mask(@Nullable String value) { - if (value == null) { - return null; - } - return "*MASKED*"; - } - private static MediaTypeCodecRegistry createDefaultMediaTypeRegistry() { JsonMapper mapper = new JacksonDatabindMapper(); ApplicationConfiguration configuration = new ApplicationConfiguration(); @@ -2116,7 +2087,7 @@ protected void channelReadInstrumented(ChannelHandlerContext ctx, R msg) throws HttpHeaders headers = msg.headers(); if (log.isTraceEnabled()) { log.trace("HTTP Client Response Received ({}) for Request: {} {}", msg.status(), finalRequest.getMethodName(), finalRequest.getUri()); - traceHeaders(headers); + HttpHeadersUtil.trace(log, headers.names(), headers::getAll); } buildResponse(responsePromise, msg, httpStatus); } diff --git a/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy b/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy index 8f3a693ed91..64134aa22b8 100644 --- a/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy +++ b/http-client/src/test/groovy/io/micronaut/http/client/netty/DefaultClientHeaderMaskTest.groovy @@ -1,66 +1,64 @@ package io.micronaut.http.client.netty import ch.qos.logback.classic.Level -import ch.qos.logback.classic.Logger +import io.micronaut.context.ApplicationContext +import io.micronaut.context.annotation.Requires +import io.micronaut.http.HttpRequest +import io.micronaut.http.annotation.Controller +import io.micronaut.http.annotation.Get +import io.micronaut.http.client.HttpClient +import io.micronaut.runtime.server.EmbeddedServer +import jakarta.inject.Singleton +import org.slf4j.Logger import ch.qos.logback.classic.spi.ILoggingEvent import ch.qos.logback.core.AppenderBase -import io.micronaut.context.ApplicationContext import io.netty.handler.codec.http.DefaultHttpHeaders import org.slf4j.LoggerFactory import spock.lang.Specification import java.util.concurrent.BlockingQueue import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.TimeUnit class DefaultClientHeaderMaskTest extends Specification { - def "check masking works for #value"() { + def "check mask detects common security headers"() { given: - def ctx = ApplicationContext.run() - def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080") + EmbeddedServer server = ApplicationContext.run(EmbeddedServer, ["spec.name": "DefaultClientHeaderMaskTest"]) + ApplicationContext ctx = server.applicationContext + HttpClient client = ctx.createBean(HttpClient, server.URL) expect: - client.mask(value) == expected + client instanceof DefaultHttpClient - cleanup: - ctx.close() + when: + MemoryAppender appender = new MemoryAppender() + Logger log = LoggerFactory.getLogger(DefaultHttpClient.class) - where: - value | expected - null | null - "foo" | "*MASKED*" - "Tim Yates" | "*MASKED*" - } + then: + log instanceof ch.qos.logback.classic.Logger - def "check mask detects common security headers"() { - given: - MemoryAppender appender = new MemoryAppender() - Logger logger = (Logger) LoggerFactory.getLogger(DefaultHttpClient.class) + when: + ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log logger.addAppender(appender) logger.setLevel(Level.TRACE) appender.start() - DefaultHttpHeaders headers = new DefaultHttpHeaders() - headers.add("Authorization", "Bearer foo") - headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar") - headers.add("Cookie", "baz") - headers.add("Set-Cookie", "qux") - headers.add("X-Forwarded-For", "quux") - headers.add("X-Forwarded-Host", "quuz") - headers.add("X-Real-IP", "waldo") - headers.add("X-Forwarded-For", "fred") - headers.add("Credential", "foo") - headers.add("Signature", "bar probably secret") - def ctx = ApplicationContext.run() - def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080") - - when: - client.traceHeaders(headers) + def response = client.toBlocking().exchange(HttpRequest.GET("/masking").headers {headers -> + headers.add("Authorization", "Bearer foo") + headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar") + headers.add("Cookie", "baz") + headers.add("Set-Cookie", "qux") + headers.add("X-Forwarded-For", "quux") + headers.add("X-Forwarded-Host", "quuz") + headers.add("X-Real-IP", "waldo") + headers.add("X-Forwarded-For", "fred") + headers.add("Credential", "foo") + headers.add("Signature", "bar probably secret") + }, String) then: - appender.events.size() == 10 - appender.events.join("\n") == """Authorization: *MASKED* + response.body() == "ok" + appender.events.join("\n").contains("""Authorization: *MASKED* |Proxy-Authorization: *MASKED* |Cookie: baz |Set-Cookie: qux @@ -69,11 +67,21 @@ class DefaultClientHeaderMaskTest extends Specification { |X-Forwarded-Host: quuz |X-Real-IP: waldo |Credential: *MASKED* - |Signature: *MASKED*""".stripMargin() + |Signature: *MASKED*""".stripMargin()) cleanup: - ctx.close() appender.stop() + ctx.close() + } + + @Requires(property = "spec.name", value = "DefaultClientHeaderMaskTest") + @Controller("/masking") + @Singleton + static class MaskedController { + @Get + String get() { + "ok" + } } static class MemoryAppender extends AppenderBase { diff --git a/http/build.gradle b/http/build.gradle index 6e41cd391d2..0d33ef4f112 100644 --- a/http/build.gradle +++ b/http/build.gradle @@ -18,6 +18,7 @@ dependencies { testAnnotationProcessor project(":inject-java") testImplementation project(":inject") testImplementation project(":runtime") + testImplementation(libs.managed.logback) } tasks.named("compileKotlin") { diff --git a/http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java b/http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java new file mode 100644 index 00000000000..efcbc4213b2 --- /dev/null +++ b/http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java @@ -0,0 +1,95 @@ +/* + * Copyright 2017-2022 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.util; + +import io.micronaut.core.annotation.NonNull; +import io.micronaut.core.annotation.Nullable; +import io.micronaut.core.util.SupplierUtil; +import io.micronaut.http.HttpHeaders; +import org.slf4j.Logger; + +import java.util.List; +import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +/** + * Utility class to work with {@link io.micronaut.http.HttpHeaders} or HTTP Headers. + * @author Sergio del Amo + * @since 3.8.0 + */ +public final class HttpHeadersUtil { + private static final Supplier HEADER_MASK_PATTERNS = SupplierUtil.memoized(() -> + Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE) + ); + + private HttpHeadersUtil() { + + } + + /** + * Trace HTTP Headers. + * @param log Logger + * @param httpHeaders HTTP Headers + */ + public static void trace(@NonNull Logger log, + @NonNull HttpHeaders httpHeaders) { + trace(log, httpHeaders.names(), httpHeaders::getAll); + } + + /** + * Trace HTTP Headers. + * @param log Logger + * @param names HTTP Header names + * @param getAllHeaders Function to get all the header values for a particular header name + */ + public static void trace(@NonNull Logger log, + @NonNull Set names, + @NonNull Function> getAllHeaders) { + names.forEach(name -> trace(log, name, getAllHeaders)); + } + + /** + * Trace HTTP Headers. + * @param log Logger + * @param name HTTP Header name + * @param getAllHeaders Function to get all the header values for a particular header name + */ + public static void trace(@NonNull Logger log, + @NonNull String name, + @NonNull Function> getAllHeaders) { + boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches(); + List all = getAllHeaders.apply(name); + if (all.size() > 1) { + for (String value : all) { + String maskedValue = isMasked ? mask(value) : value; + log.trace("{}: {}", name, maskedValue); + } + } else if (!all.isEmpty()) { + String maskedValue = isMasked ? mask(all.get(0)) : all.get(0); + log.trace("{}: {}", name, maskedValue); + } + } + + @Nullable + private static String mask(@Nullable String value) { + if (value == null) { + return null; + } + return "*MASKED*"; + } +} diff --git a/http/src/test/groovy/io/micronaut/http/util/HttpHeadersUtilSpec.groovy b/http/src/test/groovy/io/micronaut/http/util/HttpHeadersUtilSpec.groovy new file mode 100644 index 00000000000..54098639c29 --- /dev/null +++ b/http/src/test/groovy/io/micronaut/http/util/HttpHeadersUtilSpec.groovy @@ -0,0 +1,78 @@ +package io.micronaut.http.util + +import ch.qos.logback.classic.Level +import ch.qos.logback.classic.spi.ILoggingEvent +import ch.qos.logback.core.AppenderBase +import io.micronaut.http.HttpHeaders +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import spock.lang.Specification + +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue + +class HttpHeadersUtilSpec extends Specification { + def "check masking works for #value"() { + expect: + expected == HttpHeadersUtil.mask(value) + + where: + value | expected + null | null + "foo" | "*MASKED*" + "Tim Yates" | "*MASKED*" + } + + def "check mask detects common security headers"() { + given: + MemoryAppender appender = new MemoryAppender() + Logger log = LoggerFactory.getLogger(HttpHeadersUtilSpec.class) + + expect: + log instanceof ch.qos.logback.classic.Logger + + when: + ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log + logger.addAppender(appender) + logger.setLevel(Level.TRACE) + appender.start() + + HttpHeaders headers = new MockHttpHeaders([ + "Authorization": ["Bearer foo"], + "Proxy-Authorization": ["AWS4-HMAC-SHA256 bar"], + "Cookie": ["baz"], + "Set-Cookie": ["qux"], + "X-Forwarded-For": ["quux", "fred"], + "X-Forwarded-Host": ["quuz"], + "X-Real-IP": ["waldo"], + "Credential": ["foo"], + "Signature": ["bar probably secret"]]) + + HttpHeadersUtil.trace(log, headers) + + then: + appender.events.size() == headers.values().collect { it -> it.size() }.sum() + appender.events.contains("Authorization: *MASKED*") + appender.events.contains("Cookie: baz") + appender.events.contains("Credential: *MASKED*") + appender.events.contains("Set-Cookie: qux") + appender.events.contains("Proxy-Authorization: *MASKED*") + appender.events.contains("Signature: *MASKED*") + appender.events.contains("X-Forwarded-For: quux") + appender.events.contains("X-Forwarded-For: fred") + appender.events.contains("X-Forwarded-Host: quuz") + appender.events.contains("X-Real-IP: waldo") + + cleanup: + appender.stop() + } + + static class MemoryAppender extends AppenderBase { + final BlockingQueue events = new LinkedBlockingQueue<>() + + @Override + protected void append(ILoggingEvent e) { + events.add(e.formattedMessage) + } + } +} diff --git a/http/src/test/groovy/io/micronaut/http/util/MockHttpHeaders.java b/http/src/test/groovy/io/micronaut/http/util/MockHttpHeaders.java new file mode 100644 index 00000000000..29177e632f0 --- /dev/null +++ b/http/src/test/groovy/io/micronaut/http/util/MockHttpHeaders.java @@ -0,0 +1,93 @@ +/* + * Copyright 2017-2020 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.http.util; + +import io.micronaut.core.annotation.Nullable; +import io.micronaut.core.convert.ArgumentConversionContext; +import io.micronaut.core.convert.ConversionService; +import io.micronaut.http.MutableHttpHeaders; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +public class MockHttpHeaders implements MutableHttpHeaders { + + private final Map> headers; + + public MockHttpHeaders(Map> headers) { + this.headers = headers; + } + + @Override + public MutableHttpHeaders add(CharSequence header, CharSequence value) { + headers.compute(header, (key, val) -> { + if (val == null) { + val = new ArrayList<>(); + } + val.add(value.toString()); + return val; + }); + return this; + } + + @Override + public MutableHttpHeaders remove(CharSequence header) { + headers.remove(header); + return this; + } + + @Override + public List getAll(CharSequence name) { + List values = headers.get(name); + if (values == null) { + return Collections.emptyList(); + } else { + return values; + } + } + + @Nullable + @Override + public String get(CharSequence name) { + List values = headers.get(name); + if (values == null || values.isEmpty()) { + return null; + } else { + return values.get(0); + } + } + + @Override + public Set names() { + return headers.keySet().stream().map(CharSequence::toString).collect(Collectors.toSet()); + } + + @Override + public Collection> values() { + return headers.values(); + } + + @Override + public Optional get(CharSequence name, ArgumentConversionContext conversionContext) { + return ConversionService.SHARED.convert(get(name), conversionContext); + } +}