diff --git a/driver/driver.go b/driver/driver.go index 1bd93ec..ea5ee3e 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -480,6 +480,10 @@ func initConfig(lCfg *config.Driver) (*config.Driver, error) { cfg.URL = DefaultURL } + if os.Getenv(EnvSkipLocalTLS) != "" { + cfg.SkipLocalTLS = true + } + sURL := cfg.URL noScheme := !strings.Contains(sURL, "://") diff --git a/driver/grpc.go b/driver/grpc.go index f373a29..6f42a03 100644 --- a/driver/grpc.go +++ b/driver/grpc.go @@ -67,6 +67,20 @@ func GRPCError(err error) error { return &Error{api.FromStatusError(err)} } +// this is a hack to skip TLS on local connection +// when the token is set. +type localPerRPCCred struct { + parent credentials.PerRPCCredentials +} + +func (l *localPerRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return l.parent.GetRequestMetadata(ctx, uri...) +} + +func (*localPerRPCCred) RequireTransportSecurity() bool { + return false +} + // newGRPCClient return Driver interface implementation using GRPC transport protocol. func newGRPCClient(ctx context.Context, config *config.Driver) (driverWithOptions, Management, Observability, error) { if !strings.Contains(config.URL, ":") { @@ -86,16 +100,24 @@ func newGRPCClient(ctx context.Context, config *config.Driver) (driverWithOption ), } - if (config.SkipLocalTLS && localURL(config.URL)) || (config.TLS == nil && tokenSource == nil) { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } else { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(config.TLS))) + if tokenSource != nil { + var perRPCCreds credentials.PerRPCCredentials = oauth.TokenSource{TokenSource: tokenSource} - if tokenSource != nil { - opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: tokenSource})) + if config.SkipLocalTLS && localURL(config.URL) { + perRPCCreds = &localPerRPCCred{oauth.TokenSource{TokenSource: tokenSource}} } + + opts = append(opts, grpc.WithPerRPCCredentials(perRPCCreds)) } + transportCreds := insecure.NewCredentials() + + if (!config.SkipLocalTLS || !localURL(config.URL)) && (config.TLS != nil || tokenSource != nil) { + transportCreds = credentials.NewTLS(config.TLS) + } + + opts = append(opts, grpc.WithTransportCredentials(transportCreds)) + conn, err := grpc.DialContext(ctx, config.URL, opts...) if err != nil { return nil, nil, nil, GRPCError(err) diff --git a/driver/types.go b/driver/types.go index fe8a912..80c5fd7 100644 --- a/driver/types.go +++ b/driver/types.go @@ -34,6 +34,7 @@ const ( EnvURI = "TIGRIS_URI" EnvProject = "TIGRIS_PROJECT" EnvDBBranch = "TIGRIS_DB_BRANCH" + EnvSkipLocalTLS = "TIGRIS_SKIP_LOCAL_TLS" ClientVersion = "v1.0.0" UserAgent = "tigris-client-go/" + ClientVersion