From b94870d8865b47e5b7bf3c843a4d48adb2b14e88 Mon Sep 17 00:00:00 2001 From: Fufu Date: Fri, 1 Oct 2021 18:51:48 +0800 Subject: [PATCH] Fix: static file routing path rewrite. (#1538) * Fix: static file routing path rewrite. * Add: static file routing test cases. * Update: change os.CreateTemp to ioutil.TempFile for go1.14 * Update: optimize test cases. --- app_test.go | 9 ++++++ router.go | 7 +++-- router_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/app_test.go b/app_test.go index 6ae4fc7605..aa950459c1 100644 --- a/app_test.go +++ b/app_test.go @@ -735,6 +735,15 @@ func Test_App_Static_Trailing_Slash(t *testing.T) { req := httptest.NewRequest(MethodGet, "/john/", nil) resp, err := app.Test(req) utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") + utils.AssertEqual(t, MIMETextHTMLCharsetUTF8, resp.Header.Get(HeaderContentType)) + + app.Static("/john_without_index", "./.github/testdata/fs/css") + + req = httptest.NewRequest(MethodGet, "/john_without_index/", nil) + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, 404, resp.StatusCode, "Status code") utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") utils.AssertEqual(t, MIMETextPlainCharsetUTF8, resp.Header.Get(HeaderContentType)) diff --git a/router.go b/router.go index 20ba79b193..54ea5eddc1 100644 --- a/router.go +++ b/router.go @@ -327,8 +327,11 @@ func (app *App) registerStatic(prefix, root string, config ...Static) Router { if len(path) >= prefixLen { if isStar && app.getString(path[0:prefixLen]) == prefix { path = append(path[0:0], '/') - } else if len(path) > 0 && path[len(path)-1] != '/' { - path = append(path[prefixLen:], '/') + } else { + path = path[prefixLen:] + if len(path) == 0 || path[len(path)-1] != '/' { + path = append(path, '/') + } } } if len(path) > 0 && path[0] != '/' { diff --git a/router_test.go b/router_test.go index f854f8c8ab..0e257ca087 100644 --- a/router_test.go +++ b/router_test.go @@ -12,6 +12,7 @@ import ( "fmt" "io/ioutil" "net/http/httptest" + "strings" "testing" "github.com/gofiber/fiber/v2/utils" @@ -323,6 +324,84 @@ func Test_Router_Handler_Catch_Error(t *testing.T) { utils.AssertEqual(t, StatusInternalServerError, c.Response.Header.StatusCode()) } +func Test_Route_Static_Root(t *testing.T) { + dir := "./.github/testdata/fs/css" + app := New() + app.Static("/", dir, Static{ + Browse: true, + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + + app = New() + app.Static("/", dir) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) +} + +func Test_Route_Static_HasPrefix(t *testing.T) { + dir := "./.github/testdata/fs/css" + app := New() + app.Static("/static", dir, Static{ + Browse: true, + }) + + resp, err := app.Test(httptest.NewRequest(MethodGet, "/static", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + + app = New() + app.Static("/static", dir) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) +} + ////////////////////////////////////////////// ///////////////// BENCHMARKS ///////////////// //////////////////////////////////////////////