diff --git a/admin.go b/admin.go index 0a7b9330798..157ae95f88c 100644 --- a/admin.go +++ b/admin.go @@ -42,6 +42,7 @@ import ( "github.com/caddyserver/certmagic" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) // AdminConfig configures Caddy's API endpoint, which is used @@ -192,6 +193,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin } else { muxWrap.enforceHost = !addr.isWildcardInterface() muxWrap.allowedOrigins = admin.allowedOrigins(addr) + muxWrap.enforceOrigin = admin.EnforceOrigin } addRouteWithMetrics := func(pattern string, handlerLabel string, h http.Handler) { @@ -252,7 +254,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin // will be used as the default origin. If admin.Origins is // empty, no origins will be allowed, effectively bricking the // endpoint for non-unix-socket endpoints, but whatever. -func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string { +func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []*url.URL { uniqueOrigins := make(map[string]struct{}) for _, o := range admin.Origins { uniqueOrigins[o] = struct{}{} @@ -276,8 +278,23 @@ func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string { uniqueOrigins[addr.JoinHostPort(0)] = struct{}{} } } - allowed := make([]string, 0, len(uniqueOrigins)) - for origin := range uniqueOrigins { + allowed := make([]*url.URL, 0, len(uniqueOrigins)) + for originStr := range uniqueOrigins { + var origin *url.URL + if strings.Contains(originStr, "://") { + var err error + origin, err = url.Parse(originStr) + if err != nil { + continue + } + origin.Path = "" + origin.RawPath = "" + origin.Fragment = "" + origin.RawFragment = "" + origin.RawQuery = "" + } else { + origin = &url.URL{Host: originStr} + } allowed = append(allowed, origin) } return allowed @@ -358,7 +375,7 @@ func replaceLocalAdminServer(cfg *Config) error { adminLogger.Info("admin endpoint started", zap.String("address", addr.String()), zap.Bool("enforce_origin", adminConfig.EnforceOrigin), - zap.Strings("origins", handler.allowedOrigins)) + zap.Array("origins", loggableURLArray(handler.allowedOrigins))) if !handler.enforceHost { adminLogger.Warn("admin endpoint on open interface; host checking disabled", @@ -650,10 +667,10 @@ type AdminRoute struct { type adminHandler struct { mux *http.ServeMux - // security for local/plaintext) endpoint, on by default + // security for local/plaintext endpoint enforceOrigin bool enforceHost bool - allowedOrigins []string + allowedOrigins []*url.URL // security for remote/encrypted endpoint remoteControl *RemoteAdmin @@ -779,8 +796,8 @@ func (h adminHandler) handleError(w http.ResponseWriter, r *http.Request, err er // rebinding attacks. func (h adminHandler) checkHost(r *http.Request) error { var allowed bool - for _, allowedHost := range h.allowedOrigins { - if r.Host == allowedHost { + for _, allowedOrigin := range h.allowedOrigins { + if r.Host == allowedOrigin.Host { allowed = true break } @@ -799,43 +816,45 @@ func (h adminHandler) checkHost(r *http.Request) error { // sites from issuing requests to our listener. It // returns the origin that was obtained from r. func (h adminHandler) checkOrigin(r *http.Request) (string, error) { - origin := h.getOriginHost(r) - if origin == "" { - return origin, APIError{ + originStr, origin := h.getOrigin(r) + if origin == nil { + return "", APIError{ HTTPStatus: http.StatusForbidden, - Err: fmt.Errorf("missing required Origin header"), + Err: fmt.Errorf("required Origin header is missing or invalid"), } } if !h.originAllowed(origin) { - return origin, APIError{ + return "", APIError{ HTTPStatus: http.StatusForbidden, - Err: fmt.Errorf("client is not allowed to access from origin %s", origin), + Err: fmt.Errorf("client is not allowed to access from origin '%s'", originStr), } } - return origin, nil + return origin.String(), nil } -func (h adminHandler) getOriginHost(r *http.Request) string { +func (h adminHandler) getOrigin(r *http.Request) (string, *url.URL) { origin := r.Header.Get("Origin") if origin == "" { origin = r.Header.Get("Referer") } originURL, err := url.Parse(origin) - if err == nil && originURL.Host != "" { - origin = originURL.Host - } - return origin + if err != nil { + return origin, nil + } + originURL.Path = "" + originURL.RawPath = "" + originURL.Fragment = "" + originURL.RawFragment = "" + originURL.RawQuery = "" + return origin, originURL } -func (h adminHandler) originAllowed(origin string) bool { +func (h adminHandler) originAllowed(origin *url.URL) bool { for _, allowedOrigin := range h.allowedOrigins { - originCopy := origin - if !strings.Contains(allowedOrigin, "://") { - // no scheme specified, so allow both - originCopy = strings.TrimPrefix(originCopy, "http://") - originCopy = strings.TrimPrefix(originCopy, "https://") + if allowedOrigin.Scheme != "" && origin.Scheme != allowedOrigin.Scheme { + continue } - if originCopy == allowedOrigin { + if origin.Host == allowedOrigin.Host { return true } } @@ -1189,6 +1208,18 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) { return x509.ParseCertificate(derBytes) } +type loggableURLArray []*url.URL + +func (ua loggableURLArray) MarshalLogArray(enc zapcore.ArrayEncoder) error { + if ua == nil { + return nil + } + for _, u := range ua { + enc.AppendString(u.String()) + } + return nil +} + var ( // DefaultAdminListen is the address for the local admin // listener, if none is specified at startup.