Skip to content

Commit

Permalink
xds/google_default_creds: handshake based on cluster name in address …
Browse files Browse the repository at this point in the history
…attributes (#4310)
  • Loading branch information
menghanl committed Apr 12, 2021
1 parent fab5982 commit 950ddd3
Show file tree
Hide file tree
Showing 10 changed files with 584 additions and 7 deletions.
13 changes: 11 additions & 2 deletions credentials/google/google.go
Expand Up @@ -99,6 +99,15 @@ func (c *creds) PerRPCCredentials() credentials.PerRPCCredentials {
return c.perRPCCreds
}

var (
newTLS = func() credentials.TransportCredentials {
return credentials.NewTLS(nil)
}
newALTS = func() credentials.TransportCredentials {
return alts.NewClientCreds(alts.DefaultClientOptions())
}
)

// NewWithMode should make a copy of Bundle, and switch mode. Modifying the
// existing Bundle may cause races.
func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) {
Expand All @@ -110,11 +119,11 @@ func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) {
// Create transport credentials.
switch mode {
case internal.CredsBundleModeFallback:
newCreds.transportCreds = credentials.NewTLS(nil)
newCreds.transportCreds = newClusterTransportCreds(newTLS(), newALTS())
case internal.CredsBundleModeBackendFromBalancer, internal.CredsBundleModeBalancer:
// Only the clients can use google default credentials, so we only need
// to create new ALTS client creds here.
newCreds.transportCreds = alts.NewClientCreds(alts.DefaultClientOptions())
newCreds.transportCreds = newALTS()
default:
return nil, fmt.Errorf("unsupported mode: %v", mode)
}
Expand Down
132 changes: 132 additions & 0 deletions credentials/google/google_test.go
@@ -0,0 +1,132 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package google

import (
"context"
"net"
"testing"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/resolver"
)

type testCreds struct {
credentials.TransportCredentials
typ string
}

func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return nil, &testAuthInfo{typ: c.typ}, nil
}

func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return nil, &testAuthInfo{typ: c.typ}, nil
}

type testAuthInfo struct {
typ string
}

func (t *testAuthInfo) AuthType() string {
return t.typ
}

var (
testTLS = &testCreds{typ: "tls"}
testALTS = &testCreds{typ: "alts"}

contextWithHandshakeInfo = internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
)

func overrideNewCredsFuncs() func() {
oldNewTLS := newTLS
newTLS = func() credentials.TransportCredentials {
return testTLS
}
oldNewALTS := newALTS
newALTS = func() credentials.TransportCredentials {
return testALTS
}
return func() {
newTLS = oldNewTLS
newALTS = oldNewALTS
}
}

// TestClientHandshakeBasedOnClusterName that by default (without switching
// modes), ClientHandshake does either tls or alts base on the cluster name in
// attributes.
func TestClientHandshakeBasedOnClusterName(t *testing.T) {
defer overrideNewCredsFuncs()()
for bundleTyp, tc := range map[string]credentials.Bundle{
"defaultCreds": NewDefaultCredentials(),
"computeCreds": NewComputeEngineCredentials(),
} {
tests := []struct {
name string
ctx context.Context
wantTyp string
}{
{
name: "no cluster name",
ctx: context.Background(),
wantTyp: "tls",
},
{
name: "with non-CFE cluster name",
ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{
Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
}),
// non-CFE backends should use alts.
wantTyp: "alts",
},
{
name: "with CFE cluster name",
ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{
Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes,
}),
// CFE should use tls.
wantTyp: "tls",
},
}
for _, tt := range tests {
t.Run(bundleTyp+" "+tt.name, func(t *testing.T) {
_, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil)
if err != nil {
t.Fatalf("ClientHandshake failed: %v", err)
}
if gotType := info.AuthType(); gotType != tt.wantTyp {
t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp)
}

_, infoServer, err := tc.TransportCredentials().ServerHandshake(nil)
if err != nil {
t.Fatalf("ClientHandshake failed: %v", err)
}
// ServerHandshake should always do TLS.
if gotType := infoServer.AuthType(); gotType != "tls" {
t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
}
})
}
}
}
90 changes: 90 additions & 0 deletions credentials/google/xds.go
@@ -0,0 +1,90 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package google

import (
"context"
"net"

"google.golang.org/grpc/credentials"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
)

const cfeClusterName = "google-cfe"

// clusterTransportCreds is a combo of TLS + ALTS.
//
// On the client, ClientHandshake picks TLS or ALTS based on address attributes.
// - if attributes has cluster name
// - if cluster name is "google_cfe", use TLS
// - otherwise, use ALTS
// - else, do TLS
//
// On the server, ServerHandshake always does TLS.
type clusterTransportCreds struct {
tls credentials.TransportCredentials
alts credentials.TransportCredentials
}

func newClusterTransportCreds(tls, alts credentials.TransportCredentials) *clusterTransportCreds {
return &clusterTransportCreds{
tls: tls,
alts: alts,
}
}

func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
chi := credentials.ClientHandshakeInfoFromContext(ctx)
if chi.Attributes == nil {
return c.tls.ClientHandshake(ctx, authority, rawConn)
}
cn, ok := xdsinternal.GetHandshakeClusterName(chi.Attributes)
if !ok || cn == cfeClusterName {
return c.tls.ClientHandshake(ctx, authority, rawConn)
}
// If attributes have cluster name, and cluster name is not cfe, it's a
// backend address, use ALTS.
return c.alts.ClientHandshake(ctx, authority, rawConn)
}

func (c *clusterTransportCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return c.tls.ServerHandshake(conn)
}

func (c *clusterTransportCreds) Info() credentials.ProtocolInfo {
// TODO: this always returns tls.Info now, because we don't have a cluster
// name to check when this method is called. This method doesn't affect
// anything important now. We may want to revisit this if it becomes more
// important later.
return c.tls.Info()
}

func (c *clusterTransportCreds) Clone() credentials.TransportCredentials {
return &clusterTransportCreds{
tls: c.tls.Clone(),
alts: c.alts.Clone(),
}
}

func (c *clusterTransportCreds) OverrideServerName(s string) error {
if err := c.tls.OverrideServerName(s); err != nil {
return err
}
return c.alts.OverrideServerName(s)
}
42 changes: 42 additions & 0 deletions internal/credentials/xds/handshake_cluster.go
@@ -0,0 +1,42 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package xds

import (
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/resolver"
)

// handshakeClusterNameKey is the type used as the key to store cluster name in
// the Attributes field of resolver.Address.
type handshakeClusterNameKey struct{}

// SetHandshakeClusterName returns a copy of addr in which the Attributes field
// is updated with the cluster name.
func SetHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address {
addr.Attributes = addr.Attributes.WithValues(handshakeClusterNameKey{}, clusterName)
return addr
}

// GetHandshakeClusterName returns cluster name stored in attr.
func GetHandshakeClusterName(attr *attributes.Attributes) (string, bool) {
v := attr.Value(handshakeClusterNameKey{})
name, ok := v.(string)
return name, ok
}
91 changes: 91 additions & 0 deletions xds/internal/balancer/clusterimpl/balancer_test.go
Expand Up @@ -29,6 +29,7 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/xds/internal/client/load"
Expand Down Expand Up @@ -369,3 +370,93 @@ func TestPickerUpdateAfterClose(t *testing.T) {
case <-time.After(time.Millisecond * 10):
}
}

// TestClusterNameInAddressAttributes covers the case that cluster name is
// attached to the subconn address attributes.
func TestClusterNameInAddressAttributes(t *testing.T) {
xdsC := fakeclient.NewClient()
oldNewXDSClient := newXDSClient
newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil }
defer func() { newXDSClient = oldNewXDSClient }()

builder := balancer.Get(clusterImplName)
cc := testutils.NewTestClientConn(t)
b := builder.Build(cc, balancer.BuildOptions{})
defer b.Close()

if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{
Addresses: testBackendAddrs,
},
BalancerConfig: &lbConfig{
Cluster: testClusterName,
EDSServiceName: testServiceName,
ChildPolicy: &internalserviceconfig.BalancerConfig{
Name: roundrobin.Name,
},
},
}); err != nil {
t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
}

sc1 := <-cc.NewSubConnCh
b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
// This should get the connecting picker.
p0 := <-cc.NewPickerCh
for i := 0; i < 10; i++ {
_, err := p0.Pick(balancer.PickInfo{})
if err != balancer.ErrNoSubConnAvailable {
t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable)
}
}

addrs1 := <-cc.NewSubConnAddrsCh
if got, want := addrs1[0].Addr, testBackendAddrs[0].Addr; got != want {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
cn, ok := xdsinternal.GetHandshakeClusterName(addrs1[0].Attributes)
if !ok || cn != testClusterName {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName)
}

b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready})
// Test pick with one backend.
p1 := <-cc.NewPickerCh
const rpcCount = 20
for i := 0; i < rpcCount; i++ {
gotSCSt, err := p1.Pick(balancer.PickInfo{})
if err != nil || !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) {
t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1)
}
if gotSCSt.Done != nil {
gotSCSt.Done(balancer.DoneInfo{})
}
}

const testClusterName2 = "test-cluster-2"
var addr2 = resolver.Address{Addr: "2.2.2.2"}
if err := b.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{
Addresses: []resolver.Address{addr2},
},
BalancerConfig: &lbConfig{
Cluster: testClusterName2,
EDSServiceName: testServiceName,
ChildPolicy: &internalserviceconfig.BalancerConfig{
Name: roundrobin.Name,
},
},
}); err != nil {
t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
}

addrs2 := <-cc.NewSubConnAddrsCh
if got, want := addrs2[0].Addr, addr2.Addr; got != want {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
// New addresses should have the new cluster name.
cn2, ok := xdsinternal.GetHandshakeClusterName(addrs2[0].Attributes)
if !ok || cn2 != testClusterName2 {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2)
}
}

0 comments on commit 950ddd3

Please sign in to comment.