Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

okhttp: add okhttpServerBuilder permitKeepAliveTime() and permitKeepAliveWithoutCalls() for server keepAlive enforcement #9544

Merged
merged 3 commits into from Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -14,19 +14,19 @@
* limitations under the License.
*/

package io.grpc.netty;
package io.grpc.internal;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.concurrent.TimeUnit;
import javax.annotation.CheckReturnValue;

/** Monitors the client's PING usage to make sure the rate is permitted. */
class KeepAliveEnforcer {
public final class KeepAliveEnforcer {
@VisibleForTesting
static final int MAX_PING_STRIKES = 2;
public static final int MAX_PING_STRIKES = 2;
@VisibleForTesting
static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);
public static final long IMPLICIT_PERMIT_TIME_NANOS = TimeUnit.HOURS.toNanos(2);

private final boolean permitWithoutCalls;
private final long minTimeNanos;
Expand Down
1 change: 1 addition & 0 deletions netty/src/main/java/io/grpc/netty/NettyServerHandler.java
Expand Up @@ -44,6 +44,7 @@
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.LogExceptionRunnable;
import io.grpc.internal.MaxConnectionIdleManager;
Expand Down
Expand Up @@ -14,10 +14,11 @@
* limitations under the License.
*/

package io.grpc.netty;
package io.grpc.internal;

import static com.google.common.truth.Truth.assertThat;

import io.grpc.internal.KeepAliveEnforcer;
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down
Expand Up @@ -61,6 +61,7 @@
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
Expand Down
38 changes: 38 additions & 0 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
Expand Up @@ -20,6 +20,7 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.DoNotCall;
import io.grpc.ChoiceServerCredentials;
import io.grpc.ExperimentalApi;
Expand Down Expand Up @@ -117,6 +118,8 @@ public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentia
int maxInboundMetadataSize = GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE;
int maxInboundMessageSize = GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED;
boolean permitKeepAliveWithoutCalls;
long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5);

@VisibleForTesting
OkHttpServerBuilder(
Expand Down Expand Up @@ -223,6 +226,41 @@ public OkHttpServerBuilder keepAliveTimeout(long keepAliveTimeout, TimeUnit time
return this;
}

/**
* Specify the most aggressive keep-alive time clients are permitted to configure. The server will
* try to detect clients exceeding this rate and when detected will forcefully close the
* connection. The default is 5 minutes.
*
* <p>Even though a default is defined that allows some keep-alives, clients must not use
* keep-alive without approval from the service owner. Otherwise, they may experience failures in
* the future if the service becomes more restrictive. When unthrottled, keep-alives can cause a
* significant amount of traffic and CPU usage, so clients and servers should be conservative in
* what they use and accept.
*
* @see #permitKeepAliveWithoutCalls(boolean)
*/
@CanIgnoreReturnValue
@Override
public OkHttpServerBuilder permitKeepAliveTime(long keepAliveTime, TimeUnit timeUnit) {
checkArgument(keepAliveTime >= 0, "permit keepalive time must be non-negative: %s",
keepAliveTime);
permitKeepAliveTimeInNanos = timeUnit.toNanos(keepAliveTime);
return this;
}

/**
* Sets whether to allow clients to send keep-alive HTTP/2 PINGs even if there are no outstanding
* RPCs on the connection. Defaults to {@code false}.
*
* @see #permitKeepAliveTime(long, TimeUnit)
*/
@CanIgnoreReturnValue
@Override
public OkHttpServerBuilder permitKeepAliveWithoutCalls(boolean permit) {
permitKeepAliveWithoutCalls = permit;
return this;
}

/**
* Sets the flow control window in bytes. If not called, the default value is 64 KiB.
*/
Expand Down
57 changes: 50 additions & 7 deletions okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
Expand Up @@ -29,6 +29,7 @@
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.KeepAliveManager;
import io.grpc.internal.MaxConnectionIdleManager;
import io.grpc.internal.ObjectPool;
Expand Down Expand Up @@ -95,6 +96,7 @@ final class OkHttpServerTransport implements ServerTransport,
private Attributes attributes;
private KeepAliveManager keepAliveManager;
private MaxConnectionIdleManager maxConnectionIdleManager;
private KeepAliveEnforcer keepAliveEnforcer;

private final Object lock = new Object();
@GuardedBy("lock")
Expand Down Expand Up @@ -157,8 +159,25 @@ private void startIo(SerializingExecutor serializingExecutor) {
int maxQueuedControlFrames = 10000;
AsyncSink asyncSink = AsyncSink.sink(serializingExecutor, this, maxQueuedControlFrames);
asyncSink.becomeConnected(Okio.sink(socket), socket);
this.keepAliveEnforcer = new KeepAliveEnforcer(
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
config.permitKeepAliveWithoutCalls, config.permitKeepAliveTimeInNanos,
TimeUnit.NANOSECONDS);
FrameWriter rawFrameWriter = asyncSink.limitControlFramesWriter(
variant.newWriter(Okio.buffer(asyncSink), false));
FrameWriter writeMonitoringFrameWriter = new ForwardingFrameWriter(rawFrameWriter) {
@Override
public void headers(int streamId, List<Header> headerBlock) throws IOException {
ejona86 marked this conversation as resolved.
Show resolved Hide resolved
keepAliveEnforcer.resetCounters();
super.headers(streamId, headerBlock);
}

@Override
public void data(boolean outFinished, int streamId, Buffer source, int byteCount)
throws IOException {
keepAliveEnforcer.resetCounters();
super.data(outFinished, streamId, source, byteCount);
}
};
synchronized (lock) {
this.securityInfo = result.securityInfo;

Expand All @@ -167,7 +186,7 @@ private void startIo(SerializingExecutor serializingExecutor) {
// does not propagate syscall errors through the FrameWriter. But we handle the
// AsyncSink failures with the same TransportExceptionHandler instance so it is all
// mixed back together.
frameWriter = new ExceptionHandlingFrameWriter(this, rawFrameWriter);
frameWriter = new ExceptionHandlingFrameWriter(this, writeMonitoringFrameWriter);
outboundFlow = new OutboundFlowController(this, frameWriter);

// These writes will be queued in the serializingExecutor waiting for this function to
Expand Down Expand Up @@ -381,8 +400,13 @@ public OutboundFlowController.StreamState[] getActiveStreams() {
void streamClosed(int streamId, boolean flush) {
synchronized (lock) {
streams.remove(streamId);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportIdle();
if (streams.isEmpty()) {
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportIdle();
}
if (keepAliveEnforcer != null) {
YifeiZhuang marked this conversation as resolved.
Show resolved Hide resolved
keepAliveEnforcer.onTransportIdle();
}
}
if (gracefulShutdown && streams.isEmpty()) {
frameWriter.close();
Expand Down Expand Up @@ -449,6 +473,8 @@ static final class Config {
final int maxInboundMessageSize;
final int maxInboundMetadataSize;
final long maxConnectionIdleNanos;
final boolean permitKeepAliveWithoutCalls;
final long permitKeepAliveTimeInNanos;

public Config(
OkHttpServerBuilder builder,
Expand All @@ -469,6 +495,8 @@ public Config(
maxInboundMessageSize = builder.maxInboundMessageSize;
maxInboundMetadataSize = builder.maxInboundMetadataSize;
maxConnectionIdleNanos = builder.maxConnectionIdleInNanos;
permitKeepAliveWithoutCalls = builder.permitKeepAliveWithoutCalls;
permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos;
}
}

Expand Down Expand Up @@ -714,8 +742,13 @@ public void headers(boolean outFinished,
authority == null ? null : asciiString(authority),
statsTraceCtx,
tracer);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportActive();
if (streams.isEmpty()) {
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportActive();
}
if (keepAliveEnforcer != null) {
keepAliveEnforcer.onTransportActive();
}
}
streams.put(streamId, stream);
listener.streamCreated(streamForApp, method, metadata);
Expand Down Expand Up @@ -849,6 +882,11 @@ public void settings(boolean clearPrevious, Settings settings) {

@Override
public void ping(boolean ack, int payload1, int payload2) {
if (!keepAliveEnforcer.pingAcceptable()) {
abruptShutdown(ErrorCode.ENHANCE_YOUR_CALM, "too_many_pings",
Status.RESOURCE_EXHAUSTED.withDescription("Too many pings from client"), false);
return;
}
long payload = (((long) payload1) << 32) | (payload2 & 0xffffffffL);
if (!ack) {
frameLogger.logPing(OkHttpFrameLogger.Direction.INBOUND, payload);
Expand Down Expand Up @@ -973,8 +1011,13 @@ private void respondWithHttpError(
synchronized (lock) {
Http2ErrorStreamState stream =
new Http2ErrorStreamState(streamId, lock, outboundFlow, config.flowControlWindow);
if (maxConnectionIdleManager != null && streams.isEmpty()) {
maxConnectionIdleManager.onTransportActive();
if (streams.isEmpty()) {
if (maxConnectionIdleManager != null) {
maxConnectionIdleManager.onTransportActive();
}
if (keepAliveEnforcer != null) {
keepAliveEnforcer.onTransportActive();
}
}
streams.put(streamId, stream);
if (inFinished) {
Expand Down
100 changes: 99 additions & 1 deletion okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
Expand Up @@ -41,6 +41,7 @@
import io.grpc.Status;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.KeepAliveEnforcer;
import io.grpc.internal.ServerStream;
import io.grpc.internal.ServerStreamListener;
import io.grpc.internal.ServerTransportListener;
Expand Down Expand Up @@ -122,7 +123,9 @@ public class OkHttpServerTransportTest {
}
})
.flowControlWindow(INITIAL_WINDOW_SIZE)
.maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS);
.maxConnectionIdle(MAX_CONNECTION_IDLE, TimeUnit.NANOSECONDS)
.permitKeepAliveWithoutCalls(true)
.permitKeepAliveTime(0, TimeUnit.SECONDS);

@Rule public final Timeout globalTimeout = Timeout.seconds(10);

Expand Down Expand Up @@ -1054,6 +1057,101 @@ public void channelzStats() throws Exception {
assertThat(stats.remote).isEqualTo(new InetSocketAddress("127.0.0.2", 5000));
}

@Test
public void keepAliveEnforcer_enforcesPings() throws Exception {
serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
pingPong();
}
pingPongId++;
clientFrameWriter.ping(false, pingPongId, 0);
clientFrameWriter.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
}

@Test
public void keepAliveEnforcer_sendingDataResetsCounters() throws Exception {
serverBuilder.permitKeepAliveTime(1, TimeUnit.HOURS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

clientFrameWriter.headers(1, Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything")));
Buffer requestMessageFrame = createMessageFrame("Hello server");
clientFrameWriter.data(false, 1, requestMessageFrame, (int) requestMessageFrame.size());
pingPong();
MockStreamListener streamListener = mockTransportListener.newStreams.pop();

streamListener.stream.request(1);
pingPong();
assertThat(streamListener.messages.pop()).isEqualTo("Hello server");

streamListener.stream.writeHeaders(metadata("User-Data", "best data"));
streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
streamListener.stream.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();

for (int i = 0; i < 10; i++) {
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
pingPong();
streamListener.stream.writeMessage(new ByteArrayInputStream("Howdy client".getBytes(UTF_8)));
streamListener.stream.flush();
}
}

@Test
public void keepAliveEnforcer_initialIdle() throws Exception {
serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

for (int i = 0; i < KeepAliveEnforcer.MAX_PING_STRIKES; i++) {
pingPong();
}
pingPongId++;
clientFrameWriter.ping(false, pingPongId, 0);
clientFrameWriter.flush();
assertThat(clientFrameReader.nextFrame(clientFramesRead)).isTrue();
verify(clientFramesRead).goAway(0, ErrorCode.ENHANCE_YOUR_CALM,
ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII));
}

@Test
public void keepAliveEnforcer_noticesActive() throws Exception {
serverBuilder.permitKeepAliveTime(0, TimeUnit.SECONDS)
.permitKeepAliveWithoutCalls(false);
initTransport();
handshake();

clientFrameWriter.headers(1, Arrays.asList(
HTTP_SCHEME_HEADER,
METHOD_HEADER,
new Header(Header.TARGET_AUTHORITY, "example.com:80"),
new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"),
CONTENT_TYPE_HEADER,
TE_HEADER,
new Header("some-metadata", "this could be anything")));
for (int i = 0; i < 10; i++) {
pingPong();
}
verify(clientFramesRead, never()).goAway(anyInt(), eq(ErrorCode.ENHANCE_YOUR_CALM),
eq(ByteString.encodeString("too_many_pings", GrpcUtil.US_ASCII)));
}

private void initTransport() throws Exception {
serverTransport = new OkHttpServerTransport(
new OkHttpServerTransport.Config(serverBuilder, Arrays.asList()),
Expand Down