diff --git a/netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java
similarity index 94%
rename from netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java
rename to core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java
index 6470e440327..dd539e75a18 100644
--- a/netty/src/main/java/io/grpc/netty/KeepAliveEnforcer.java
+++ b/core/src/main/java/io/grpc/internal/KeepAliveEnforcer.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package io.grpc.netty;
+package io.grpc.internal;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
@@ -22,11 +22,11 @@
import javax.annotation.CheckReturnValue;
/** Monitors the client's PING usage to make sure the rate is permitted. */
-class KeepAliveEnforcer {
+public final class KeepAliveEnforcer {
@VisibleForTesting
- static final int MAX_PING_STRIKES = 2;
+ public static final int MAX_PING_STRIKES = 2;
@VisibleForTesting
- static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);
+ public static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);
private final boolean permitWithoutCalls;
private final long minTimeNanos;
diff --git a/netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java b/core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java
similarity index 99%
rename from netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java
rename to core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java
index 8dfeb990e2b..c58ed6ea160 100644
--- a/netty/src/test/java/io/grpc/netty/KeepAliveEnforcerTest.java
+++ b/core/src/test/java/io/grpc/internal/KeepAliveEnforcerTest.java
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package io.grpc.netty;
+package io.grpc.internal;
import static com.google.common.truth.Truth.assertThat;
diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java
index fa21221a0ae..62dd50ce65e 100644
--- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java
+++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java
@@ -44,6 +44,7 @@
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
+import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.LogExceptionRunnable;
import io.grpc.internal.MaxConnectionIdleManager;
diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
index 170273e2c60..72c267a4825 100644
--- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java
@@ -61,6 +61,7 @@
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.GrpcUtil;
+import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
index 0e1273c7f23..d3ea82894b0 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
@@ -20,6 +20,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.DoNotCall;
import io.grpc.ChoiceServerCredentials;
import io.grpc.ExperimentalApi;
@@ -117,6 +118,8 @@ public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentia
int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE;
int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED;
+ boolean permitKeepAliveWithoutCalls;
+ long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5);
@VisibleForTesting
OkHttpServerBuilder(
@@ -223,6 +226,41 @@ public OkHttpServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit time
return this;
}
+ /**
+ * Specify the most aggressive keep-alive time clients are permitted to configure. The server will
+ * try to detect clients exceeding this rate and when detected will forcefully close the
+ * connection. The default is 5 minutes.
+ *
+ *
Even though a default is defined that allows some keep-alives, clients must not use
+ * keep-alive without approval from the service owner. Otherwise, they may experience failures in
+ * the future if the service becomes more restrictive. When unthrottled, keep-alives can cause a
+ * significant amount of traffic and CPU usage, so clients and servers should be conservative in
+ * what they use and accept.
+ *
+ * @see #permitKeepAliveWithoutCalls(boolean)
+ */
+ @CanIgnoreReturnValue
+ @Override
+ public OkHttpServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) {
+ checkArgument(keepAliveTime >= 0, "permit keepalive time must be non-negative: %s",
+ keepAliveTime);
+ permitKeepAliveTimeInNanos = timeUnit.toNanos(keepAliveTime);
+ return this;
+ }
+
+ /**
+ * Sets whether to allow clients to send keep-alive HTTP/2 PINGs even if there are no outstanding
+ * RPCs on the connection. Defaults to {@code false}.
+ *
+ * @see #permitKeepAliveTime(long, TimeUnit)
+ */
+ @CanIgnoreReturnValue
+ @Override
+ public OkHttpServerBuilder permitKeepAliveWithoutCalls(boolean permit) {
+ permitKeepAliveWithoutCalls = permit;
+ return this;
+ }
+
/**
* Sets the flow control window in bytes. If not called, the default value is 64 KiB.
*/
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
index b2b581d7155..f6099bec17a 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
@@ -29,6 +29,7 @@
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
+import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.MaxConnectionIdleManager;
import io.grpc.internal.ObjectPool;
@@ -95,6 +96,7 @@ final class OkHttpServerTransport implements ServerTransport,
private Attributes attributes;
private KeepAliveManager keepAliveManager;
private MaxConnectionIdleManager maxConnectionIdleManager;
+ private final KeepAliveEnforcer keepAliveEnforcer;
private final Object lock = new Object();
@GuardedBy("lock")
@@ -137,6 +139,8 @@ public OkHttpServerTransport(Config config, Socket bareSocket) {
logId = InternalLogId.allocate(getClass(), bareSocket.getRemoteSocketAddress().toString());
transportExecutor = config.transportExecutorPool.getObject();
scheduledExecutorService = config.scheduledExecutorServicePool.getObject();
+ keepAliveEnforcer = new KeepAliveEnforcer(config.permitKeepAliveWithoutCalls,
+ config.permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS);
}
public void start(ServerTransportListener listener) {
@@ -159,6 +163,27 @@ private void startIo(SerializingExecutor serializingExecutor) {
asyncSink.becomeConnected(Okio.sink(socket), socket);
FrameWriter rawFrameWriter = asyncSink.limitControlFramesWriter(
variant.newWriter(Okio.buffer(asyncSink), false));
+ FrameWriter writeMonitoringFrameWriter = new ForwardingFrameWriter(rawFrameWriter) {
+ @Override
+ public void synReply(boolean outFinished, int streamId, List headerBlock)
+ throws IOException {
+ keepAliveEnforcer.resetCounters();
+ super.synReply(outFinished, streamId, headerBlock);
+ }
+
+ @Override
+ public void headers(int streamId, List headerBlock) throws IOException {
+ keepAliveEnforcer.resetCounters();
+ super.headers(streamId, headerBlock);
+ }
+
+ @Override
+ public void data(boolean outFinished, int streamId, Buffer source, int byteCount)
+ throws IOException {
+ keepAliveEnforcer.resetCounters();
+ super.data(outFinished, streamId, source, byteCount);
+ }
+ };
synchronized (lock) {
this.securityInfo = result.securityInfo;
@@ -167,7 +192,7 @@ private void startIo(SerializingExecutor serializingExecutor) {
// does not propagate syscall errors through the FrameWriter. But we handle the
// AsyncSink failures with the same TransportExceptionHandler instance so it is all
// mixed back together.
- frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter);
+ frameWriter = new ExceptionHandlingFrameWriter(this, writeMonitoringFrameWriter);
outboundFlow = new OutboundFlowController(this, frameWriter);
// These writes will be queued in the serializingExecutor waiting for this function to
@@ -381,8 +406,11 @@ public OutboundFlowController.StreamState[] getActiveStreams() {
void streamClosed(int streamId, boolean flush) {
synchronized (lock) {
streams.remove(streamId);
- if (maxConnectionIdleManager != null && streams.isEmpty()) {
- maxConnectionIdleManager.onTransportIdle();
+ if (streams.isEmpty()) {
+ keepAliveEnforcer.onTransportIdle();
+ if (maxConnectionIdleManager != null) {
+ maxConnectionIdleManager.onTransportIdle();
+ }
}
if (gracefulShutdown && streams.isEmpty()) {
frameWriter.close();
@@ -449,6 +477,8 @@ static final class Config {
final int maxInboundMessageSize;
final int maxInboundMetadataSize;
final long maxConnectionIdleNanos;
+ final boolean permitKeepAliveWithoutCalls;
+ final long permitKeepAliveTimeInNanos;
public Config(
OkHttpServerBuilder builder,
@@ -469,6 +499,8 @@ public Config(
maxInboundMessageSize = builder.maxInboundMessageSize;
maxInboundMetadataSize = builder.maxInboundMetadataSize;
maxConnectionIdleNanos = builder.maxConnectionIdleInNanos;
+ permitKeepAliveWithoutCalls = builder.permitKeepAliveWithoutCalls;
+ permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos;
}
}
@@ -714,8 +746,11 @@ public void headers(boolean outFinished,
authority == null ? null : asciiString(authority),
statsTraceCtx,
tracer);
- if (maxConnectionIdleManager != null && streams.isEmpty()) {
- maxConnectionIdleManager.onTransportActive();
+ if (streams.isEmpty()) {
+ keepAliveEnforcer.onTransportActive();
+ if (maxConnectionIdleManager != null) {
+ maxConnectionIdleManager.onTransportActive();
+ }
}
streams.put(streamId, stream);
listener.streamCreated(streamForApp, method, metadata);
@@ -849,6 +884,11 @@ public void settings(boolean clearPrevious, Settings settings) {
@Override
public void ping(boolean ack, int payload1, int payload2) {
+ if (!keepAliveEnforcer.pingAcceptable()) {
+ abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings",
+ Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false);
+ return;
+ }
long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL);
if (!ack) {
frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload);
@@ -973,8 +1013,11 @@ private void respondWithHttpError(
synchronized (lock) {
Http2ErrorStreamState stream =
new Http2ErrorStreamState(streamId, lock, outboundFlow, config.flowControlWindow);
- if (maxConnectionIdleManager != null && streams.isEmpty()) {
- maxConnectionIdleManager.onTransportActive();
+ if (streams.isEmpty()) {
+ keepAliveEnforcer.onTransportActive();
+ if (maxConnectionIdleManager != null) {
+ maxConnectionIdleManager.onTransportActive();
+ }
}
streams.put(streamId, stream);
if (inFinished) {
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
index 2711978bb00..a52045011ae 100644
--- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
+++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
@@ -41,6 +41,7 @@
import io.grpc.Status;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
+import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener;
@@ -122,7 +123,9 @@ public class OkHttpServerTransportTest {
}
})
.flowControlWindow(INITIAL_WINDOW_SIZE)
- .maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS);
+ .maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS)
+ .permitKeepAliveWithoutCalls(true)
+ .permitKeepAliveTime(0, TimeUnit.SECONDS);
@Rule public final Timeout globalTimeout = Timeout.seconds(10);
@@ -1054,6 +1057,101 @@ public void channelzStats() throws Exception {
assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000));
}
+ @Test
+ public void keepAliveEnforcer_enforcesPings() throws Exception {
+ serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
+ .permitKeepAliveWithoutCalls(false);
+ initTransport();
+ handshake();
+
+ for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
+ pingPong();
+ }
+ pingPongId++;
+ clientFrameWriter.ping(false, pingPongId, 0);
+ clientFrameWriter.flush();
+ assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
+ verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
+ ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
+ }
+
+ @Test
+ public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception {
+ serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
+ .permitKeepAliveWithoutCalls(false);
+ initTransport();
+ handshake();
+
+ clientFrameWriter.headers(1, Arrays.asList(
+ HTTP_SCHEME_HEADER,
+ METHOD_HEADER,
+ new Header(Header.TARGET_AUTHORITY, "example.com:80"),
+ new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
+ CONTENT_TYPE_HEADER,
+ TE_HEADER,
+ new Header("some-metadata", "this could be anything")));
+ Buffer requestMessageFrame = createMessageFrame("Hello server");
+ clientFrameWriter.data(false, 1, requestMessageFrame, (int) requestMessageFrame.size());
+ pingPong();
+ MockStreamListener streamListener = mockTransportListener.newStreams.pop();
+
+ streamListener.stream.request(1);
+ pingPong();
+ assertThat(streamListener.messages.pop()).isEqualTo("Hello server");
+
+ streamListener.stream.writeHeaders(metadata("User-Data", "best data"));
+ streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
+ streamListener.stream.flush();
+ assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
+
+ for (int i = 0; i < 10; i++) {
+ assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
+ pingPong();
+ streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
+ streamListener.stream.flush();
+ }
+ }
+
+ @Test
+ public void keepAliveEnforcer_initialIdle() throws Exception {
+ serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
+ .permitKeepAliveWithoutCalls(false);
+ initTransport();
+ handshake();
+
+ for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
+ pingPong();
+ }
+ pingPongId++;
+ clientFrameWriter.ping(false, pingPongId, 0);
+ clientFrameWriter.flush();
+ assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
+ verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
+ ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
+ }
+
+ @Test
+ public void keepAliveEnforcer_noticesActive() throws Exception {
+ serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
+ .permitKeepAliveWithoutCalls(false);
+ initTransport();
+ handshake();
+
+ clientFrameWriter.headers(1, Arrays.asList(
+ HTTP_SCHEME_HEADER,
+ METHOD_HEADER,
+ new Header(Header.TARGET_AUTHORITY, "example.com:80"),
+ new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
+ CONTENT_TYPE_HEADER,
+ TE_HEADER,
+ new Header("some-metadata", "this could be anything")));
+ for (int i = 0; i < 10; i++) {
+ pingPong();
+ }
+ verify(clientFramesRead, never()).goAway(anyInt(), eq(ErrorCode.ENHANCE_YOUR_CALM),
+ eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)));
+ }
+
private void initTransport() throws Exception {
serverTransport = new OkHttpServerTransport(
new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()),