Skip to content

Commit

Permalink
feat: Allow to skip TLS when connecting to local instance
Browse files Browse the repository at this point in the history
  • Loading branch information
efirs committed May 24, 2023
1 parent 0262afd commit 583a482
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 5 deletions.
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ type Driver struct {

SkipSchemaValidation bool `json:"skip_schema_validation,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
SkipLocalTLS bool `json:"skip_local_tls"`
}
31 changes: 31 additions & 0 deletions driver/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tigrisdata/tigris-client-go/config"
)

Expand Down Expand Up @@ -212,3 +213,33 @@ func TestDriverConfigProto(t *testing.T) {
})
}
}

func TestLocalURL(t *testing.T) {
// Test cases with local URLs that should return true
localURLs := []string{
"localhost:8080",
"127.0.0.1:8000",
"http://localhost:3000",
"http://127.0.0.1:5000",
"[::1]:8080",
"http://[::1]:8000",
}

for _, url := range localURLs {
require.True(t, localURL(url))
}

// Test cases with non-local URLs that should return false
nonLocalURLs := []string{
"example.com",
"http://example.com",
"www.google.com",
"http://www.google.com",
"127.0.0.1.5",
"localhost123",
}

for _, url := range nonLocalURLs {
require.False(t, localURL(url))
}
}
12 changes: 11 additions & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,8 @@ func initConfig(lCfg *config.Driver) (*config.Driver, error) {
// Retain only host:port for connection
cfg.URL = u.Host

if cfg.TLS == nil && (cfg.ClientID != "" || cfg.ClientSecret != "" || cfg.Token != "" || u.Scheme == "https" || sec) {
if cfg.TLS == nil && (!cfg.SkipLocalTLS || !localURL(cfg.URL)) && (cfg.ClientID != "" || cfg.ClientSecret != "" ||
cfg.Token != "" || u.Scheme == "https" || sec) {
cfg.TLS = &tls.Config{MinVersion: tls.VersionTLS12}
}

Expand Down Expand Up @@ -594,3 +595,12 @@ func (c *driver) Close() error {

return c.driverWithOptions.Close()
}

func localURL(url string) bool {
return strings.HasPrefix(url, "localhost:") ||
strings.HasPrefix(url, "127.0.0.1:") ||
strings.HasPrefix(url, "http://localhost:") ||
strings.HasPrefix(url, "http://127.0.0.1:") ||
strings.HasPrefix(url, "[::1]") ||
strings.HasPrefix(url, "http://[::1]:")
}
6 changes: 3 additions & 3 deletions driver/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ func newGRPCClient(ctx context.Context, config *config.Driver) (driverWithOption
),
}

if config.TLS != nil || tokenSource != nil {
if (config.SkipLocalTLS && localURL(config.URL)) || (config.TLS == nil && tokenSource == nil) {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(config.TLS)))

if tokenSource != nil {
opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: tokenSource}))
}
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}

conn, err := grpc.DialContext(ctx, config.URL, opts...)
Expand Down
3 changes: 2 additions & 1 deletion driver/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ func newHTTPClient(_ context.Context, config *config.Driver) (driverWithOptions,
return nil
}

c, err := apiHTTP.NewClientWithResponses(config.URL, apiHTTP.WithHTTPClient(httpClient), apiHTTP.WithRequestEditorFn(hf))
c, err := apiHTTP.NewClientWithResponses(config.URL, apiHTTP.WithHTTPClient(httpClient),
apiHTTP.WithRequestEditorFn(hf))
if err != nil {
return nil, nil, nil, err
}
Expand Down

0 comments on commit 583a482

Please sign in to comment.