Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interop-testing: add fake altsHandshakerService for test #7847

Merged
merged 7 commits into from Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -170,6 +170,7 @@ public abstract class AbstractInteropTest {

private ScheduledExecutorService testServiceExecutor;
private Server server;
private Server handshakerServer;

private final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers =
new LinkedBlockingQueue<>();
Expand Down Expand Up @@ -223,6 +224,7 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata
protected static final Empty EMPTY = Empty.getDefaultInstance();

private void startServer() {
maybeStartHandshakerServer();
ServerBuilder<?> builder = getServerBuilder();
if (builder == null) {
server = null;
Expand Down Expand Up @@ -251,13 +253,27 @@ private void startServer() {
}
}

private void maybeStartHandshakerServer() {
ServerBuilder<?> handshakerServerBuilder = getHandshakerServerBuilder();
if (handshakerServerBuilder != null) {
try {
handshakerServer = handshakerServerBuilder.build().start();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}

private void stopServer() {
if (server != null) {
server.shutdownNow();
}
if (testServiceExecutor != null) {
testServiceExecutor.shutdown();
}
if (handshakerServer != null) {
handshakerServer.shutdownNow();
}
}

@VisibleForTesting
Expand Down Expand Up @@ -348,6 +364,11 @@ protected ServerBuilder<?> getServerBuilder() {
return null;
}

@Nullable
protected ServerBuilder<?> getHandshakerServerBuilder() {
return null;
}

protected final ClientInterceptor createCensusStatsClientInterceptor() {
return
InternalCensusStatsAccessor
Expand Down
@@ -0,0 +1,139 @@
/*
* Copyright 2021 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.testing.integration;

import static com.google.common.base.Preconditions.checkState;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.CLIENT_START;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.NEXT;
import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.SERVER_START;

import com.google.protobuf.ByteString;
import io.grpc.alts.internal.HandshakerReq;
import io.grpc.alts.internal.HandshakerResp;
import io.grpc.alts.internal.HandshakerResult;
import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceImplBase;
import io.grpc.alts.internal.Identity;
import io.grpc.alts.internal.RpcProtocolVersions;
import io.grpc.alts.internal.RpcProtocolVersions.Version;
import io.grpc.stub.StreamObserver;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

public class AltsHandshakerTestService extends HandshakerServiceImplBase {
private static final Logger log = Logger.getLogger(AltsHandshakerTestService.class.getName());

private final Random random = new Random();
private static final int FIXED_LENGTH_OUTPUT = 16;
private final ByteString fakeOutput = data(FIXED_LENGTH_OUTPUT);
private final ByteString secret = data(128);
private State expectState = State.CLIENT_INIT;

@Override
public StreamObserver<HandshakerReq> doHandshake(
final StreamObserver<HandshakerResp> responseObserver) {
return new StreamObserver<HandshakerReq>() {
@Override
public void onNext(HandshakerReq value) {
log.log(Level.FINE, "request received: " + value);
switch (expectState) {
case CLIENT_INIT:
checkState(CLIENT_START.equals(value.getReqOneofCase()));
HandshakerResp initClient = HandshakerResp.newBuilder()
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "init client response " + initClient);
responseObserver.onNext(initClient);
expectState = State.SERVER_INIT;
break;
case SERVER_INIT:
checkState(SERVER_START.equals(value.getReqOneofCase()));
HandshakerResp initServer = HandshakerResp.newBuilder()
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "init server response" + initServer);
responseObserver.onNext(initServer);
expectState = State.CLIENT_FINISH;
break;
case CLIENT_FINISH:
checkState(NEXT.equals(value.getReqOneofCase()));
HandshakerResp resp = HandshakerResp.newBuilder()
.setResult(getResult())
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.setOutFrames(fakeOutput)
.build();
log.log(Level.FINE, "client finished response " + resp);
responseObserver.onNext(resp);
expectState = State.SERVER_FINISH;
break;
case SERVER_FINISH:
resp = HandshakerResp.newBuilder()
.setResult(getResult())
.setBytesConsumed(FIXED_LENGTH_OUTPUT)
.build();
log.log(Level.FINE, "server finished response " + resp);
responseObserver.onNext(resp);
break;
default:
throw new RuntimeException("unknown state");
}
}

@Override
public void onError(Throwable t) {
log.log(Level.INFO, "onError " + t);
}

@Override
public void onCompleted() {
}
dapengzhang0 marked this conversation as resolved.
Show resolved Hide resolved
};
}

private HandshakerResult getResult() {
return HandshakerResult.newBuilder().setApplicationProtocol("grpc")
.setRecordProtocol("ALTSRP_GCM_AES128_REKEY")
.setKeyData(secret)
.setMaxFrameSize(131072)
.setPeerIdentity(Identity.newBuilder()
.setServiceAccount("123456789-compute@developer.gserviceaccount.com")
.build())
.setPeerRpcVersions(RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(Version.newBuilder()
.setMajor(2).setMinor(1)
.build())
.setMinRpcVersion(Version.newBuilder()
.setMajor(2).setMinor(1)
.build())
.build())
.build();
}

private ByteString data(int len) {
byte[] k = new byte[len];
random.nextBytes(k);
return ByteString.copyFrom(k);
}

private enum State {
CLIENT_INIT,
SERVER_INIT,
CLIENT_FINISH,
SERVER_FINISH
}
}
Expand Up @@ -21,8 +21,11 @@
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelCredentials;
Expand All @@ -42,6 +45,7 @@
import java.io.FileInputStream;
import java.nio.charset.Charset;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
* Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs
Expand Down Expand Up @@ -83,6 +87,8 @@ public static void main(String[] args) throws Exception {
private String serviceAccountKeyFile;
private String oauthScope;
private boolean fullStreamDecompression;
private boolean localHandshakerForTesting;
private int localHandshakerPort = 8000;
dapengzhang0 marked this conversation as resolved.
Show resolved Hide resolved

private Tester tester = new Tester();

Expand Down Expand Up @@ -141,6 +147,8 @@ void parseArgs(String[] args) {
oauthScope = value;
} else if ("full_stream_decompression".equals(key)) {
fullStreamDecompression = Boolean.parseBoolean(value);
} else if ("use_test_handshaker".equals(key)) {
localHandshakerForTesting = Boolean.parseBoolean(value);
} else {
System.err.println("Unknown argument: " + key);
usage = true;
Expand All @@ -165,6 +173,9 @@ void parseArgs(String[] args) {
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + c.useAlts
+ "\n --use_test_handshaker Whether to use local ALTS handshaker service for "
+ "\n testing. Only effective when --use_alts=true. Default "
+ c.localHandshakerForTesting
+ "\n --use_upgrade=true|false Whether to use the h2c Upgrade mechanism."
+ "\n Enabling h2c Upgrade will disable TLS."
+ "\n Default " + c.useH2cUpgrade
Expand Down Expand Up @@ -398,7 +409,13 @@ protected ManagedChannelBuilder<?> createChannelBuilder() {

} else if (useAlts) {
useGeneric = true; // Retain old behavior; avoids erroring if incompatible
channelCredentials = AltsChannelCredentials.create();
if (localHandshakerForTesting) {
channelCredentials = AltsChannelCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
} else {
channelCredentials = AltsChannelCredentials.create();
}

} else if (useTls) {
if (!useTestCa) {
Expand Down Expand Up @@ -475,6 +492,18 @@ protected boolean metricsExpected() {
// TODO(zhangkun83): remove this override once the said issue is fixed.
return false;
}

@Override
@Nullable
protected ServerBuilder<?> getHandshakerServerBuilder() {
if (localHandshakerForTesting) {
return Grpc.newServerBuilderForPort(localHandshakerPort,
InsecureServerCredentials.create())
.addService(ServerInterceptors.intercept(new AltsHandshakerTestService()));
dapengzhang0 marked this conversation as resolved.
Show resolved Hide resolved
} else {
return null;
}
}
}

private static String validTestCasesHelpText() {
Expand Down
Expand Up @@ -70,6 +70,8 @@ public void run() {

private ScheduledExecutorService executor;
private Server server;
private boolean localHandshakerForTesting;
private int localHandshakerPort = 8000;

@VisibleForTesting
void parseArgs(String[] args) {
Expand Down Expand Up @@ -98,6 +100,8 @@ void parseArgs(String[] args) {
useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
} else if ("use_test_handshaker".equals(key)) {
localHandshakerForTesting = Boolean.parseBoolean(value);
} else if ("grpc_version".equals(key)) {
if (!"2".equals(value)) {
System.err.println("Only grpc version 2 is supported");
Expand All @@ -122,6 +126,9 @@ void parseArgs(String[] args) {
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + s.useAlts
+ "\n --use_test_handshaker Whether to use local ALTS handshaker service for "
+ "\n testing. Only effective when --use_alts=true. Default "
+ s.localHandshakerForTesting
);
System.exit(1);
}
Expand All @@ -132,7 +139,13 @@ void start() throws Exception {
executor = Executors.newSingleThreadScheduledExecutor();
ServerCredentials serverCreds;
if (useAlts) {
serverCreds = AltsServerCredentials.create();
if (localHandshakerForTesting) {
serverCreds = AltsServerCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
} else {
serverCreds = AltsServerCredentials.create();
}
} else if (useTls) {
serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
Expand Down