Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix StreamBufferingEncoder GOAWAY bug #11144

Merged
merged 13 commits into from Apr 19, 2021
Merged
Expand Up @@ -106,6 +106,9 @@ public byte[] debugData() {
private final TreeMap<Integer, PendingStream> pendingStreams = new TreeMap<Integer, PendingStream>();
private int maxConcurrentStreams;
private boolean closed;
private Integer goAwayLastStreamId;
private long goAwayErrorCode;
private ByteBuf goAwayDebugData;
normanmaurer marked this conversation as resolved.
Show resolved Hide resolved

public StreamBufferingEncoder(Http2ConnectionEncoder delegate) {
this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS);
Expand All @@ -118,7 +121,10 @@ public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxCon

@Override
public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) {
cancelGoAwayStreams(lastStreamId, errorCode, debugData);
goAwayLastStreamId = lastStreamId;
goAwayErrorCode = errorCode;
goAwayDebugData = debugData;
cancelGoAwayStreams();
}

@Override
Expand Down Expand Up @@ -149,13 +155,14 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2
if (closed) {
return promise.setFailure(new Http2ChannelClosedException());
}
if (isExistingStream(streamId) || connection().goAwayReceived()) {
if (isExistingStream(streamId) || canCreateStream()) {
return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
}
if (canCreateStream()) {
return super.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
if (goAwayLastStreamId != null) {
promise.setFailure(new Http2GoAwayException(
goAwayLastStreamId, goAwayErrorCode, ByteBufUtil.getBytes(goAwayDebugData)));
return promise;
normanmaurer marked this conversation as resolved.
Show resolved Hide resolved
}
PendingStream pendingStream = pendingStreams.get(streamId);
if (pendingStream == null) {
Expand Down Expand Up @@ -248,12 +255,13 @@ private void tryCreatePendingStreams() {
}
}

private void cancelGoAwayStreams(int lastStreamId, long errorCode, ByteBuf debugData) {
private void cancelGoAwayStreams() {
Iterator<PendingStream> iter = pendingStreams.values().iterator();
Exception e = new Http2GoAwayException(lastStreamId, errorCode, ByteBufUtil.getBytes(debugData));
Exception e = new Http2GoAwayException(
goAwayLastStreamId, goAwayErrorCode, ByteBufUtil.getBytes(goAwayDebugData));
while (iter.hasNext()) {
PendingStream stream = iter.next();
if (stream.streamId > lastStreamId) {
if (stream.streamId > goAwayLastStreamId) {
iter.remove();
stream.close(e);
}
Expand Down
Expand Up @@ -49,6 +49,7 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.channel.DefaultMessageSizeEstimator;
import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.ImmediateEventExecutor;
Expand Down Expand Up @@ -111,6 +112,11 @@ public void setup() throws Exception {
when(writer.writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class),
any(ChannelPromise.class)))
.thenAnswer(successAnswer());
when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class),
normanmaurer marked this conversation as resolved.
Show resolved Hide resolved
anyInt(), anyBoolean(), any(ChannelPromise.class))).thenAnswer(noopAnswer());
when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class),
anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), any(ChannelPromise.class)))
.thenAnswer(noopAnswer());

connection = new DefaultHttp2Connection(false);
connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection));
Expand Down Expand Up @@ -167,7 +173,7 @@ public void multipleWritesToActiveStream() {
encoder.writeData(ctx, 3, data(), 0, false, newPromise());
encoderWriteHeaders(3, newPromise());

writeVerifyWriteHeaders(times(2), 3);
writeVerifyWriteHeaders(times(1), 3);
// Contiguous data writes are coalesced
ArgumentCaptor<ByteBuf> bufCaptor = ArgumentCaptor.forClass(ByteBuf.class);
verify(writer, times(1))
Expand Down Expand Up @@ -245,18 +251,32 @@ public void receivingGoAwayFailsBufferedStreams() throws Http2Exception {
futures.add(encoderWriteHeaders(streamId, newPromise()));
streamId += 2;
}
assertEquals(5, connection.numActiveStreams());
assertEquals(4, encoder.numBufferedStreams());

connection.goAwayReceived(11, 8, EMPTY_BUFFER);

assertEquals(5, connection.numActiveStreams());
assertEquals(0, encoder.numBufferedStreams());
int failCount = 0;
for (ChannelFuture f : futures) {
if (f.cause() != null) {
assertTrue(f.cause() instanceof Http2GoAwayException);
failCount++;
}
}
assertEquals(9, failCount);
assertEquals(4, failCount);
}

@Test
public void receivingGoAwayFailsNewStreamIfMaxConcurrentStreamsReached() throws Http2Exception {
encoder.writeSettingsAck(ctx, newPromise());
setMaxConcurrentStreams(1);
encoderWriteHeaders(3, newPromise());
connection.goAwayReceived(11, 8, EMPTY_BUFFER);
ChannelFuture f = encoderWriteHeaders(5, newPromise());

assertTrue(f.cause() instanceof Http2GoAwayException);
assertEquals(0, encoder.numBufferedStreams());
}

Expand Down Expand Up @@ -533,6 +553,20 @@ public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
};
}

private Answer<ChannelFuture> noopAnswer() {
return new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
for (Object a : invocation.getArguments()) {
if (a instanceof ChannelPromise) {
return (ChannelFuture) a;
}
}
return newPromise();
}
};
}

private ChannelPromise newPromise() {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
}
Expand Down