Skip to content

Commit

Permalink
feat: add support for unix sockets
Browse files Browse the repository at this point in the history
This is an adaptation from GoogleCloudPlatform/cloud-sql-proxy#1182
  • Loading branch information
enocom committed May 26, 2022
1 parent 7069117 commit c7cab02
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 38 deletions.
34 changes: 32 additions & 2 deletions cmd/root.go
Expand Up @@ -128,6 +128,8 @@ without having to manage any client SSL certificates.`,
"Address on which to bind AlloyDB instance listeners.")
cmd.PersistentFlags().IntVarP(&c.conf.Port, "port", "p", 5432,
"Initial port to use for listeners. Subsequent listeners increment from this value.")
cmd.PersistentFlags().StringVarP(&c.conf.UnixSocket, "unix-socket", "u", "",
`Enables Unix sockets for all listeners using the provided directory.`)

c.Command = cmd
return c
Expand All @@ -138,6 +140,15 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
if len(args) == 0 {
return newBadCommandError("missing instance uri (e.g., /projects/$PROJECTS/locations/$LOCTION/clusters/$CLUSTER/instances/$INSTANCES)")
}
userHasSet := func(f string) bool {
return cmd.PersistentFlags().Lookup(f).Changed
}
if userHasSet("address") && userHasSet("unix-socket") {
return newBadCommandError("cannot specify --unix-socket and --address together")
}
if userHasSet("port") && userHasSet("unix-socket") {
return newBadCommandError("cannot specify --unix-socket and --port together")
}
// First, validate global config.
if ip := net.ParseIP(conf.Addr); ip == nil {
return newBadCommandError(fmt.Sprintf("not a valid IP address: %q", conf.Addr))
Expand Down Expand Up @@ -171,7 +182,18 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
return newBadCommandError(fmt.Sprintf("could not parse query: %q", res[1]))
}

if a, ok := q["address"]; ok {
a, aok := q["address"]
p, pok := q["port"]
u, uok := q["unix-socket"]

if aok && uok {
return newBadCommandError("cannot specify both address and unix-socket query params")
}
if pok && uok {
return newBadCommandError("cannot specify both port and unix-socket query params")
}

if aok {
if len(a) != 1 {
return newBadCommandError(fmt.Sprintf("address query param should be only one value: %q", a))
}
Expand All @@ -184,7 +206,7 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
ic.Addr = a[0]
}

if p, ok := q["port"]; ok {
if pok {
if len(p) != 1 {
return newBadCommandError(fmt.Sprintf("port query param should be only one value: %q", a))
}
Expand All @@ -197,6 +219,14 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error {
}
ic.Port = pp
}

if uok {
if len(u) != 1 {
return newBadCommandError(fmt.Sprintf("unix query param should be only one value: %q", a))
}
ic.UnixSocket = u[0]

}
}
ics = append(ics, ic)
}
Expand Down
43 changes: 43 additions & 0 deletions cmd/root_test.go
Expand Up @@ -137,6 +137,29 @@ func TestNewCommandArguments(t *testing.T) {
CredentialsFile: "/path/to/file",
}),
},
{
desc: "using the unix socket flag",
args: []string{"--unix-socket", "/path/to/dir/", "/projects/proj/locations/region/clusters/clust/instances/inst"},
want: withDefaults(&proxy.Config{
UnixSocket: "/path/to/dir/",
}),
},
{
desc: "using the (short) unix socket flag",
args: []string{"-u", "/path/to/dir/", "/projects/proj/locations/region/clusters/clust/instances/inst"},
want: withDefaults(&proxy.Config{
UnixSocket: "/path/to/dir/",
}),
},
{
desc: "using the unix socket query param",
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path/to/dir/"},
want: withDefaults(&proxy.Config{
Instances: []proxy.InstanceConnConfig{{
UnixSocket: "/path/to/dir/",
}},
}),
},
}

for _, tc := range tcs {
Expand Down Expand Up @@ -210,6 +233,26 @@ func TestNewCommandWithErrors(t *testing.T) {
"--token", "my-token",
"--credentials-file", "/path/to/file", "/projects/proj/locations/region/clusters/clust/instances/inst"},
},
{
desc: "when the unix socket query param contains multiple values",
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/one&unix-socket=/two"},
},
{
desc: "using the unix socket flag with addr",
args: []string{"-u", "/path/to/dir/", "-a", "127.0.0.1", "/projects/proj/locations/region/clusters/clust/instances/inst"},
},
{
desc: "using the unix socket flag with port",
args: []string{"-u", "/path/to/dir/", "-p", "5432", "/projects/proj/locations/region/clusters/clust/instances/inst"},
},
{
desc: "using the unix socket and addr query params",
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path&address=127.0.0.1"},
},
{
desc: "using the unix socket and port query params",
args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path&port=5000"},
},
}

for _, tc := range tcs {
Expand Down
112 changes: 97 additions & 15 deletions internal/proxy/proxy.go
Expand Up @@ -19,6 +19,10 @@ import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"

Expand All @@ -37,6 +41,10 @@ type InstanceConnConfig struct {
Addr string
// Port is the port on which to bind a listener for the instance.
Port int
// UnixSocket is the directory where a Unix socket will be created,
// connected to the Cloud SQL instance. If set, takes precedence over Addr
// and Port.
UnixSocket string
}

// Config contains all the configuration provided by the caller.
Expand All @@ -54,6 +62,10 @@ type Config struct {
// increments from this value.
Port int

// UnixSocket is the directory where Unix sockets will be created,
// connected to any Instances. If set, takes precedence over Addr and Port.
UnixSocket string

// Instances are configuration for individual instances. Instance
// configuration takes precedence over global configuration.
Instances []InstanceConnConfig
Expand Down Expand Up @@ -95,6 +107,28 @@ func (c *portConfig) nextPort() int {
return p
}

var (
// Instance URI is in the format:
// '/projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>'
// Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT")
instURIRegex = regexp.MustCompile("projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)")
)

// UnixSocketDir returns a shorted instance connection name to prevent exceeding
// the Unix socket length.
func UnixSocketDir(dir, inst string) (string, error) {
m := instURIRegex.FindSubmatch([]byte(inst))
if m == nil {
return "", fmt.Errorf("invalid instance name: %v", inst)
}
project := string(m[1])
region := string(m[3])
cluster := string(m[4])
name := string(m[5])
shortName := strings.Join([]string{project, region, cluster, name}, ".")
return filepath.Join(dir, shortName), nil
}

// Client represents the state of the current instantiation of the proxy.
type Client struct {
cmd *cobra.Command
Expand All @@ -106,31 +140,79 @@ type Client struct {

// NewClient completes the initial setup required to get the proxy to a "steady" state.
func NewClient(ctx context.Context, d alloydb.Dialer, cmd *cobra.Command, conf *Config) (*Client, error) {
var mnts []*socketMount
pc := newPortConfig(conf.Port)
var mnts []*socketMount
for _, inst := range conf.Instances {
m := &socketMount{inst: inst.Name}
a := conf.Addr
if inst.Addr != "" {
a = inst.Addr
}
var np int
switch {
case inst.Port != 0:
np = inst.Port
default: // use next increment from conf.Port
np = pc.nextPort()
var (
// network is one of "tcp" or "unix"
network string
// address is either a TCP host port, or a Unix socket
address string
)
// IF
// a global Unix socket directory is NOT set AND
// an instance-level Unix socket is NOT set
// (e.g., I didn't set a Unix socket globally or for this instance)
// OR
// an instance-level TCP address or port IS set
// (e.g., I'm overriding any global settings to use TCP for this
// instance)
// use a TCP listener.
// Otherwise, use a Unix socket.
if (conf.UnixSocket == "" && inst.UnixSocket == "") ||
(inst.Addr != "" || inst.Port != 0) {
network = "tcp"

a := conf.Addr
if inst.Addr != "" {
a = inst.Addr
}

var np int
switch {
case inst.Port != 0:
np = inst.Port
case conf.Port != 0:
np = pc.nextPort()
default:
np = pc.nextPort()
}

address = net.JoinHostPort(a, fmt.Sprint(np))
} else {
network = "unix"

dir := conf.UnixSocket
if dir == "" {
dir = inst.UnixSocket
}
ud, err := UnixSocketDir(dir, inst.Name)
if err != nil {
return nil, err
}
// Create the parent directory that will hold the socket.
if _, err := os.Stat(ud); err != nil {
if err = os.Mkdir(ud, 0777); err != nil {
return nil, err
}
}
// use the Postgres-specific socket name
address = filepath.Join(ud, ".s.PGSQL.5432")
}
addr, err := m.listen(ctx, "tcp", net.JoinHostPort(a, fmt.Sprint(np)))

m := &socketMount{inst: inst.Name}
addr, err := m.listen(ctx, network, address)
if err != nil {
for _, m := range mnts {
m.close()
}
return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err)
}

cmd.Printf("[%s] Listening on %s\n", inst.Name, addr.String())
mnts = append(mnts, m)
}

return &Client{mnts: mnts, cmd: cmd, dialer: d}, nil
}

Expand Down Expand Up @@ -210,9 +292,9 @@ type socketMount struct {
}

// listen causes a socketMount to create a Listener at the specified network address.
func (s *socketMount) listen(ctx context.Context, network string, host string) (net.Addr, error) {
func (s *socketMount) listen(ctx context.Context, network string, address string) (net.Addr, error) {
lc := net.ListenConfig{KeepAlive: 30 * time.Second}
l, err := lc.Listen(ctx, network, host)
l, err := lc.Listen(ctx, network, address)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit c7cab02

Please sign in to comment.