Skip to content

Commit

Permalink
reverse_proxy: Finish adding support for SRV-based backends (#3179)
Browse files Browse the repository at this point in the history
  • Loading branch information
mholt committed Mar 23, 2020
1 parent 82b7fd7 commit 3945dd6
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 38 deletions.
30 changes: 26 additions & 4 deletions modules/caddyhttp/reverseproxy/caddyfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,36 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return net.JoinHostPort(host, port), nil
}

// appendUpstream creates an upstream for address and adds
// it to the list. If the address starts with "srv+" it is
// treated as a SRV-based upstream, and any port will be
// dropped.
appendUpstream := func(address string) error {
isSRV := strings.HasPrefix(address, "srv+")
if isSRV {
address = strings.TrimPrefix(address, "srv+")
}
dialAddr, err := upstreamDialAddress(address)
if err != nil {
return err
}
if isSRV {
if host, _, err := net.SplitHostPort(dialAddr); err == nil {
dialAddr = host
}
h.Upstreams = append(h.Upstreams, &Upstream{LookupSRV: dialAddr})
} else {
h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr})
}
return nil
}

for d.Next() {
for _, up := range d.RemainingArgs() {
dialAddr, err := upstreamDialAddress(up)
err := appendUpstream(up)
if err != nil {
return err
}
h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr})
}

for d.NextBlock(0) {
Expand All @@ -194,11 +217,10 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return d.ArgErr()
}
for _, up := range args {
dialAddr, err := upstreamDialAddress(up)
err := appendUpstream(up)
if err != nil {
return err
}
h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr})
}

case "lb_policy":
Expand Down
72 changes: 48 additions & 24 deletions modules/caddyhttp/reverseproxy/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package reverseproxy
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"sync/atomic"

Expand Down Expand Up @@ -75,14 +77,16 @@ type Upstream struct {
// backends is down. Also be aware of open proxy vulnerabilities.
Dial string `json:"dial,omitempty"`

// If DNS SRV records are used for service discovery with this
// upstream, specify the DNS name for which to look up SRV
// records here, instead of specifying a dial address.
LookupSRV string `json:"lookup_srv,omitempty"`

// The maximum number of simultaneous requests to allow to
// this upstream. If set, overrides the global passive health
// check UnhealthyRequestCount value.
MaxRequests int `json:"max_requests,omitempty"`

// TODO:...
SRV bool

// TODO: This could be really useful, to bind requests
// with certain properties to specific backends
// HeaderAffinity string
Expand Down Expand Up @@ -121,6 +125,47 @@ func (u *Upstream) Full() bool {
return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
}

// fillDialInfo returns a filled DialInfo for upstream u, using the request
// context. If the upstream has a SRV lookup configured, that is done and a
// returned address is chosen; otherwise, the upstream's regular dial address
// field is used. Note that the returned value is not a pointer.
func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
var addr caddy.ParsedAddress

if u.LookupSRV != "" {
// perform DNS lookup for SRV records and choose one
srvName := repl.ReplaceAll(u.LookupSRV, "")
_, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName)
if err != nil {
return DialInfo{}, err
}
addr.Network = "tcp"
addr.Host = records[0].Target
addr.StartPort, addr.EndPort = uint(records[0].Port), uint(records[0].Port)
} else {
// use provided dial address
var err error
dial := repl.ReplaceAll(u.Dial, "")
addr, err = caddy.ParseNetworkAddress(dial)
if err != nil {
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", u.Dial, dial, err)
}
if numPorts := addr.PortRangeSize(); numPorts != 1 {
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
u.Dial, dial, numPorts)
}
}

return DialInfo{
Upstream: u,
Network: addr.Network,
Address: addr.JoinHostPort(0),
Host: addr.Host,
Port: strconv.Itoa(int(addr.StartPort)),
}, nil
}

// upstreamHost is the basic, in-memory representation
// of the state of a remote host. It implements the
// Host interface.
Expand Down Expand Up @@ -207,27 +252,6 @@ func (di DialInfo) String() string {
return caddy.JoinNetworkAddress(di.Network, di.Host, di.Port)
}

// fillDialInfo returns a filled DialInfo for the given upstream, using
// the given Replacer. Note that the returned value is not a pointer.
func fillDialInfo(upstream *Upstream, repl *caddy.Replacer) (DialInfo, error) {
dial := repl.ReplaceAll(upstream.Dial, "")
addr, err := caddy.ParseNetworkAddress(dial)
if err != nil {
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err)
}
if numPorts := addr.PortRangeSize(); numPorts != 1 {
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
upstream.Dial, dial, numPorts)
}
return DialInfo{
Upstream: upstream,
Network: addr.Network,
Address: addr.JoinHostPort(0),
Host: addr.Host,
Port: strconv.Itoa(int(addr.StartPort)),
}, nil
}

// GetDialInfo gets the upstream dialing info out of the context,
// and returns true if there was a valid value; false otherwise.
func GetDialInfo(ctx context.Context) (DialInfo, bool) {
Expand Down
9 changes: 0 additions & 9 deletions modules/caddyhttp/reverseproxy/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"net"
"net/http"
"reflect"
"strconv"
"time"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -96,14 +95,6 @@ func (h *HTTPTransport) newTransport() (*http.Transport, error) {
if dialInfo, ok := GetDialInfo(ctx); ok {
network = dialInfo.Network
address = dialInfo.Address
// TODO: experimental SRV lookups
if dialInfo.Upstream.SRV {
_, addrs, err := net.DefaultResolver.LookupSRV(ctx, "", "", dialInfo.Host)
if err != nil {
return nil, err
}
address = net.JoinHostPort(addrs[0].Target, strconv.Itoa(int(addrs[0].Port)))
}
}
conn, err := dialer.DialContext(ctx, network, address)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion modules/caddyhttp/reverseproxy/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
// the dial address may vary per-request if placeholders are
// used, so perform those replacements here; the resulting
// DialInfo struct should have valid network address syntax
dialInfo, err := fillDialInfo(upstream, repl)
dialInfo, err := upstream.fillDialInfo(r)
if err != nil {
return fmt.Errorf("making dial info: %v", err)
}
Expand Down

0 comments on commit 3945dd6

Please sign in to comment.