From e2daae9ac8e4b5c0a04f47b163157445915b949a Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Mon, 19 Apr 2021 07:05:53 -0700 Subject: [PATCH] Fix StreamBufferingEncoder GOAWAY bug (#11144) Motivation: There is a bug in `StreamBufferingEncoder` such that when client receives GOWAY while there are pending streams due to MAX_CONCURRENT_STREAMS, we see the following error: ``` io.netty.handler.codec.http2.Http2Exception$StreamException: Maximum active streams violated for this endpoint. at io.netty.handler.codec.http2.Http2Exception.streamError(Http2Exception.java:147) at io.netty.handler.codec.http2.DefaultHttp2Connection$DefaultEndpoint.checkNewStreamAllowed(DefaultHttp2Connection.java:896) at io.netty.handler.codec.http2.DefaultHttp2Connection$DefaultEndpoint.createStream(DefaultHttp2Connection.java:748) at io.netty.handler.codec.http2.DefaultHttp2Connection$DefaultEndpoint.createStream(DefaultHttp2Connection.java:668) at io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder.writeHeaders0(DefaultHttp2ConnectionEncoder.java:201) at io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder.writeHeaders(DefaultHttp2ConnectionEncoder.java:167) at io.netty.handler.codec.http2.DecoratingHttp2FrameWriter.writeHeaders(DecoratingHttp2FrameWriter.java:53) at io.netty.handler.codec.http2.StreamBufferingEncoder.writeHeaders(StreamBufferingEncoder.java:153) at io.netty.handler.codec.http2.StreamBufferingEncoder.writeHeaders(StreamBufferingEncoder.java:141) at io.grpc.netty.NettyClientHandler.createStreamTraced(NettyClientHandler.java:584) at io.grpc.netty.NettyClientHandler.createStream(NettyClientHandler.java:567) at io.grpc.netty.NettyClientHandler.write(NettyClientHandler.java:328) at io.netty.channel.AbstractChannelHandlerContext.invokeWrite0(AbstractChannelHandlerContext.java:717) at io.netty.channel.AbstractChannelHandlerContext.invokeWrite(AbstractChannelHandlerContext.java:709) at io.netty.channel.AbstractChannelHandlerContext.write(AbstractChannelHandlerContext.java:792) at io.netty.channel.AbstractChannelHandlerContext.write(AbstractChannelHandlerContext.java:702) at io.netty.channel.DefaultChannelPipeline.write(DefaultChannelPipeline.java:1015) at io.netty.channel.AbstractChannel.write(AbstractChannel.java:289) at io.grpc.netty.WriteQueue$AbstractQueuedCommand.run(WriteQueue.java:213) at io.grpc.netty.WriteQueue.flush(WriteQueue.java:128) at io.grpc.netty.WriteQueue.drainNow(WriteQueue.java:114) at io.grpc.netty.NettyClientHandler.goingAway(NettyClientHandler.java:783) at io.grpc.netty.NettyClientHandler.access$300(NettyClientHandler.java:91) at io.grpc.netty.NettyClientHandler$3.onGoAwayReceived(NettyClientHandler.java:280) at io.netty.handler.codec.http2.DefaultHttp2Connection.goAwayReceived(DefaultHttp2Connection.java:236) at io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder.onGoAwayRead0(DefaultHttp2ConnectionDecoder.java:218) at io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder$FrameReadListener.onGoAwayRead(DefaultHttp2ConnectionDecoder.java:551) at io.netty.handler.codec.http2.Http2InboundFrameLogger$1.onGoAwayRead(Http2InboundFrameLogger.java:119) at io.netty.handler.codec.http2.DefaultHttp2FrameReader.readGoAwayFrame(DefaultHttp2FrameReader.java:591) at io.netty.handler.codec.http2.DefaultHttp2FrameReader.processPayloadState(DefaultHttp2FrameReader.java:272) at io.netty.handler.codec.http2.DefaultHttp2FrameReader.readFrame(DefaultHttp2FrameReader.java:160) at io.netty.handler.codec.http2.Http2InboundFrameLogger.readFrame(Http2InboundFrameLogger.java:41) at io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder.decodeFrame(DefaultHttp2ConnectionDecoder.java:174) at io.netty.handler.codec.http2.Http2ConnectionHandler$FrameDecoder.decode(Http2ConnectionHandler.java:378) at io.netty.handler.codec.http2.Http2ConnectionHandler.decode(Http2ConnectionHandler.java:438) at io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:498) at io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:437) at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:276) at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379) at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365) at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:357) at io.netty.handler.ssl.SslHandler.unwrap(SslHandler.java:1486) at io.netty.handler.ssl.SslHandler.decodeJdkCompatible(SslHandler.java:1235) at io.netty.handler.ssl.SslHandler.decode(SslHandler.java:1282) at io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:498) at io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:437) at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:276) at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379) at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365) at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:357) at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410) at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379) at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365) at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919) at io.netty.channel.epoll.AbstractEpollStreamChannel$EpollStreamUnsafe.epollInReady(AbstractEpollStreamChannel.java:792) at io.netty.channel.epoll.EpollEventLoop.processReady(EpollEventLoop.java:475) at io.netty.channel.epoll.EpollEventLoop.run(EpollEventLoop.java:378) at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:989) at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74) at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30) at java.base/java.lang.Thread.run(Unknown Source) ``` The bug should come from the way that `StreamBufferingEncoder.writeHeaders()` handles the condition `connection().goAwayReceived()`. The current behavior is to delegate to `super.writeHeaders()` and let the stream fail, but this will end up with `Http2Exception` with the message "Maximum active streams violated for this endpoint" which is horrible. See https://github.com/netty/netty/blob/e5951d46fc89db507ba7d2968d2ede26378f0b04/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java#L152-L155 Modification: Abort new stream immediately if goaway received *and* MAX_CONCURRENT_STREAM reached in `StreamBufferingEncoder` rather than delegating to the `writeHeaders()` method of its super class. Result: In the situation when GOAWAY received as well as MAX_CONCURRENT_STREAM exceeded, the client will fail the buffered streams with `Http2Error.NO_ERROR` and message "GOAWAY received" instead of "Maximum active streams violated for this endpoint". Co-authored-by: Norman Maurer --- .../codec/http2/StreamBufferingEncoder.java | 50 ++++++++++++------- .../http2/StreamBufferingEncoderTest.java | 38 +++++++++++++- 2 files changed, 69 insertions(+), 19 deletions(-) diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java index ac754edd440..c3c340bc82b 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java @@ -69,33 +69,45 @@ public Http2ChannelClosedException() { } } + private static final class GoAwayDetail { + private final int lastStreamId; + private final long errorCode; + private final byte[] debugData; + + GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) { + this.lastStreamId = lastStreamId; + this.errorCode = errorCode; + this.debugData = debugData.clone(); + } + } + /** * Thrown by {@link StreamBufferingEncoder} if buffered streams are terminated due to * receipt of a {@code GOAWAY}. */ public static final class Http2GoAwayException extends Http2Exception { private static final long serialVersionUID = 1326785622777291198L; - private final int lastStreamId; - private final long errorCode; - private final byte[] debugData; + private final GoAwayDetail goAwayDetail; public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) { + this(new GoAwayDetail(lastStreamId, errorCode, debugData)); + } + + Http2GoAwayException(GoAwayDetail goAwayDetail) { super(Http2Error.STREAM_CLOSED); - this.lastStreamId = lastStreamId; - this.errorCode = errorCode; - this.debugData = debugData; + this.goAwayDetail = goAwayDetail; } public int lastStreamId() { - return lastStreamId; + return goAwayDetail.lastStreamId; } public long errorCode() { - return errorCode; + return goAwayDetail.errorCode; } public byte[] debugData() { - return debugData; + return goAwayDetail.debugData.clone(); } } @@ -106,6 +118,7 @@ public byte[] debugData() { private final TreeMap pendingStreams = new TreeMap(); private int maxConcurrentStreams; private boolean closed; + private GoAwayDetail goAwayDetail; public StreamBufferingEncoder(Http2ConnectionEncoder delegate) { this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS); @@ -118,7 +131,11 @@ public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxCon @Override public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { - cancelGoAwayStreams(lastStreamId, errorCode, debugData); + goAwayDetail = new GoAwayDetail( + // Using getBytes(..., false) is safe here as GoAwayDetail(...) will clone the byte[]. + lastStreamId, errorCode, + ByteBufUtil.getBytes(debugData, debugData.readerIndex(), debugData.readableBytes(), false)); + cancelGoAwayStreams(goAwayDetail); } @Override @@ -149,13 +166,12 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 if (closed) { return promise.setFailure(new Http2ChannelClosedException()); } - if (isExistingStream(streamId) || connection().goAwayReceived()) { + if (isExistingStream(streamId) || canCreateStream()) { return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); } - if (canCreateStream()) { - return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, - exclusive, padding, endOfStream, promise); + if (goAwayDetail != null) { + return promise.setFailure(new Http2GoAwayException(goAwayDetail)); } PendingStream pendingStream = pendingStreams.get(streamId); if (pendingStream == null) { @@ -248,12 +264,12 @@ private void tryCreatePendingStreams() { } } - private void cancelGoAwayStreams(int lastStreamId, long errorCode, ByteBuf debugData) { + private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) { Iterator iter = pendingStreams.values().iterator(); - Exception e = new Http2GoAwayException(lastStreamId, errorCode, ByteBufUtil.getBytes(debugData)); + Exception e = new Http2GoAwayException(goAwayDetail); while (iter.hasNext()) { PendingStream stream = iter.next(); - if (stream.streamId > lastStreamId) { + if (stream.streamId > goAwayDetail.lastStreamId) { iter.remove(); stream.close(e); } diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java index 2ad6068ca61..26c1714315a 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java @@ -49,6 +49,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultMessageSizeEstimator; +import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.ImmediateEventExecutor; @@ -111,6 +112,11 @@ public void setup() throws Exception { when(writer.writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), any(ChannelPromise.class))) .thenAnswer(successAnswer()); + when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyBoolean(), any(ChannelPromise.class))).thenAnswer(noopAnswer()); + when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), any(ChannelPromise.class))) + .thenAnswer(noopAnswer()); connection = new DefaultHttp2Connection(false); connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection)); @@ -167,7 +173,7 @@ public void multipleWritesToActiveStream() { encoder.writeData(ctx, 3, data(), 0, false, newPromise()); encoderWriteHeaders(3, newPromise()); - writeVerifyWriteHeaders(times(2), 3); + writeVerifyWriteHeaders(times(1), 3); // Contiguous data writes are coalesced ArgumentCaptor bufCaptor = ArgumentCaptor.forClass(ByteBuf.class); verify(writer, times(1)) @@ -245,18 +251,32 @@ public void receivingGoAwayFailsBufferedStreams() throws Http2Exception { futures.add(encoderWriteHeaders(streamId, newPromise())); streamId += 2; } + assertEquals(5, connection.numActiveStreams()); assertEquals(4, encoder.numBufferedStreams()); connection.goAwayReceived(11, 8, EMPTY_BUFFER); assertEquals(5, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); int failCount = 0; for (ChannelFuture f : futures) { if (f.cause() != null) { + assertTrue(f.cause() instanceof Http2GoAwayException); failCount++; } } - assertEquals(9, failCount); + assertEquals(4, failCount); + } + + @Test + public void receivingGoAwayFailsNewStreamIfMaxConcurrentStreamsReached() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + encoderWriteHeaders(3, newPromise()); + connection.goAwayReceived(11, 8, EMPTY_BUFFER); + ChannelFuture f = encoderWriteHeaders(5, newPromise()); + + assertTrue(f.cause() instanceof Http2GoAwayException); assertEquals(0, encoder.numBufferedStreams()); } @@ -533,6 +553,20 @@ public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { }; } + private Answer noopAnswer() { + return new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + for (Object a : invocation.getArguments()) { + if (a instanceof ChannelPromise) { + return (ChannelFuture) a; + } + } + return newPromise(); + } + }; + } + private ChannelPromise newPromise() { return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); }