Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation for pseudo-headers on requests/responses #13647

Draft
wants to merge 4 commits into
base: 4.1
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -548,7 +548,7 @@ private T buildFromConnection(Http2Connection connection) {
writer = new Http2OutboundFrameLogger(writer, frameLogger);
}

Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, writer);
Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, writer, isValidateHeaders());
boolean encoderEnforceMaxConcurrentStreams = encoderEnforceMaxConcurrentStreams();

if (maxQueuedControlFrames != 0) {
Expand Down
Expand Up @@ -25,6 +25,8 @@
import io.netty.util.internal.UnstableApi;

import java.util.ArrayDeque;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Queue;

import static io.netty.handler.codec.http.HttpStatusClass.INFORMATIONAL;
Expand All @@ -48,10 +50,18 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder, Ht
// This initial capacity is plenty for SETTINGS traffic.
private final Queue<Http2Settings> outstandingLocalSettingsQueue = new ArrayDeque<Http2Settings>(4);
private Queue<Http2Settings> outstandingRemoteSettingsQueue;
private final boolean validateHeaders;

@Deprecated
public DefaultHttp2ConnectionEncoder(Http2Connection connection, Http2FrameWriter frameWriter) {
this(connection, frameWriter, true);
}

public DefaultHttp2ConnectionEncoder(Http2Connection connection,
Http2FrameWriter frameWriter, boolean validateHeaders) {
this.connection = checkNotNull(connection, "connection");
this.frameWriter = checkNotNull(frameWriter, "frameWriter");
this.validateHeaders = validateHeaders;
if (connection.remote().flowController() == null) {
connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection));
}
Expand Down Expand Up @@ -160,6 +170,41 @@ private static boolean validateHeadersSentState(Http2Stream stream, Http2Headers
return isInformational;
}

/**
* According to RFC 9113, "Pseudo-header fields defined for requests MUST NOT appear in responses;
* pseudo-header fields defined for responses MUST NOT appear in requests." In the same section it also state
* "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document"
* @return {@code true} if validation completed successfully, {@code false} if not
*/
private boolean validatePseudoHeaders(Http2Headers headers) {
// No pseudo headers to verify
if (headers.names().isEmpty()) {
return true;
}
boolean isResponseHeaders = connection.isServer();
Iterator<Entry<CharSequence, CharSequence>> iterator = headers.iterator();
while (iterator.hasNext()) {
CharSequence name = iterator.next().getKey();
if (Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat(name)) {
Http2Headers.PseudoHeaderName pseudoHeader = Http2Headers.PseudoHeaderName.getPseudoHeader(name);
if (pseudoHeader == null) {
// Found pseudo header not documented by HTTP2
return false;
}
if ((isResponseHeaders && pseudoHeader.isRequestOnly()) ||
(!isResponseHeaders && !pseudoHeader.isRequestOnly())) {
// Found pseudo header where it should not have been
return false;
}
} else {
// Found first header without pseudo header prefix. No need to continue since
// all pseudo header should be in the front of the iterator
break;
}
}
return true;
}

@Override
public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId,
final Http2Headers headers, final int streamDependency, final short weight,
Expand Down Expand Up @@ -221,6 +266,12 @@ private ChannelFuture writeHeaders0(final ChannelHandlerContext ctx, final int s
}
}

if (validateHeaders && !validatePseudoHeaders(headers)) {
promise.tryFailure(new IllegalArgumentException("Invalid Pseudo-Header found in response for: " +
streamId));
return promise;
}

// Trailing headers must go through flow control if there are other frames queued in flow control
// for this stream.
Http2RemoteFlowController flowController = flowController();
Expand Down
Expand Up @@ -213,7 +213,8 @@ public Http2FrameCodec build() {
frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger());
frameReader = new Http2InboundFrameLogger(frameReader, frameLogger());
}
Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter);
Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection,
frameWriter, isValidateHeaders());
if (encoderEnforceMaxConcurrentStreams()) {
encoder = new StreamBufferingEncoder(encoder);
}
Expand Down
Expand Up @@ -227,7 +227,8 @@ public Http2MultiplexCodec build() {
frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger());
frameReader = new Http2InboundFrameLogger(frameReader, frameLogger());
}
Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter);
Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection,
frameWriter, isValidateHeaders());
if (encoderEnforceMaxConcurrentStreams()) {
encoder = new StreamBufferingEncoder(encoder);
}
Expand Down
Expand Up @@ -79,6 +79,7 @@
*/
public class DefaultHttp2ConnectionEncoderTest {
private static final int STREAM_ID = 2;
private static final int CLIENT_STREAM_ID = 3;
private static final int PUSH_STREAM_ID = 4;

@Mock
Expand Down Expand Up @@ -333,6 +334,70 @@ public void dataFramesShouldMergeUseVoidPromise() throws Exception {
assertFalse(promise2.isSuccess());
}

@Test
public void writeEmptyHeadersWithValidation() throws Exception {
writeAllFlowControlledFrames();
createStream(STREAM_ID, false);
ChannelPromise promise = newPromise();
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise);
assertTrue(promise.isDone());
assertTrue(promise.isSuccess());
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE),
eq(0), eq(true), eq(promise));
}

@Test
public void writeServerResponseHeadersWithValidation() throws Exception {
writeAllFlowControlledFrames();
createStream(STREAM_ID, false);
ChannelPromise promise = newPromise();
Http2Headers headers = dummyResponseHeaders();
encoder.writeHeaders(ctx, STREAM_ID, headers, 0, true, promise);
assertTrue(promise.isDone());
assertTrue(promise.isSuccess());
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(headers),
eq(0), eq(true), eq(promise));
}

@Test
public void writeServerRequestHeadersWithValidationShouldFail() throws Exception {
writeAllFlowControlledFrames();
createStream(STREAM_ID, false);
ChannelPromise promise = newPromise();
Http2Headers headers = dummyRequestHeaders();
encoder.writeHeaders(ctx, STREAM_ID, headers, 0, true, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertThat(promise.cause(), instanceOf(IllegalArgumentException.class));
}

@Test
public void writeClientRequestHeadersWithValidation() throws Exception {
createClientConnection();
writeAllFlowControlledFrames();
createStream(CLIENT_STREAM_ID, false);
ChannelPromise promise = newPromise();
Http2Headers headers = dummyRequestHeaders();
encoder.writeHeaders(ctx, CLIENT_STREAM_ID, headers, 0, true, promise);
assertTrue(promise.isDone());
assertTrue(promise.isSuccess());
verify(writer).writeHeaders(eq(ctx), eq(CLIENT_STREAM_ID), eq(headers),
eq(0), eq(true), eq(promise));
}

@Test
public void writeClientResponseHeadersWithValidationShouldFail() throws Exception {
createClientConnection();
writeAllFlowControlledFrames();
createStream(CLIENT_STREAM_ID, false);
ChannelPromise promise = newPromise();
Http2Headers headers = dummyResponseHeaders();
encoder.writeHeaders(ctx, CLIENT_STREAM_ID, headers, 0, true, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertThat(promise.cause(), instanceOf(IllegalArgumentException.class));
}

@Test
public void dataFramesDontMergeWithHeaders() throws Exception {
createStream(STREAM_ID, false);
Expand Down Expand Up @@ -950,8 +1015,26 @@ private ChannelFuture newSucceededFuture() {
return newPromise().setSuccess();
}

private void createClientConnection() {
// To mimic a client connection, recreate connection object for client
connection = new DefaultHttp2Connection(false);
connection.remote().flowController(remoteFlow);
// Re-create encoder with the client connection
encoder = new DefaultHttp2ConnectionEncoder(connection, writer);
encoder.lifecycleManager(lifecycleManager);
}

private static ByteBuf dummyData() {
// The buffer is purposely 8 bytes so it will even work for a ping frame.
return wrappedBuffer("abcdefgh".getBytes(UTF_8));
}

private static Http2Headers dummyResponseHeaders() {
return new DefaultHttp2Headers().status("200");
}

private static Http2Headers dummyRequestHeaders() {
return new DefaultHttp2Headers().scheme("https")
.method("GET").path("/foo.txt");
}
}
Expand Up @@ -148,13 +148,14 @@ public void teardown() throws Exception {
@Test
public void inflightFrameAfterStreamResetShouldNotMakeConnectionUnusable() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final Http2Headers responseHeaders = dummyOKResponseHeaders();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelHandlerContext ctx = invocationOnMock.getArgument(0);
http2Server.encoder().writeHeaders(ctx,
(Integer) invocationOnMock.getArgument(1),
(Http2Headers) invocationOnMock.getArgument(2),
responseHeaders,
0,
false,
ctx.newPromise());
Expand Down Expand Up @@ -1272,6 +1273,11 @@ private static Http2Headers dummyHeaders() {
.add(randomString(), randomString());
}

private static Http2Headers dummyOKResponseHeaders() {
return new DefaultHttp2Headers(false).status(new AsciiString("200"))
.add("response-" + randomString(), randomString());
}

private static void mockFlowControl(Http2FrameListener listener) throws Http2Exception {
doAnswer(new Answer<Integer>() {
@Override
Expand Down
Expand Up @@ -706,10 +706,11 @@ public void creatingWritingReadingAndClosingOutboundStreamShouldWork() {
assertTrue(inboundHandler.isChannelActive());

// Write to the child channel
Http2Headers headers = new DefaultHttp2Headers().scheme("https").method("GET").path("/foo.txt");
Http2Headers headers = new DefaultHttp2Headers().status("200");
// Write to initialize child channel
childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers));

// Read from the child channel
headers = new DefaultHttp2Headers().scheme("https").method("GET").path("/foo.txt");
frameInboundWriter.writeInboundHeaders(childChannel.stream().id(), headers, 0, false);

Http2HeadersFrame headersFrame = inboundHandler.readInbound();
Expand Down