Skip to content

Commit

Permalink
okhttp: add okhttpServerBuilder permitKeepAliveTime() and permitKeepA…
Browse files Browse the repository at this point in the history
…liveWithoutCalls() for server keepAlive enforcement (grpc#9544)
  • Loading branch information
YifeiZhuang authored and larry-safran committed Oct 6, 2022
1 parent 0c2cb1c commit b3e10e0
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 13 deletions.
Expand Up @@ -14,19 +14,19 @@
* limitations under the License.
*/

package io.grpc.netty;
package io.grpc.internal;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.concurrent.TimeUnit;
import javax.annotation.CheckReturnValue;

/** Monitors the client's PING usage to make sure the rate is permitted. */
class KeepAliveEnforcer {
public final class KeepAliveEnforcer {
@VisibleForTesting
static final int MAX_PING_STRIKES = 2;
public static final int MAX_PING_STRIKES = 2;
@VisibleForTesting
static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);
public static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);

private final boolean permitWithoutCalls;
private final long minTimeNanos;
Expand Down
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package io.grpc.netty;
package io.grpc.internal;

import static com.google.common.truth.Truth.assertThat;

Expand Down
1 change: 1 addition & 0 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Expand Up @@ -44,6 +44,7 @@
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.LogExceptionRunnable;
import io.grpc.internal.MaxConnectionIdleManager;
Expand Down
Expand Up @@ -61,6 +61,7 @@
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
Expand Down
38 changes: 38 additions & 0 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
Expand Up @@ -20,6 +20,7 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.DoNotCall;
import io.grpc.ChoiceServerCredentials;
import io.grpc.ExperimentalApi;
Expand Down Expand Up @@ -117,6 +118,8 @@ public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentia
int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE;
int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED;
boolean permitKeepAliveWithoutCalls;
long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5);

@VisibleForTesting
OkHttpServerBuilder(
Expand Down Expand Up @@ -223,6 +226,41 @@ public OkHttpServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit time
return this;
}

/**
* Specify the most aggressive keep-alive time clients are permitted to configure. The server will
* try to detect clients exceeding this rate and when detected will forcefully close the
* connection. The default is 5 minutes.
*
* <p>Even though a default is defined that allows some keep-alives, clients must not use
* keep-alive without approval from the service owner. Otherwise, they may experience failures in
* the future if the service becomes more restrictive. When unthrottled, keep-alives can cause a
* significant amount of traffic and CPU usage, so clients and servers should be conservative in
* what they use and accept.
*
* @see #permitKeepAliveWithoutCalls(boolean)
*/
@CanIgnoreReturnValue
@Override
public OkHttpServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) {
checkArgument(keepAliveTime >= 0, "permit keepalive time must be non-negative: %s",
keepAliveTime);
permitKeepAliveTimeInNanos = timeUnit.toNanos(keepAliveTime);
return this;
}

/**
* Sets whether to allow clients to send keep-alive HTTP/2 PINGs even if there are no outstanding
* RPCs on the connection. Defaults to {@code false}.
*
* @see #permitKeepAliveTime(long, TimeUnit)
*/
@CanIgnoreReturnValue
@Override
public OkHttpServerBuilder permitKeepAliveWithoutCalls(boolean permit) {
permitKeepAliveWithoutCalls = permit;
return this;
}

/**
* Sets the flow control window in bytes. If not called, the default value is 64 KiB.
*/
Expand Down
57 changes: 50 additions & 7 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
Expand Up @@ -29,6 +29,7 @@
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.MaxConnectionIdleManager;
import io.grpc.internal.ObjectPool;
Expand Down Expand Up @@ -95,6 +96,7 @@ final class OkHttpServerTransport implements ServerTransport,
private Attributes attributes;
private KeepAliveManager keepAliveManager;
private MaxConnectionIdleManager maxConnectionIdleManager;
private final KeepAliveEnforcer keepAliveEnforcer;

private final Object lock = new Object();
@GuardedBy("lock")
Expand Down Expand Up @@ -137,6 +139,8 @@ public OkHttpServerTransport(Config config, Socket bareSocket) {
logId = InternalLogId.allocate(getClass(), bareSocket.getRemoteSocketAddress().toString());
transportExecutor = config.transportExecutorPool.getObject();
scheduledExecutorService = config.scheduledExecutorServicePool.getObject();
keepAliveEnforcer = new KeepAliveEnforcer(config.permitKeepAliveWithoutCalls,
config.permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS);
}

public void start(ServerTransportListener listener) {
Expand All @@ -159,6 +163,27 @@ private void startIo(SerializingExecutor serializingExecutor) {
asyncSink.becomeConnected(Okio.sink(socket), socket);
FrameWriter rawFrameWriter = asyncSink.limitControlFramesWriter(
variant.newWriter(Okio.buffer(asyncSink), false));
FrameWriter writeMonitoringFrameWriter = new ForwardingFrameWriter(rawFrameWriter) {
@Override
public void synReply(boolean outFinished, int streamId, List<Header> headerBlock)
throws IOException {
keepAliveEnforcer.resetCounters();
super.synReply(outFinished, streamId, headerBlock);
}

@Override
public void headers(int streamId, List<Header> headerBlock) throws IOException {
keepAliveEnforcer.resetCounters();
super.headers(streamId, headerBlock);
}

@Override
public void data(boolean outFinished, int streamId, Buffer source, int byteCount)
throws IOException {
keepAliveEnforcer.resetCounters();
super.data(outFinished, streamId, source, byteCount);
}
};
synchronized (lock) {
this.securityInfo = result.securityInfo;

Expand All @@ -167,7 +192,7 @@ private void startIo(SerializingExecutor serializingExecutor) {
// does not propagate syscall errors through the FrameWriter. But we handle the
// AsyncSink failures with the same TransportExceptionHandler instance so it is all
// mixed back together.
frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter);
frameWriter = new ExceptionHandlingFrameWriter(this, writeMonitoringFrameWriter);
outboundFlow = new OutboundFlowController(this, frameWriter);

// These writes will be queued in the serializingExecutor waiting for this function to
Expand Down Expand Up @@ -381,8 +406,11 @@ public OutboundFlowController.StreamState[] getActiveStreams() {
void streamClosed(int streamId, boolean flush) {
synchronized (lock) {
streams.remove(streamId);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportIdle();
if (streams.isEmpty()) {
keepAliveEnforcer.onTransportIdle();
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportIdle();
}
}
if (gracefulShutdown && streams.isEmpty()) {
frameWriter.close();
Expand Down Expand Up @@ -449,6 +477,8 @@ static final class Config {
final int maxInboundMessageSize;
final int maxInboundMetadataSize;
final long maxConnectionIdleNanos;
final boolean permitKeepAliveWithoutCalls;
final long permitKeepAliveTimeInNanos;

public Config(
OkHttpServerBuilder builder,
Expand All @@ -469,6 +499,8 @@ public Config(
maxInboundMessageSize = builder.maxInboundMessageSize;
maxInboundMetadataSize = builder.maxInboundMetadataSize;
maxConnectionIdleNanos = builder.maxConnectionIdleInNanos;
permitKeepAliveWithoutCalls = builder.permitKeepAliveWithoutCalls;
permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos;
}
}

Expand Down Expand Up @@ -714,8 +746,11 @@ public void headers(boolean outFinished,
authority == null ? null : asciiString(authority),
statsTraceCtx,
tracer);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportActive();
if (streams.isEmpty()) {
keepAliveEnforcer.onTransportActive();
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportActive();
}
}
streams.put(streamId, stream);
listener.streamCreated(streamForApp, method, metadata);
Expand Down Expand Up @@ -849,6 +884,11 @@ public void settings(boolean clearPrevious, Settings settings) {

@Override
public void ping(boolean ack, int payload1, int payload2) {
if (!keepAliveEnforcer.pingAcceptable()) {
abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings",
Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false);
return;
}
long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL);
if (!ack) {
frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload);
Expand Down Expand Up @@ -973,8 +1013,11 @@ private void respondWithHttpError(
synchronized (lock) {
Http2ErrorStreamState stream =
new Http2ErrorStreamState(streamId, lock, outboundFlow, config.flowControlWindow);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportActive();
if (streams.isEmpty()) {
keepAliveEnforcer.onTransportActive();
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportActive();
}
}
streams.put(streamId, stream);
if (inFinished) {
Expand Down
100 changes: 99 additions & 1 deletion okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
Expand Up @@ -41,6 +41,7 @@
import io.grpc.Status;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener;
Expand Down Expand Up @@ -122,7 +123,9 @@ public class OkHttpServerTransportTest {
}
})
.flowControlWindow(INITIAL_WINDOW_SIZE)
.maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS);
.maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS)
.permitKeepAliveWithoutCalls(true)
.permitKeepAliveTime(0, TimeUnit.SECONDS);

@Rule public final Timeout globalTimeout = Timeout.seconds(10);

Expand Down Expand Up @@ -1054,6 +1057,101 @@ public void channelzStats() throws Exception {
assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000));
}

@Test
public void keepAliveEnforcer_enforcesPings() throws Exception {
serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
pingPong();
}
pingPongId++;
clientFrameWriter.ping(false, pingPongId, 0);
clientFrameWriter.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
}

@Test
public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception {
serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

clientFrameWriter.headers(1, Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything")));
Buffer requestMessageFrame = createMessageFrame("Hello server");
clientFrameWriter.data(false, 1, requestMessageFrame, (int) requestMessageFrame.size());
pingPong();
MockStreamListener streamListener = mockTransportListener.newStreams.pop();

streamListener.stream.request(1);
pingPong();
assertThat(streamListener.messages.pop()).isEqualTo("Hello server");

streamListener.stream.writeHeaders(metadata("User-Data", "best data"));
streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
streamListener.stream.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();

for (int i = 0; i < 10; i++) {
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
pingPong();
streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
streamListener.stream.flush();
}
}

@Test
public void keepAliveEnforcer_initialIdle() throws Exception {
serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
pingPong();
}
pingPongId++;
clientFrameWriter.ping(false, pingPongId, 0);
clientFrameWriter.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
}

@Test
public void keepAliveEnforcer_noticesActive() throws Exception {
serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

clientFrameWriter.headers(1, Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything")));
for (int i = 0; i < 10; i++) {
pingPong();
}
verify(clientFramesRead, never()).goAway(anyInt(), eq(ErrorCode.ENHANCE_YOUR_CALM),
eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)));
}

private void initTransport() throws Exception {
serverTransport = new OkHttpServerTransport(
new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()),
Expand Down

0 comments on commit b3e10e0

Please sign in to comment.