Skip to content

Commit

Permalink
interop-testing: add fake altsHandshakerService for test (#7847)
Browse files Browse the repository at this point in the history
  • Loading branch information
YifeiZhuang committed Feb 10, 2021
1 parent 514101d commit 7f7821c
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 2 deletions.
Expand Up @@ -170,6 +170,7 @@ public abstract class AbstractInteropTest {

private ScheduledExecutorService testServiceExecutor;
private Server server;
private Server handshakerServer;

private final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers =
new LinkedBlockingQueue<>();
Expand Down Expand Up @@ -223,6 +224,7 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata
protected static final Empty EMPTY = Empty.getDefaultInstance();

private void startServer() {
maybeStartHandshakerServer();
ServerBuilder<?> builder = getServerBuilder();
if (builder == null) {
server = null;
Expand Down Expand Up @@ -251,13 +253,27 @@ private void startServer() {
}
}

private void maybeStartHandshakerServer() {
ServerBuilder<?> handshakerServerBuilder = getHandshakerServerBuilder();
if (handshakerServerBuilder != null) {
try {
handshakerServer = handshakerServerBuilder.build().start();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}

private void stopServer() {
if (server != null) {
server.shutdownNow();
}
if (testServiceExecutor != null) {
testServiceExecutor.shutdown();
}
if (handshakerServer != null) {
handshakerServer.shutdownNow();
}
}

@VisibleForTesting
Expand Down Expand Up @@ -348,6 +364,11 @@ protected ServerBuilder<?> getServerBuilder() {
return null;
}

@Nullable
protected ServerBuilder<?> getHandshakerServerBuilder() {
return null;
}

protected final ClientInterceptor createCensusStatsClientInterceptor() {
return
InternalCensusStatsAccessor
Expand Down
@@ -0,0 +1,146 @@
/*
* Copyright 2021 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.testing.integration;

import static com.google.common.base.Preconditions.checkState;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.CLIENT_START;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.NEXT;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.SERVER_START;

import com.google.protobuf.ByteString;
import io.grpc.alts.internal.HandshakerReq;
import io.grpc.alts.internal.HandshakerResp;
import io.grpc.alts.internal.HandshakerResult;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceImplBase;
import io.grpc.alts.internal.Identity;
import io.grpc.alts.internal.RpcProtocolVersions;
import io.grpc.alts.internal.RpcProtocolVersions.Version;
import io.grpc.stub.StreamObserver;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
* A fake HandshakeService for ALTS integration testing in non-gcp environments.
* */
public class AltsHandshakerTestService extends HandshakerServiceImplBase {
private static final Logger log = Logger.getLogger(AltsHandshakerTestService.class.getName());

private final Random random = new Random();
private static final int FIXED_LENGTH_OUTPUT = 16;
private final ByteString fakeOutput = data(FIXED_LENGTH_OUTPUT);
private final ByteString secret = data(128);
private State expectState = State.CLIENT_INIT;

@Override
public StreamObserver<HandshakerReq> doHandshake(
final StreamObserver<HandshakerResp> responseObserver) {
return new StreamObserver<HandshakerReq>() {
@Override
public void onNext(HandshakerReq value) {
log.log(Level.FINE, "request received: " + value);
synchronized (this) {
switch (expectState) {
case CLIENT_INIT:
checkState(CLIENT_START.equals(value.getReqOneofCase()));
HandshakerResp initClient = HandshakerResp.newBuilder()
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "init client response " + initClient);
responseObserver.onNext(initClient);
expectState = State.SERVER_INIT;
break;
case SERVER_INIT:
checkState(SERVER_START.equals(value.getReqOneofCase()));
HandshakerResp initServer = HandshakerResp.newBuilder()
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "init server response" + initServer);
responseObserver.onNext(initServer);
expectState = State.CLIENT_FINISH;
break;
case CLIENT_FINISH:
checkState(NEXT.equals(value.getReqOneofCase()));
HandshakerResp resp = HandshakerResp.newBuilder()
.setResult(getResult())
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "client finished response " + resp);
responseObserver.onNext(resp);
expectState = State.SERVER_FINISH;
break;
case SERVER_FINISH:
resp = HandshakerResp.newBuilder()
.setResult(getResult())
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.build();
log.log(Level.FINE, "server finished response " + resp);
responseObserver.onNext(resp);
expectState = State.CLIENT_INIT;
break;
default:
throw new RuntimeException("unknown state");
}
}
}

@Override
public void onError(Throwable t) {
log.log(Level.INFO, "onError " + t);
}

@Override
public void onCompleted() {
responseObserver.onCompleted();
}
};
}

private HandshakerResult getResult() {
return HandshakerResult.newBuilder().setApplicationProtocol("grpc")
.setRecordProtocol("ALTSRP_GCM_AES128_REKEY")
.setKeyData(secret)
.setMaxFrameSize(131072)
.setPeerIdentity(Identity.newBuilder()
.setServiceAccount("123456789-compute@developer.gserviceaccount.com")
.build())
.setPeerRpcVersions(RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder()
.setMajor(2).setMinor(1)
.build())
.setMinRpcVersion(Version.newBuilder()
.setMajor(2).setMinor(1)
.build())
.build())
.build();
}

private ByteString data(int len) {
byte[] k = new byte[len];
random.nextBytes(k);
return ByteString.copyFrom(k);
}

private enum State {
CLIENT_INIT,
SERVER_INIT,
CLIENT_FINISH,
SERVER_FINISH
}
}
Expand Up @@ -21,8 +21,10 @@
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.ServerBuilder;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelCredentials;
Expand All @@ -42,6 +44,7 @@
import java.io.FileInputStream;
import java.nio.charset.Charset;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
* Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs
Expand Down Expand Up @@ -83,6 +86,7 @@ public static void main(String[] args) throws Exception {
private String serviceAccountKeyFile;
private String oauthScope;
private boolean fullStreamDecompression;
private int localHandshakerPort = -1;

private Tester tester = new Tester();

Expand Down Expand Up @@ -141,6 +145,8 @@ void parseArgs(String[] args) {
oauthScope = value;
} else if ("full_stream_decompression".equals(key)) {
fullStreamDecompression = Boolean.parseBoolean(value);
} else if ("local_handshaker_port".equals(key)) {
localHandshakerPort = Integer.parseInt(value);
} else {
System.err.println("Unknown argument: " + key);
usage = true;
Expand All @@ -165,6 +171,9 @@ void parseArgs(String[] args) {
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + c.useAlts
+ "\n --local_handshaker_port=PORT"
+ "\n Use local ALTS handshaker service on the specified "
+ "\n port for testing. Only effective when --use_alts=true."
+ "\n --use_upgrade=true|false Whether to use the h2c Upgrade mechanism."
+ "\n Enabling h2c Upgrade will disable TLS."
+ "\n Default " + c.useH2cUpgrade
Expand Down Expand Up @@ -398,7 +407,13 @@ protected ManagedChannelBuilder<?> createChannelBuilder() {

} else if (useAlts) {
useGeneric = true; // Retain old behavior; avoids erroring if incompatible
channelCredentials = AltsChannelCredentials.create();
if (localHandshakerPort > -1) {
channelCredentials = AltsChannelCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
} else {
channelCredentials = AltsChannelCredentials.create();
}

} else if (useTls) {
if (!useTestCa) {
Expand Down Expand Up @@ -475,6 +490,18 @@ protected boolean metricsExpected() {
// TODO(zhangkun83): remove this override once the said issue is fixed.
return false;
}

@Override
@Nullable
protected ServerBuilder<?> getHandshakerServerBuilder() {
if (localHandshakerPort > -1) {
return Grpc.newServerBuilderForPort(localHandshakerPort,
InsecureServerCredentials.create())
.addService(new AltsHandshakerTestService());
} else {
return null;
}
}
}

private static String validTestCasesHelpText() {
Expand Down
Expand Up @@ -70,6 +70,7 @@ public void run() {

private ScheduledExecutorService executor;
private Server server;
private int localHandshakerPort = -1;

@VisibleForTesting
void parseArgs(String[] args) {
Expand Down Expand Up @@ -98,6 +99,8 @@ void parseArgs(String[] args) {
useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
} else if ("local_handshaker_port".equals(key)) {
localHandshakerPort = Integer.parseInt(value);
} else if ("grpc_version".equals(key)) {
if (!"2".equals(value)) {
System.err.println("Only grpc version 2 is supported");
Expand All @@ -122,6 +125,9 @@ void parseArgs(String[] args) {
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + s.useAlts
+ "\n --local_handshaker_port=PORT"
+ "\n Use local ALTS handshaker service on the specified port "
+ "\n for testing. Only effective when --use_alts=true."
);
System.exit(1);
}
Expand All @@ -132,7 +138,13 @@ void start() throws Exception {
executor = Executors.newSingleThreadScheduledExecutor();
ServerCredentials serverCreds;
if (useAlts) {
serverCreds = AltsServerCredentials.create();
if (localHandshakerPort > -1) {
serverCreds = AltsServerCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
} else {
serverCreds = AltsServerCredentials.create();
}
} else if (useTls) {
serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
Expand Down

0 comments on commit 7f7821c

Please sign in to comment.