Skip to content

Commit

Permalink
Refactor in preparation for specifying cluster listeners
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai committed Aug 12, 2016
1 parent b978c5c commit 1c87bc7
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 49 deletions.
20 changes: 19 additions & 1 deletion command/server.go
Expand Up @@ -334,6 +334,8 @@ func (c *ServerCommand) Run(args []string) int {
}
}

clusterAddrs := []string{}

// Initialize the listeners
lns := make([]net.Listener, 0, len(config.Listeners))
for i, lnConfig := range config.Listeners {
Expand Down Expand Up @@ -364,6 +366,22 @@ func (c *ServerCommand) Run(args []string) int {
relSlice = append(relSlice, reloadFunc)
c.ReloadFuncs["listener|"+lnConfig.Type] = relSlice
}

if lnConfig.Type == "tcp" {
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
c.Ui.Error("Failed to parse tcp listener")
return 1
}
ipStr := tcpAddr.IP.String()
if len(tcpAddr.IP) == net.IPv6len {
ipStr = fmt.Sprintf("[%s]", ipStr)
}
clusterAddrs = append(clusterAddrs, fmt.Sprintf("%s:%d", ipStr, tcpAddr.Port+1))
}
}
if !disableClustering {
c.logger.Printf("[TRACE] cluster listeners will be started on %v", clusterAddrs)
}

// Make sure we close all listeners from this point on
Expand Down Expand Up @@ -428,7 +446,7 @@ func (c *ServerCommand) Run(args []string) int {

// This needs to happen before we first unseal, so before we trigger dev
// mode if it's set
core.SetClusterListenerSetupFunc(vault.WrapListenersForClustering(lns, handler, c.logger))
core.SetClusterListenerSetupFunc(vault.WrapListenersForClustering(clusterAddrs, handler, c.logger))

// If we're in dev mode, then initialize the core
if dev {
Expand Down
23 changes: 4 additions & 19 deletions vault/cluster.go
Expand Up @@ -537,7 +537,7 @@ func (c *Core) ForwardRequest(req *http.Request) (*http.Response, error) {
// handler, creates a new handler that handles forwarded requests, and returns
// the cluster setup function that creates the new listners and assigns to the
// new handler
func WrapListenersForClustering(lns []net.Listener, handler http.Handler, logger *log.Logger) func() ([]net.Listener, http.Handler, error) {
func WrapListenersForClustering(addrs []string, handler http.Handler, logger *log.Logger) func() ([]net.Listener, http.Handler, error) {
// This mux handles cluster functions (right now, only forwarded requests)
mux := http.NewServeMux()
mux.HandleFunc("/cluster/forwarded-request", func(w http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -571,25 +571,10 @@ func WrapListenersForClustering(lns []net.Listener, handler http.Handler, logger
})

return func() ([]net.Listener, http.Handler, error) {
ret := make([]net.Listener, 0, len(lns))
ret := make([]net.Listener, 0, len(addrs))
// Loop over the existing listeners and start listeners on appropriate ports
for _, ln := range lns {
tcpAddr, ok := ln.Addr().(*net.TCPAddr)
if !ok {
if logger != nil {
logger.Printf("[TRACE] http/WrapClusterListener: %s not a candidate for cluster request handling", ln.Addr().String())
}
continue
}
if logger != nil {
logger.Printf("[TRACE] http/WrapClusterListener: %s is a candidate for cluster request handling at addr %s and port %d", tcpAddr.String(), tcpAddr.IP.String(), tcpAddr.Port+1)
}

ipStr := tcpAddr.IP.String()
if len(tcpAddr.IP) == net.IPv6len {
ipStr = fmt.Sprintf("[%s]", ipStr)
}
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", ipStr, tcpAddr.Port+1))
for _, addr := range addrs {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, nil, err
}
Expand Down
11 changes: 1 addition & 10 deletions vault/cluster_test.go
Expand Up @@ -9,7 +9,6 @@ import (
"net"
"net/http"
"os"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -118,18 +117,10 @@ func TestClusterHAFetching(t *testing.T) {
}

// Make sure the certificate meets expectations
cert, err := x509.ParseCertificate(cluster.Certificate)
_, err = x509.ParseCertificate(cluster.Certificate)
if err != nil {
t.Fatal("error parsing local cluster certificate: %v", err)
}

// Make sure the cert pool is as expected
if len(c.localClusterCertPool.Subjects()) != 1 {
t.Fatal("unexpected local cluster cert pool length")
}
if !reflect.DeepEqual(cert.RawSubject, c.localClusterCertPool.Subjects()[0]) {
t.Fatal("cert pool subject does not match expected")
}
}

func TestCluster_ListenForRequests(t *testing.T) {
Expand Down
31 changes: 12 additions & 19 deletions vault/testing.go
Expand Up @@ -640,29 +640,22 @@ func TestCluster(t *testing.T, handlers []http.Handler, base *CoreConfig, unseal
//
// Clustering setup
//
c1SetupFunc := func() []net.Listener {
ret := make([]net.Listener, len(c1lns))
for i, ln := range c1lns {
ret[i] = ln.Listener
clusterAddrGen := func(lns []*TestListener) []string {
ret := make([]string, len(lns))
for i, ln := range lns {
curAddr := ln.Address
ipStr := curAddr.IP.String()
if len(curAddr.IP) == net.IPv6len {
ipStr = fmt.Sprintf("[%s]", ipStr)
}
ret[i] = fmt.Sprintf("%s:%d", ipStr, curAddr.Port+1)
}
return ret
}
c2.SetClusterListenerSetupFunc(WrapListenersForClustering(func() []net.Listener {
ret := make([]net.Listener, len(c2lns))
for i, ln := range c2lns {
ret[i] = ln.Listener
}
return ret
}(), handlers[1], logger))
c3.SetClusterListenerSetupFunc(WrapListenersForClustering(func() []net.Listener {
ret := make([]net.Listener, len(c3lns))
for i, ln := range c3lns {
ret[i] = ln.Listener
}
return ret
}(), handlers[2], logger))

key, root := TestCoreInitClusterListenerSetup(t, c1, WrapListenersForClustering(c1SetupFunc(), handlers[0], logger))
c2.SetClusterListenerSetupFunc(WrapListenersForClustering(clusterAddrGen(c2lns), handlers[1], logger))
c3.SetClusterListenerSetupFunc(WrapListenersForClustering(clusterAddrGen(c3lns), handlers[2], logger))
key, root := TestCoreInitClusterListenerSetup(t, c1, WrapListenersForClustering(clusterAddrGen(c1lns), handlers[0], logger))
if _, err := c1.Unseal(TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
}
Expand Down

0 comments on commit 1c87bc7

Please sign in to comment.