Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve filesystem support #2064

Merged
merged 5 commits into from Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 0 additions & 25 deletions context.go
Expand Up @@ -9,8 +9,6 @@ import (
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
)
Expand Down Expand Up @@ -569,29 +567,6 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
return
}

func (c *context) File(file string) (err error) {
f, err := os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()

fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
f, err = os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return
}
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
return
}

func (c *context) Attachment(file, name string) error {
return c.contentDisposition(file, name, "attachment")
}
Expand Down
33 changes: 33 additions & 0 deletions context_fs.go
@@ -0,0 +1,33 @@
//go:build !go1.16
// +build !go1.16

package echo

import (
"net/http"
"os"
"path/filepath"
)

func (c *context) File(file string) (err error) {
f, err := os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()

fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
f, err = os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return
}
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
return
}
52 changes: 52 additions & 0 deletions context_fs_go1.16.go
@@ -0,0 +1,52 @@
//go:build go1.16
// +build go1.16

package echo

import (
"errors"
"io"
"io/fs"
"net/http"
"path/filepath"
)

func (c *context) File(file string) error {
return fsFile(c, file, c.echo.Filesystem)
}

// FileFS serves file from given file system.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (c *context) FileFS(file string, filesystem fs.FS) error {
return fsFile(c, file, filesystem)
}

func fsFile(c Context, file string, filesystem fs.FS) error {
f, err := filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()

fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
f, err = filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return err
}
}
ff, ok := f.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
return nil
}
135 changes: 135 additions & 0 deletions context_fs_go1.16_test.go
@@ -0,0 +1,135 @@
//go:build go1.16
// +build go1.16

package echo

import (
"github.com/stretchr/testify/assert"
"io/fs"
"net/http"
"net/http/httptest"
"os"
"testing"
)

func TestContext_File(t *testing.T) {
var testCases = []struct {
name string
whenFile string
whenFS fs.FS
expectStatus int
expectStartsWith []byte
expectError string
}{
{
name: "ok, from default file system",
whenFile: "_fixture/images/walle.png",
whenFS: nil,
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "ok, from custom file system",
whenFile: "walle.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, not existent file",
whenFile: "not.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: nil,
expectError: "code=404, message=Not Found",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
if tc.whenFS != nil {
e.Filesystem = tc.whenFS
}

handler := func(ec Context) error {
return ec.(*context).File(tc.whenFile)
}

req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

err := handler(c)

assert.Equal(t, tc.expectStatus, rec.Code)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}

body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, body)
})
}
}

func TestContext_FileFS(t *testing.T) {
var testCases = []struct {
name string
whenFile string
whenFS fs.FS
expectStatus int
expectStartsWith []byte
expectError string
}{
{
name: "ok",
whenFile: "walle.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, not existent file",
whenFile: "not.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: nil,
expectError: "code=404, message=Not Found",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()

handler := func(ec Context) error {
return ec.(*context).FileFS(tc.whenFile, tc.whenFS)
}

req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

err := handler(c)

assert.Equal(t, tc.expectStatus, rec.Code)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}

body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
assert.Equal(t, tc.expectStartsWith, body)
})
}
}
53 changes: 4 additions & 49 deletions echo.go
Expand Up @@ -47,9 +47,6 @@ import (
stdLog "log"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"sync"
Expand All @@ -66,6 +63,7 @@ import (
type (
// Echo is the top-level framework instance.
Echo struct {
filesystem
common
// startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get
// listener address info (on which interface/port was listener binded) without having data races.
Expand Down Expand Up @@ -319,8 +317,9 @@ var (
// New creates an instance of Echo.
func New() (e *Echo) {
e = &Echo{
Server: new(http.Server),
TLSServer: new(http.Server),
filesystem: createFilesystem(),
Server: new(http.Server),
TLSServer: new(http.Server),
AutoTLSManager: autocert.Manager{
Prompt: autocert.AcceptTOS,
},
Expand Down Expand Up @@ -499,50 +498,6 @@ func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middlew
return routes
}

// Static registers a new route with path prefix to serve static files from the
// provided root directory.
func (e *Echo) Static(prefix, root string) *Route {
if root == "" {
root = "." // For security we want to restrict to CWD.
}
return e.static(prefix, root, e.GET)
}

func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route {
h := func(c Context) error {
p, err := url.PathUnescape(c.Param("*"))
if err != nil {
return err
}

name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security
fi, err := os.Stat(name)
if err != nil {
// The access path does not exist
return NotFoundHandler(c)
}

// If the request is for a directory and does not end with "/"
p = c.Request().URL.Path // path must not be empty.
if fi.IsDir() && p[len(p)-1] != '/' {
// Redirect to ends with "/"
return c.Redirect(http.StatusMovedPermanently, p+"/")
}
return c.File(name)
}
// Handle added routes based on trailing slash:
// /prefix => exact route "/prefix" + any route "/prefix/*"
// /prefix/ => only any route "/prefix/*"
if prefix != "" {
if prefix[len(prefix)-1] == '/' {
// Only add any route for intentional trailing slash
return get(prefix+"*", h)
}
get(prefix, h)
}
return get(prefix+"/*", h)
}

func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route,
m ...MiddlewareFunc) *Route {
return get(path, func(c Context) error {
Expand Down