Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials/tls: reject connections with ALPN disabled #7184

Merged
merged 14 commits into from
May 21, 2024
34 changes: 33 additions & 1 deletion credentials/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ import (
"net/url"
"os"

"google.golang.org/grpc/grpclog"
credinternal "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/envconfig"
)

var logger = grpclog.Component("credentials")

// TLSInfo contains the auth information for a TLS authenticated connection.
// It implements the AuthInfo interface.
type TLSInfo struct {
Expand Down Expand Up @@ -112,6 +116,22 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
conn.Close()
return nil, nil, ctx.Err()
}

// The negotiated protocol can be either of the following:
// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
// it is the only protocol advertised by the client during the handshake.
// The tls library ensures that the server chooses a protocol advertised
// by the client.
// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
// for using HTTP/2 over TLS. We can terminate the connection immediately.
np := conn.ConnectionState().NegotiatedProtocol
if np == "" {
easwars marked this conversation as resolved.
Show resolved Hide resolved
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
}
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
CommonAuthInfo: CommonAuthInfo{
Expand All @@ -131,8 +151,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
conn.Close()
return nil, nil, err
}
cs := conn.ConnectionState()
// The negotiated application protocol can be empty only if the client doesn't
// support ALPN. In such cases, we can close the connection since ALPN is required
// for using HTTP/2 over TLS.
if cs.NegotiatedProtocol == "" {
if envconfig.EnforceALPNEnabled {
conn.Close()
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
} else if logger.V(2) {
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
}
easwars marked this conversation as resolved.
Show resolved Hide resolved
}
tlsInfo := TLSInfo{
State: conn.ConnectionState(),
State: cs,
CommonAuthInfo: CommonAuthInfo{
SecurityLevel: PrivacyAndIntegrity,
},
Expand Down
160 changes: 160 additions & 0 deletions credentials/tls_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"strings"
"testing"
Expand All @@ -31,6 +32,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -236,3 +238,161 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
t.Fatalf("EmptyCall err = %v; want <nil>", err)
}
}

// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
// connecting to a server that doesn't support ALPN.
func (s) TestTLS_DisabledALPNClient(t *testing.T) {
initialVal := envconfig.EnforceALPNEnabled
defer func() {
envconfig.EnforceALPNEnabled = initialVal
}()

tests := []struct {
name string
alpnEnforced bool
wantErr bool
}{
{
name: "enforced",
alpnEnforced: true,
wantErr: true,
},
{
name: "not_enforced",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
envconfig.EnforceALPNEnabled = tc.alpnEnforced

listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
})
if err != nil {
t.Fatalf("Error starting TLS server: %v", err)
}

errCh := make(chan error)
easwars marked this conversation as resolved.
Show resolved Hide resolved

easwars marked this conversation as resolved.
Show resolved Hide resolved
go func() {
easwars marked this conversation as resolved.
Show resolved Hide resolved
conn, err := listener.Accept()
if err != nil {
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
} else {
// The first write to the TLS listener initiates the TLS handshake.
conn.Write([]byte("Hello, World!"))
easwars marked this conversation as resolved.
Show resolved Hide resolved
conn.Close()
}
close(errCh)
}()

serverAddr := listener.Addr().String()
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
}
defer conn.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

clientCfg := tls.Config{
ServerName: serverName,
RootCAs: certPool,
NextProtos: []string{"h2"},
}
_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn)

if gotErr := (err != nil); gotErr != tc.wantErr {
t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
}

err = nil
select {
case err = <-errCh:
break
case <-ctx.Done():
err = ctx.Err()
}

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
easwars marked this conversation as resolved.
Show resolved Hide resolved
})
}
}

// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
// accepting a request from a client that doesn't support ALPN.
func (s) TestTLS_DisabledALPNServer(t *testing.T) {
initialVal := envconfig.EnforceALPNEnabled
defer func() {
envconfig.EnforceALPNEnabled = initialVal
}()

tests := []struct {
name string
alpnEnforced bool
wantErr bool
}{
{
name: "enforced",
alpnEnforced: true,
wantErr: true,
},
{
name: "not_enforced",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
envconfig.EnforceALPNEnabled = tc.alpnEnforced

listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error starting server: %v", err)
}

errCh := make(chan error)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one needs a buffer of size 1 too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed.


easwars marked this conversation as resolved.
Show resolved Hide resolved
go func() {
conn, err := listener.Accept()
if err != nil {
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
close(errCh)
return
}
defer conn.Close()
serverCfg := tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{"h2"},
}
_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn)
if gotErr := (err != nil); gotErr != tc.wantErr {
t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
}
close(errCh)
}()

serverAddr := listener.Addr().String()
clientCfg := &tls.Config{
Certificates: []tls.Certificate{serverCert},
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
RootCAs: certPool,
ServerName: serverName,
}
conn, err := tls.Dial("tcp", serverAddr, clientCfg)
if err != nil {
t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
}
defer conn.Close()

if err := <-errCh; err != nil {
t.Fatalf("Unexpected server error: %v", err)
}
easwars marked this conversation as resolved.
Show resolved Hide resolved
})
}
}
6 changes: 6 additions & 0 deletions internal/envconfig/envconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ var (
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
// option is present for backward compatibility. This option may be overridden
// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
// or "false".
EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false)
)

func boolFromEnv(envVar string, def bool) bool {
Expand Down