Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Extract http headers utils #8471

Merged
merged 4 commits into from Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -34,7 +34,6 @@
import io.micronaut.core.util.ArrayUtils;
import io.micronaut.core.util.CollectionUtils;
import io.micronaut.core.util.StringUtils;
import io.micronaut.core.util.SupplierUtil;
import io.micronaut.http.HttpAttributes;
import io.micronaut.http.HttpResponse;
import io.micronaut.http.HttpResponseWrapper;
Expand Down Expand Up @@ -90,6 +89,7 @@
import io.micronaut.http.sse.Event;
import io.micronaut.http.uri.UriBuilder;
import io.micronaut.http.uri.UriTemplate;
import io.micronaut.http.util.HttpHeadersUtil;
import io.micronaut.jackson.databind.JacksonDatabindMapper;
import io.micronaut.json.JsonMapper;
import io.micronaut.json.codec.JsonMediaTypeCodec;
Expand Down Expand Up @@ -183,9 +183,7 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static io.micronaut.scheduling.instrument.InvocationInstrumenter.NOOP;

/**
Expand All @@ -211,9 +209,6 @@ public class DefaultHttpClient implements
private static final int DEFAULT_HTTP_PORT = 80;
private static final int DEFAULT_HTTPS_PORT = 443;

private static final Supplier<Pattern> HEADER_MASK_PATTERNS = SupplierUtil.memoized(() ->
Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE)
);
/**
* Which headers <i>not</i> to copy from the first request when redirecting to a second request. There doesn't
* appear to be a spec for this. {@link java.net.HttpURLConnection} seems to drop all headers, but that would be a
Expand Down Expand Up @@ -1754,7 +1749,7 @@ private void debugRequest(URI requestURI, io.netty.handler.codec.http.HttpReques

private void traceRequest(io.micronaut.http.HttpRequest<?> request, io.netty.handler.codec.http.HttpRequest nettyRequest) {
HttpHeaders headers = nettyRequest.headers();
traceHeaders(headers);
HttpHeadersUtil.trace(log, headers.names(), headers::getAll);
if (io.micronaut.http.HttpMethod.permitsRequestBody(request.getMethod()) && request.getBody().isPresent() && nettyRequest instanceof FullHttpRequest) {
FullHttpRequest fullHttpRequest = (FullHttpRequest) nettyRequest;
ByteBuf content = fullHttpRequest.content();
Expand All @@ -1778,30 +1773,6 @@ private void traceChunk(ByteBuf content) {
log.trace("----");
}

private void traceHeaders(HttpHeaders headers) {
for (String name : headers.names()) {
boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches();
List<String> all = headers.getAll(name);
if (all.size() > 1) {
for (String value : all) {
String maskedValue = isMasked ? mask(value) : value;
log.trace("{}: {}", name, maskedValue);
}
} else if (!all.isEmpty()) {
String maskedValue = isMasked ? mask(all.get(0)) : all.get(0);
log.trace("{}: {}", name, maskedValue);
}
}
}

@Nullable
private String mask(@Nullable String value) {
if (value == null) {
return null;
}
return "*MASKED*";
}

private static MediaTypeCodecRegistry createDefaultMediaTypeRegistry() {
JsonMapper mapper = new JacksonDatabindMapper();
ApplicationConfiguration configuration = new ApplicationConfiguration();
Expand Down Expand Up @@ -2116,7 +2087,7 @@ protected void channelReadInstrumented(ChannelHandlerContext ctx, R msg) throws
HttpHeaders headers = msg.headers();
if (log.isTraceEnabled()) {
log.trace("HTTP Client Response Received ({}) for Request: {} {}", msg.status(), finalRequest.getMethodName(), finalRequest.getUri());
traceHeaders(headers);
HttpHeadersUtil.trace(log, headers.names(), headers::getAll);
}
buildResponse(responsePromise, msg, httpStatus);
}
Expand Down
@@ -1,66 +1,64 @@
package io.micronaut.http.client.netty

sdelamo marked this conversation as resolved.
Show resolved Hide resolved
import ch.qos.logback.classic.Level
import ch.qos.logback.classic.Logger
import io.micronaut.context.ApplicationContext
import io.micronaut.context.annotation.Requires
import io.micronaut.http.HttpRequest
import io.micronaut.http.annotation.Controller
import io.micronaut.http.annotation.Get
import io.micronaut.http.client.HttpClient
import io.micronaut.runtime.server.EmbeddedServer
import jakarta.inject.Singleton
import org.slf4j.Logger
import ch.qos.logback.classic.spi.ILoggingEvent
import ch.qos.logback.core.AppenderBase
import io.micronaut.context.ApplicationContext
import io.netty.handler.codec.http.DefaultHttpHeaders
import org.slf4j.LoggerFactory
import spock.lang.Specification

import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit

class DefaultClientHeaderMaskTest extends Specification {

def "check masking works for #value"() {
def "check mask detects common security headers"() {
given:
def ctx = ApplicationContext.run()
def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080")
EmbeddedServer server = ApplicationContext.run(EmbeddedServer, ["spec.name": "DefaultClientHeaderMaskTest"])
ApplicationContext ctx = server.applicationContext
HttpClient client = ctx.createBean(HttpClient, server.URL)

expect:
client.mask(value) == expected
client instanceof DefaultHttpClient

cleanup:
ctx.close()
when:
MemoryAppender appender = new MemoryAppender()
Logger log = LoggerFactory.getLogger(DefaultHttpClient.class)

where:
value | expected
null | null
"foo" | "*MASKED*"
"Tim Yates" | "*MASKED*"
}
then:
log instanceof ch.qos.logback.classic.Logger

def "check mask detects common security headers"() {
given:
MemoryAppender appender = new MemoryAppender()
Logger logger = (Logger) LoggerFactory.getLogger(DefaultHttpClient.class)
when:
ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log
logger.addAppender(appender)
logger.setLevel(Level.TRACE)
appender.start()

DefaultHttpHeaders headers = new DefaultHttpHeaders()
headers.add("Authorization", "Bearer foo")
headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar")
headers.add("Cookie", "baz")
headers.add("Set-Cookie", "qux")
headers.add("X-Forwarded-For", "quux")
headers.add("X-Forwarded-Host", "quuz")
headers.add("X-Real-IP", "waldo")
headers.add("X-Forwarded-For", "fred")
headers.add("Credential", "foo")
headers.add("Signature", "bar probably secret")
def ctx = ApplicationContext.run()
def client = ctx.createBean(DefaultHttpClient, "http://localhost:8080")

when:
client.traceHeaders(headers)
def response = client.toBlocking().exchange(HttpRequest.GET("/masking").headers {headers ->
headers.add("Authorization", "Bearer foo")
headers.add("Proxy-Authorization", "AWS4-HMAC-SHA256 bar")
headers.add("Cookie", "baz")
headers.add("Set-Cookie", "qux")
headers.add("X-Forwarded-For", "quux")
headers.add("X-Forwarded-Host", "quuz")
headers.add("X-Real-IP", "waldo")
headers.add("X-Forwarded-For", "fred")
headers.add("Credential", "foo")
headers.add("Signature", "bar probably secret")
}, String)

then:
appender.events.size() == 10
appender.events.join("\n") == """Authorization: *MASKED*
response.body() == "ok"
appender.events.join("\n").contains("""Authorization: *MASKED*
|Proxy-Authorization: *MASKED*
|Cookie: baz
|Set-Cookie: qux
Expand All @@ -69,11 +67,21 @@ class DefaultClientHeaderMaskTest extends Specification {
|X-Forwarded-Host: quuz
|X-Real-IP: waldo
|Credential: *MASKED*
|Signature: *MASKED*""".stripMargin()
|Signature: *MASKED*""".stripMargin())

cleanup:
ctx.close()
appender.stop()
ctx.close()
}

@Requires(property = "spec.name", value = "DefaultClientHeaderMaskTest")
@Controller("/masking")
@Singleton
static class MaskedController {
@Get
String get() {
"ok"
}
}

static class MemoryAppender extends AppenderBase<ILoggingEvent> {
Expand Down
1 change: 1 addition & 0 deletions http/build.gradle
Expand Up @@ -18,6 +18,7 @@ dependencies {
testAnnotationProcessor project(":inject-java")
testImplementation project(":inject")
testImplementation project(":runtime")
testImplementation(libs.managed.logback)
}

tasks.named("compileKotlin") {
Expand Down
95 changes: 95 additions & 0 deletions http/src/main/java/io/micronaut/http/util/HttpHeadersUtil.java
@@ -0,0 +1,95 @@
/*
* Copyright 2017-2022 original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.micronaut.http.util;

import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.core.util.SupplierUtil;
import io.micronaut.http.HttpHeaders;
import org.slf4j.Logger;

import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.regex.Pattern;

/**
* Utility class to work with {@link io.micronaut.http.HttpHeaders} or HTTP Headers.
* @author Sergio del Amo
* @since 3.8.0
*/
public final class HttpHeadersUtil {
private static final Supplier<Pattern> HEADER_MASK_PATTERNS = SupplierUtil.memoized(() ->
Pattern.compile(".*(password|cred|cert|key|secret|token|auth|signat).*", Pattern.CASE_INSENSITIVE)
);

private HttpHeadersUtil() {

}

/**
* Trace HTTP Headers.
* @param log Logger
* @param httpHeaders HTTP Headers
*/
public static void trace(@NonNull Logger log,
@NonNull HttpHeaders httpHeaders) {
trace(log, httpHeaders.names(), httpHeaders::getAll);
}

/**
* Trace HTTP Headers.
* @param log Logger
* @param names HTTP Header names
* @param getAllHeaders Function to get all the header values for a particular header name
*/
public static void trace(@NonNull Logger log,
@NonNull Set<String> names,
@NonNull Function<String, List<String>> getAllHeaders) {
names.forEach(name -> trace(log, name, getAllHeaders));
}

/**
* Trace HTTP Headers.
* @param log Logger
* @param name HTTP Header name
* @param getAllHeaders Function to get all the header values for a particular header name
*/
public static void trace(@NonNull Logger log,
@NonNull String name,
@NonNull Function<String, List<String>> getAllHeaders) {
boolean isMasked = HEADER_MASK_PATTERNS.get().matcher(name).matches();
List<String> all = getAllHeaders.apply(name);
if (all.size() > 1) {
for (String value : all) {
String maskedValue = isMasked ? mask(value) : value;
log.trace("{}: {}", name, maskedValue);
}
} else if (!all.isEmpty()) {
String maskedValue = isMasked ? mask(all.get(0)) : all.get(0);
log.trace("{}: {}", name, maskedValue);
}
}

@Nullable
private static String mask(@Nullable String value) {
if (value == null) {
return null;
}
return "*MASKED*";
}
}
@@ -0,0 +1,78 @@
package io.micronaut.http.util

import ch.qos.logback.classic.Level
import ch.qos.logback.classic.spi.ILoggingEvent
import ch.qos.logback.core.AppenderBase
import io.micronaut.http.HttpHeaders
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import spock.lang.Specification

import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue

class HttpHeadersUtilSpec extends Specification {
def "check masking works for #value"() {
expect:
expected == HttpHeadersUtil.mask(value)

where:
value | expected
null | null
"foo" | "*MASKED*"
"Tim Yates" | "*MASKED*"
}

def "check mask detects common security headers"() {
given:
MemoryAppender appender = new MemoryAppender()
Logger log = LoggerFactory.getLogger(HttpHeadersUtilSpec.class)

expect:
log instanceof ch.qos.logback.classic.Logger

when:
ch.qos.logback.classic.Logger logger = (ch.qos.logback.classic.Logger) log
logger.addAppender(appender)
logger.setLevel(Level.TRACE)
appender.start()

HttpHeaders headers = new MockHttpHeaders([
"Authorization": ["Bearer foo"],
"Proxy-Authorization": ["AWS4-HMAC-SHA256 bar"],
"Cookie": ["baz"],
"Set-Cookie": ["qux"],
"X-Forwarded-For": ["quux", "fred"],
"X-Forwarded-Host": ["quuz"],
"X-Real-IP": ["waldo"],
"Credential": ["foo"],
"Signature": ["bar probably secret"]])

HttpHeadersUtil.trace(log, headers)

then:
appender.events.size() == headers.values().collect { it -> it.size() }.sum()
appender.events.contains("Authorization: *MASKED*")
appender.events.contains("Cookie: baz")
appender.events.contains("Credential: *MASKED*")
appender.events.contains("Set-Cookie: qux")
appender.events.contains("Proxy-Authorization: *MASKED*")
appender.events.contains("Signature: *MASKED*")
appender.events.contains("X-Forwarded-For: quux")
appender.events.contains("X-Forwarded-For: fred")
appender.events.contains("X-Forwarded-Host: quuz")
appender.events.contains("X-Real-IP: waldo")

cleanup:
appender.stop()
}

static class MemoryAppender extends AppenderBase<ILoggingEvent> {
final BlockingQueue<String> events = new LinkedBlockingQueue<>()

@Override
protected void append(ILoggingEvent e) {
events.add(e.formattedMessage)
}
}
}