From 962cae0bc1a4cdfd91b741c863598d57cbb9f485 Mon Sep 17 00:00:00 2001 From: Fufu Date: Fri, 24 Sep 2021 22:06:02 +0800 Subject: [PATCH 1/4] Fix: static file routing path rewrite. --- router.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/router.go b/router.go index 20ba79b193..54ea5eddc1 100644 --- a/router.go +++ b/router.go @@ -327,8 +327,11 @@ func (app *App) registerStatic(prefix, root string, config ...Static) Router { if len(path) >= prefixLen { if isStar && app.getString(path[0:prefixLen]) == prefix { path = append(path[0:0], '/') - } else if len(path) > 0 && path[len(path)-1] != '/' { - path = append(path[prefixLen:], '/') + } else { + path = path[prefixLen:] + if len(path) == 0 || path[len(path)-1] != '/' { + path = append(path, '/') + } } } if len(path) > 0 && path[0] != '/' { From ac37ea9a39c8cecc68ae699c3f6424401f24094f Mon Sep 17 00:00:00 2001 From: Fufu Date: Sat, 25 Sep 2021 11:32:58 +0800 Subject: [PATCH 2/4] Add: static file routing test cases. --- router_test.go | 109 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/router_test.go b/router_test.go index f854f8c8ab..0503a45b5a 100644 --- a/router_test.go +++ b/router_test.go @@ -10,8 +10,11 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "net/http/httptest" + "os" + "path/filepath" "testing" "github.com/gofiber/fiber/v2/utils" @@ -323,6 +326,112 @@ func Test_Router_Handler_Catch_Error(t *testing.T) { utils.AssertEqual(t, StatusInternalServerError, c.Response.Header.StatusCode()) } +func Test_Route_Static_Root(t *testing.T) { + rootDir, _ := os.Getwd() + f, err := os.CreateTemp(rootDir, "") + if err != nil { + t.Error(err) + } + defer func() { + _ = os.Remove(f.Name()) + }() + + _, err = f.WriteString("Fiber") + utils.AssertEqual(t, nil, err) + _ = f.Close() + + dir, filename := filepath.Split(f.Name()) + + app := New() + app.Static("/", dir, Static{ + Browse: true, + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/"+filename, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, "Fiber", app.getString(body)) + + app = New() + app.Static("/", dir) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/"+filename, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, "Fiber", app.getString(body)) +} + +func Test_Route_Static_HasPrefix(t *testing.T) { + rootDir, _ := os.Getwd() + f, err := os.CreateTemp(rootDir, "") + if err != nil { + t.Error(err) + } + defer func() { + _ = os.Remove(f.Name()) + }() + + _, err = f.WriteString("Fiber") + utils.AssertEqual(t, nil, err) + _ = f.Close() + + dir, filename := filepath.Split(f.Name()) + + app := New() + app.Static("/static", dir, Static{ + Browse: true, + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/static", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/"+filename, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err := io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, "Fiber", app.getString(body)) + + app = New() + app.Static("/static", dir) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/"+filename, nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = io.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, "Fiber", app.getString(body)) +} + ////////////////////////////////////////////// ///////////////// BENCHMARKS ///////////////// ////////////////////////////////////////////// From 64c006d429d7e3800328f04b8cb645be12a98641 Mon Sep 17 00:00:00 2001 From: Fufu Date: Fri, 1 Oct 2021 12:58:52 +0800 Subject: [PATCH 3/4] Update: change os.CreateTemp to ioutil.TempFile for go1.14 --- router_test.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/router_test.go b/router_test.go index 0503a45b5a..867565fd66 100644 --- a/router_test.go +++ b/router_test.go @@ -10,7 +10,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "io/ioutil" "net/http/httptest" "os" @@ -328,7 +327,7 @@ func Test_Router_Handler_Catch_Error(t *testing.T) { func Test_Route_Static_Root(t *testing.T) { rootDir, _ := os.Getwd() - f, err := os.CreateTemp(rootDir, "") + f, err := ioutil.TempFile(rootDir, "") if err != nil { t.Error(err) } @@ -355,7 +354,7 @@ func Test_Route_Static_Root(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - body, err := io.ReadAll(resp.Body) + body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, "Fiber", app.getString(body)) @@ -370,14 +369,14 @@ func Test_Route_Static_Root(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - body, err = io.ReadAll(resp.Body) + body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, "Fiber", app.getString(body)) } func Test_Route_Static_HasPrefix(t *testing.T) { rootDir, _ := os.Getwd() - f, err := os.CreateTemp(rootDir, "") + f, err := ioutil.TempFile(rootDir, "") if err != nil { t.Error(err) } @@ -408,7 +407,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - body, err := io.ReadAll(resp.Body) + body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, "Fiber", app.getString(body)) @@ -427,7 +426,7 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - body, err = io.ReadAll(resp.Body) + body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, "Fiber", app.getString(body)) } From fe11a690d0d37308310987a3b75c2026d2899f7d Mon Sep 17 00:00:00 2001 From: Fufu Date: Fri, 1 Oct 2021 18:33:44 +0800 Subject: [PATCH 4/4] Update: optimize test cases. --- app_test.go | 9 +++++++++ router_test.go | 51 +++++++++++--------------------------------------- 2 files changed, 20 insertions(+), 40 deletions(-) diff --git a/app_test.go b/app_test.go index 6ae4fc7605..aa950459c1 100644 --- a/app_test.go +++ b/app_test.go @@ -735,6 +735,15 @@ func Test_App_Static_Trailing_Slash(t *testing.T) { req := httptest.NewRequest(MethodGet, "/john/", nil) resp, err := app.Test(req) utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") + utils.AssertEqual(t, MIMETextHTMLCharsetUTF8, resp.Header.Get(HeaderContentType)) + + app.Static("/john_without_index", "./.github/testdata/fs/css") + + req = httptest.NewRequest(MethodGet, "/john_without_index/", nil) + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 404, resp.StatusCode, "Status code") utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") utils.AssertEqual(t, MIMETextPlainCharsetUTF8, resp.Header.Get(HeaderContentType)) diff --git a/router_test.go b/router_test.go index 867565fd66..0e257ca087 100644 --- a/router_test.go +++ b/router_test.go @@ -12,8 +12,7 @@ import ( "fmt" "io/ioutil" "net/http/httptest" - "os" - "path/filepath" + "strings" "testing" "github.com/gofiber/fiber/v2/utils" @@ -326,21 +325,7 @@ func Test_Router_Handler_Catch_Error(t *testing.T) { } func Test_Route_Static_Root(t *testing.T) { - rootDir, _ := os.Getwd() - f, err := ioutil.TempFile(rootDir, "") - if err != nil { - t.Error(err) - } - defer func() { - _ = os.Remove(f.Name()) - }() - - _, err = f.WriteString("Fiber") - utils.AssertEqual(t, nil, err) - _ = f.Close() - - dir, filename := filepath.Split(f.Name()) - + dir := "./.github/testdata/fs/css" app := New() app.Static("/", dir, Static{ Browse: true, @@ -350,13 +335,13 @@ func Test_Route_Static_Root(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/"+filename, nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/style.css", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "Fiber", app.getString(body)) + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) app = New() app.Static("/", dir) @@ -365,31 +350,17 @@ func Test_Route_Static_Root(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 404, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/"+filename, nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/style.css", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "Fiber", app.getString(body)) + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) } func Test_Route_Static_HasPrefix(t *testing.T) { - rootDir, _ := os.Getwd() - f, err := ioutil.TempFile(rootDir, "") - if err != nil { - t.Error(err) - } - defer func() { - _ = os.Remove(f.Name()) - }() - - _, err = f.WriteString("Fiber") - utils.AssertEqual(t, nil, err) - _ = f.Close() - - dir, filename := filepath.Split(f.Name()) - + dir := "./.github/testdata/fs/css" app := New() app.Static("/static", dir, Static{ Browse: true, @@ -403,13 +374,13 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/"+filename, nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") body, err := ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "Fiber", app.getString(body)) + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) app = New() app.Static("/static", dir) @@ -422,13 +393,13 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 404, resp.StatusCode, "Status code") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/"+filename, nil)) + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil)) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 200, resp.StatusCode, "Status code") body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") - utils.AssertEqual(t, "Fiber", app.getString(body)) + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) } //////////////////////////////////////////////