Skip to content

Commit

Permalink
refactor(): back to unexported default headers
Browse files Browse the repository at this point in the history
  • Loading branch information
n33pm committed Apr 9, 2024
1 parent da49574 commit 277d4a5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
17 changes: 8 additions & 9 deletions middleware/realip.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"strings"
)

var DefaultRealIPHeaders = []string{
"True-Client-IP", // Cloudflare Enterprise plan
var defaultHeaders = []string{
"True-Client-IP", // Cloudflare Enterprise plan
"X-Real-IP",
"X-Forwarded-For",
}
Expand All @@ -32,7 +32,7 @@ var DefaultRealIPHeaders = []string{
// how you're using RemoteAddr, vulnerable to an attack of some sort).
func RealIP(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := getRealIP(r, DefaultRealIPHeaders); rip != "" {
if rip := getRealIP(r, defaultHeaders); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
Expand All @@ -45,12 +45,11 @@ func RealIP(h http.Handler) http.Handler {
// of parsing the custom headers.
//
// usage:
// r.Use(RealIPFromHeaders([]string{"CF-Connecting-IP"}))
// r.Use(RealIPFromHeaders(append(DefaultRealIPHeaders, "CF-Connecting-IP")))
func RealIPFromHeaders(realIPHeaders []string) func(http.Handler) http.Handler {
// r.Use(RealIPFromHeaders("CF-Connecting-IP"))
func RealIPFromHeaders(headers ...string) func(http.Handler) http.Handler {
f := func(h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if rip := getRealIP(r, realIPHeaders); rip != "" {
if rip := getRealIP(r, headers); rip != "" {
r.RemoteAddr = rip
}
h.ServeHTTP(w, r)
Expand All @@ -60,8 +59,8 @@ func RealIPFromHeaders(realIPHeaders []string) func(http.Handler) http.Handler {
return f
}

func getRealIP(r *http.Request, realIPHeaders []string) string {
for _, header := range realIPHeaders {
func getRealIP(r *http.Request, headers []string) string {
for _, header := range headers {
if ip := r.Header.Get(header); ip != "" {
ips := strings.Split(ip, ",")
if ips[0] == "" || net.ParseIP(ips[0]) == nil {
Expand Down
4 changes: 2 additions & 2 deletions middleware/realip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestCustomIPHeader(t *testing.T) {
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(RealIPFromHeaders([]string{customHeaderKey}))
r.Use(RealIPFromHeaders(customHeaderKey))

realIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -145,7 +145,7 @@ func TestCustomIPHeaderWithoutDefault(t *testing.T) {
w := httptest.NewRecorder()

r := chi.NewRouter()
r.Use(RealIPFromHeaders([]string{"X-CUSTOM-IP"}))
r.Use(RealIPFromHeaders("CF-Connecting-IP"))

realIP := ""
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 277d4a5

Please sign in to comment.