diff --git a/physical/consul.go b/physical/consul.go index 6835b500fd376..949db16178d2d 100644 --- a/physical/consul.go +++ b/physical/consul.go @@ -58,7 +58,7 @@ type ConsulBackend struct { sealedCheck *api.AgentCheckRegistration registrationLock int64 advertiseHost string - advertisePort int + advertisePort int64 consulClientConf *api.Config serviceName string running bool @@ -244,14 +244,6 @@ func (c *ConsulBackend) AdvertiseSealed(sealed bool) error { return nil } -func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) { - c.advertiseHost, c.advertisePort, err = c.parseAdvertiseAddr(addr) - if err != nil { - return err - } - return nil -} - func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string) (err error) { c.serviceLock.Lock() defer c.serviceLock.Unlock() @@ -270,7 +262,7 @@ func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertis ID: serviceID, Name: c.serviceName, Tags: serviceTags(c.active), - Port: c.advertisePort, + Port: int(c.advertisePort), Address: c.advertiseHost, EnableTagOverride: false, } @@ -369,36 +361,36 @@ func (c *ConsulBackend) serviceID() string { return fmt.Sprintf("%s:%s:%d", c.serviceName, c.advertiseHost, c.advertisePort) } -func (c *ConsulBackend) parseAdvertiseAddr(addr string) (host string, port int, err error) { +func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) { if addr == "" { - return "", -1, fmt.Errorf("advertise address must not be empty") + return fmt.Errorf("advertise address must not be empty") } url, err := url.Parse(addr) if err != nil { - return "", -2, errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err) + return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err) } var portStr string - host, portStr, err = net.SplitHostPort(url.Host) + c.advertiseHost, portStr, err = net.SplitHostPort(url.Host) if err != nil { if url.Scheme == "http" { portStr = "80" } else if url.Scheme == "https" { portStr = "443" } else if url.Scheme == "unix" { - portStr = "0" - host = url.Path + portStr = "-1" + c.advertiseHost = url.Path } else { - return "", -3, errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err) + return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err) } } - portNum, err := strconv.ParseInt(portStr, 10, 0) - if err != nil || portNum < 0 || portNum > 65535 { - return "", -4, errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err) + c.advertisePort, err = strconv.ParseInt(portStr, 10, 0) + if err != nil || c.advertisePort < -1 || c.advertisePort > 65535 { + return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err) } - return host, int(portNum), nil + return nil } func setupTLSConfig(conf map[string]string) (*tls.Config, error) { diff --git a/physical/consul_test.go b/physical/consul_test.go index 364c238bbd0cf..33e1821aee5f4 100644 --- a/physical/consul_test.go +++ b/physical/consul_test.go @@ -221,11 +221,11 @@ func TestConsul_serviceTags(t *testing.T) { } } -func TestConsul_parseAdvertiseAddr(t *testing.T) { +func TestConsul_setAdvertiseAddr(t *testing.T) { tests := []struct { addr string host string - port int + port int64 pass bool }{ { @@ -249,7 +249,7 @@ func TestConsul_parseAdvertiseAddr(t *testing.T) { { addr: "unix:///tmp/.vault.addr.sock", host: "/tmp/.vault.addr.sock", - port: 0, + port: -1, pass: true, }, { @@ -263,7 +263,7 @@ func TestConsul_parseAdvertiseAddr(t *testing.T) { } for _, test := range tests { c := testConsulBackend(t) - host, port, err := c.parseAdvertiseAddr(test.addr) + err := c.setAdvertiseAddr(test.addr) if test.pass { if err != nil { t.Fatalf("bad: %v", err) @@ -276,12 +276,12 @@ func TestConsul_parseAdvertiseAddr(t *testing.T) { } } - if host != test.host { - t.Fatalf("bad: %v != %v", host, test.host) + if c.advertiseHost != test.host { + t.Fatalf("bad: %v != %v", c.advertiseHost, test.host) } - if port != test.port { - t.Fatalf("bad: %v != %v", port, test.port) + if c.advertisePort != test.port { + t.Fatalf("bad: %v != %v", c.advertisePort, test.port) } } }