Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interop-testing: add fake altsHandshakerService for test #7847

Merged
merged 7 commits into from Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -101,6 +101,7 @@ public void onError(Throwable t) {

@Override
public void onCompleted() {
responseObserver.onCompleted();
}
dapengzhang0 marked this conversation as resolved.
Show resolved Hide resolved
};
}
Expand Down
Expand Up @@ -25,7 +25,6 @@
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.ServerBuilder;
import io.grpc.ServerInterceptors;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelCredentials;
Expand Down Expand Up @@ -87,8 +86,7 @@ public static void main(String[] args) throws Exception {
private String serviceAccountKeyFile;
private String oauthScope;
private boolean fullStreamDecompression;
private boolean localHandshakerForTesting;
private int localHandshakerPort = 8000;
private int localHandshakerPort = -1;

private Tester tester = new Tester();

Expand Down Expand Up @@ -147,8 +145,8 @@ void parseArgs(String[] args) {
oauthScope = value;
} else if ("full_stream_decompression".equals(key)) {
fullStreamDecompression = Boolean.parseBoolean(value);
} else if ("use_test_handshaker".equals(key)) {
localHandshakerForTesting = Boolean.parseBoolean(value);
} else if ("local_handshaker_port".equals(key)) {
localHandshakerPort = Integer.parseInt(value);
} else {
System.err.println("Unknown argument: " + key);
usage = true;
Expand All @@ -173,9 +171,9 @@ void parseArgs(String[] args) {
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + c.useAlts
+ "\n --use_test_handshaker Whether to use local ALTS handshaker service for "
+ "\n testing. Only effective when --use_alts=true. Default "
+ c.localHandshakerForTesting
+ "\n --local_handshaker_port=PORT"
+ "\n Use local ALTS handshaker service on the specified "
+ "\n port for testing. Only effective when --use_alts=true."
+ "\n --use_upgrade=true|false Whether to use the h2c Upgrade mechanism."
+ "\n Enabling h2c Upgrade will disable TLS."
+ "\n Default " + c.useH2cUpgrade
Expand Down Expand Up @@ -409,7 +407,7 @@ protected ManagedChannelBuilder<?> createChannelBuilder() {

} else if (useAlts) {
useGeneric = true; // Retain old behavior; avoids erroring if incompatible
if (localHandshakerForTesting) {
if (localHandshakerPort > -1) {
channelCredentials = AltsChannelCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
Expand Down Expand Up @@ -496,10 +494,10 @@ protected boolean metricsExpected() {
@Override
@Nullable
protected ServerBuilder<?> getHandshakerServerBuilder() {
if (localHandshakerForTesting) {
if (localHandshakerPort > -1) {
return Grpc.newServerBuilderForPort(localHandshakerPort,
InsecureServerCredentials.create())
.addService(ServerInterceptors.intercept(new AltsHandshakerTestService()));
.addService(new AltsHandshakerTestService());
} else {
return null;
}
Expand Down
Expand Up @@ -70,8 +70,7 @@ public void run() {

private ScheduledExecutorService executor;
private Server server;
private boolean localHandshakerForTesting;
private int localHandshakerPort = 8000;
private int localHandshakerPort = -1;

@VisibleForTesting
void parseArgs(String[] args) {
Expand Down Expand Up @@ -100,8 +99,8 @@ void parseArgs(String[] args) {
useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
} else if ("use_test_handshaker".equals(key)) {
localHandshakerForTesting = Boolean.parseBoolean(value);
} else if ("local_handshaker_port".equals(key)) {
localHandshakerPort = Integer.parseInt(value);
} else if ("grpc_version".equals(key)) {
if (!"2".equals(value)) {
System.err.println("Only grpc version 2 is supported");
Expand All @@ -126,9 +125,9 @@ void parseArgs(String[] args) {
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + s.useAlts
+ "\n --use_test_handshaker Whether to use local ALTS handshaker service for "
+ "\n testing. Only effective when --use_alts=true. Default "
+ s.localHandshakerForTesting
+ "\n --local_handshaker_port=PORT"
+ "\n Use local ALTS handshaker service on the specified port "
+ "\n for testing. Only effective when --use_alts=true."
);
System.exit(1);
}
Expand All @@ -139,7 +138,7 @@ void start() throws Exception {
executor = Executors.newSingleThreadScheduledExecutor();
ServerCredentials serverCreds;
if (useAlts) {
if (localHandshakerForTesting) {
if (localHandshakerPort > -1) {
serverCreds = AltsServerCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
Expand Down
Expand Up @@ -16,91 +16,82 @@

package io.grpc.testing.integration;

import static io.grpc.testing.integration.AbstractInteropTest.EMPTY;
import static org.junit.Assert.assertEquals;

import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureServerCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.ServerCredentials;
import io.grpc.ServerInterceptors;
import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.AltsServerCredentials;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.integration.Messages.Payload;
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class AltsHandshakerTest {
private ScheduledExecutorService executor;
private Server testServer;
@Rule
public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private Server handshakeServer;
private Server testServer;
private ManagedChannel channel;

private final int handshakerServerPort = 8000;
private final int testServerPort = 8080;
private final String serverHost = "localhost";

private void startHandshakerServer() throws Exception {
handshakeServer = Grpc.newServerBuilderForPort(handshakerServerPort,
InsecureServerCredentials.create())
.addService(ServerInterceptors.intercept(new AltsHandshakerTestService()))
.build()
.start();
private Server registerHandshakeServer() {
return grpcCleanup.register(ServerBuilder.forPort(0)
.addService(new AltsHandshakerTestService())
.build());
}

private void startAltsServer() throws Exception {
executor = Executors.newSingleThreadScheduledExecutor();
ServerCredentials serverCreds = AltsServerCredentials.newBuilder()
private Server registerTestServer() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually you can still use void startAltsServer(), only change in the old code needed is testServer = grpcCleanup.register(.....).start(); Registering is only a small step and does not deserve a method name.

ServerCredentials serverCredentials = AltsServerCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting(serverHost + ":" + handshakerServerPort)
.setHandshakerAddressForTesting("localhost:" + handshakeServer.getPort())
.build();
testServer = Grpc.newServerBuilderForPort(testServerPort, serverCreds)
.addService(ServerInterceptors.intercept(new TestServiceImpl(executor)))
.build()
.start();
return grpcCleanup.register(
Grpc.newServerBuilderForPort(0, serverCredentials)
.addService(new TestServiceGrpc.TestServiceImplBase() {
@Override
public void unaryCall(SimpleRequest request, StreamObserver<SimpleResponse> so) {
so.onNext(SimpleResponse.getDefaultInstance());
so.onCompleted();
}
}).build());
}

@Before
public void setup() throws Exception {
startHandshakerServer();
startAltsServer();

private ManagedChannel registerChannel() {
ChannelCredentials channelCredentials = AltsChannelCredentials.newBuilder()
.enableUntrustedAltsForTesting()
.setHandshakerAddressForTesting(serverHost + ":" + handshakerServerPort).build();
channel = Grpc.newChannelBuilderForAddress(serverHost, testServerPort, channelCredentials)
.build();
.setHandshakerAddressForTesting("localhost:" + handshakeServer.getPort()).build();
return grpcCleanup.register(
Grpc.newChannelBuilderForAddress("localhost", testServer.getPort(), channelCredentials)
.build());
}

@After
public void stop() throws Exception {
if (testServer != null) {
testServer.shutdown();
testServer.awaitTermination();
}
if (handshakeServer != null) {
handshakeServer.shutdown();
handshakeServer.awaitTermination();
}
if (channel != null) {
channel.shutdown();
channel.awaitTermination(1, TimeUnit.SECONDS);
}
MoreExecutors.shutdownAndAwaitTermination(executor, 10, TimeUnit.SECONDS);
@Before
public void setup() throws Exception {
handshakeServer = registerHandshakeServer();
handshakeServer.start();
testServer = registerTestServer();
testServer.start();
channel = registerChannel();
}

@Test
public void testAlts() {
TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel);
assertEquals(EMPTY, blockingStub.emptyCall(EMPTY));
final SimpleRequest request = SimpleRequest.newBuilder()
.setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10])))
.build();
assertEquals(SimpleResponse.getDefaultInstance(), blockingStub.unaryCall(request));
}
}