Skip to content

Commit

Permalink
admin: Enforce and refactor origin checking
Browse files Browse the repository at this point in the history
Using URLs seems a little cleaner and more correct

cf: https://caddy.community/t/protect-admin-endpoint/15114

(This used to work. Something must have changed recently.)
  • Loading branch information
mholt committed Feb 15, 2022
1 parent 1d0425b commit 40b5443
Showing 1 changed file with 58 additions and 27 deletions.
85 changes: 58 additions & 27 deletions admin.go
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/caddyserver/certmagic"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

// AdminConfig configures Caddy's API endpoint, which is used
Expand Down Expand Up @@ -192,6 +193,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
} else {
muxWrap.enforceHost = !addr.isWildcardInterface()
muxWrap.allowedOrigins = admin.allowedOrigins(addr)
muxWrap.enforceOrigin = admin.EnforceOrigin
}

addRouteWithMetrics := func(pattern string, handlerLabel string, h http.Handler) {
Expand Down Expand Up @@ -252,7 +254,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
// will be used as the default origin. If admin.Origins is
// empty, no origins will be allowed, effectively bricking the
// endpoint for non-unix-socket endpoints, but whatever.
func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []*url.URL {
uniqueOrigins := make(map[string]struct{})
for _, o := range admin.Origins {
uniqueOrigins[o] = struct{}{}
Expand All @@ -276,8 +278,23 @@ func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
uniqueOrigins[addr.JoinHostPort(0)] = struct{}{}
}
}
allowed := make([]string, 0, len(uniqueOrigins))
for origin := range uniqueOrigins {
allowed := make([]*url.URL, 0, len(uniqueOrigins))
for originStr := range uniqueOrigins {
var origin *url.URL
if strings.Contains(originStr, "://") {
var err error
origin, err = url.Parse(originStr)
if err != nil {
continue
}
origin.Path = ""
origin.RawPath = ""
origin.Fragment = ""
origin.RawFragment = ""
origin.RawQuery = ""
} else {
origin = &url.URL{Host: originStr}
}
allowed = append(allowed, origin)
}
return allowed
Expand Down Expand Up @@ -358,7 +375,7 @@ func replaceLocalAdminServer(cfg *Config) error {
adminLogger.Info("admin endpoint started",
zap.String("address", addr.String()),
zap.Bool("enforce_origin", adminConfig.EnforceOrigin),
zap.Strings("origins", handler.allowedOrigins))
zap.Array("origins", loggableURLArray(handler.allowedOrigins)))

if !handler.enforceHost {
adminLogger.Warn("admin endpoint on open interface; host checking disabled",
Expand Down Expand Up @@ -650,10 +667,10 @@ type AdminRoute struct {
type adminHandler struct {
mux *http.ServeMux

// security for local/plaintext) endpoint, on by default
// security for local/plaintext endpoint
enforceOrigin bool
enforceHost bool
allowedOrigins []string
allowedOrigins []*url.URL

// security for remote/encrypted endpoint
remoteControl *RemoteAdmin
Expand Down Expand Up @@ -779,8 +796,8 @@ func (h adminHandler) handleError(w http.ResponseWriter, r *http.Request, err er
// rebinding attacks.
func (h adminHandler) checkHost(r *http.Request) error {
var allowed bool
for _, allowedHost := range h.allowedOrigins {
if r.Host == allowedHost {
for _, allowedOrigin := range h.allowedOrigins {
if r.Host == allowedOrigin.Host {
allowed = true
break
}
Expand All @@ -799,43 +816,45 @@ func (h adminHandler) checkHost(r *http.Request) error {
// sites from issuing requests to our listener. It
// returns the origin that was obtained from r.
func (h adminHandler) checkOrigin(r *http.Request) (string, error) {
origin := h.getOriginHost(r)
if origin == "" {
return origin, APIError{
originStr, origin := h.getOrigin(r)
if origin == nil {
return "", APIError{
HTTPStatus: http.StatusForbidden,
Err: fmt.Errorf("missing required Origin header"),
Err: fmt.Errorf("required Origin header is missing or invalid"),
}
}
if !h.originAllowed(origin) {
return origin, APIError{
return "", APIError{
HTTPStatus: http.StatusForbidden,
Err: fmt.Errorf("client is not allowed to access from origin %s", origin),
Err: fmt.Errorf("client is not allowed to access from origin '%s'", originStr),
}
}
return origin, nil
return origin.String(), nil
}

func (h adminHandler) getOriginHost(r *http.Request) string {
func (h adminHandler) getOrigin(r *http.Request) (string, *url.URL) {
origin := r.Header.Get("Origin")
if origin == "" {
origin = r.Header.Get("Referer")
}
originURL, err := url.Parse(origin)
if err == nil && originURL.Host != "" {
origin = originURL.Host
}
return origin
if err != nil {
return origin, nil
}
originURL.Path = ""
originURL.RawPath = ""
originURL.Fragment = ""
originURL.RawFragment = ""
originURL.RawQuery = ""
return origin, originURL
}

func (h adminHandler) originAllowed(origin string) bool {
func (h adminHandler) originAllowed(origin *url.URL) bool {
for _, allowedOrigin := range h.allowedOrigins {
originCopy := origin
if !strings.Contains(allowedOrigin, "://") {
// no scheme specified, so allow both
originCopy = strings.TrimPrefix(originCopy, "http://")
originCopy = strings.TrimPrefix(originCopy, "https://")
if allowedOrigin.Scheme != "" && origin.Scheme != allowedOrigin.Scheme {
continue
}
if originCopy == allowedOrigin {
if origin.Host == allowedOrigin.Host {
return true
}
}
Expand Down Expand Up @@ -1189,6 +1208,18 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
return x509.ParseCertificate(derBytes)
}

type loggableURLArray []*url.URL

func (ua loggableURLArray) MarshalLogArray(enc zapcore.ArrayEncoder) error {
if ua == nil {
return nil
}
for _, u := range ua {
enc.AppendString(u.String())
}
return nil
}

var (
// DefaultAdminListen is the address for the local admin
// listener, if none is specified at startup.
Expand Down

0 comments on commit 40b5443

Please sign in to comment.