Skip to content

Commit

Permalink
Improve filesystem support (Go 1.16+). Add field echo.Filesystem, met…
Browse files Browse the repository at this point in the history
…hods: echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS. Following methods will use echo.Filesystem to server files: echo.File, echo.Static, group.File, group.Static, Context.File
  • Loading branch information
aldas committed Jan 10, 2022
1 parent 6f6befe commit f182524
Show file tree
Hide file tree
Showing 15 changed files with 819 additions and 97 deletions.
25 changes: 0 additions & 25 deletions context.go
Expand Up @@ -9,8 +9,6 @@ import (
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
)
Expand Down Expand Up @@ -569,29 +567,6 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error)
return
}

func (c *context) File(file string) (err error) {
f, err := os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()

fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
f, err = os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return
}
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
return
}

func (c *context) Attachment(file, name string) error {
return c.contentDisposition(file, name, "attachment")
}
Expand Down
33 changes: 33 additions & 0 deletions context_fs.go
@@ -0,0 +1,33 @@
//go:build !go1.16
// +build !go1.16

package echo

import (
"net/http"
"os"
"path/filepath"
)

func (c *context) File(file string) (err error) {
f, err := os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()

fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
f, err = os.Open(file)
if err != nil {
return NotFoundHandler(c)
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return
}
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
return
}
47 changes: 47 additions & 0 deletions context_fs_go1.16.go
@@ -0,0 +1,47 @@
//go:build go1.16
// +build go1.16

package echo

import (
"errors"
"io"
"io/fs"
"net/http"
"path/filepath"
)

func (c *context) File(file string) error {
return fsFile(c, file, c.echo.Filesystem)
}

func (c *context) FileFS(file string, filesystem fs.FS) error {
return fsFile(c, file, filesystem)
}

func fsFile(c Context, file string, filesystem fs.FS) error {
f, err := filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()

fi, _ := f.Stat()
if fi.IsDir() {
file = filepath.Join(file, indexPage)
f, err = filesystem.Open(file)
if err != nil {
return ErrNotFound
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
return err
}
}
ff, ok := f.(io.ReadSeeker)
if !ok {
return errors.New("file does not implement io.ReadSeeker")
}
http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
return nil
}
135 changes: 135 additions & 0 deletions context_fs_go1.16_test.go
@@ -0,0 +1,135 @@
//go:build go1.16
// +build go1.16

package echo

import (
testify "github.com/stretchr/testify/assert"
"io/fs"
"net/http"
"net/http/httptest"
"os"
"testing"
)

func TestContext_File(t *testing.T) {
var testCases = []struct {
name string
whenFile string
whenFS fs.FS
expectStatus int
expectStartsWith []byte
expectError string
}{
{
name: "ok, from default file system",
whenFile: "_fixture/images/walle.png",
whenFS: nil,
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "ok, from custom file system",
whenFile: "walle.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, not existent file",
whenFile: "not.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: nil,
expectError: "code=404, message=Not Found",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
if tc.whenFS != nil {
e.Filesystem = tc.whenFS
}

handler := func(ec Context) error {
return ec.(*context).File(tc.whenFile)
}

req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

err := handler(c)

testify.Equal(t, tc.expectStatus, rec.Code)
if tc.expectError != "" {
testify.EqualError(t, err, tc.expectError)
} else {
testify.NoError(t, err)
}

body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
testify.Equal(t, tc.expectStartsWith, body)
})
}
}

func TestContext_FileFS(t *testing.T) {
var testCases = []struct {
name string
whenFile string
whenFS fs.FS
expectStatus int
expectStartsWith []byte
expectError string
}{
{
name: "ok",
whenFile: "walle.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: []byte{0x89, 0x50, 0x4e},
},
{
name: "nok, not existent file",
whenFile: "not.png",
whenFS: os.DirFS("_fixture/images"),
expectStatus: http.StatusOK,
expectStartsWith: nil,
expectError: "code=404, message=Not Found",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()

handler := func(ec Context) error {
return ec.(*context).FileFS(tc.whenFile, tc.whenFS)
}

req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

err := handler(c)

testify.Equal(t, tc.expectStatus, rec.Code)
if tc.expectError != "" {
testify.EqualError(t, err, tc.expectError)
} else {
testify.NoError(t, err)
}

body := rec.Body.Bytes()
if len(body) > len(tc.expectStartsWith) {
body = body[:len(tc.expectStartsWith)]
}
testify.Equal(t, tc.expectStartsWith, body)
})
}
}
53 changes: 4 additions & 49 deletions echo.go
Expand Up @@ -47,9 +47,6 @@ import (
stdLog "log"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"sync"
Expand All @@ -66,6 +63,7 @@ import (
type (
// Echo is the top-level framework instance.
Echo struct {
filesystem
common
// startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get
// listener address info (on which interface/port was listener binded) without having data races.
Expand Down Expand Up @@ -319,8 +317,9 @@ var (
// New creates an instance of Echo.
func New() (e *Echo) {
e = &Echo{
Server: new(http.Server),
TLSServer: new(http.Server),
filesystem: createFilesystem(),
Server: new(http.Server),
TLSServer: new(http.Server),
AutoTLSManager: autocert.Manager{
Prompt: autocert.AcceptTOS,
},
Expand Down Expand Up @@ -499,50 +498,6 @@ func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middlew
return routes
}

// Static registers a new route with path prefix to serve static files from the
// provided root directory.
func (e *Echo) Static(prefix, root string) *Route {
if root == "" {
root = "." // For security we want to restrict to CWD.
}
return e.static(prefix, root, e.GET)
}

func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route {
h := func(c Context) error {
p, err := url.PathUnescape(c.Param("*"))
if err != nil {
return err
}

name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security
fi, err := os.Stat(name)
if err != nil {
// The access path does not exist
return NotFoundHandler(c)
}

// If the request is for a directory and does not end with "/"
p = c.Request().URL.Path // path must not be empty.
if fi.IsDir() && p[len(p)-1] != '/' {
// Redirect to ends with "/"
return c.Redirect(http.StatusMovedPermanently, p+"/")
}
return c.File(name)
}
// Handle added routes based on trailing slash:
// /prefix => exact route "/prefix" + any route "/prefix/*"
// /prefix/ => only any route "/prefix/*"
if prefix != "" {
if prefix[len(prefix)-1] == '/' {
// Only add any route for intentional trailing slash
return get(prefix+"*", h)
}
get(prefix, h)
}
return get(prefix+"/*", h)
}

func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route,
m ...MiddlewareFunc) *Route {
return get(path, func(c Context) error {
Expand Down

0 comments on commit f182524

Please sign in to comment.