Skip to content

Commit

Permalink
add health check for load balance servers (#857)
Browse files Browse the repository at this point in the history
* add health check for loadbalance servers

* add lock

* fix lock for race condition

* refactor lock to atomic.Value

* optimize usage of atomic.Value

* fix goroutine leak & other optimizations

* validate config & fix typo

* fix leak and close & use duration format

* fix test

* apply suggestions
  • Loading branch information
samanhappy committed Nov 28, 2022
1 parent 59cc1a2 commit b80a9cf
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 18 deletions.
11 changes: 11 additions & 0 deletions doc/reference/filters.md
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,7 @@ Rules to revise request header.
| policy | string | Load balance policy, valid values are `roundRobin`, `random`, `weightedRandom`, `ipHash`, and `headerHash` | Yes |
| headerHashKey | string | When `policy` is `headerHash`, this option is the name of a header whose value is used for hash calculation | No |
| stickySession | [proxy.StickySession](#proxyStickySessionSpec) | Sticky session spec | No |
| healthCheck | [proxy.HealthCheck](#proxyHealthCheckSpec) | Health check spec. It is not needed, and is rejected by spec validation, when the pool uses a service registry (`serviceName`) | No |

### proxy.StickySessionSpec

Expand All @@ -1132,6 +1133,16 @@ Rules to revise request header.
| lbCookieName | string | Name of the cookie generated by load balancer, its value will be used as the session identifier for stickiness in `DurationBased` and `ApplicationBased` mode, default is `EG_SESSION` | No |
| lbCookieExpire | string | Expire duration of the cookie generated by load balancer, its value will be used as the session expire time for stickiness in `DurationBased` and `ApplicationBased` mode, default is 2 hours | No |

### proxy.HealthCheckSpec

| Name | Type | Description | Required |
| ------------- | ------ | ----------------------------------------------------------------------------------------------------------- | -------- |
| interval | string | Interval between two rounds of health checks, default is 60s | No |
| path | string | URL path appended to each server URL to build the health check request, empty by default | No |
| timeout | string | Timeout of a single health check request, default is 3s | No |
| fails | int | Number of consecutive failed checks after which a server is considered unhealthy, default is 1 | No |
| passes | int | Number of consecutive passed checks after which a server is considered healthy again, default is 1 | No |

### proxy.MemoryCacheSpec

| Name | Type | Description | Required |
Expand Down
16 changes: 14 additions & 2 deletions pkg/filters/proxy/basepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func (sps *BaseServerPoolSpec) Validate() error {
return fmt.Errorf(msgFmt, serversGotWeight, len(sps.Servers))
}

if sps.ServiceName != "" && sps.LoadBalance.HealthCheck != nil {
return fmt.Errorf("can not open health check for service discovery")
}

return nil
}

Expand Down Expand Up @@ -112,7 +116,10 @@ func (bsp *BaseServerPool) Init(super *supervisor.Supervisor, name string, spec

// LoadBalancer returns the load balancer of the server pool, or nil if no
// load balancer has been stored yet.
func (bsp *BaseServerPool) LoadBalancer() LoadBalancer {
	// Load yields nil until the first Store/Swap. Asserting on that nil value
	// unconditionally would panic, so the value is checked before the type
	// assertion.
	if v := bsp.loadBalancer.Load(); v != nil {
		return v.(LoadBalancer)
	}
	return nil
}

func (bsp *BaseServerPool) createLoadBalancer(spec *LoadBalanceSpec, servers []*Server) {
Expand All @@ -125,7 +132,9 @@ func (bsp *BaseServerPool) createLoadBalancer(spec *LoadBalanceSpec, servers []*
}

lb := NewLoadBalancer(spec, servers)
bsp.loadBalancer.Store(lb)
if old := bsp.loadBalancer.Swap(lb); old != nil {
old.(LoadBalancer).Close()
}
}

func (bsp *BaseServerPool) useService(spec *BaseServerPoolSpec, instances map[string]*serviceregistry.ServiceInstanceSpec) {
Expand Down Expand Up @@ -162,4 +171,7 @@ func (bsp *BaseServerPool) useService(spec *BaseServerPoolSpec, instances map[st
// close shuts the server pool down: it closes the done channel, blocks until
// the pool's wait group is drained, and finally closes the current load
// balancer, if there is one.
func (bsp *BaseServerPool) close() {
	// Signal shutdown first and wait on the wait group (presumably the pool's
	// background goroutines) before touching the load balancer.
	close(bsp.done)
	bsp.wg.Wait()

	current := bsp.LoadBalancer()
	if current == nil {
		return
	}
	current.Close()
}
151 changes: 135 additions & 16 deletions pkg/filters/proxy/loadbalance.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,22 @@ const (
StickySessionDefaultLBCookieExpire = time.Hour * 2
// KeyLen is the key length used by HMAC.
KeyLen = 8
// HealthCheckDefaultInterval is the default interval for health check
HealthCheckDefaultInterval = time.Second * 60
// HealthCheckDefaultTimeout is the default timeout for health check
HealthCheckDefaultTimeout = time.Second * 3
// HealthCheckDefaultFailThreshold is the default fail threshold for health check
HealthCheckDefaultFailThreshold = 1
// HealthCheckDefaultPassThreshold is the default pass threshold for health check
HealthCheckDefaultPassThreshold = 1
)

// LoadBalancer is the interface of an HTTP load balancer.
type LoadBalancer interface {
// ChooseServer picks the server that should handle req. The
// implementations in this file return nil when there is no healthy
// server to choose from.
ChooseServer(req *httpprot.Request) *Server
// ReturnServer hands a previously chosen server back to the load
// balancer together with the request and the response it produced.
// NOTE(review): presumably used for sticky-session bookkeeping — confirm
// against the implementation.
ReturnServer(server *Server, req *httpprot.Request, resp *httpprot.Response)
// HealthyServers returns the servers currently considered healthy.
HealthyServers() []*Server
// Close releases the resources held by the load balancer.
Close()
}

// StickySessionSpec is the spec for sticky session.
Expand All @@ -78,11 +88,26 @@ type StickySessionSpec struct {
LBCookieExpire string `json:"lbCookieExpire" jsonschema:"omitempty,format=duration"`
}

// HealthCheckSpec is the spec for health check.
// Durations that are empty, invalid or not positive, and counts that are
// zero, are replaced by the HealthCheckDefault* values when the health check
// is initialized.
type HealthCheckSpec struct {
// Interval is the interval duration for health check, default is 60s.
Interval string `json:"interval" jsonschema:"omitempty,format=duration"`
// Path is the health check path for server. It is appended to the server
// URL to build the probe URL; when empty the server URL is probed as is.
Path string `json:"path" jsonschema:"omitempty"`
// Timeout is the timeout duration for health check, default is 3s.
Timeout string `json:"timeout" jsonschema:"omitempty,format=duration"`
// Fails is the consecutive fails count for assert fail, default is 1.
Fails int `json:"fails" jsonschema:"omitempty,minimum=1"`
// Passes is the consecutive passes count for assert pass, default is 1.
Passes int `json:"passes" jsonschema:"omitempty,minimum=1"`
}

// LoadBalanceSpec is the spec to create a load balancer.
type LoadBalanceSpec struct {
// Policy is the load balance policy, one of roundRobin, random,
// weightedRandom, ipHash and headerHash; it may be left empty.
Policy string `json:"policy" jsonschema:"omitempty,enum=,enum=roundRobin,enum=random,enum=weightedRandom,enum=ipHash,enum=headerHash"`
// HeaderHashKey is the name of the header whose value is hashed when
// Policy is headerHash.
HeaderHashKey string `json:"headerHashKey" jsonschema:"omitempty"`
// StickySession is the optional sticky session spec; nil disables sticky
// sessions.
StickySession *StickySessionSpec `json:"stickySession" jsonschema:"omitempty"`
// HealthCheck is the optional health check spec; nil disables active
// health checking.
HealthCheck *HealthCheckSpec `json:"healthCheck" jsonschema:"omitempty"`
}

// NewLoadBalancer creates a load balancer for servers according to spec.
Expand Down Expand Up @@ -126,30 +151,117 @@ func (h hasher) Sum64(data []byte) uint64 {
type BaseLoadBalancer struct {
// spec is the load balance spec passed to init.
spec *LoadBalanceSpec
// Servers holds every server passed to init, healthy or not.
Servers []*Server
// healthyServers stores a []*Server with the servers currently considered
// healthy. It starts as the full server list and is replaced as a whole
// by probeServers when a health status changes.
healthyServers atomic.Value
// NOTE(review): presumably the hash ring built by initConsistentHash for
// the cookie consistent hash sticky session mode — confirm.
consistentHash *consistent.Consistent
// NOTE(review): presumably the lifetime of the load balancer cookie set
// by configLBCookie — confirm.
cookieExpire time.Duration
// done is closed by Close to stop the health check goroutine. It stays
// nil when health checking is not enabled.
done chan bool
// probeClient is the HTTP client used for health check probes; its
// Timeout is probeTimeout.
probeClient *http.Client
// probeInterval is the interval between two rounds of health checks.
probeInterval time.Duration
// probeTimeout is the timeout of a single health check probe.
probeTimeout time.Duration
}

// HealthyServers returns the servers that are currently considered healthy.
// It returns nil when no server list has been stored yet, i.e. before init
// has run.
func (blb *BaseLoadBalancer) HealthyServers() []*Server {
	// The comma-ok assertion keeps a zero-value balancer, whose atomic value
	// still holds nil, from panicking.
	servers, _ := blb.healthyServers.Load().([]*Server)
	return servers
}

// init initializes load balancer
func (blb *BaseLoadBalancer) init(spec *LoadBalanceSpec, servers []*Server) {
blb.spec = spec
blb.Servers = servers
blb.healthyServers.Store(servers)

if spec.StickySession == nil || len(servers) == 0 {
blb.initStickySession(spec.StickySession, blb.HealthyServers())
blb.initHealthCheck(spec.HealthCheck, servers)
}

// initStickySession initializes for sticky session
func (blb *BaseLoadBalancer) initStickySession(spec *StickySessionSpec, servers []*Server) {
if spec == nil || len(servers) == 0 {
return
}

switch spec.StickySession.Mode {
switch spec.Mode {
case StickySessionModeCookieConsistentHash:
blb.initConsistentHash()
case StickySessionModeDurationBased, StickySessionModeApplicationBased:
blb.configLBCookie()
}
}

// initHealthCheck fills in the health check defaults and starts a goroutine
// that probes all servers once per probe interval until Close is called.
// It does nothing when spec is nil or servers is empty.
func (blb *BaseLoadBalancer) initHealthCheck(spec *HealthCheckSpec, servers []*Server) {
	if spec == nil || len(servers) == 0 {
		return
	}

	// Parse errors are ignored on purpose: an empty, malformed or
	// non-positive duration simply falls back to the default.
	interval, _ := time.ParseDuration(spec.Interval)
	if interval <= 0 {
		interval = HealthCheckDefaultInterval
	}
	blb.probeInterval = interval

	timeout, _ := time.ParseDuration(spec.Timeout)
	if timeout <= 0 {
		timeout = HealthCheckDefaultTimeout
	}
	blb.probeTimeout = timeout

	// Zero means "not configured" for the thresholds.
	if spec.Fails == 0 {
		spec.Fails = HealthCheckDefaultFailThreshold
	}
	if spec.Passes == 0 {
		spec.Passes = HealthCheckDefaultPassThreshold
	}

	blb.probeClient = &http.Client{Timeout: blb.probeTimeout}
	ticker := time.NewTicker(blb.probeInterval)
	blb.done = make(chan bool)

	go func() {
		// Release the ticker once the loop exits.
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				blb.probeServers()
			case <-blb.done:
				return
			}
		}
	}()
}

// probeServers runs one round of health checks over all servers and, if the
// health status of any server changed, publishes the new healthy list and
// re-initializes the sticky session state.
func (blb *BaseLoadBalancer) probeServers() {
	check := blb.spec.HealthCheck
	alive := make([]*Server, 0, len(blb.Servers))
	changed := false

	for _, server := range blb.Servers {
		ok := blb.probeHTTP(server.URL)
		healthy, flipped := server.recordHealth(ok, check.Passes, check.Fails)
		changed = changed || flipped
		if healthy {
			alive = append(alive, server)
		}
	}

	if !changed {
		return
	}

	blb.healthyServers.Store(alive)
	// init consistent hash in sticky session when servers change
	blb.initStickySession(blb.spec.StickySession, blb.HealthyServers())
}

// probeHTTP reports whether a server passes a single health probe. It sends a
// GET request to url, with the configured health check path appended when one
// is set, and returns true only if the request succeeds with a status below
// 500. Transport errors, including timeouts of the probe client, count as a
// failed probe.
func (blb *BaseLoadBalancer) probeHTTP(url string) bool {
	if path := blb.spec.HealthCheck.Path; path != "" {
		url += path
	}

	res, err := blb.probeClient.Get(url)
	if err != nil {
		return false
	}
	// The body is not needed, but it must be closed or every probe leaks the
	// underlying connection.
	_ = res.Body.Close()

	// Every 5xx status, including 500 itself, means the server is failing.
	return res.StatusCode < http.StatusInternalServerError
}

// initConsistentHash initializes for consistent hash mode
func (blb *BaseLoadBalancer) initConsistentHash() {
members := make([]consistent.Member, len(blb.Servers))
for i, s := range blb.Servers {
members := make([]consistent.Member, len(blb.HealthyServers()))
for i, s := range blb.HealthyServers() {
members[i] = hashMember{server: s}
}

Expand Down Expand Up @@ -219,7 +331,7 @@ func (blb *BaseLoadBalancer) chooseServerByLBCookie(req *httpprot.Request) *Serv

key := signed[:KeyLen]
macBytes := signed[KeyLen:]
for _, s := range blb.Servers {
for _, s := range blb.HealthyServers() {
mac := hmac.New(sha256.New, key)
mac.Write([]byte(s.ID()))
expected := mac.Sum(nil)
Expand Down Expand Up @@ -259,6 +371,13 @@ func (blb *BaseLoadBalancer) ReturnServer(server *Server, req *httpprot.Request,
}
}

// Close stops the health check goroutine by closing the done channel. It is
// a no-op when health checking was never started.
func (blb *BaseLoadBalancer) Close() {
	if blb.done == nil {
		// No health check goroutine was started, nothing to stop.
		return
	}
	close(blb.done)
}

// sign signs plain text byte array to encoded string
func sign(plain []byte) string {
signed := make([]byte, KeyLen+sha256.Size)
Expand Down Expand Up @@ -287,15 +406,15 @@ func newRandomLoadBalancer(spec *LoadBalanceSpec, servers []*Server) *randomLoad

// ChooseServer implements the LoadBalancer interface.
func (lb *randomLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
if len(lb.Servers) == 0 {
if len(lb.HealthyServers()) == 0 {
return nil
}

if server := lb.BaseLoadBalancer.ChooseServer(req); server != nil {
return server
}

return lb.Servers[rand.Intn(len(lb.Servers))]
return lb.HealthyServers()[rand.Intn(len(lb.HealthyServers()))]
}

// roundRobinLoadBalancer does load balancing in a round robin manner.
Expand All @@ -312,7 +431,7 @@ func newRoundRobinLoadBalancer(spec *LoadBalanceSpec, servers []*Server) *roundR

// ChooseServer implements the LoadBalancer interface.
func (lb *roundRobinLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
if len(lb.Servers) == 0 {
if len(lb.HealthyServers()) == 0 {
return nil
}

Expand All @@ -321,7 +440,7 @@ func (lb *roundRobinLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
}

counter := atomic.AddUint64(&lb.counter, 1) - 1
return lb.Servers[int(counter)%len(lb.Servers)]
return lb.HealthyServers()[int(counter)%len(lb.HealthyServers())]
}

// WeightedRandomLoadBalancer does load balancing in a weighted random manner.
Expand All @@ -333,15 +452,15 @@ type WeightedRandomLoadBalancer struct {
func newWeightedRandomLoadBalancer(spec *LoadBalanceSpec, servers []*Server) *WeightedRandomLoadBalancer {
lb := &WeightedRandomLoadBalancer{}
lb.init(spec, servers)
for _, server := range lb.Servers {
for _, server := range lb.HealthyServers() {
lb.totalWeight += server.Weight
}
return lb
}

// ChooseServer implements the LoadBalancer interface.
func (lb *WeightedRandomLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
if len(lb.Servers) == 0 {
if len(lb.HealthyServers()) == 0 {
return nil
}

Expand All @@ -350,7 +469,7 @@ func (lb *WeightedRandomLoadBalancer) ChooseServer(req *httpprot.Request) *Serve
}

randomWeight := rand.Intn(lb.totalWeight)
for _, server := range lb.Servers {
for _, server := range lb.HealthyServers() {
randomWeight -= server.Weight
if randomWeight < 0 {
return server
Expand All @@ -373,7 +492,7 @@ func newIPHashLoadBalancer(spec *LoadBalanceSpec, servers []*Server) *ipHashLoad

// ChooseServer implements the LoadBalancer interface.
func (lb *ipHashLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
if len(lb.Servers) == 0 {
if len(lb.HealthyServers()) == 0 {
return nil
}

Expand All @@ -384,7 +503,7 @@ func (lb *ipHashLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
ip := req.RealIP()
hash := fnv.New32()
hash.Write([]byte(ip))
return lb.Servers[hash.Sum32()%uint32(len(lb.Servers))]
return lb.HealthyServers()[hash.Sum32()%uint32(len(lb.HealthyServers()))]
}

// headerHashLoadBalancer does load balancing based on header hash.
Expand All @@ -402,7 +521,7 @@ func newHeaderHashLoadBalancer(spec *LoadBalanceSpec, servers []*Server) *header

// ChooseServer implements the LoadBalancer interface.
func (lb *headerHashLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
if len(lb.Servers) == 0 {
if len(lb.HealthyServers()) == 0 {
return nil
}

Expand All @@ -413,5 +532,5 @@ func (lb *headerHashLoadBalancer) ChooseServer(req *httpprot.Request) *Server {
v := req.HTTPHeader().Get(lb.key)
hash := fnv.New32()
hash.Write([]byte(v))
return lb.Servers[hash.Sum32()%uint32(len(lb.Servers))]
return lb.HealthyServers()[hash.Sum32()%uint32(len(lb.HealthyServers()))]
}
21 changes: 21 additions & 0 deletions pkg/filters/proxy/loadbalance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"math/rand"
"net/http"
"testing"
"time"

"github.com/megaease/easegress/pkg/protocols/httpprot"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -263,3 +264,23 @@ func BenchmarkSign(b *testing.B) {
sign([]byte("192.168.1.2"))
}
}

// TestHealthCheck verifies that servers stay in the healthy list until the
// configured number of consecutive probes has failed.
// NOTE(review): relies on the servers from prepareServers not answering the
// probes — confirm.
func TestHealthCheck(t *testing.T) {
	assert := assert.New(t)
	servers := prepareServers(3)
	lb := NewLoadBalancer(&LoadBalanceSpec{
		Policy: LoadBalancePolicyRandom,
		HealthCheck: &HealthCheckSpec{
			Interval: "3s",
			Fails:    2,
		},
	}, servers)
	// Stop the probe goroutine and its ticker so they do not outlive the test.
	defer lb.Close()

	// Before any probe has run, every server counts as healthy.
	assert.Equal(len(servers), len(lb.HealthyServers()))

	// A single failed probe round is below the Fails threshold of 2.
	time.Sleep(5 * time.Second)
	assert.Equal(len(servers), len(lb.HealthyServers()))

	// By now at least two probe rounds are expected to have failed.
	time.Sleep(5 * time.Second)
	assert.Equal(0, len(lb.HealthyServers()))
}

0 comments on commit b80a9cf

Please sign in to comment.