From b3e10e09706a2d19e2f97ccbcbce1ccf446a0b07 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Thu, 15 Sep 2022 16:08:55 -0700 Subject: [PATCH] okhttp: add okhttpServerBuilder permitKeepAliveTime() and permitKeepAliveWithoutCalls() for server keepAlive enforcement (#9544) --- .../io/grpc/internal}/KeepAliveEnforcer.java | 8 +- .../grpc/internal}/KeepAliveEnforcerTest.java | 2 +- .../io/grpc/netty/NettyServerHandler.java | 1 + .../io/grpc/netty/NettyServerHandlerTest.java | 1 + .../io/grpc/okhttp/OkHttpServerBuilder.java | 38 +++++++ .../io/grpc/okhttp/OkHttpServerTransport.java | 57 ++++++++-- .../okhttp/OkHttpServerTransportTest.java | 100 +++++++++++++++++- 7 files changed, 194 insertions(+), 13 deletions(-) rename {netty/src/main/java/io/grpc/netty => core/src/main/java/io/grpc/internal}/KeepAliveEnforcer.java (94%) rename {netty/src/test/java/io/grpc/netty => core/src/test/java/io/grpc/internal}/KeepAliveEnforcerTest.java (99%) diff --git a/netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java similarity index 94% rename from netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java rename to core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java index 6470e440327..dd539e75a18 100644 --- a/netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java +++ b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.netty; +package io.grpc.internal; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -22,11 +22,11 @@ import javax.annotation.CheckReturnValue; /** Monitors the client's PING usage to make sure the rate is permitted. */ -class KeepAliveEnforcer { +public final class KeepAliveEnforcer { @VisibleForTesting - static final int MAX_PING_STRIKES = 2; + public static final int MAX_PING_STRIKES = 2; @VisibleForTesting - static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2); + public static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2); private final boolean permitWithoutCalls; private final long minTimeNanos; diff --git a/netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java b/core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java similarity index 99% rename from netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java rename to core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java index 8dfeb990e2b..c58ed6ea160 100644 --- a/netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java +++ b/core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.netty; +package io.grpc.internal; import static com.google.common.truth.Truth.assertThat; diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index fa21221a0ae..62dd50ce65e 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -44,6 +44,7 @@ import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.LogExceptionRunnable; import io.grpc.internal.MaxConnectionIdleManager; diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 170273e2c60..72c267a4825 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -61,6 +61,7 @@ import io.grpc.Status.Code; import io.grpc.StreamTracer; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index 0e1273c7f23..d3ea82894b0 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChoiceServerCredentials; import io.grpc.ExperimentalApi; @@ -117,6 +118,8 @@ public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentia int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE; int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; + boolean permitKeepAliveWithoutCalls; + long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5); @VisibleForTesting OkHttpServerBuilder( @@ -223,6 +226,41 @@ public OkHttpServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit time return this; } + /** + * Specify the most aggressive keep-alive time clients are permitted to configure. The server will + * try to detect clients exceeding this rate and when detected will forcefully close the + * connection. The default is 5 minutes. + * + *

Even though a default is defined that allows some keep-alives, clients must not use + * keep-alive without approval from the service owner. Otherwise, they may experience failures in + * the future if the service becomes more restrictive. When unthrottled, keep-alives can cause a + * significant amount of traffic and CPU usage, so clients and servers should be conservative in + * what they use and accept. + * + * @see #permitKeepAliveWithoutCalls(boolean) + */ + @CanIgnoreReturnValue + @Override + public OkHttpServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) { + checkArgument(keepAliveTime >= 0, "permit keepalive time must be non-negative: %s", + keepAliveTime); + permitKeepAliveTimeInNanos = timeUnit.toNanos(keepAliveTime); + return this; + } + + /** + * Sets whether to allow clients to send keep-alive HTTP/2 PINGs even if there are no outstanding + * RPCs on the connection. Defaults to {@code false}. + * + * @see #permitKeepAliveTime(long, TimeUnit) + */ + @CanIgnoreReturnValue + @Override + public OkHttpServerBuilder permitKeepAliveWithoutCalls(boolean permit) { + permitKeepAliveWithoutCalls = permit; + return this; + } + /** * Sets the flow control window in bytes. If not called, the default value is 64 KiB. */ diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java index b2b581d7155..f6099bec17a 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java @@ -29,6 +29,7 @@ import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.KeepAliveManager; import io.grpc.internal.MaxConnectionIdleManager; import io.grpc.internal.ObjectPool; @@ -95,6 +96,7 @@ final class OkHttpServerTransport implements ServerTransport, private Attributes attributes; private KeepAliveManager keepAliveManager; private MaxConnectionIdleManager maxConnectionIdleManager; + private final KeepAliveEnforcer keepAliveEnforcer; private final Object lock = new Object(); @GuardedBy("lock") @@ -137,6 +139,8 @@ public OkHttpServerTransport(Config config, Socket bareSocket) { logId = InternalLogId.allocate(getClass(), bareSocket.getRemoteSocketAddress().toString()); transportExecutor = config.transportExecutorPool.getObject(); scheduledExecutorService = config.scheduledExecutorServicePool.getObject(); + keepAliveEnforcer = new KeepAliveEnforcer(config.permitKeepAliveWithoutCalls, + config.permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS); } public void start(ServerTransportListener listener) { @@ -159,6 +163,27 @@ private void startIo(SerializingExecutor serializingExecutor) { asyncSink.becomeConnected(Okio.sink(socket), socket); FrameWriter rawFrameWriter = asyncSink.limitControlFramesWriter( variant.newWriter(Okio.buffer(asyncSink), false)); + FrameWriter writeMonitoringFrameWriter = new ForwardingFrameWriter(rawFrameWriter) { + @Override + public void synReply(boolean outFinished, int streamId, List

headerBlock) + throws IOException { + keepAliveEnforcer.resetCounters(); + super.synReply(outFinished, streamId, headerBlock); + } + + @Override + public void headers(int streamId, List
headerBlock) throws IOException { + keepAliveEnforcer.resetCounters(); + super.headers(streamId, headerBlock); + } + + @Override + public void data(boolean outFinished, int streamId, Buffer source, int byteCount) + throws IOException { + keepAliveEnforcer.resetCounters(); + super.data(outFinished, streamId, source, byteCount); + } + }; synchronized (lock) { this.securityInfo = result.securityInfo; @@ -167,7 +192,7 @@ private void startIo(SerializingExecutor serializingExecutor) { // does not propagate syscall errors through the FrameWriter. But we handle the // AsyncSink failures with the same TransportExceptionHandler instance so it is all // mixed back together. - frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter); + frameWriter = new ExceptionHandlingFrameWriter(this, writeMonitoringFrameWriter); outboundFlow = new OutboundFlowController(this, frameWriter); // These writes will be queued in the serializingExecutor waiting for this function to @@ -381,8 +406,11 @@ public OutboundFlowController.StreamState[] getActiveStreams() { void streamClosed(int streamId, boolean flush) { synchronized (lock) { streams.remove(streamId); - if (maxConnectionIdleManager != null && streams.isEmpty()) { - maxConnectionIdleManager.onTransportIdle(); + if (streams.isEmpty()) { + keepAliveEnforcer.onTransportIdle(); + if (maxConnectionIdleManager != null) { + maxConnectionIdleManager.onTransportIdle(); + } } if (gracefulShutdown && streams.isEmpty()) { frameWriter.close(); @@ -449,6 +477,8 @@ static final class Config { final int maxInboundMessageSize; final int maxInboundMetadataSize; final long maxConnectionIdleNanos; + final boolean permitKeepAliveWithoutCalls; + final long permitKeepAliveTimeInNanos; public Config( OkHttpServerBuilder builder, @@ -469,6 +499,8 @@ public Config( maxInboundMessageSize = builder.maxInboundMessageSize; maxInboundMetadataSize = builder.maxInboundMetadataSize; maxConnectionIdleNanos = builder.maxConnectionIdleInNanos; + permitKeepAliveWithoutCalls = builder.permitKeepAliveWithoutCalls; + permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos; } } @@ -714,8 +746,11 @@ public void headers(boolean outFinished, authority == null ? null : asciiString(authority), statsTraceCtx, tracer); - if (maxConnectionIdleManager != null && streams.isEmpty()) { - maxConnectionIdleManager.onTransportActive(); + if (streams.isEmpty()) { + keepAliveEnforcer.onTransportActive(); + if (maxConnectionIdleManager != null) { + maxConnectionIdleManager.onTransportActive(); + } } streams.put(streamId, stream); listener.streamCreated(streamForApp, method, metadata); @@ -849,6 +884,11 @@ public void settings(boolean clearPrevious, Settings settings) { @Override public void ping(boolean ack, int payload1, int payload2) { + if (!keepAliveEnforcer.pingAcceptable()) { + abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings", + Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false); + return; + } long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL); if (!ack) { frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload); @@ -973,8 +1013,11 @@ private void respondWithHttpError( synchronized (lock) { Http2ErrorStreamState stream = new Http2ErrorStreamState(streamId, lock, outboundFlow, config.flowControlWindow); - if (maxConnectionIdleManager != null && streams.isEmpty()) { - maxConnectionIdleManager.onTransportActive(); + if (streams.isEmpty()) { + keepAliveEnforcer.onTransportActive(); + if (maxConnectionIdleManager != null) { + maxConnectionIdleManager.onTransportActive(); + } } streams.put(streamId, stream); if (inFinished) { diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java index 2711978bb00..a52045011ae 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -41,6 +41,7 @@ import io.grpc.Status; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; +import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.ServerStream; import io.grpc.internal.ServerStreamListener; import io.grpc.internal.ServerTransportListener; @@ -122,7 +123,9 @@ public class OkHttpServerTransportTest { } }) .flowControlWindow(INITIAL_WINDOW_SIZE) - .maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS); + .maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS) + .permitKeepAliveWithoutCalls(true) + .permitKeepAliveTime(0, TimeUnit.SECONDS); @Rule public final Timeout globalTimeout = Timeout.seconds(10); @@ -1054,6 +1057,101 @@ public void channelzStats() throws Exception { assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000)); } + @Test + public void keepAliveEnforcer_enforcesPings() throws Exception { + serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) { + pingPong(); + } + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM, + ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)); + } + + @Test + public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception { + serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + Buffer requestMessageFrame = createMessageFrame("Hello server"); + clientFrameWriter.data(false, 1, requestMessageFrame, (int) requestMessageFrame.size()); + pingPong(); + MockStreamListener streamListener = mockTransportListener.newStreams.pop(); + + streamListener.stream.request(1); + pingPong(); + assertThat(streamListener.messages.pop()).isEqualTo("Hello server"); + + streamListener.stream.writeHeaders(metadata("User-Data", "best data")); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + + for (int i = 0; i < 10; i++) { + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + pingPong(); + streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8))); + streamListener.stream.flush(); + } + } + + @Test + public void keepAliveEnforcer_initialIdle() throws Exception { + serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) { + pingPong(); + } + pingPongId++; + clientFrameWriter.ping(false, pingPongId, 0); + clientFrameWriter.flush(); + assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue(); + verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM, + ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)); + } + + @Test + public void keepAliveEnforcer_noticesActive() throws Exception { + serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS) + .permitKeepAliveWithoutCalls(false); + initTransport(); + handshake(); + + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER, + new Header("some-metadata", "this could be anything"))); + for (int i = 0; i < 10; i++) { + pingPong(); + } + verify(clientFramesRead, never()).goAway(anyInt(), eq(ErrorCode.ENHANCE_YOUR_CALM), + eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII))); + } + private void initTransport() throws Exception { serverTransport = new OkHttpServerTransport( new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()),