Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials/tls: reject connections with ALPN disabled #7184

Merged
merged 14 commits into from
May 21, 2024
20 changes: 20 additions & 0 deletions credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ import (
"net/url"
"os"

"google.golang.org/grpc/grpclog"
credinternal "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/envconfig"
)

var logger = grpclog.Component("credentials")

// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
Expand Down Expand Up @@ -112,6 +116,22 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
conn.Close()
return nil, nil, ctx.Err()
}

// The negotiated protocol can be either of the following:
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
// it is the only protocol advertised by the client during the handshake.
// The tls library ensures that the server chooses a protocol advertised
// by the client.
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
easwars marked this conversation as resolved.
Show resolved Hide resolved
if envconfig.EnforceALPNEnabled {
_ = conn.Close()
easwars marked this conversation as resolved.
Show resolved Hide resolved
return nil, nil, fmt.Errorf("cannot check peer: missing selected ALPN property")
easwars marked this conversation as resolved.
Show resolved Hide resolved
}
logger.Warning("Allowing TLS connection to server %q with ALPN disabled")
easwars marked this conversation as resolved.
Show resolved Hide resolved
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: CommonAuthInfo{
Expand Down
98 changes: 98 additions & 0 deletions credentials/tls_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ import (
"crypto/x509"
"fmt"
"os"
"regexp"
"strings"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -236,3 +238,99 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
t.Fatalf("EmptyCall err = %v; want <nil>", err)
}
}

// TestTLS_DisabledALPN tests the behaviour of a gRPC client when connecting to
// a server that doesn't support ALPN.
easwars marked this conversation as resolved.
Show resolved Hide resolved
func (s) TestTLS_DisabledALPN(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

initialVal := envconfig.EnforceALPNEnabled
defer func() {
envconfig.EnforceALPNEnabled = initialVal
}()

// Start a non gRPC TLS server.
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
}
listner, err := tls.Listen("tcp", ":0", config)
easwars marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatalf("Error starting TLS server: %v", err)
}
defer listner.Close()

tests := []struct {
name string
alpnEnforced bool
wantErrMatchPattern string
wantErrNonMatchPattern string
}{
{
name: "enforced",
alpnEnforced: true,
wantErrMatchPattern: "transport: .*missing selected ALPN property",
easwars marked this conversation as resolved.
Show resolved Hide resolved
},
{
name: "not_enforced",
wantErrNonMatchPattern: "transport:",
easwars marked this conversation as resolved.
Show resolved Hide resolved
},
{
name: "default_value",
wantErrNonMatchPattern: "transport:",
alpnEnforced: initialVal,
},
easwars marked this conversation as resolved.
Show resolved Hide resolved
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
envconfig.EnforceALPNEnabled = tc.alpnEnforced
// Listen to one TCP connection request.
go func() {
easwars marked this conversation as resolved.
Show resolved Hide resolved
conn, err := listner.Accept()
if err != nil {
t.Errorf("tls.Accept failed err = %v", err)
} else {
_, _ = conn.Write([]byte("Hello, World!"))
easwars marked this conversation as resolved.
Show resolved Hide resolved
_ = conn.Close()
}
}()

clientCreds := credentials.NewTLS(&tls.Config{
ServerName: serverName,
RootCAs: certPool,
})

cc, err := grpc.NewClient("dns:"+listner.Addr().String(), grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("grpc.NewClient error: %v", err)
}
defer cc.Close()
client := testgrpc.NewTestServiceClient(cc)
_, rpcErr := client.EmptyCall(ctx, &testpb.Empty{})

if gotCode := status.Code(rpcErr); gotCode != codes.Unavailable {
t.Errorf("EmptyCall returned unexpected code: got=%v, want=%v", gotCode, codes.Unavailable)
}

matchPat, err := regexp.Compile(tc.wantErrMatchPattern)
if err != nil {
t.Fatalf("Error message match pattern %q is invalid due to error: %v", tc.wantErrMatchPattern, err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be changed to got before want? Similar for line 328

https://google.github.io/styleguide/go/decisions#got-before-want

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The regex pattern is expected to always be valid, the want is implicitly nil. There is not want in this message. If this line executes, it means that the test case is invalid.

}

if tc.wantErrMatchPattern != "" && !matchPat.MatchString(status.Convert(rpcErr).Message()) {
t.Errorf("EmptyCall err = %v; want pattern match %q", rpcErr, matchPat)
}

nonMatchPat, err := regexp.Compile(tc.wantErrNonMatchPattern)
if err != nil {
t.Fatalf("Error message non-match pattern %q is invalid due to error: %v", tc.wantErrNonMatchPattern, err)
}

if tc.wantErrNonMatchPattern != "" && nonMatchPat.MatchString(status.Convert(rpcErr).Message()) {
t.Errorf("EmptyCall err = %v; want pattern missing %q", rpcErr, nonMatchPat)
}
easwars marked this conversation as resolved.
Show resolved Hide resolved
})
}
}
6 changes: 6 additions & 0 deletions internal/envconfig/envconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ var (
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
)

func boolFromEnv(envVar string, def bool) bool {
Expand Down