Skip to content

Commit

Permalink
Merge branch 'main' into refactor/jsoniter
Browse files Browse the repository at this point in the history
  • Loading branch information
ankitsridhar16 committed Jun 1, 2023
2 parents 1458250 + 12e3c25 commit 4d2b9a8
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 48 deletions.
4 changes: 4 additions & 0 deletions driver/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func TestDriverConfig(t *testing.T) {
{name: "https_localhost_with_port_token", url: "https://localhost:555?token=tkn1", cfg: &config.Driver{URL: "localhost:555", Protocol: HTTP, TLS: cTLS, Token: "tkn1"}},
{name: "https_ip_token", url: "https://127.0.0.1?token=tkn1", cfg: &config.Driver{URL: "127.0.0.1", Protocol: HTTP, TLS: cTLS, Token: "tkn1"}},
{name: "https_ip_with_port_token", url: "https://127.0.0.1:777?token=tkn1", cfg: &config.Driver{URL: "127.0.0.1:777", Protocol: HTTP, TLS: cTLS, Token: "tkn1"}},
{name: "unix_socket", url: "/var/lib/tigris/unix.sock", cfg: &config.Driver{URL: "/var/lib/tigris/unix.sock", Protocol: DefaultProtocol}},
{name: "unix_socket", url: "/var/lib/tigris/unix.sock", cfg: &config.Driver{URL: "/var/lib/tigris/unix.sock", Protocol: DefaultProtocol}},
{name: "unix_socket_scheme", url: "unix://localhost:/var/lib/tigris/unix.sock", cfg: &config.Driver{URL: "/var/lib/tigris/unix.sock", Protocol: DefaultProtocol}},
{name: "unix_socket_relative", url: "./tigris/unix.sock", cfg: &config.Driver{URL: "./tigris/unix.sock", Protocol: DefaultProtocol}},
}

for _, v := range cases {
Expand Down
26 changes: 18 additions & 8 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ type Database interface {
options ...*CreateCollectionOptions) error

// CreateOrUpdateCollections creates batch of collections
CreateOrUpdateCollections(ctx context.Context, schema []Schema, options ...*CreateCollectionOptions) (*CreateOrUpdateCollectionsResponse, error)
CreateOrUpdateCollections(ctx context.Context, schema []Schema,
options ...*CreateCollectionOptions) (*CreateOrUpdateCollectionsResponse, error)

// DropCollection deletes the collection and all documents it contains.
DropCollection(ctx context.Context, collection string, options ...*CollectionOptions) error
Expand Down Expand Up @@ -170,7 +171,8 @@ type driver struct {
cfg *config.Driver
}

func (c *driver) CreateProject(ctx context.Context, project string, options ...*CreateProjectOptions) (*CreateProjectResponse, error) {
func (c *driver) CreateProject(ctx context.Context, project string, options ...*CreateProjectOptions,
) (*CreateProjectResponse, error) {
opts, err := validateOptionsParam(options, &CreateProjectOptions{})
if err != nil {
return nil, err
Expand All @@ -179,7 +181,8 @@ func (c *driver) CreateProject(ctx context.Context, project string, options ...*
return c.createProjectWithOptions(ctx, project, opts.(*CreateProjectOptions))
}

func (c *driver) DescribeDatabase(ctx context.Context, project string, options ...*DescribeProjectOptions) (*DescribeDatabaseResponse, error) {
func (c *driver) DescribeDatabase(ctx context.Context, project string, options ...*DescribeProjectOptions,
) (*DescribeDatabaseResponse, error) {
opts, err := validateOptionsParam(options, &DescribeProjectOptions{})
if err != nil {
return nil, err
Expand All @@ -188,7 +191,8 @@ func (c *driver) DescribeDatabase(ctx context.Context, project string, options .
return c.describeProjectWithOptions(ctx, project, opts.(*DescribeProjectOptions))
}

func (c *driver) DeleteProject(ctx context.Context, project string, options ...*DeleteProjectOptions) (*DeleteProjectResponse, error) {
func (c *driver) DeleteProject(ctx context.Context, project string, options ...*DeleteProjectOptions,
) (*DeleteProjectResponse, error) {
opts, err := validateOptionsParam(options, &DeleteProjectOptions{})
if err != nil {
return nil, err
Expand Down Expand Up @@ -487,7 +491,7 @@ func initConfig(lCfg *config.Driver) (*config.Driver, error) {
sURL := cfg.URL

noScheme := !strings.Contains(sURL, "://")
if noScheme {
if noScheme && sURL[0] != '/' && sURL[0] != '.' {
if DefaultProtocol == "" {
sURL = strings.ToLower(GRPC) + "://" + sURL
} else {
Expand All @@ -500,7 +504,9 @@ func initConfig(lCfg *config.Driver) (*config.Driver, error) {
return nil, err
}

if noScheme {
unix := u.Scheme == "unix" || isUnixSock(cfg.URL)

if noScheme || unix {
u.Scheme = ""
}

Expand All @@ -513,7 +519,11 @@ func initConfig(lCfg *config.Driver) (*config.Driver, error) {
initSecrets(u, &cfg)

// Retain only host:port for connection
cfg.URL = u.Host
if !unix {
cfg.URL = u.Host
} else {
cfg.URL = u.Path
}

if cfg.TLS == nil && (!cfg.SkipLocalTLS || !localURL(cfg.URL)) && (cfg.ClientID != "" || cfg.ClientSecret != "" ||
cfg.Token != "" || u.Scheme == "https" || sec) {
Expand Down Expand Up @@ -601,7 +611,7 @@ func (c *driver) Close() error {
}

func localURL(url string) bool {
return strings.HasPrefix(url, "localhost:") ||
return isUnixSock(url) || strings.HasPrefix(url, "localhost:") ||
strings.HasPrefix(url, "127.0.0.1:") ||
strings.HasPrefix(url, "http://localhost:") ||
strings.HasPrefix(url, "http://127.0.0.1:") ||
Expand Down
48 changes: 22 additions & 26 deletions driver/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
jsoniter "github.com/json-iterator/go"
"io"
"net"
"strings"
"unsafe"

Expand All @@ -30,6 +31,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/credentials/local"
"google.golang.org/grpc/credentials/oauth"
meta "google.golang.org/grpc/metadata"
)
Expand Down Expand Up @@ -67,27 +69,13 @@ func GRPCError(err error) error {
return &Error{api.FromStatusError(err)}
}

// this is a hack to skip TLS on local connection
// when the token is set.
type localPerRPCCred struct {
parent credentials.PerRPCCredentials
}

func (l *localPerRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return l.parent.GetRequestMetadata(ctx, uri...)
}

func (*localPerRPCCred) RequireTransportSecurity() bool {
return false
}

// newGRPCClient return Driver interface implementation using GRPC transport protocol.
func newGRPCClient(ctx context.Context, config *config.Driver) (driverWithOptions, Management, Observability, error) {
if !strings.Contains(config.URL, ":") {
config.URL = fmt.Sprintf("%s:%d", config.URL, DefaultGRPCPort)
func newGRPCClient(ctx context.Context, cfg *config.Driver) (driverWithOptions, Management, Observability, error) {
if !strings.Contains(cfg.URL, ":") && !isUnixSock(cfg.URL) {
cfg.URL = fmt.Sprintf("%s:%d", cfg.URL, DefaultGRPCPort)
}

tokenSource, _, _ := configAuth(config)
tokenSource, _, _ := configAuth(cfg)

opts := []grpc.DialOption{
grpc.FailOnNonTempDialError(true),
Expand All @@ -100,25 +88,32 @@ func newGRPCClient(ctx context.Context, config *config.Driver) (driverWithOption
),
}

if isUnixSock(cfg.URL) {
dialer := func(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", cfg.URL)
}

opts = append(opts, grpc.WithContextDialer(dialer))
}

if tokenSource != nil {
var perRPCCreds credentials.PerRPCCredentials = oauth.TokenSource{TokenSource: tokenSource}

if config.SkipLocalTLS && localURL(config.URL) {
perRPCCreds = &localPerRPCCred{oauth.TokenSource{TokenSource: tokenSource}}
}

opts = append(opts, grpc.WithPerRPCCredentials(perRPCCreds))
}

transportCreds := insecure.NewCredentials()

if (!config.SkipLocalTLS || !localURL(config.URL)) && (config.TLS != nil || tokenSource != nil) {
transportCreds = credentials.NewTLS(config.TLS)
if cfg.SkipLocalTLS && localURL(cfg.URL) {
transportCreds = local.NewCredentials()
} else if cfg.TLS != nil || tokenSource != nil {
transportCreds = credentials.NewTLS(cfg.TLS)
}

opts = append(opts, grpc.WithTransportCredentials(transportCreds))

conn, err := grpc.DialContext(ctx, config.URL, opts...)
conn, err := grpc.DialContext(ctx, cfg.URL, opts...)
if err != nil {
return nil, nil, nil, GRPCError(err)
}
Expand All @@ -131,7 +126,7 @@ func newGRPCClient(ctx context.Context, config *config.Driver) (driverWithOption
o11y: api.NewObservabilityClient(conn),
health: api.NewHealthAPIClient(conn),
search: api.NewSearchClient(conn),
cfg: config,
cfg: cfg,
}

return drv, drv, drv, nil
Expand All @@ -141,6 +136,7 @@ func (c *grpcDriver) Close() error {
if c.conn == nil {
return nil
}

return GRPCError(c.conn.Close())
}

Expand Down
37 changes: 23 additions & 14 deletions driver/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,50 +210,59 @@ func setHTTPTxCtx(ctx context.Context, txCtx *api.TransactionCtx, cookies []*htt
}

// newHTTPClient return Driver interface implementation using HTTP transport protocol.
func newHTTPClient(_ context.Context, config *config.Driver) (driverWithOptions, Management, Observability, error) {
if !strings.Contains(config.URL, ":") {
if config.TLS != nil {
config.URL = fmt.Sprintf("%s:%d", config.URL, DefaultHTTPSPort)
func newHTTPClient(_ context.Context, cfg *config.Driver) (driverWithOptions, Management, Observability, error) {
u := cfg.URL

if isUnixSock(u) {
u = "localhost"
} else if !strings.Contains(u, ":") {
if cfg.TLS != nil {
u = fmt.Sprintf("%s:%d", u, DefaultHTTPSPort)
} else {
config.URL = fmt.Sprintf("%s:%d", config.URL, DefaultHTTPPort)
u = fmt.Sprintf("%s:%d", u, DefaultHTTPPort)
}
}

if config.TLS != nil {
config.URL = "https://" + config.URL
if cfg.TLS != nil {
u = "https://" + u
} else {
config.URL = "http://" + config.URL
u = "http://" + u
}

_, httpClient, tokenURL := configAuth(config)
_, httpClient, tokenURL := configAuth(cfg)

if httpClient == nil {
httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: config.TLS}}
httpClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: cfg.TLS,
DialContext: getUnixHTTPDialer(cfg.URL),
},
}
}

hf := func(ctx context.Context, req *http.Request) error {
if err := setHeaders(ctx, req); err != nil {
return err
}

if config.SkipSchemaValidation {
if cfg.SkipSchemaValidation {
req.Header[api.HeaderSchemaSignOff] = []string{"true"}
}

if config.DisableSearch {
if cfg.DisableSearch {
req.Header[api.HeaderDisableSearch] = []string{"true"}
}

return nil
}

c, err := apiHTTP.NewClientWithResponses(config.URL, apiHTTP.WithHTTPClient(httpClient),
c, err := apiHTTP.NewClientWithResponses(u, apiHTTP.WithHTTPClient(httpClient),
apiHTTP.WithRequestEditorFn(hf))
if err != nil {
return nil, nil, nil, err
}

drv := &httpDriver{api: c, tokenURL: tokenURL, cfg: config}
drv := &httpDriver{api: c, tokenURL: tokenURL, cfg: cfg}

return drv, drv, drv, nil
}
Expand Down
18 changes: 18 additions & 0 deletions driver/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package driver
import (
"context"
"fmt"
"net"
"net/http"
"strings"

Expand Down Expand Up @@ -85,10 +86,27 @@ type txWithOptions interface {
Rollback(ctx context.Context) error
}

func isUnixSock(url string) bool {
return len(url) > 0 && (url[0] == '/' || url[0] == '.')
}

func getUnixHTTPDialer(url string) func(_ context.Context, _, _ string) (net.Conn, error) {
var dialer func(_ context.Context, _, _ string) (net.Conn, error)

if url[0] == '/' || url[0] == '.' {
dialer = func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", url)
}
}

return dialer
}

// func configAuth(config *config.Driver) (*clientcredentials.Config, context.Context) {.
func configAuth(cfg *config.Driver) (oauth2.TokenSource, *http.Client, string) {
tr := &http.Transport{
TLSClientConfig: cfg.TLS,
DialContext: getUnixHTTPDialer(cfg.URL),
}

tokenURL := cfg.URL + "/v1/auth/token"
Expand Down

0 comments on commit 4d2b9a8

Please sign in to comment.