From c126d46683c1d06beccf66bb1874351532a3d4d7 Mon Sep 17 00:00:00 2001 From: Fufu Date: Fri, 1 Oct 2021 23:03:54 +0800 Subject: [PATCH] Fix: register static file routing with trailing slash --- app_test.go | 25 +++++++++++++++++++++++++ router.go | 5 +++++ router_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/app_test.go b/app_test.go index aa950459c1..c3ec6df361 100644 --- a/app_test.go +++ b/app_test.go @@ -747,6 +747,31 @@ func Test_App_Static_Trailing_Slash(t *testing.T) { utils.AssertEqual(t, 404, resp.StatusCode, "Status code") utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") utils.AssertEqual(t, MIMETextPlainCharsetUTF8, resp.Header.Get(HeaderContentType)) + + app.Static("/john/", "./.github") + + req = httptest.NewRequest(MethodGet, "/john/", nil) + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") + utils.AssertEqual(t, MIMETextHTMLCharsetUTF8, resp.Header.Get(HeaderContentType)) + + req = httptest.NewRequest(MethodGet, "/john", nil) + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") + utils.AssertEqual(t, MIMETextHTMLCharsetUTF8, resp.Header.Get(HeaderContentType)) + + app.Static("/john_without_index/", "./.github/testdata/fs/css") + + req = httptest.NewRequest(MethodGet, "/john_without_index/", nil) + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + utils.AssertEqual(t, false, resp.Header.Get(HeaderContentLength) == "") + utils.AssertEqual(t, MIMETextPlainCharsetUTF8, resp.Header.Get(HeaderContentType)) } func Test_App_Static_Next(t *testing.T) { diff --git a/router.go b/router.go index 54ea5eddc1..730b70390d 100644 --- a/router.go +++ b/router.go @@ -313,6 +313,11 @@ func (app *App) registerStatic(prefix, root string, config ...Static) Router { // Fix this later } prefixLen := len(prefix) + if prefixLen > 1 && prefix[prefixLen-1:] == "/" { + // /john/ -> /john + prefixLen-- + prefix = prefix[:prefixLen] + } // Fileserver settings fs := &fasthttp.FS{ Root: root, diff --git a/router_test.go b/router_test.go index 0e257ca087..c33761d2f3 100644 --- a/router_test.go +++ b/router_test.go @@ -382,6 +382,27 @@ func Test_Route_Static_HasPrefix(t *testing.T) { utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + app = New() + app.Static("/static/", dir, Static{ + Browse: true, + }) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + app = New() app.Static("/static", dir) @@ -400,6 +421,26 @@ func Test_Route_Static_HasPrefix(t *testing.T) { body, err = ioutil.ReadAll(resp.Body) utils.AssertEqual(t, nil, err, "app.Test(req)") utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) + + + app = New() + app.Static("/static/", dir) + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 404, resp.StatusCode, "Status code") + + resp, err = app.Test(httptest.NewRequest(MethodGet, "/static/style.css", nil)) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, 200, resp.StatusCode, "Status code") + + body, err = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err, "app.Test(req)") + utils.AssertEqual(t, true, strings.Contains(app.getString(body), "color")) } //////////////////////////////////////////////