diff --git a/balancer/rls/internal/balancer.go b/balancer/rls/internal/balancer.go index b23783bf9da..e5985eeee35 100644 --- a/balancer/rls/internal/balancer.go +++ b/balancer/rls/internal/balancer.go @@ -19,183 +19,36 @@ package rls import ( - "sync" - - "google.golang.org/grpc" "google.golang.org/grpc/balancer" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/grpcsync" ) var ( _ balancer.Balancer = (*rlsBalancer)(nil) - // For overriding in tests. - newRLSClientFunc = newRLSClient - logger = grpclog.Component("rls") + logger = grpclog.Component("rls") ) // rlsBalancer implements the RLS LB policy. -type rlsBalancer struct { - done *grpcsync.Event - cc balancer.ClientConn - opts balancer.BuildOptions - - // Mutex protects all the state maintained by the LB policy. - // TODO(easwars): Once we add the cache, we will also have another lock for - // the cache alone. - mu sync.Mutex - lbCfg *lbConfig // Most recently received service config. - rlsCC *grpc.ClientConn // ClientConn to the RLS server. - rlsC *rlsClient // RLS client wrapper. - - ccUpdateCh chan *balancer.ClientConnState -} - -// run is a long running goroutine which handles all the updates that the -// balancer wishes to handle. The appropriate updateHandler will push the update -// on to a channel that this goroutine will select on, thereby the handling of -// the update will happen asynchronously. -func (lb *rlsBalancer) run() { - for { - // TODO(easwars): Handle other updates like subConn state changes, RLS - // responses from the server etc. - select { - case u := <-lb.ccUpdateCh: - lb.handleClientConnUpdate(u) - case <-lb.done.Done(): - return - } - } -} - -// handleClientConnUpdate handles updates to the service config. -// If the RLS server name or the RLS RPC timeout changes, it updates the control -// channel accordingly. -// TODO(easwars): Handle updates to other fields in the service config. 
-func (lb *rlsBalancer) handleClientConnUpdate(ccs *balancer.ClientConnState) { - logger.Infof("rls: service config: %+v", ccs.BalancerConfig) - lb.mu.Lock() - defer lb.mu.Unlock() - - if lb.done.HasFired() { - logger.Warning("rls: received service config after balancer close") - return - } - - newCfg := ccs.BalancerConfig.(*lbConfig) - if lb.lbCfg.Equal(newCfg) { - logger.Info("rls: new service config matches existing config") - return - } - - lb.updateControlChannel(newCfg) - lb.lbCfg = newCfg -} +type rlsBalancer struct{} -// UpdateClientConnState pushes the received ClientConnState update on the -// update channel which will be processed asynchronously by the run goroutine. -// Implements balancer.Balancer interface. func (lb *rlsBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { - select { - case lb.ccUpdateCh <- &ccs: - case <-lb.done.Done(): - } + logger.Fatal("rls: UpdateClientConnState is not yet unimplemented") return nil } -// ResolverErr implements balancer.Balancer interface. func (lb *rlsBalancer) ResolverError(error) { - // ResolverError is called by gRPC when the name resolver reports an error. - // TODO(easwars): How do we handle this? logger.Fatal("rls: ResolverError is not yet unimplemented") } -// UpdateSubConnState implements balancer.Balancer interface. func (lb *rlsBalancer) UpdateSubConnState(_ balancer.SubConn, _ balancer.SubConnState) { logger.Fatal("rls: UpdateSubConnState is not yet implemented") } -// Cleans up the resources allocated by the LB policy including the clientConn -// to the RLS server. -// Implements balancer.Balancer. func (lb *rlsBalancer) Close() { - lb.mu.Lock() - defer lb.mu.Unlock() - - lb.done.Fire() - if lb.rlsCC != nil { - lb.rlsCC.Close() - } + logger.Fatal("rls: Close is not yet implemented") } func (lb *rlsBalancer) ExitIdle() { - // TODO: are we 100% sure this should be a nop? -} - -// updateControlChannel updates the RLS client if required. -// Caller must hold lb.mu. 
-func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) { - oldCfg := lb.lbCfg - if newCfg.lookupService == oldCfg.lookupService && newCfg.lookupServiceTimeout == oldCfg.lookupServiceTimeout { - return - } - - // Use RPC timeout from new config, if different from existing one. - timeout := oldCfg.lookupServiceTimeout - if timeout != newCfg.lookupServiceTimeout { - timeout = newCfg.lookupServiceTimeout - } - - if newCfg.lookupService == oldCfg.lookupService { - // This is the case where only the timeout has changed. We will continue - // to use the existing clientConn. but will create a new rlsClient with - // the new timeout. - lb.rlsC = newRLSClientFunc(lb.rlsCC, lb.opts.Target.Endpoint, timeout) - return - } - - // This is the case where the RLS server name has changed. We need to create - // a new clientConn and close the old one. - var dopts []grpc.DialOption - if dialer := lb.opts.Dialer; dialer != nil { - dopts = append(dopts, grpc.WithContextDialer(dialer)) - } - dopts = append(dopts, dialCreds(lb.opts)) - - cc, err := grpc.Dial(newCfg.lookupService, dopts...) - if err != nil { - logger.Errorf("rls: dialRLS(%s, %v): %v", newCfg.lookupService, lb.opts, err) - // An error from a non-blocking dial indicates something serious. We - // should continue to use the old control channel if one exists, and - // return so that the rest of the config updates can be processes. - return - } - if lb.rlsCC != nil { - lb.rlsCC.Close() - } - lb.rlsCC = cc - lb.rlsC = newRLSClientFunc(cc, lb.opts.Target.Endpoint, timeout) -} - -func dialCreds(opts balancer.BuildOptions) grpc.DialOption { - // The control channel should use the same authority as that of the parent - // channel. This ensures that the identify of the RLS server and that of the - // backend is the same, so if the RLS config is injected by an attacker, it - // cannot cause leakage of private information contained in headers set by - // the application. 
- server := opts.Target.Authority - switch { - case opts.DialCreds != nil: - if err := opts.DialCreds.OverrideServerName(server); err != nil { - logger.Warningf("rls: OverrideServerName(%s) = (%v), using Insecure", server, err) - return grpc.WithInsecure() - } - return grpc.WithTransportCredentials(opts.DialCreds) - case opts.CredsBundle != nil: - return grpc.WithTransportCredentials(opts.CredsBundle.TransportCredentials()) - default: - logger.Warning("rls: no credentials available, using Insecure") - return grpc.WithInsecure() - } + logger.Fatal("rls: ExitIdle is not yet implemented") } diff --git a/balancer/rls/internal/balancer_test.go b/balancer/rls/internal/balancer_test.go deleted file mode 100644 index 2378a86fff1..00000000000 --- a/balancer/rls/internal/balancer_test.go +++ /dev/null @@ -1,238 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package rls - -import ( - "context" - "net" - "testing" - "time" - - "google.golang.org/grpc" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal/grpctest" - "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/testdata" -) - -const defaultTestTimeout = 1 * time.Second - -type s struct { - grpctest.Tester -} - -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) -} - -type listenerWrapper struct { - net.Listener - connCh *testutils.Channel -} - -// Accept waits for and returns the next connection to the listener. -func (l *listenerWrapper) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - l.connCh.Send(c) - return c, nil -} - -func setupwithListener(t *testing.T, opts ...grpc.ServerOption) (*fakeserver.Server, *listenerWrapper, func()) { - t.Helper() - - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("net.Listen(tcp, localhost:0): %v", err) - } - lw := &listenerWrapper{ - Listener: l, - connCh: testutils.NewChannel(), - } - - server, cleanup, err := fakeserver.Start(lw, opts...) - if err != nil { - t.Fatalf("fakeserver.Start(): %v", err) - } - t.Logf("Fake RLS server started at %s ...", server.Address) - - return server, lw, cleanup -} - -type testBalancerCC struct { - balancer.ClientConn -} - -// TestUpdateControlChannelFirstConfig tests the scenario where the LB policy -// receives its first service config and verifies that a control channel to the -// RLS server specified in the serviceConfig is established. 
-func (s) TestUpdateControlChannelFirstConfig(t *testing.T) { - server, lis, cleanup := setupwithListener(t) - defer cleanup() - - bb := balancer.Get(rlsBalancerName) - if bb == nil { - t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) - } - rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}) - defer rlsB.Close() - t.Log("Built RLS LB policy ...") - - lbCfg := &lbConfig{lookupService: server.Address} - t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) - rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := lis.connCh.Receive(ctx); err != nil { - t.Fatal("Timeout expired when waiting for LB policy to create control channel") - } - - // TODO: Verify channel connectivity state once control channel connectivity - // state monitoring is in place. - - // TODO: Verify RLS RPC can be made once we integrate with the picker. -} - -// TestUpdateControlChannelSwitch tests the scenario where a control channel -// exists and the LB policy receives a new serviceConfig with a different RLS -// server name. Verifies that the new control channel is created and the old one -// is closed (the leakchecker takes care of this). 
-func (s) TestUpdateControlChannelSwitch(t *testing.T) { - server1, lis1, cleanup1 := setupwithListener(t) - defer cleanup1() - - server2, lis2, cleanup2 := setupwithListener(t) - defer cleanup2() - - bb := balancer.Get(rlsBalancerName) - if bb == nil { - t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) - } - rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}) - defer rlsB.Close() - t.Log("Built RLS LB policy ...") - - lbCfg := &lbConfig{lookupService: server1.Address} - t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) - rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := lis1.connCh.Receive(ctx); err != nil { - t.Fatal("Timeout expired when waiting for LB policy to create control channel") - } - - lbCfg = &lbConfig{lookupService: server2.Address} - t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) - rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) - - if _, err := lis2.connCh.Receive(ctx); err != nil { - t.Fatal("Timeout expired when waiting for LB policy to create control channel") - } - - // TODO: Verify channel connectivity state once control channel connectivity - // state monitoring is in place. - - // TODO: Verify RLS RPC can be made once we integrate with the picker. -} - -// TestUpdateControlChannelTimeout tests the scenario where the LB policy -// receives a service config update with a different lookupServiceTimeout, but -// the lookupService itself remains unchanged. It verifies that the LB policy -// does not create a new control channel in this case. 
-func (s) TestUpdateControlChannelTimeout(t *testing.T) { - server, lis, cleanup := setupwithListener(t) - defer cleanup() - - bb := balancer.Get(rlsBalancerName) - if bb == nil { - t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) - } - rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}) - defer rlsB.Close() - t.Log("Built RLS LB policy ...") - - lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second} - t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) - rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := lis.connCh.Receive(ctx); err != nil { - t.Fatal("Timeout expired when waiting for LB policy to create control channel") - } - - lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second} - t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) - rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) - if _, err := lis.connCh.Receive(ctx); err != context.DeadlineExceeded { - t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed") - } - - // TODO: Verify channel connectivity state once control channel connectivity - // state monitoring is in place. - - // TODO: Verify RLS RPC can be made once we integrate with the picker. -} - -// TestUpdateControlChannelWithCreds tests the scenario where the control -// channel is to established with credentials from the parent channel. 
-func (s) TestUpdateControlChannelWithCreds(t *testing.T) { - sCreds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) - if err != nil { - t.Fatalf("credentials.NewServerTLSFromFile(server1.pem, server1.key) = %v", err) - } - cCreds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "") - if err != nil { - t.Fatalf("credentials.NewClientTLSFromFile(ca.pem) = %v", err) - } - - server, lis, cleanup := setupwithListener(t, grpc.Creds(sCreds)) - defer cleanup() - - bb := balancer.Get(rlsBalancerName) - if bb == nil { - t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) - } - rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{ - DialCreds: cCreds, - }) - defer rlsB.Close() - t.Log("Built RLS LB policy ...") - - lbCfg := &lbConfig{lookupService: server.Address} - t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) - rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := lis.connCh.Receive(ctx); err != nil { - t.Fatal("Timeout expired when waiting for LB policy to create control channel") - } - - // TODO: Verify channel connectivity state once control channel connectivity - // state monitoring is in place. - - // TODO: Verify RLS RPC can be made once we integrate with the picker. 
-} diff --git a/balancer/rls/internal/builder.go b/balancer/rls/internal/builder.go index 7c29caef404..9707b08420d 100644 --- a/balancer/rls/internal/builder.go +++ b/balancer/rls/internal/builder.go @@ -21,10 +21,9 @@ package rls import ( "google.golang.org/grpc/balancer" - "google.golang.org/grpc/internal/grpcsync" ) -const rlsBalancerName = "rls" +const rlsBalancerName = "rls_experimental" func init() { balancer.Register(&rlsBB{}) @@ -41,13 +40,6 @@ func (*rlsBB) Name() string { } func (*rlsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - lb := &rlsBalancer{ - done: grpcsync.NewEvent(), - cc: cc, - opts: opts, - lbCfg: &lbConfig{}, - ccUpdateCh: make(chan *balancer.ClientConnState), - } - go lb.run() - return lb + // TODO(easwars): Fix this once the LB policy implementation is pulled in. + return &rlsBalancer{} } diff --git a/balancer/rls/internal/client.go b/balancer/rls/internal/client.go deleted file mode 100644 index b0c858e032e..00000000000 --- a/balancer/rls/internal/client.go +++ /dev/null @@ -1,80 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package rls - -import ( - "context" - "time" - - "google.golang.org/grpc" - rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1" - rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" -) - -// For gRPC services using RLS, the value of target_type in the -// RouteLookupServiceRequest will be set to this. -const grpcTargetType = "grpc" - -// rlsClient is a simple wrapper around a RouteLookupService client which -// provides non-blocking semantics on top of a blocking unary RPC call. -// -// The RLS LB policy creates a new rlsClient object with the following values: -// * a grpc.ClientConn to the RLS server using appropriate credentials from the -// parent channel -// * dialTarget corresponding to the original user dial target, e.g. -// "firestore.googleapis.com". -// -// The RLS LB policy uses an adaptive throttler to perform client side -// throttling and asks this client to make an RPC call only after checking with -// the throttler. -type rlsClient struct { - stub rlsgrpc.RouteLookupServiceClient - // origDialTarget is the original dial target of the user and sent in each - // RouteLookup RPC made to the RLS server. - origDialTarget string - // rpcTimeout specifies the timeout for the RouteLookup RPC call. The LB - // policy receives this value in its service config. - rpcTimeout time.Duration -} - -func newRLSClient(cc *grpc.ClientConn, dialTarget string, rpcTimeout time.Duration) *rlsClient { - return &rlsClient{ - stub: rlsgrpc.NewRouteLookupServiceClient(cc), - origDialTarget: dialTarget, - rpcTimeout: rpcTimeout, - } -} - -type lookupCallback func(targets []string, headerData string, err error) - -// lookup starts a RouteLookup RPC in a separate goroutine and returns the -// results (and error, if any) in the provided callback. 
-func (c *rlsClient) lookup(keyMap map[string]string, cb lookupCallback) { - go func() { - ctx, cancel := context.WithTimeout(context.Background(), c.rpcTimeout) - resp, err := c.stub.RouteLookup(ctx, &rlspb.RouteLookupRequest{ - // TODO(easwars): Use extra_keys field to populate host, service and - // method keys. - TargetType: grpcTargetType, - KeyMap: keyMap, - }) - cb(resp.GetTargets(), resp.GetHeaderData(), err) - cancel() - }() -} diff --git a/balancer/rls/internal/client_test.go b/balancer/rls/internal/client_test.go deleted file mode 100644 index 9a805c77ca3..00000000000 --- a/balancer/rls/internal/client_test.go +++ /dev/null @@ -1,178 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package rls - -import ( - "context" - "errors" - "fmt" - "testing" - "time" - - "github.com/golang/protobuf/proto" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc" - "google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver" - "google.golang.org/grpc/codes" - rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" - "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/status" -) - -const ( - defaultDialTarget = "dummy" - defaultRPCTimeout = 5 * time.Second -) - -func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) { - t.Helper() - - server, sCleanup, err := fakeserver.Start(nil) - if err != nil { - t.Fatalf("Failed to start fake RLS server: %v", err) - } - - cc, cCleanup, err := server.ClientConn() - if err != nil { - sCleanup() - t.Fatalf("Failed to get a ClientConn to the RLS server: %v", err) - } - - return server, cc, func() { - sCleanup() - cCleanup() - } -} - -// TestLookupFailure verifies the case where the RLS server returns an error. -func (s) TestLookupFailure(t *testing.T) { - server, cc, cleanup := setup(t) - defer cleanup() - - // We setup the fake server to return an error. 
- server.ResponseChan <- fakeserver.Response{Err: errors.New("rls failure")} - - rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout) - - errCh := testutils.NewChannel() - rlsClient.lookup(nil, func(targets []string, headerData string, err error) { - if err == nil { - errCh.Send(errors.New("rlsClient.lookup() succeeded, should have failed")) - return - } - if len(targets) != 0 || headerData != "" { - errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData)) - return - } - errCh.Send(nil) - }) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if e, err := errCh.Receive(ctx); err != nil || e != nil { - t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) - } -} - -// TestLookupDeadlineExceeded tests the case where the RPC deadline associated -// with the lookup expires. -func (s) TestLookupDeadlineExceeded(t *testing.T) { - _, cc, cleanup := setup(t) - defer cleanup() - - // Give the Lookup RPC a small deadline, but don't setup the fake server to - // return anything. So the Lookup call will block and eventually expire. - rlsClient := newRLSClient(cc, defaultDialTarget, 100*time.Millisecond) - - errCh := testutils.NewChannel() - rlsClient.lookup(nil, func(_ []string, _ string, err error) { - if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded { - errCh.Send(fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)) - return - } - errCh.Send(nil) - }) - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if e, err := errCh.Receive(ctx); err != nil || e != nil { - t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) - } -} - -// TestLookupSuccess verifies the successful Lookup API case. 
-func (s) TestLookupSuccess(t *testing.T) { - server, cc, cleanup := setup(t) - defer cleanup() - - const wantHeaderData = "headerData" - - rlsReqKeyMap := map[string]string{ - "k1": "v1", - "k2": "v2", - } - wantLookupRequest := &rlspb.RouteLookupRequest{ - // TODO(easwars): Use extra_keys field to populate host, service and - // method keys. - TargetType: "grpc", - KeyMap: rlsReqKeyMap, - } - wantRespTargets := []string{"us_east_1.firestore.googleapis.com"} - - rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout) - - errCh := testutils.NewChannel() - rlsClient.lookup(rlsReqKeyMap, func(targets []string, hd string, err error) { - if err != nil { - errCh.Send(fmt.Errorf("rlsClient.Lookup() failed: %v", err)) - return - } - if !cmp.Equal(targets, wantRespTargets) || hd != wantHeaderData { - errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData)) - return - } - errCh.Send(nil) - }) - - // Make sure that the fake server received the expected RouteLookupRequest - // proto. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - req, err := server.RequestChan.Receive(ctx) - if err != nil { - t.Fatalf("Timed out wile waiting for a RouteLookupRequest") - } - gotLookupRequest := req.(*rlspb.RouteLookupRequest) - if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" { - t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff) - } - - // We setup the fake server to return this response when it receives a - // request. 
- server.ResponseChan <- fakeserver.Response{ - Resp: &rlspb.RouteLookupResponse{ - Targets: wantRespTargets, - HeaderData: wantHeaderData, - }, - } - - if e, err := errCh.Receive(ctx); err != nil || e != nil { - t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) - } -} diff --git a/balancer/rls/internal/config_test.go b/balancer/rls/internal/config_test.go index 41d330c604e..84da9fe426f 100644 --- a/balancer/rls/internal/config_test.go +++ b/balancer/rls/internal/config_test.go @@ -61,7 +61,7 @@ func testEqual(a, b *lbConfig) bool { childPolicyConfigEqual(a.childPolicyConfig, b.childPolicyConfig) } -func TestParseConfig(t *testing.T) { +func (s) TestParseConfig(t *testing.T) { childPolicyTargetFieldVal, _ := json.Marshal(dummyChildPolicyTarget) tests := []struct { desc string @@ -158,7 +158,7 @@ func TestParseConfig(t *testing.T) { } } -func TestParseConfigErrors(t *testing.T) { +func (s) TestParseConfigErrors(t *testing.T) { tests := []struct { desc string input []byte diff --git a/balancer/rls/internal/control_channel.go b/balancer/rls/internal/control_channel.go new file mode 100644 index 00000000000..dc8446313e7 --- /dev/null +++ b/balancer/rls/internal/control_channel.go @@ -0,0 +1,206 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package rls + +import ( + "context" + "fmt" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/rls/internal/adaptive" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal" + internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/pretty" + rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1" + rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" +) + +var newAdaptiveThrottler = func() adaptiveThrottler { return adaptive.New() } + +type adaptiveThrottler interface { + ShouldThrottle() bool + RegisterBackendResponse(throttled bool) +} + +// controlChannel is a wrapper around the gRPC channel to the RLS server +// specified in the service config. +type controlChannel struct { + // rpcTimeout specifies the timeout for the RouteLookup RPC call. The LB + // policy receives this value in its service config. + rpcTimeout time.Duration + // backToReadyCh is the channel on which an update is pushed when the + // connectivity state changes from READY --> TRANSIENT_FAILURE --> READY. + backToReadyCh chan struct{} + // throttler is an adaptive throttling implementation used to avoid + // hammering the RLS service while it is overloaded or down. + throttler adaptiveThrottler + + cc *grpc.ClientConn + client rlsgrpc.RouteLookupServiceClient + logger *internalgrpclog.PrefixLogger +} + +func newControlChannel(rlsServerName string, rpcTimeout time.Duration, bOpts balancer.BuildOptions, backToReadyCh chan struct{}) (*controlChannel, error) { + ctrlCh := &controlChannel{ + rpcTimeout: rpcTimeout, + backToReadyCh: backToReadyCh, + throttler: newAdaptiveThrottler(), + } + ctrlCh.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-control-channel %p] ", ctrlCh)) + + dopts, err := ctrlCh.dialOpts(bOpts) + if err != nil { + return nil, err + } + ctrlCh.cc, err = grpc.Dial(rlsServerName, dopts...) 
+ if err != nil { + return nil, err + } + ctrlCh.client = rlsgrpc.NewRouteLookupServiceClient(ctrlCh.cc) + ctrlCh.logger.Infof("Control channel created to RLS server at: %v", rlsServerName) + + go ctrlCh.monitorConnectivityState() + return ctrlCh, nil +} + +// dialOpts constructs the dial options for the control plane channel. +func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions) ([]grpc.DialOption, error) { + // The control plane channel will use the same authority as the parent + // channel for server authorization. This ensures that the identity of the + // RLS server and the identity of the backends is the same, so if the RLS + // config is injected by an attacker, it cannot cause leakage of private + // information contained in headers set by the application. + dopts := []grpc.DialOption{grpc.WithAuthority(bOpts.Authority)} + if bOpts.Dialer != nil { + dopts = append(dopts, grpc.WithContextDialer(bOpts.Dialer)) + } + + // The control channel will use the channel credentials from the parent + // channel, including any call creds associated with the channel creds. + var credsOpt grpc.DialOption + switch { + case bOpts.DialCreds != nil: + credsOpt = grpc.WithTransportCredentials(bOpts.DialCreds.Clone()) + case bOpts.CredsBundle != nil: + // The "fallback" mode in google default credentials (which is the only + // type of credentials we expect to be used with RLS) uses TLS/ALTS + // creds for transport and uses the same call creds as that on the + // parent bundle. 
+ bundle, err := bOpts.CredsBundle.NewWithMode(internal.CredsBundleModeFallback) + if err != nil { + return nil, err + } + credsOpt = grpc.WithCredentialsBundle(bundle) + default: + cc.logger.Warningf("no credentials available, using Insecure") + credsOpt = grpc.WithInsecure() + } + return append(dopts, credsOpt), nil +} + +func (cc *controlChannel) monitorConnectivityState() { + cc.logger.Infof("Starting connectivity state monitoring goroutine") + // Since we use two mechanisms to deal with RLS server being down: + // - adaptive throttling for the channel as a whole + // - exponential backoff on a per-request basis + // we need a way to avoid double-penalizing requests by counting failures + // toward both mechanisms when the RLS server is unreachable. + // + // To accomplish this, we monitor the state of the control plane channel. If + // the state has been TRANSIENT_FAILURE since the last time it was in state + // READY, and it then transitions into state READY, we push on a channel + // which is being read by the LB policy. + // + // The LB policy will iterate through the cache to reset the backoff + // timeouts in all cache entries. Specifically, this means that it will + // reset the backoff state and cancel the pending backoff timer. Note that + // when cancelling the backoff timer, just like when the backoff timer fires + // normally, a new picker is returned to the channel, to force it to + // re-process any wait-for-ready RPCs that may still be queued if we failed + // them while we were in backoff. However, we should optimize this case by + // returning only one new picker, regardless of how many backoff timers are + // cancelled. + + // Using the background context is fine here since we check for the ClientConn + // entering SHUTDOWN and return early in that case. + ctx := context.Background() + + first := true + for { + // Wait for the control channel to become READY. 
+ for s := cc.cc.GetState(); s != connectivity.Ready; s = cc.cc.GetState() { + if s == connectivity.Shutdown { + return + } + cc.cc.WaitForStateChange(ctx, s) + } + cc.logger.Infof("Connectivity state is READY") + + if !first { + cc.logger.Infof("Control channel back to READY") + cc.backToReadyCh <- struct{}{} + } + first = false + + // Wait for the control channel to move out of READY. + cc.cc.WaitForStateChange(ctx, connectivity.Ready) + if cc.cc.GetState() == connectivity.Shutdown { + return + } + cc.logger.Infof("Connectivity state is %s", cc.cc.GetState()) + } +} + +func (cc *controlChannel) close() { + cc.logger.Infof("Closing control channel") + cc.cc.Close() +} + +type lookupCallback func(targets []string, headerData string, err error) + +// lookup starts a RouteLookup RPC in a separate goroutine and returns the +// results (and error, if any) in the provided callback. +// +// The returned boolean indicates whether the request was throttled by the +// client-side adaptive throttling algorithm in which case the provided callback +// will not be invoked. 
+func (cc *controlChannel) lookup(reqKeys map[string]string, reason rlspb.RouteLookupRequest_Reason, staleHeaders string, cb lookupCallback) (throttled bool) { + if cc.throttler.ShouldThrottle() { + cc.logger.Infof("RLS request throttled by client-side adaptive throttling") + return true + } + go func() { + req := &rlspb.RouteLookupRequest{ + TargetType: "grpc", + KeyMap: reqKeys, + Reason: reason, + StaleHeaderData: staleHeaders, + } + cc.logger.Infof("Sending RLS request %+v", pretty.ToJSON(req)) + + ctx, cancel := context.WithTimeout(context.Background(), cc.rpcTimeout) + defer cancel() + resp, err := cc.client.RouteLookup(ctx, req) + cb(resp.GetTargets(), resp.GetHeaderData(), err) + }() + return false +} diff --git a/balancer/rls/internal/control_channel_test.go b/balancer/rls/internal/control_channel_test.go new file mode 100644 index 00000000000..953f6531428 --- /dev/null +++ b/balancer/rls/internal/control_channel_test.go @@ -0,0 +1,469 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package rls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/rls/internal/test/e2e" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" + rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/grpc/testdata" + "google.golang.org/protobuf/proto" +) + +// TestControlChannelThrottled tests the case where the adaptive throttler +// indicates that the control channel needs to be throttled. +func (s) TestControlChannelThrottled(t *testing.T) { + // Start an RLS server and set the throttler to always throttle requests. + rlsServer, rlsReqCh := setupFakeRLSServer(t, nil) + overrideAdaptiveThrottler(t, alwaysThrottlingThrottler()) + + // Create a control channel to the fake RLS server. + ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil) + if err != nil { + t.Fatalf("Failed to create control channel to RLS server: %v", err) + } + defer ctrlCh.close() + + // Perform the lookup and expect the attempt to be throttled. + ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil) + + select { + case <-rlsReqCh: + t.Fatal("RouteLookup RPC invoked when control channel is throtlled") + case <-time.After(defaultTestShortTimeout): + } +} + +// TestLookupFailure tests the case where the RLS server responds with an error. +func (s) TestLookupFailure(t *testing.T) { + // Start an RLS server and set the throttler to never throttle requests. + rlsServer, _ := setupFakeRLSServer(t, nil) + overrideAdaptiveThrottler(t, neverThrottlingThrottler()) + + // Setup the RLS server to respond with errors. 
+ rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse { + return &e2e.RouteLookupResponse{Err: errors.New("rls failure")} + }) + + // Create a control channel to the fake RLS server. + ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil) + if err != nil { + t.Fatalf("Failed to create control channel to RLS server: %v", err) + } + defer ctrlCh.close() + + // Perform the lookup and expect the callback to be invoked with an error. + errCh := make(chan error, 1) + ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { + if err == nil { + errCh <- errors.New("rlsClient.lookup() succeeded, should have failed") + return + } + errCh <- nil + }) + + select { + case <-time.After(defaultTestTimeout): + t.Fatal("timeout when waiting for lookup callback to be invoked") + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } +} + +// TestLookupDeadlineExceeded tests the case where the RLS server does not +// respond within the configured rpc timeout. +func (s) TestLookupDeadlineExceeded(t *testing.T) { + // A unary interceptor which sleeps for long enough to cause lookup RPCs to + // exceed their deadline. + rlsReqCh := make(chan struct{}, 1) + interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + rlsReqCh <- struct{}{} + time.Sleep(2 * defaultTestShortTimeout) + return handler(ctx, req) + } + + // Start an RLS server and set the throttler to never throttle. + rlsServer, _ := setupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor)) + overrideAdaptiveThrottler(t, neverThrottlingThrottler()) + + // Create a control channel with a small deadline. 
+ ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestShortTimeout, balancer.BuildOptions{}, nil) + if err != nil { + t.Fatalf("Failed to create control channel to RLS server: %v", err) + } + defer ctrlCh.close() + + // Perform the lookup and expect the callback to be invoked with an error. + errCh := make(chan error) + ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { + if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded { + errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded) + return + } + errCh <- nil + }) + + select { + case <-time.After(defaultTestTimeout): + t.Fatal("timeout when waiting for lookup callback to be invoked") + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } +} + +// testCredsBundle wraps a test call creds and real transport creds. +type testCredsBundle struct { + transportCreds credentials.TransportCredentials + callCreds credentials.PerRPCCredentials +} + +func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials { + return f.transportCreds +} + +func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials { + return f.callCreds +} + +func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { + if mode != internal.CredsBundleModeFallback { + return nil, fmt.Errorf("unsupported mode: %v", mode) + } + return &testCredsBundle{ + transportCreds: f.transportCreds, + callCreds: f.callCreds, + }, nil +} + +var ( + // Call creds sent by the testPerRPCCredentials on the client, and verified + // by an interceptor on the server. 
+ perRPCCredsData = map[string]string{ + "test-key": "test-value", + "test-key-bin": string([]byte{1, 2, 3}), + } +) + +type testPerRPCCredentials struct { + callCreds map[string]string +} + +func (f *testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { + return f.callCreds, nil +} + +func (f *testPerRPCCredentials) RequireTransportSecurity() bool { + return true +} + +// Unary server interceptor which validates if the RPC contains call credentials +// which match `perRPCCredsData +func callCredsValidatingServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context") + } + for k, want := range perRPCCredsData { + got, ok := md[k] + if !ok { + return ctx, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k) + } + if got[0] != want { + return ctx, status.Errorf(codes.PermissionDenied, "for key %v, got value %v, want %v", k, got, want) + } + } + return handler(ctx, req) +} + +// makeTLSCreds is a test helper which creates a TLS based transport credentials +// from files specified in the arguments. 
+func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials { + cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath)) + if err != nil { + t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err) + } + b, err := ioutil.ReadFile(testdata.Path(rootsPath)) + if err != nil { + t.Fatalf("ioutil.ReadFile(%q) failed: %v", rootsPath, err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(b) { + t.Fatal("failed to append certificates") + } + return credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: roots, + }) +} + +const ( + wantHeaderData = "headerData" + staleHeaderData = "staleHeaderData" +) + +var ( + keyMap = map[string]string{ + "k1": "v1", + "k2": "v2", + } + wantTargets = []string{"us_east_1.firestore.googleapis.com"} + lookupRequest = &rlspb.RouteLookupRequest{ + TargetType: "grpc", + KeyMap: keyMap, + Reason: rlspb.RouteLookupRequest_REASON_MISS, + StaleHeaderData: staleHeaderData, + } + lookupResponse = &e2e.RouteLookupResponse{ + Resp: &rlspb.RouteLookupResponse{ + Targets: wantTargets, + HeaderData: wantHeaderData, + }, + } +) + +func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) { + // Start an RLS server and set the throttler to never throttle requests. + rlsServer, _ := setupFakeRLSServer(t, nil, sopts...) + overrideAdaptiveThrottler(t, neverThrottlingThrottler()) + + // Setup the RLS server to respond with a valid response. + rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse { + return lookupResponse + }) + + // Verify that the request received by the RLS matches the expected one. 
+ rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) { + if diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal)); diff != "" { + t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff) + } + }) + + // Create a control channel to the fake server. + ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil) + if err != nil { + t.Fatalf("Failed to create control channel to RLS server: %v", err) + } + defer ctrlCh.close() + + // Perform the lookup and expect a successful callback invocation. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + errCh := make(chan error, 1) + ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, err error) { + if err != nil { + errCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", err) + return + } + if !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData { + errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData) + return + } + errCh <- nil + }) + + select { + case <-ctx.Done(): + t.Fatal("timeout when waiting for lookup callback to be invoked") + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } +} + +// TestControlChannelCredsSuccess tests creation of the control channel with +// different credentials, which are expected to succeed. 
+func (s) TestControlChannelCredsSuccess(t *testing.T) { + serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") + clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") + + tests := []struct { + name string + sopts []grpc.ServerOption + bopts balancer.BuildOptions + }{ + { + name: "insecure", + sopts: nil, + bopts: balancer.BuildOptions{}, + }, + { + name: "transport creds only", + sopts: []grpc.ServerOption{grpc.Creds(serverCreds)}, + bopts: balancer.BuildOptions{ + DialCreds: clientCreds, + Authority: "x.test.example.com", + }, + }, + { + name: "creds bundle", + sopts: []grpc.ServerOption{ + grpc.Creds(serverCreds), + grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), + }, + bopts: balancer.BuildOptions{ + CredsBundle: &testCredsBundle{ + transportCreds: clientCreds, + callCreds: &testPerRPCCredentials{callCreds: perRPCCredsData}, + }, + Authority: "x.test.example.com", + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testControlChannelCredsSuccess(t, test.sopts, test.bopts) + }) + } +} + +func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErr string) { + // StartFakeRouteLookupServer a fake server. + // + // Start an RLS server and set the throttler to never throttle requests. The + // creds failures happen before the RPC handler on the server is invoked. + // So, there is need to setup the request and responses on the fake server. + rlsServer, _ := setupFakeRLSServer(t, nil, sopts...) + overrideAdaptiveThrottler(t, neverThrottlingThrottler()) + + // Create the control channel to the fake server. 
+ ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil) + if err != nil { + t.Fatalf("Failed to create control channel to RLS server: %v", err) + } + defer ctrlCh.close() + + // Perform the lookup and expect the callback to be invoked with an error. + errCh := make(chan error) + ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { + if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !strings.Contains(st.String(), wantErr) { + errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErr) + return + } + errCh <- nil + }) + + select { + case <-time.After(defaultTestTimeout): + t.Fatal("timeout when waiting for lookup callback to be invoked") + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } +} + +// TestControlChannelCredsFailure tests creation of the control channel with +// different credentials, which are expected to fail. 
+func (s) TestControlChannelCredsFailure(t *testing.T) { + serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") + clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") + + tests := []struct { + name string + sopts []grpc.ServerOption + bopts balancer.BuildOptions + wantCode codes.Code + wantErr string + }{ + { + name: "transport creds authority mismatch", + sopts: []grpc.ServerOption{grpc.Creds(serverCreds)}, + bopts: balancer.BuildOptions{ + DialCreds: clientCreds, + Authority: "authority-mismatch", + }, + wantCode: codes.Unavailable, + wantErr: "transport: authentication handshake failed: x509: certificate is valid for *.test.example.com, not authority-mismatch", + }, + { + name: "transport creds handshake failure", + sopts: nil, // server expects insecure connection + bopts: balancer.BuildOptions{ + DialCreds: clientCreds, + Authority: "x.test.example.com", + }, + wantCode: codes.Unavailable, + wantErr: "transport: authentication handshake failed: tls: first record does not look like a TLS handshake", + }, + { + name: "call creds mismatch", + sopts: []grpc.ServerOption{ + grpc.Creds(serverCreds), + grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), // server expects call creds + }, + bopts: balancer.BuildOptions{ + CredsBundle: &testCredsBundle{ + transportCreds: clientCreds, + callCreds: &testPerRPCCredentials{}, // sends no call creds + }, + Authority: "x.test.example.com", + }, + wantCode: codes.PermissionDenied, + wantErr: "didn't find call creds", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErr) + }) + } +} + +type unsupportedCredsBundle struct { + credentials.Bundle +} + +func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { + return nil, fmt.Errorf("unsupported mode: %v", mode) +} + +// 
TestNewControlChannelUnsupportedCredsBundle tests the case where the control +// channel is configured with a bundle which does not support the mode we use. +func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) { + rlsServer, _ := setupFakeRLSServer(t, nil) + + // Create the control channel to the fake server. + ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}, nil) + if err == nil { + ctrlCh.close() + t.Fatal("newControlChannel succeeded when expected to fail") + } +} diff --git a/balancer/rls/internal/helpers_test.go b/balancer/rls/internal/helpers_test.go new file mode 100644 index 00000000000..bb5478a3fa5 --- /dev/null +++ b/balancer/rls/internal/helpers_test.go @@ -0,0 +1,327 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package rls + +import ( + "context" + "net" + "strings" + "sync" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer/rls/internal/test/e2e" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/balancergroup" + "google.golang.org/grpc/internal/grpctest" + rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/status" + testgrpc "google.golang.org/grpc/test/grpc_testing" + testpb "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/protobuf/types/known/durationpb" +) + +// TODO(easwars): Remove this once all RLS code is merged. +//lint:file-ignore U1000 Ignore all unused code, not all code is merged yet. + +const ( + defaultTestTimeout = 5 * time.Second + defaultTestShortTimeout = 100 * time.Millisecond +) + +func init() { + balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond +} + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// connWrapper wraps a net.Conn and pushes on a channel when closed. +type connWrapper struct { + net.Conn + closeCh *testutils.Channel +} + +func (cw *connWrapper) Close() error { + err := cw.Conn.Close() + cw.closeCh.Replace(nil) + return err +} + +// listenerWrapper wraps a net.Listener and the returned net.Conn. +// +// It pushes on a channel whenever it accepts a new connection. 
+type listenerWrapper struct { + net.Listener + newConnCh *testutils.Channel +} + +func (l *listenerWrapper) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + closeCh := testutils.NewChannel() + conn := &connWrapper{Conn: c, closeCh: closeCh} + l.newConnCh.Send(conn) + return conn, nil +} + +func newListenerWrapper(t *testing.T, lis net.Listener) *listenerWrapper { + if lis == nil { + var err error + lis, err = testutils.LocalTCPListener() + if err != nil { + t.Fatal(err) + } + } + + return &listenerWrapper{ + Listener: lis, + newConnCh: testutils.NewChannel(), + } +} + +// fakeBackoffStrategy is a fake implementation of the backoff.Strategy +// interface, for tests to inject the backoff duration. +type fakeBackoffStrategy struct { + backoff time.Duration +} + +func (f *fakeBackoffStrategy) Backoff(retries int) time.Duration { + return f.backoff +} + +// fakeThrottler is a fake implementation of the adaptiveThrottler interface. +type fakeThrottler struct { + throttleFunc func() bool +} + +func (f *fakeThrottler) ShouldThrottle() bool { return f.throttleFunc() } +func (f *fakeThrottler) RegisterBackendResponse(bool) {} + +// alwaysThrottlingThrottler returns a fake throttler which always throttles. +func alwaysThrottlingThrottler() *fakeThrottler { + return &fakeThrottler{throttleFunc: func() bool { return true }} +} + +// neverThrottlingThrottler returns a fake throttler which never throttles. +func neverThrottlingThrottler() *fakeThrottler { + return &fakeThrottler{throttleFunc: func() bool { return false }} +} + +// oneTimeAllowingThrottler returns a fake throttler which does not throttle the +// first request, but throttles everything that comes after. This is useful for +// tests which need to set up a valid cache entry before testing other cases. 
+func oneTimeAllowingThrottler() *fakeThrottler { + var once sync.Once + return &fakeThrottler{ + throttleFunc: func() bool { + throttle := true + once.Do(func() { throttle = false }) + return throttle + }, + } +} + +func overrideAdaptiveThrottler(t *testing.T, f *fakeThrottler) { + origAdaptiveThrottler := newAdaptiveThrottler + newAdaptiveThrottler = func() adaptiveThrottler { return f } + t.Cleanup(func() { newAdaptiveThrottler = origAdaptiveThrottler }) +} + +// setupFakeRLSServer starts and returns a fake RouteLookupService server +// listening on the given listener or on a random local port. Also returns a +// channel for tests to get notified whenever the RouteLookup RPC is invoked on +// the fake server. +// +// This function sets up the fake server to respond with an empty response for +// the RouteLookup RPCs. Tests can override this by calling the +// SetResponseCallback() method on the returned fake server. +func setupFakeRLSServer(t *testing.T, lis net.Listener, opts ...grpc.ServerOption) (*e2e.FakeRouteLookupServer, chan struct{}) { + s, cancel := e2e.StartFakeRouteLookupServer(t, lis, opts...) + t.Logf("Started fake RLS server at %q", s.Address) + + ch := make(chan struct{}, 1) + s.SetRequestCallback(func(request *rlspb.RouteLookupRequest) { + select { + case ch <- struct{}{}: + default: + } + }) + t.Cleanup(cancel) + return s, ch +} + +// buildBasicRLSConfig constructs a basic service config for the RLS LB policy +// which header matching rules. This expects the passed child policy name to +// have been registered by the caller. 
+func buildBasicRLSConfig(childPolicyName, rlsServerAddress string) *e2e.RLSConfig { + return &e2e.RLSConfig{ + RouteLookupConfig: &rlspb.RouteLookupConfig{ + GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{ + { + Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "grpc.testing.TestService"}}, + Headers: []*rlspb.NameMatcher{ + {Key: "k1", Names: []string{"n1"}}, + {Key: "k2", Names: []string{"n2"}}, + }, + }, + }, + LookupService: rlsServerAddress, + LookupServiceTimeout: durationpb.New(defaultTestTimeout), + CacheSizeBytes: 1024, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: childPolicyName}, + ChildPolicyConfigTargetFieldName: e2e.RLSChildPolicyTargetNameField, + } +} + +// buildBasicRLSConfigWithChildPolicy constructs a very basic service config for +// the RLS LB policy. It also registers a test LB policy which is capable of +// being a child of the RLS LB policy. +func buildBasicRLSConfigWithChildPolicy(t *testing.T, childPolicyName, rlsServerAddress string) *e2e.RLSConfig { + childPolicyName = "test-child-policy" + childPolicyName + e2e.RegisterRLSChildPolicy(childPolicyName, nil) + t.Logf("Registered child policy with name %q", childPolicyName) + + return &e2e.RLSConfig{ + RouteLookupConfig: &rlspb.RouteLookupConfig{ + GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{{Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "grpc.testing.TestService"}}}}, + LookupService: rlsServerAddress, + LookupServiceTimeout: durationpb.New(defaultTestTimeout), + CacheSizeBytes: 1024, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: childPolicyName}, + ChildPolicyConfigTargetFieldName: e2e.RLSChildPolicyTargetNameField, + } +} + +// startBackend starts a backend implementing the TestService on a local port. +// It returns a channel for tests to get notified whenever an RPC is invoked on +// the backend. This allows tests to ensure that RPCs reach expected backends. +// Also returns the address of the backend. 
+func startBackend(t *testing.T, sopts ...grpc.ServerOption) (rpcCh chan struct{}, address string) { + t.Helper() + + rpcCh = make(chan struct{}, 1) + backend := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + select { + case rpcCh <- struct{}{}: + default: + } + return &testpb.Empty{}, nil + }, + } + if err := backend.StartServer(sopts...); err != nil { + t.Fatalf("Failed to start backend: %v", err) + } + t.Logf("Started TestService backend at: %q", backend.Address) + t.Cleanup(func() { backend.Stop() }) + return rpcCh, backend.Address +} + +// startManualResolverWithConfig registers and returns a manual resolver which +// pushes the RLS LB policy's service config on the channel. +func startManualResolverWithConfig(t *testing.T, rlsConfig *e2e.RLSConfig) *manual.Resolver { + t.Helper() + + scJSON, err := rlsConfig.ServiceConfigJSON() + if err != nil { + t.Fatal(err) + } + + sc := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(scJSON) + r := manual.NewBuilderWithScheme("rls-e2e") + r.InitialState(resolver.State{ServiceConfig: sc}) + t.Cleanup(r.Close) + return r +} + +// makeTestRPCAndExpectItToReachBackend is a test helper function which makes +// the EmptyCall RPC on the given ClientConn and verifies that it reaches a +// backend. The latter is accomplished by listening on the provided channel +// which gets pushed to whenever the backend in question gets an RPC. 
+func makeTestRPCAndExpectItToReachBackend(ctx context.Context, t *testing.T, cc *grpc.ClientConn, ch chan struct{}) { + t.Helper() + + client := testgrpc.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall() failed with error: %v", err) + } + select { + case <-ctx.Done(): + t.Fatalf("Timeout when waiting for backend to receive RPC") + case <-ch: + } +} + +// makeTestRPCAndVerifyError is a test helper function which makes the EmptyCall +// RPC on the given ClientConn and verifies that the RPC fails with the given +// status code and error. +func makeTestRPCAndVerifyError(ctx context.Context, t *testing.T, cc *grpc.ClientConn, wantCode codes.Code, wantErr error) { + t.Helper() + + client := testgrpc.NewTestServiceClient(cc) + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if err == nil { + t.Fatal("TestService/EmptyCall() succeeded when expected to fail") + } + if code := status.Code(err); code != wantCode { + t.Fatalf("TestService/EmptyCall() returned code: %v, want: %v", code, wantCode) + } + if wantErr != nil && !strings.Contains(err.Error(), wantErr.Error()) { + t.Fatalf("TestService/EmptyCall() returned err: %v, want: %v", err, wantErr) + } +} + +// verifyRLSRequest is a test helper which listens on a channel to see if an RLS +// request was received by the fake RLS server. Based on whether the test +// expects a request to be sent out or not, it uses a different timeout. 
+func verifyRLSRequest(t *testing.T, ch chan struct{}, wantRequest bool) { + t.Helper() + + if wantRequest { + select { + case <-time.After(defaultTestTimeout): + t.Fatalf("Timeout when waiting for an RLS request to be sent out") + case <-ch: + } + } else { + select { + case <-time.After(defaultTestShortTimeout): + case <-ch: + t.Fatalf("RLS request sent out when not expecting one") + } + } +} diff --git a/balancer/rls/internal/picker.go b/balancer/rls/internal/picker.go deleted file mode 100644 index 37e58759e25..00000000000 --- a/balancer/rls/internal/picker.go +++ /dev/null @@ -1,147 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package rls - -import ( - "errors" - "time" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/rls/internal/cache" - "google.golang.org/grpc/balancer/rls/internal/keys" - "google.golang.org/grpc/metadata" -) - -var errRLSThrottled = errors.New("RLS call throttled at client side") - -// RLS rlsPicker selects the subConn to be used for a particular RPC. It does -// not manage subConns directly and usually deletegates to pickers provided by -// child policies. -// -// The RLS LB policy creates a new rlsPicker object whenever its ServiceConfig -// is updated and provides a bunch of hooks for the rlsPicker to get the latest -// state that it can used to make its decision. 
-type rlsPicker struct { - // The keyBuilder map used to generate RLS keys for the RPC. This is built - // by the LB policy based on the received ServiceConfig. - kbm keys.BuilderMap - // Endpoint from the user's original dial target. Used to set the `host_key` - // field in `extra_keys`. - origEndpoint string - - // The following hooks are setup by the LB policy to enable the rlsPicker to - // access state stored in the policy. This approach has the following - // advantages: - // 1. The rlsPicker is loosely coupled with the LB policy in the sense that - // updates happening on the LB policy like the receipt of an RLS - // response, or an update to the default rlsPicker etc are not explicitly - // pushed to the rlsPicker, but are readily available to the rlsPicker - // when it invokes these hooks. And the LB policy takes care of - // synchronizing access to these shared state. - // 2. It makes unit testing the rlsPicker easy since any number of these - // hooks could be overridden. - - // readCache is used to read from the data cache and the pending request - // map in an atomic fashion. The first return parameter is the entry in the - // data cache, and the second indicates whether an entry for the same key - // is present in the pending cache. - readCache func(cache.Key) (*cache.Entry, bool) - // shouldThrottle decides if the current RPC should be throttled at the - // client side. It uses an adaptive throttling algorithm. - shouldThrottle func() bool - // startRLS kicks off an RLS request in the background for the provided RPC - // path and keyMap. An entry in the pending request map is created before - // sending out the request and an entry in the data cache is created or - // updated upon receipt of a response. See implementation in the LB policy - // for details. 
- startRLS func(string, keys.KeyMap) - // defaultPick enables the rlsPicker to delegate the pick decision to the - // rlsPicker returned by the child LB policy pointing to the default target - // specified in the service config. - defaultPick func(balancer.PickInfo) (balancer.PickResult, error) -} - -// Pick makes the routing decision for every outbound RPC. -func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - // Build the request's keys using the key builders from LB config. - md, _ := metadata.FromOutgoingContext(info.Ctx) - km := p.kbm.RLSKey(md, p.origEndpoint, info.FullMethodName) - - // We use the LB policy hook to read the data cache and the pending request - // map (whether or not an entry exists) for the RPC path and the generated - // RLS keys. We will end up kicking off an RLS request only if there is no - // pending request for the current RPC path and keys, and either we didn't - // find an entry in the data cache or the entry was stale and it wasn't in - // backoff. - startRequest := false - now := time.Now() - entry, pending := p.readCache(cache.Key{Path: info.FullMethodName, KeyMap: km.Str}) - if entry == nil { - startRequest = true - } else { - entry.Mu.Lock() - defer entry.Mu.Unlock() - if entry.StaleTime.Before(now) && entry.BackoffTime.Before(now) { - // This is the proactive cache refresh. - startRequest = true - } - } - - if startRequest && !pending { - if p.shouldThrottle() { - // The entry doesn't exist or has expired and the new RLS request - // has been throttled. Treat it as an error and delegate to default - // pick, if one exists, or fail the pick. - if entry == nil || entry.ExpiryTime.Before(now) { - if p.defaultPick != nil { - return p.defaultPick(info) - } - return balancer.PickResult{}, errRLSThrottled - } - // The proactive refresh has been throttled. Nothing to worry, just - // keep using the existing entry. 
- } else { - p.startRLS(info.FullMethodName, km) - } - } - - if entry != nil { - if entry.ExpiryTime.After(now) { - // This is the jolly good case where we have found a valid entry in - // the data cache. We delegate to the LB policy associated with - // this cache entry. - return entry.ChildPicker.Pick(info) - } else if entry.BackoffTime.After(now) { - // The entry has expired, but is in backoff. We delegate to the - // default pick, if one exists, or return the error from the last - // failed RLS request for this entry. - if p.defaultPick != nil { - return p.defaultPick(info) - } - return balancer.PickResult{}, entry.CallStatus - } - } - - // We get here only in the following cases: - // * No data cache entry or expired entry, RLS request sent out - // * No valid data cache entry and Pending cache entry exists - // We need to queue to pick which will be handled once the RLS response is - // received. - return balancer.PickResult{}, balancer.ErrNoSubConnAvailable -} diff --git a/balancer/rls/internal/picker_test.go b/balancer/rls/internal/picker_test.go deleted file mode 100644 index f115be98a39..00000000000 --- a/balancer/rls/internal/picker_test.go +++ /dev/null @@ -1,615 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package rls - -import ( - "context" - "errors" - "fmt" - "math" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/rls/internal/cache" - "google.golang.org/grpc/balancer/rls/internal/keys" - "google.golang.org/grpc/internal/grpcrand" - rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" - "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/metadata" -) - -const defaultTestMaxAge = 5 * time.Second - -// initKeyBuilderMap initializes a keyBuilderMap of the form: -// { -// "gFoo": "k1=n1", -// "gBar/method1": "k2=n21,n22" -// "gFoobar": "k3=n3", -// } -func initKeyBuilderMap() (keys.BuilderMap, error) { - kb1 := &rlspb.GrpcKeyBuilder{ - Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "gFoo"}}, - Headers: []*rlspb.NameMatcher{{Key: "k1", Names: []string{"n1"}}}, - } - kb2 := &rlspb.GrpcKeyBuilder{ - Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "gBar", Method: "method1"}}, - Headers: []*rlspb.NameMatcher{{Key: "k2", Names: []string{"n21", "n22"}}}, - } - kb3 := &rlspb.GrpcKeyBuilder{ - Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "gFoobar"}}, - Headers: []*rlspb.NameMatcher{{Key: "k3", Names: []string{"n3"}}}, - } - return keys.MakeBuilderMap(&rlspb.RouteLookupConfig{ - GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{kb1, kb2, kb3}, - }) -} - -// fakeSubConn embeds the balancer.SubConn interface and contains an id which -// helps verify that the expected subConn was returned by the rlsPicker. -type fakeSubConn struct { - balancer.SubConn - id int -} - -// fakePicker sends a PickResult with a fakeSubConn with the configured id. -type fakePicker struct { - id int -} - -func (p *fakePicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) { - return balancer.PickResult{SubConn: &fakeSubConn{id: p.id}}, nil -} - -// newFakePicker returns a fakePicker configured with a random ID. 
The subConns -// returned by this picker are of type fakefakeSubConn, and contain the same -// random ID, which tests can use to verify. -func newFakePicker() *fakePicker { - return &fakePicker{id: grpcrand.Intn(math.MaxInt32)} -} - -func verifySubConn(sc balancer.SubConn, wantID int) error { - fsc, ok := sc.(*fakeSubConn) - if !ok { - return fmt.Errorf("Pick() returned a SubConn of type %T, want %T", sc, &fakeSubConn{}) - } - if fsc.id != wantID { - return fmt.Errorf("Pick() returned SubConn %d, want %d", fsc.id, wantID) - } - return nil -} - -// TestPickKeyBuilder verifies the different possible scenarios for forming an -// RLS key for an incoming RPC. -func TestPickKeyBuilder(t *testing.T) { - kbm, err := initKeyBuilderMap() - if err != nil { - t.Fatalf("Failed to create keyBuilderMap: %v", err) - } - - tests := []struct { - desc string - rpcPath string - md metadata.MD - wantKey cache.Key - }{ - { - desc: "non existent service in keyBuilder map", - rpcPath: "/gNonExistentService/method", - md: metadata.New(map[string]string{"n1": "v1", "n3": "v3"}), - wantKey: cache.Key{Path: "/gNonExistentService/method", KeyMap: ""}, - }, - { - desc: "no metadata in incoming context", - rpcPath: "/gFoo/method", - md: metadata.MD{}, - wantKey: cache.Key{Path: "/gFoo/method", KeyMap: ""}, - }, - { - desc: "keyBuilderMatch", - rpcPath: "/gFoo/method", - md: metadata.New(map[string]string{"n1": "v1", "n3": "v3"}), - wantKey: cache.Key{Path: "/gFoo/method", KeyMap: "k1=v1"}, - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - randID := grpcrand.Intn(math.MaxInt32) - p := rlsPicker{ - kbm: kbm, - readCache: func(key cache.Key) (*cache.Entry, bool) { - if !cmp.Equal(key, test.wantKey) { - t.Fatalf("rlsPicker using cacheKey %v, want %v", key, test.wantKey) - } - - now := time.Now() - return &cache.Entry{ - ExpiryTime: now.Add(defaultTestMaxAge), - StaleTime: now.Add(defaultTestMaxAge), - // Cache entry is configured with a child policy whose - // 
rlsPicker always returns an empty PickResult and nil - // error. - ChildPicker: &fakePicker{id: randID}, - }, false - }, - // The other hooks are not set here because they are not expected to be - // invoked for these cases and if they get invoked, they will panic. - } - - gotResult, err := p.Pick(balancer.PickInfo{ - FullMethodName: test.rpcPath, - Ctx: metadata.NewOutgoingContext(context.Background(), test.md), - }) - if err != nil { - t.Fatalf("Pick() failed with error: %v", err) - } - sc, ok := gotResult.SubConn.(*fakeSubConn) - if !ok { - t.Fatalf("Pick() returned a SubConn of type %T, want %T", gotResult.SubConn, &fakeSubConn{}) - } - if sc.id != randID { - t.Fatalf("Pick() returned SubConn %d, want %d", sc.id, randID) - } - }) - } -} - -// TestPick_DataCacheMiss_PendingCacheMiss verifies different Pick scenarios -// where the entry is neither found in the data cache nor in the pending cache. -func TestPick_DataCacheMiss_PendingCacheMiss(t *testing.T) { - const ( - rpcPath = "/gFoo/method" - wantKeyMapStr = "k1=v1" - ) - kbm, err := initKeyBuilderMap() - if err != nil { - t.Fatalf("Failed to create keyBuilderMap: %v", err) - } - md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"}) - wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr} - - tests := []struct { - desc string - // Whether or not a default target is configured. - defaultPickExists bool - // Whether or not the RLS request should be throttled. - throttle bool - // Whether or not the test is expected to make a new RLS request. - wantRLSRequest bool - // Expected error returned by the rlsPicker under test. 
- wantErr error - }{ - { - desc: "rls request throttled with default pick", - defaultPickExists: true, - throttle: true, - }, - { - desc: "rls request throttled without default pick", - throttle: true, - wantErr: errRLSThrottled, - }, - { - desc: "rls request not throttled", - wantRLSRequest: true, - wantErr: balancer.ErrNoSubConnAvailable, - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - rlsCh := testutils.NewChannel() - defaultPicker := newFakePicker() - - p := rlsPicker{ - kbm: kbm, - // Cache lookup fails, no pending entry. - readCache: func(key cache.Key) (*cache.Entry, bool) { - if !cmp.Equal(key, wantKey) { - t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey) - } - return nil, false - }, - shouldThrottle: func() bool { return test.throttle }, - startRLS: func(path string, km keys.KeyMap) { - if !test.wantRLSRequest { - rlsCh.Send(errors.New("RLS request attempted when none was expected")) - return - } - if path != rpcPath { - rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath)) - return - } - if km.Str != wantKeyMapStr { - rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr)) - return - } - rlsCh.Send(nil) - }, - } - if test.defaultPickExists { - p.defaultPick = defaultPicker.Pick - } - - gotResult, err := p.Pick(balancer.PickInfo{ - FullMethodName: rpcPath, - Ctx: metadata.NewOutgoingContext(context.Background(), md), - }) - if err != test.wantErr { - t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr) - } - // If the test specified that a new RLS request should be made, - // verify it. 
- if test.wantRLSRequest { - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil { - t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err) - } - } - if test.wantErr != nil { - return - } - - // We get here only for cases where we expect the pick to be - // delegated to the default picker. - if err := verifySubConn(gotResult.SubConn, defaultPicker.id); err != nil { - t.Fatal(err) - } - }) - } -} - -// TestPick_DataCacheMiss_PendingCacheMiss verifies different Pick scenarios -// where the entry is not found in the data cache, but there is a entry in the -// pending cache. For all of these scenarios, no new RLS request will be sent. -func TestPick_DataCacheMiss_PendingCacheHit(t *testing.T) { - const ( - rpcPath = "/gFoo/method" - wantKeyMapStr = "k1=v1" - ) - kbm, err := initKeyBuilderMap() - if err != nil { - t.Fatalf("Failed to create keyBuilderMap: %v", err) - } - md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"}) - wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr} - - tests := []struct { - desc string - defaultPickExists bool - }{ - { - desc: "default pick exists", - defaultPickExists: true, - }, - { - desc: "default pick does not exists", - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - rlsCh := testutils.NewChannel() - p := rlsPicker{ - kbm: kbm, - // Cache lookup fails, pending entry exists. - readCache: func(key cache.Key) (*cache.Entry, bool) { - if !cmp.Equal(key, wantKey) { - t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey) - } - return nil, true - }, - // Never throttle. We do not expect an RLS request to be sent out anyways. 
- shouldThrottle: func() bool { return false }, - startRLS: func(_ string, _ keys.KeyMap) { - rlsCh.Send(nil) - }, - } - if test.defaultPickExists { - p.defaultPick = func(info balancer.PickInfo) (balancer.PickResult, error) { - // We do not expect the default picker to be invoked at all. - // So, if we get here, the test will fail, because it - // expects the pick to be queued. - return balancer.PickResult{}, nil - } - } - - if _, err := p.Pick(balancer.PickInfo{ - FullMethodName: rpcPath, - Ctx: metadata.NewOutgoingContext(context.Background(), md), - }); err != balancer.ErrNoSubConnAvailable { - t.Fatalf("Pick() returned error {%v}, want {%v}", err, balancer.ErrNoSubConnAvailable) - } - - // Make sure that no RLS request was sent out. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded { - t.Fatalf("RLS request sent out when pending entry exists") - } - }) - } -} - -// TestPick_DataCacheHit_PendingCacheMiss verifies different Pick scenarios -// where the entry is found in the data cache, and there is no entry in the -// pending cache. This includes cases where the entry in the data cache is -// stale, expired or in backoff. -func TestPick_DataCacheHit_PendingCacheMiss(t *testing.T) { - const ( - rpcPath = "/gFoo/method" - wantKeyMapStr = "k1=v1" - ) - kbm, err := initKeyBuilderMap() - if err != nil { - t.Fatalf("Failed to create keyBuilderMap: %v", err) - } - md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"}) - wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr} - rlsLastErr := errors.New("last RLS request failed") - - tests := []struct { - desc string - // The cache entry, as returned by the overridden readCache hook. - cacheEntry *cache.Entry - // Whether or not a default target is configured. - defaultPickExists bool - // Whether or not the RLS request should be throttled. 
- throttle bool - // Whether or not the test is expected to make a new RLS request. - wantRLSRequest bool - // Whether or not the rlsPicker should delegate to the child picker. - wantChildPick bool - // Whether or not the rlsPicker should delegate to the default picker. - wantDefaultPick bool - // Expected error returned by the rlsPicker under test. - wantErr error - }{ - { - desc: "valid entry", - cacheEntry: &cache.Entry{ - ExpiryTime: time.Now().Add(defaultTestMaxAge), - StaleTime: time.Now().Add(defaultTestMaxAge), - }, - wantChildPick: true, - }, - { - desc: "entryStale_requestThrottled", - cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)}, - throttle: true, - wantChildPick: true, - }, - { - desc: "entryStale_requestNotThrottled", - cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)}, - wantRLSRequest: true, - wantChildPick: true, - }, - { - desc: "entryExpired_requestThrottled_defaultPickExists", - cacheEntry: &cache.Entry{}, - throttle: true, - defaultPickExists: true, - wantDefaultPick: true, - }, - { - desc: "entryExpired_requestThrottled_defaultPickNotExists", - cacheEntry: &cache.Entry{}, - throttle: true, - wantErr: errRLSThrottled, - }, - { - desc: "entryExpired_requestNotThrottled", - cacheEntry: &cache.Entry{}, - wantRLSRequest: true, - wantErr: balancer.ErrNoSubConnAvailable, - }, - { - desc: "entryExpired_backoffNotExpired_defaultPickExists", - cacheEntry: &cache.Entry{ - BackoffTime: time.Now().Add(defaultTestMaxAge), - CallStatus: rlsLastErr, - }, - defaultPickExists: true, - }, - { - desc: "entryExpired_backoffNotExpired_defaultPickNotExists", - cacheEntry: &cache.Entry{ - BackoffTime: time.Now().Add(defaultTestMaxAge), - CallStatus: rlsLastErr, - }, - wantErr: rlsLastErr, - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - rlsCh := testutils.NewChannel() - childPicker := newFakePicker() - defaultPicker := newFakePicker() - - p := rlsPicker{ - kbm: kbm, - readCache: 
func(key cache.Key) (*cache.Entry, bool) { - if !cmp.Equal(key, wantKey) { - t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey) - } - test.cacheEntry.ChildPicker = childPicker - return test.cacheEntry, false - }, - shouldThrottle: func() bool { return test.throttle }, - startRLS: func(path string, km keys.KeyMap) { - if !test.wantRLSRequest { - rlsCh.Send(errors.New("RLS request attempted when none was expected")) - return - } - if path != rpcPath { - rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath)) - return - } - if km.Str != wantKeyMapStr { - rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr)) - return - } - rlsCh.Send(nil) - }, - } - if test.defaultPickExists { - p.defaultPick = defaultPicker.Pick - } - - gotResult, err := p.Pick(balancer.PickInfo{ - FullMethodName: rpcPath, - Ctx: metadata.NewOutgoingContext(context.Background(), md), - }) - if err != test.wantErr { - t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr) - } - // If the test specified that a new RLS request should be made, - // verify it. - if test.wantRLSRequest { - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil { - t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err) - } - } - if test.wantErr != nil { - return - } - - // We get here only for cases where we expect the pick to be - // delegated to the child picker or the default picker. - if test.wantChildPick { - if err := verifySubConn(gotResult.SubConn, childPicker.id); err != nil { - t.Fatal(err) - } - - } - if test.wantDefaultPick { - if err := verifySubConn(gotResult.SubConn, defaultPicker.id); err != nil { - t.Fatal(err) - } - } - }) - } -} - -// TestPick_DataCacheHit_PendingCacheHit verifies different Pick scenarios where -// the entry is found both in the data cache and in the pending cache. 
This -// mostly verifies cases where the entry is stale, but there is already a -// pending RLS request, so no new request should be sent out. -func TestPick_DataCacheHit_PendingCacheHit(t *testing.T) { - const ( - rpcPath = "/gFoo/method" - wantKeyMapStr = "k1=v1" - ) - kbm, err := initKeyBuilderMap() - if err != nil { - t.Fatalf("Failed to create keyBuilderMap: %v", err) - } - md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"}) - wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr} - - tests := []struct { - desc string - // The cache entry, as returned by the overridden readCache hook. - cacheEntry *cache.Entry - // Whether or not a default target is configured. - defaultPickExists bool - // Expected error returned by the rlsPicker under test. - wantErr error - }{ - { - desc: "stale entry", - cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)}, - }, - { - desc: "stale entry with default picker", - cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)}, - defaultPickExists: true, - }, - { - desc: "entryExpired_defaultPickExists", - cacheEntry: &cache.Entry{}, - defaultPickExists: true, - wantErr: balancer.ErrNoSubConnAvailable, - }, - { - desc: "entryExpired_defaultPickNotExists", - cacheEntry: &cache.Entry{}, - wantErr: balancer.ErrNoSubConnAvailable, - }, - } - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - rlsCh := testutils.NewChannel() - childPicker := newFakePicker() - - p := rlsPicker{ - kbm: kbm, - readCache: func(key cache.Key) (*cache.Entry, bool) { - if !cmp.Equal(key, wantKey) { - t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey) - } - test.cacheEntry.ChildPicker = childPicker - return test.cacheEntry, true - }, - // Never throttle. We do not expect an RLS request to be sent out anyways. 
- shouldThrottle: func() bool { return false }, - startRLS: func(path string, km keys.KeyMap) { - rlsCh.Send(nil) - }, - } - if test.defaultPickExists { - p.defaultPick = func(info balancer.PickInfo) (balancer.PickResult, error) { - // We do not expect the default picker to be invoked at all. - // So, if we get here, we return an error. - return balancer.PickResult{}, errors.New("default picker invoked when expecting a child pick") - } - } - - gotResult, err := p.Pick(balancer.PickInfo{ - FullMethodName: rpcPath, - Ctx: metadata.NewOutgoingContext(context.Background(), md), - }) - if err != test.wantErr { - t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr) - } - // Make sure that no RLS request was sent out. - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded { - t.Fatalf("RLS request sent out when pending entry exists") - } - if test.wantErr != nil { - return - } - - // We get here only for cases where we expect the pick to be - // delegated to the child picker. - if err := verifySubConn(gotResult.SubConn, childPicker.id); err != nil { - t.Fatal(err) - } - }) - } -} diff --git a/balancer/rls/internal/test/e2e/e2e.go b/balancer/rls/internal/test/e2e/e2e.go new file mode 100644 index 00000000000..7b8a8bbde13 --- /dev/null +++ b/balancer/rls/internal/test/e2e/e2e.go @@ -0,0 +1,20 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package e2e contains utilities for end-to-end RouteLookupService tests. +package e2e diff --git a/balancer/rls/internal/test/e2e/rls_child_policy.go b/balancer/rls/internal/test/e2e/rls_child_policy.go new file mode 100644 index 00000000000..5a6e3e69175 --- /dev/null +++ b/balancer/rls/internal/test/e2e/rls_child_policy.go @@ -0,0 +1,131 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package e2e + +import ( + "encoding/json" + "errors" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +const ( + // RLSChildPolicyTargetNameField is a top-level field name to add to the child + // policy's config, whose value is set to the target for the child policy. + RLSChildPolicyTargetNameField = "Backend" + // RLSChildPolicyBadTarget is a value which is considered a bad target by the + // child policy. This is useful to test bad child policy configuration. + RLSChildPolicyBadTarget = "bad-target" +) + +// ErrParseConfigBadTarget is the error returned from ParseConfig when the +// backend field is set to RLSChildPolicyBadTarget. 
+var ErrParseConfigBadTarget = errors.New("backend field set to RLSChildPolicyBadTarget") + +// BalancerFuncs is a set of callbacks which get invoked when the corresponding +// method on the child policy is invoked. +type BalancerFuncs struct { + UpdateClientConnState func(cfg *RLSChildPolicyConfig) error + Close func() +} + +// RegisterRLSChildPolicy registers a balancer builder with the given name, to +// be used as a child policy for the RLS LB policy. +// +// The child policy uses a pickfirst balancer under the hood to send all traffic +// to the single backend specified by the `RLSChildPolicyTargetNameField` field +// in its configuration which looks like: {"Backend": "Backend-address"}. +func RegisterRLSChildPolicy(name string, bf *BalancerFuncs) { + balancer.Register(bb{name: name, bf: bf}) +} + +type bb struct { + name string + bf *BalancerFuncs +} + +func (bb bb) Name() string { return bb.name } + +func (bb bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + pf := balancer.Get(grpc.PickFirstBalancerName) + b := &bal{ + Balancer: pf.Build(cc, opts), + bf: bb.bf, + done: grpcsync.NewEvent(), + } + go b.run() + return b +} + +func (bb bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + cfg := &RLSChildPolicyConfig{} + if err := json.Unmarshal(c, cfg); err != nil { + return nil, err + } + if cfg.Backend == RLSChildPolicyBadTarget { + return nil, ErrParseConfigBadTarget + } + return cfg, nil +} + +type bal struct { + balancer.Balancer + bf *BalancerFuncs + done *grpcsync.Event +} + +// RLSChildPolicyConfig is the LB config for the test child policy. +type RLSChildPolicyConfig struct { + serviceconfig.LoadBalancingConfig + Backend string // The target for which this child policy was created. + Random string // A random field to test child policy config changes. 
+} + +func (b *bal) UpdateClientConnState(c balancer.ClientConnState) error { + cfg, ok := c.BalancerConfig.(*RLSChildPolicyConfig) + if !ok { + return fmt.Errorf("received balancer config of type %T, want %T", c.BalancerConfig, &RLSChildPolicyConfig{}) + } + if b.bf != nil && b.bf.UpdateClientConnState != nil { + b.bf.UpdateClientConnState(cfg) + } + return b.Balancer.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: resolver.State{Addresses: []resolver.Address{{Addr: cfg.Backend}}}, + }) +} + +func (b *bal) Close() { + b.Balancer.Close() + if b.bf != nil && b.bf.Close != nil { + b.bf.Close() + } + b.done.Fire() +} + +// run is a dummy goroutine to make sure that child policies are closed at the +// end of tests. If they are not closed, these goroutines will be picked up by +// the leakcheker and tests will fail. +func (b *bal) run() { + <-b.done.Done() +} diff --git a/balancer/rls/internal/test/e2e/rls_fakeserver.go b/balancer/rls/internal/test/e2e/rls_fakeserver.go new file mode 100644 index 00000000000..52198541282 --- /dev/null +++ b/balancer/rls/internal/test/e2e/rls_fakeserver.go @@ -0,0 +1,110 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package e2e + +import ( + "context" + "net" + "sync" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1" + rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/status" +) + +// RouteLookupResponse wraps an RLS response and the associated error to be sent +// to a client when the RouteLookup RPC is invoked. +type RouteLookupResponse struct { + Resp *rlspb.RouteLookupResponse + Err error +} + +// FakeRouteLookupServer is a fake implementation of the RouteLookupService. +// +// It is safe for concurrent use. +type FakeRouteLookupServer struct { + rlsgrpc.UnimplementedRouteLookupServiceServer + Address string + + mu sync.Mutex + respCb func(context.Context, *rlspb.RouteLookupRequest) *RouteLookupResponse + reqCb func(*rlspb.RouteLookupRequest) +} + +// StartFakeRouteLookupServer starts a fake RLS server listening for requests on +// lis. If lis is nil, it creates a new listener on a random local port. The +// returned cancel function should be invoked by the caller upon completion of +// the test. +func StartFakeRouteLookupServer(t *testing.T, lis net.Listener, opts ...grpc.ServerOption) (*FakeRouteLookupServer, func()) { + t.Helper() + + if lis == nil { + var err error + lis, err = testutils.LocalTCPListener() + if err != nil { + t.Fatalf("net.Listen() failed: %v", err) + } + } + + s := &FakeRouteLookupServer{Address: lis.Addr().String()} + server := grpc.NewServer(opts...) + rlsgrpc.RegisterRouteLookupServiceServer(server, s) + go server.Serve(lis) + return s, func() { server.Stop() } +} + +// RouteLookup implements the RouteLookupService. 
+func (s *FakeRouteLookupServer) RouteLookup(ctx context.Context, req *rlspb.RouteLookupRequest) (*rlspb.RouteLookupResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.reqCb != nil { + s.reqCb(req) + } + if err := ctx.Err(); err != nil { + return nil, status.Error(codes.DeadlineExceeded, err.Error()) + } + if s.respCb == nil { + return &rlspb.RouteLookupResponse{}, nil + } + resp := s.respCb(ctx, req) + return resp.Resp, resp.Err +} + +// SetResponseCallback sets a callback to be invoked on every RLS request. If +// this callback is set, the response returned by the fake server depends on the +// value returned by the callback. If this callback is not set, the fake server +// responds with an empty response. +func (s *FakeRouteLookupServer) SetResponseCallback(f func(context.Context, *rlspb.RouteLookupRequest) *RouteLookupResponse) { + s.mu.Lock() + s.respCb = f + s.mu.Unlock() +} + +// SetRequestCallback sets a callback to be invoked on every RLS request. The +// callback is given the incoming request, and tests can use this to verify that +// the request matches its expectations. +func (s *FakeRouteLookupServer) SetRequestCallback(f func(*rlspb.RouteLookupRequest)) { + s.mu.Lock() + s.reqCb = f + s.mu.Unlock() +} diff --git a/balancer/rls/internal/test/e2e/rls_lb_config.go b/balancer/rls/internal/test/e2e/rls_lb_config.go new file mode 100644 index 00000000000..2aec642c77e --- /dev/null +++ b/balancer/rls/internal/test/e2e/rls_lb_config.go @@ -0,0 +1,100 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package e2e + +import ( + "errors" + "fmt" + + "google.golang.org/grpc/balancer" + rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/serviceconfig" + + "google.golang.org/protobuf/encoding/protojson" +) + +// RLSConfig is a utility type to build service config for the RLS LB policy. +type RLSConfig struct { + RouteLookupConfig *rlspb.RouteLookupConfig + ChildPolicy *internalserviceconfig.BalancerConfig + ChildPolicyConfigTargetFieldName string +} + +// ServiceConfigJSON generates service config with a load balancing config +// corresponding to the RLS LB policy. +func (c *RLSConfig) ServiceConfigJSON() (string, error) { + m := protojson.MarshalOptions{ + Multiline: true, + Indent: " ", + UseProtoNames: true, + } + routeLookupCfg, err := m.Marshal(c.RouteLookupConfig) + if err != nil { + return "", err + } + childPolicy, err := c.ChildPolicy.MarshalJSON() + if err != nil { + return "", err + } + + return fmt.Sprintf(` +{ + "loadBalancingConfig": [ + { + "rls_experimental": { + "routeLookupConfig": %s, + "childPolicy": %s, + "childPolicyConfigTargetFieldName": %q + } + } + ] +}`, string(routeLookupCfg), string(childPolicy), c.ChildPolicyConfigTargetFieldName), nil +} + +// LoadBalancingConfig generates load balancing config which can used as part of +// a ClientConnState update to the RLS LB policy. 
+func (c *RLSConfig) LoadBalancingConfig() (serviceconfig.LoadBalancingConfig, error) { + m := protojson.MarshalOptions{ + Multiline: true, + Indent: " ", + UseProtoNames: true, + } + routeLookupCfg, err := m.Marshal(c.RouteLookupConfig) + if err != nil { + return nil, err + } + childPolicy, err := c.ChildPolicy.MarshalJSON() + if err != nil { + return nil, err + } + lbConfigJSON := fmt.Sprintf(` +{ + "routeLookupConfig": %s, + "childPolicy": %s, + "childPolicyConfigTargetFieldName": %q +}`, string(routeLookupCfg), string(childPolicy), c.ChildPolicyConfigTargetFieldName) + + builder := balancer.Get("rls_experimental") + if builder == nil { + return nil, errors.New("balancer builder not found for RLS LB policy") + } + parser := builder.(balancer.ConfigParser) + return parser.ParseConfig([]byte(lbConfigJSON)) +} diff --git a/internal/stubserver/stubserver.go b/internal/stubserver/stubserver.go index c97010dfe9a..f3ed23aa32a 100644 --- a/internal/stubserver/stubserver.go +++ b/internal/stubserver/stubserver.go @@ -80,6 +80,14 @@ func (ss *StubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallSer // Start starts the server and creates a client connected to it. func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) error { + if err := ss.StartServer(sopts...); err != nil { + return err + } + return ss.StartClient(dopts...) +} + +// StartServer only starts the server. It does not create a client to it. +func (ss *StubServer) StartServer(sopts ...grpc.ServerOption) error { if ss.Network == "" { ss.Network = "tcp" } @@ -102,7 +110,12 @@ func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) go s.Serve(lis) ss.cleanups = append(ss.cleanups, s.Stop) ss.S = s + return nil +} +// StartClient creates a client connected to this service that the test may use. +// The newly created client will be available in the Client field of StubServer. 
+func (ss *StubServer) StartClient(dopts ...grpc.DialOption) error { opts := append([]grpc.DialOption{grpc.WithInsecure()}, dopts...) if ss.R != nil { ss.Target = ss.R.Scheme() + ":///" + ss.Address diff --git a/internal/testutils/restartable_listener.go b/internal/testutils/restartable_listener.go index 1f501939191..efe4019a08c 100644 --- a/internal/testutils/restartable_listener.go +++ b/internal/testutils/restartable_listener.go @@ -83,12 +83,11 @@ func (l *RestartableListener) Addr() net.Addr { func (l *RestartableListener) Stop() { l.mu.Lock() l.stopped = true - tmp := l.conns - l.conns = nil - l.mu.Unlock() - for _, conn := range tmp { + for _, conn := range l.conns { conn.Close() } + l.conns = nil + l.mu.Unlock() } // Restart gets a previously stopped listener to start accepting connections.