Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
xds/google_default_creds: handshake based on cluster name in address …
…attributes (#4310)
- Loading branch information
Showing
10 changed files
with
584 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/* | ||
* | ||
* Copyright 2021 gRPC authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
*/ | ||
|
||
package google | ||
|
||
import ( | ||
"context" | ||
"net" | ||
"testing" | ||
|
||
"google.golang.org/grpc/credentials" | ||
"google.golang.org/grpc/internal" | ||
xdsinternal "google.golang.org/grpc/internal/credentials/xds" | ||
"google.golang.org/grpc/resolver" | ||
) | ||
|
||
type testCreds struct { | ||
credentials.TransportCredentials | ||
typ string | ||
} | ||
|
||
func (c *testCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
return nil, &testAuthInfo{typ: c.typ}, nil | ||
} | ||
|
||
func (c *testCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
return nil, &testAuthInfo{typ: c.typ}, nil | ||
} | ||
|
||
type testAuthInfo struct { | ||
typ string | ||
} | ||
|
||
func (t *testAuthInfo) AuthType() string { | ||
return t.typ | ||
} | ||
|
||
var ( | ||
testTLS = &testCreds{typ: "tls"} | ||
testALTS = &testCreds{typ: "alts"} | ||
|
||
contextWithHandshakeInfo = internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) | ||
) | ||
|
||
func overrideNewCredsFuncs() func() { | ||
oldNewTLS := newTLS | ||
newTLS = func() credentials.TransportCredentials { | ||
return testTLS | ||
} | ||
oldNewALTS := newALTS | ||
newALTS = func() credentials.TransportCredentials { | ||
return testALTS | ||
} | ||
return func() { | ||
newTLS = oldNewTLS | ||
newALTS = oldNewALTS | ||
} | ||
} | ||
|
||
// TestClientHandshakeBasedOnClusterName that by default (without switching | ||
// modes), ClientHandshake does either tls or alts base on the cluster name in | ||
// attributes. | ||
func TestClientHandshakeBasedOnClusterName(t *testing.T) { | ||
defer overrideNewCredsFuncs()() | ||
for bundleTyp, tc := range map[string]credentials.Bundle{ | ||
"defaultCreds": NewDefaultCredentials(), | ||
"computeCreds": NewComputeEngineCredentials(), | ||
} { | ||
tests := []struct { | ||
name string | ||
ctx context.Context | ||
wantTyp string | ||
}{ | ||
{ | ||
name: "no cluster name", | ||
ctx: context.Background(), | ||
wantTyp: "tls", | ||
}, | ||
{ | ||
name: "with non-CFE cluster name", | ||
ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{ | ||
Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, "lalala").Attributes, | ||
}), | ||
// non-CFE backends should use alts. | ||
wantTyp: "alts", | ||
}, | ||
{ | ||
name: "with CFE cluster name", | ||
ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{ | ||
Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes, | ||
}), | ||
// CFE should use tls. | ||
wantTyp: "tls", | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(bundleTyp+" "+tt.name, func(t *testing.T) { | ||
_, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil) | ||
if err != nil { | ||
t.Fatalf("ClientHandshake failed: %v", err) | ||
} | ||
if gotType := info.AuthType(); gotType != tt.wantTyp { | ||
t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp) | ||
} | ||
|
||
_, infoServer, err := tc.TransportCredentials().ServerHandshake(nil) | ||
if err != nil { | ||
t.Fatalf("ClientHandshake failed: %v", err) | ||
} | ||
// ServerHandshake should always do TLS. | ||
if gotType := infoServer.AuthType(); gotType != "tls" { | ||
t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls") | ||
} | ||
}) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
/* | ||
* | ||
* Copyright 2021 gRPC authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
*/ | ||
|
||
package google | ||
|
||
import ( | ||
"context" | ||
"net" | ||
|
||
"google.golang.org/grpc/credentials" | ||
xdsinternal "google.golang.org/grpc/internal/credentials/xds" | ||
) | ||
|
||
const cfeClusterName = "google-cfe" | ||
|
||
// clusterTransportCreds is a combo of TLS + ALTS. | ||
// | ||
// On the client, ClientHandshake picks TLS or ALTS based on address attributes. | ||
// - if attributes has cluster name | ||
// - if cluster name is "google_cfe", use TLS | ||
// - otherwise, use ALTS | ||
// - else, do TLS | ||
// | ||
// On the server, ServerHandshake always does TLS. | ||
type clusterTransportCreds struct { | ||
tls credentials.TransportCredentials | ||
alts credentials.TransportCredentials | ||
} | ||
|
||
func newClusterTransportCreds(tls, alts credentials.TransportCredentials) *clusterTransportCreds { | ||
return &clusterTransportCreds{ | ||
tls: tls, | ||
alts: alts, | ||
} | ||
} | ||
|
||
func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
chi := credentials.ClientHandshakeInfoFromContext(ctx) | ||
if chi.Attributes == nil { | ||
return c.tls.ClientHandshake(ctx, authority, rawConn) | ||
} | ||
cn, ok := xdsinternal.GetHandshakeClusterName(chi.Attributes) | ||
if !ok || cn == cfeClusterName { | ||
return c.tls.ClientHandshake(ctx, authority, rawConn) | ||
} | ||
// If attributes have cluster name, and cluster name is not cfe, it's a | ||
// backend address, use ALTS. | ||
return c.alts.ClientHandshake(ctx, authority, rawConn) | ||
} | ||
|
||
func (c *clusterTransportCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { | ||
return c.tls.ServerHandshake(conn) | ||
} | ||
|
||
func (c *clusterTransportCreds) Info() credentials.ProtocolInfo { | ||
// TODO: this always returns tls.Info now, because we don't have a cluster | ||
// name to check when this method is called. This method doesn't affect | ||
// anything important now. We may want to revisit this if it becomes more | ||
// important later. | ||
return c.tls.Info() | ||
} | ||
|
||
func (c *clusterTransportCreds) Clone() credentials.TransportCredentials { | ||
return &clusterTransportCreds{ | ||
tls: c.tls.Clone(), | ||
alts: c.alts.Clone(), | ||
} | ||
} | ||
|
||
func (c *clusterTransportCreds) OverrideServerName(s string) error { | ||
if err := c.tls.OverrideServerName(s); err != nil { | ||
return err | ||
} | ||
return c.alts.OverrideServerName(s) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* | ||
* | ||
* Copyright 2021 gRPC authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
*/ | ||
|
||
package xds | ||
|
||
import ( | ||
"google.golang.org/grpc/attributes" | ||
"google.golang.org/grpc/resolver" | ||
) | ||
|
||
// handshakeClusterNameKey is the type used as the key to store cluster name in | ||
// the Attributes field of resolver.Address. | ||
type handshakeClusterNameKey struct{} | ||
|
||
// SetHandshakeClusterName returns a copy of addr in which the Attributes field | ||
// is updated with the cluster name. | ||
func SetHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address { | ||
addr.Attributes = addr.Attributes.WithValues(handshakeClusterNameKey{}, clusterName) | ||
return addr | ||
} | ||
|
||
// GetHandshakeClusterName returns cluster name stored in attr. | ||
func GetHandshakeClusterName(attr *attributes.Attributes) (string, bool) { | ||
v := attr.Value(handshakeClusterNameKey{}) | ||
name, ok := v.(string) | ||
return name, ok | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.