Skip to content

Commit

Permalink
refactor: Extract http headers utils (#8471)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdelamo committed Dec 9, 2022
1 parent 4fddaa9 commit 91be1e4
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 71 deletions.
Expand Up @@ -34,7 +34,6 @@
import io.micronaut.core.util.ArrayUtils;
import io.micronaut.core.util.CollectionUtils;
import io.micronaut.core.util.StringUtils;
import io.micronaut.core.util.SupplierUtil;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpResponse;
import io.micronaut.http.HttpResponseWrapper;
Expand Down Expand Up @@ -90,6 +89,7 @@
import io.micronaut.http.sse.Event;
import io.micronaut.http.uri.UriBuilder;
import io.micronaut.http.uri.UriTemplate;
import io.micronaut.http.util.HttpHeadersUtil;
import io.micronaut.jackson.databind.JacksonDatabindMapper;
import io.micronaut.json.JsonMapper;
import io.micronaut.json.codec.JsonMediaTypeCodec;
Expand Down Expand Up @@ -183,9 +183,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static io.micronaut.scheduling.instrument.InvocationInstrumenter.NOOP;

/**
Expand All @@ -211,9 +209,6 @@ public class DefaultHttpClient implements
private static final int DEFAULT_HTTP_PORT = 80;
private static final int DEFAULT_HTTPS_PORT = 443;

private static final Supplier<Pattern> HEADER_MASK_PATTERNS = SupplierUtil.memoized(() ->
Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE)
);
/**
* Which headers <i>not</i> to copy from the first request when redirecting to a second request. There doesn't
* appear to be a spec for this. {@link java.net.HttpURLConnection} seems to drop all headers, but that would be a
Expand Down Expand Up @@ -1754,7 +1749,7 @@ private void debugRequest(URI requestURI, io.netty.handler.codec.http.HttpReques

private void traceRequest(io.micronaut.http.HttpRequest<?> request, io.netty.handler.codec.http.HttpRequest nettyRequest) {
HttpHeaders headers = nettyRequest.headers();
traceHeaders(headers);
HttpHeadersUtil.trace(log, headers.names(), headers::getAll);
if (io.micronaut.http.HttpMethod.permitsRequestBody(request.getMethod()) && request.getBody().isPresent() && nettyRequest instanceof FullHttpRequest) {
FullHttpRequest fullHttpRequest = (FullHttpRequest) nettyRequest;
ByteBuf content = fullHttpRequest.content();
Expand All @@ -1778,30 +1773,6 @@ private void traceChunk(ByteBuf content) {
log.trace("----");
}

private void traceHeaders(HttpHeaders headers) {
for (String name : headers.names()) {
boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches();
List<String> all = headers.getAll(name);
if (all.size() > 1) {
for (String value : all) {
String maskedValue = isMasked ? mask(value) : value;
log.trace("{}: {}", name, maskedValue);
}
} else if (!all.isEmpty()) {
String maskedValue = isMasked ? mask(all.get(0)) : all.get(0);
log.trace("{}: {}", name, maskedValue);
}
}
}

@Nullable
private String mask(@Nullable String value) {
if (value == null) {
return null;
}
return "*MASKED*";
}

private static MediaTypeCodecRegistry createDefaultMediaTypeRegistry() {
JsonMapper mapper = new JacksonDatabindMapper();
ApplicationConfiguration configuration = new ApplicationConfiguration();
Expand Down Expand Up @@ -2116,7 +2087,7 @@ protected void channelReadInstrumented(ChannelHandlerContext ctx, R msg) throws
HttpHeaders headers = msg.headers();
if (log.isTraceEnabled()) {
log.trace("HTTP Client Response Received ({}) for Request: {} {}", msg.status(), finalRequest.getMethodName(), finalRequest.getUri());
traceHeaders(headers);
HttpHeadersUtil.trace(log, headers.names(), headers::getAll);
}
buildResponse(responsePromise, msg, httpStatus);
}
Expand Down
@@ -1,66 +1,64 @@
package io.micronaut.http.client.netty

import ch.qos.logback.classic.Level
import ch.qos.logback.classic.Logger
import io.micronaut.context.ApplicationContext
import io.micronaut.context.annotation.Requires
import io.micronaut.http.HttpRequest
import io.micronaut.http.annotation.Controller
import io.micronaut.http.annotation.Get
import io.micronaut.http.client.HttpClient
import io.micronaut.runtime.server.EmbeddedServer
import jakarta.inject.Singleton
import org.slf4j.Logger
import ch.qos.logback.classic.spi.ILoggingEvent
import ch.qos.logback.core.AppenderBase
import io.micronaut.context.ApplicationContext
import io.netty.handler.codec.http.DefaultHttpHeaders
import org.slf4j.LoggerFactory
import spock.lang.Specification

import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit

class DefaultClientHeaderMaskTest extends Specification {

def "check masking works for #value"() {
def "check mask detects common security headers"() {
given:
def ctx = ApplicationContext.run()
def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080")
EmbeddedServer server = ApplicationContext.run(EmbeddedServer, ["spec.name": "DefaultClientHeaderMaskTest"])
ApplicationContext ctx = server.applicationContext
HttpClient client = ctx.createBean(HttpClient, server.URL)

expect:
client.mask(value) == expected
client instanceof DefaultHttpClient

cleanup:
ctx.close()
when:
MemoryAppender appender = new MemoryAppender()
Logger log = LoggerFactory.getLogger(DefaultHttpClient.class)

where:
value | expected
null | null
"foo" | "*MASKED*"
"Tim Yates" | "*MASKED*"
}
then:
log instanceof ch.qos.logback.classic.Logger

def "check mask detects common security headers"() {
given:
MemoryAppender appender = new MemoryAppender()
Logger logger = (Logger) LoggerFactory.getLogger(DefaultHttpClient.class)
when:
ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log
logger.addAppender(appender)
logger.setLevel(Level.TRACE)
appender.start()

DefaultHttpHeaders headers = new DefaultHttpHeaders()
headers.add("Authorization", "Bearer foo")
headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar")
headers.add("Cookie", "baz")
headers.add("Set-Cookie", "qux")
headers.add("X-Forwarded-For", "quux")
headers.add("X-Forwarded-Host", "quuz")
headers.add("X-Real-IP", "waldo")
headers.add("X-Forwarded-For", "fred")
headers.add("Credential", "foo")
headers.add("Signature", "bar probably secret")
def ctx = ApplicationContext.run()
def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080")

when:
client.traceHeaders(headers)
def response = client.toBlocking().exchange(HttpRequest.GET("/masking").headers {headers ->
headers.add("Authorization", "Bearer foo")
headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar")
headers.add("Cookie", "baz")
headers.add("Set-Cookie", "qux")
headers.add("X-Forwarded-For", "quux")
headers.add("X-Forwarded-Host", "quuz")
headers.add("X-Real-IP", "waldo")
headers.add("X-Forwarded-For", "fred")
headers.add("Credential", "foo")
headers.add("Signature", "bar probably secret")
}, String)

then:
appender.events.size() == 10
appender.events.join("\n") == """Authorization: *MASKED*
response.body() == "ok"
appender.events.join("\n").contains("""Authorization: *MASKED*
|Proxy-Authorization: *MASKED*
|Cookie: baz
|Set-Cookie: qux
Expand All @@ -69,11 +67,21 @@ class DefaultClientHeaderMaskTest extends Specification {
|X-Forwarded-Host: quuz
|X-Real-IP: waldo
|Credential: *MASKED*
|Signature: *MASKED*""".stripMargin()
|Signature: *MASKED*""".stripMargin())

cleanup:
ctx.close()
appender.stop()
ctx.close()
}

@Requires(property = "spec.name", value = "DefaultClientHeaderMaskTest")
@Controller("/masking")
@Singleton
static class MaskedController {
@Get
String get() {
"ok"
}
}

static class MemoryAppender extends AppenderBase<ILoggingEvent> {
Expand Down
1 change: 1 addition & 0 deletions http/build.gradle
Expand Up @@ -18,6 +18,7 @@ dependencies {
testAnnotationProcessor project(":inject-java")
testImplementation project(":inject")
testImplementation project(":runtime")
testImplementation(libs.managed.logback)
}

tasks.named("compileKotlin") {
Expand Down
95 changes: 95 additions & 0 deletions http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java
@@ -0,0 +1,95 @@
/*
* Copyright 2017-2022 original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.micronaut.http.util;

import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.util.SupplierUtil;
import io.micronaut.http.HttpHeaders;
import org.slf4j.Logger;

import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.regex.Pattern;

/**
* Utility class to work with {@link io.micronaut.http.HttpHeaders} or HTTP Headers.
* @author Sergio del Amo
* @since 3.8.0
*/
public final class HttpHeadersUtil {
private static final Supplier<Pattern> HEADER_MASK_PATTERNS = SupplierUtil.memoized(() ->
Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE)
);

private HttpHeadersUtil() {

}

/**
* Trace HTTP Headers.
* @param log Logger
* @param httpHeaders HTTP Headers
*/
public static void trace(@NonNull Logger log,
@NonNull HttpHeaders httpHeaders) {
trace(log, httpHeaders.names(), httpHeaders::getAll);
}

/**
* Trace HTTP Headers.
* @param log Logger
* @param names HTTP Header names
* @param getAllHeaders Function to get all the header values for a particular header name
*/
public static void trace(@NonNull Logger log,
@NonNull Set<String> names,
@NonNull Function<String, List<String>> getAllHeaders) {
names.forEach(name -> trace(log, name, getAllHeaders));
}

/**
* Trace HTTP Headers.
* @param log Logger
* @param name HTTP Header name
* @param getAllHeaders Function to get all the header values for a particular header name
*/
public static void trace(@NonNull Logger log,
@NonNull String name,
@NonNull Function<String, List<String>> getAllHeaders) {
boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches();
List<String> all = getAllHeaders.apply(name);
if (all.size() > 1) {
for (String value : all) {
String maskedValue = isMasked ? mask(value) : value;
log.trace("{}: {}", name, maskedValue);
}
} else if (!all.isEmpty()) {
String maskedValue = isMasked ? mask(all.get(0)) : all.get(0);
log.trace("{}: {}", name, maskedValue);
}
}

@Nullable
private static String mask(@Nullable String value) {
if (value == null) {
return null;
}
return "*MASKED*";
}
}
@@ -0,0 +1,78 @@
package io.micronaut.http.util

import ch.qos.logback.classic.Level
import ch.qos.logback.classic.spi.ILoggingEvent
import ch.qos.logback.core.AppenderBase
import io.micronaut.http.HttpHeaders
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import spock.lang.Specification

import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue

class HttpHeadersUtilSpec extends Specification {
def "check masking works for #value"() {
expect:
expected == HttpHeadersUtil.mask(value)

where:
value | expected
null | null
"foo" | "*MASKED*"
"Tim Yates" | "*MASKED*"
}

def "check mask detects common security headers"() {
given:
MemoryAppender appender = new MemoryAppender()
Logger log = LoggerFactory.getLogger(HttpHeadersUtilSpec.class)

expect:
log instanceof ch.qos.logback.classic.Logger

when:
ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log
logger.addAppender(appender)
logger.setLevel(Level.TRACE)
appender.start()

HttpHeaders headers = new MockHttpHeaders([
"Authorization": ["Bearer foo"],
"Proxy-Authorization": ["AWS4-HMAC-SHA256 bar"],
"Cookie": ["baz"],
"Set-Cookie": ["qux"],
"X-Forwarded-For": ["quux", "fred"],
"X-Forwarded-Host": ["quuz"],
"X-Real-IP": ["waldo"],
"Credential": ["foo"],
"Signature": ["bar probably secret"]])

HttpHeadersUtil.trace(log, headers)

then:
appender.events.size() == headers.values().collect { it -> it.size() }.sum()
appender.events.contains("Authorization: *MASKED*")
appender.events.contains("Cookie: baz")
appender.events.contains("Credential: *MASKED*")
appender.events.contains("Set-Cookie: qux")
appender.events.contains("Proxy-Authorization: *MASKED*")
appender.events.contains("Signature: *MASKED*")
appender.events.contains("X-Forwarded-For: quux")
appender.events.contains("X-Forwarded-For: fred")
appender.events.contains("X-Forwarded-Host: quuz")
appender.events.contains("X-Real-IP: waldo")

cleanup:
appender.stop()
}

static class MemoryAppender extends AppenderBase<ILoggingEvent> {
final BlockingQueue<String> events = new LinkedBlockingQueue<>()

@Override
protected void append(ILoggingEvent e) {
events.add(e.formattedMessage)
}
}
}

0 comments on commit 91be1e4

Please sign in to comment.