diff --git a/fs.go b/fs.go index a8bcd13f9d..bb3ee07893 100644 --- a/fs.go +++ b/fs.go @@ -2,6 +2,7 @@ package fasthttp import ( "bytes" + "crypto/rand" "errors" "fmt" "html" @@ -501,7 +502,8 @@ type fsHandler struct { cacheGzip map[string]*fsFile cacheLock sync.Mutex - smallFileReaderPool sync.Pool + smallFileReaderPool sync.Pool + smallRangeReaderPool sync.Pool } type fsFile struct { @@ -520,6 +522,9 @@ type fsFile struct { bigFiles []*bigFileReader bigFilesLock sync.Mutex + + rangeBigFiles []*bigRangeReader + rangeBigFilesLock sync.Mutex } func (ff *fsFile) NewReader() (io.Reader, error) { @@ -540,9 +545,8 @@ func (ff *fsFile) smallFileReader() (io.Reader, error) { } r := v.(*fsSmallFileReader) r.ff = ff - r.endPos = ff.contentLength - if r.startPos > 0 { - return nil, errors.New("bug: fsSmallFileReader with non-nil startPos found in the pool") + if r.offset > 0 { + panic("BUG: fsSmallFileReader with non-nil offset found in the pool") } return r, nil } @@ -580,7 +584,6 @@ func (ff *fsFile) bigFileReader() (io.Reader, error) { return &bigFileReader{ f: f, ff: ff, - r: f, }, nil } @@ -607,143 +610,175 @@ func (ff *fsFile) decReadersCount() { ff.h.cacheLock.Unlock() } +func (ff *fsFile) NewRangeReader() (io.Reader, error) { + if ff.isBig() { + r, err := ff.bigRangeReader() + if err != nil { + ff.decReadersCount() + } + return r, err + } + return ff.smallRangeReader(), nil +} + +func (ff *fsFile) smallRangeReader() io.Reader { + v := ff.h.smallRangeReaderPool.Get() + if v == nil { + v = &smallRangeReader{ + sta: make([]int, 64), + end: make([]int, 64), + buf: make([]byte, 4096), + } + } + r := v.(*smallRangeReader) + r.ff = ff + r.sta, r.end, r.buf = r.sta[:0], r.end[:0], r.buf[:0] + r.hf = true + r.he = false + r.hasCurRangeBodyHeader = false + r.cRange = 0 + r.boundary = randomBoundary() + return r +} + +func (ff *fsFile) bigRangeReader() (io.Reader, error) { + if ff.f == nil { + panic("BUG: ff.f must be non-nil in bigRangeReader") + } + var r io.Reader + + ff.rangeBigFilesLock.Lock() + n := len(ff.rangeBigFiles) + if n > 0 { + r = ff.rangeBigFiles[n-1] + ff.rangeBigFiles = ff.rangeBigFiles[:n-1] + } + ff.rangeBigFilesLock.Unlock() + + if r != nil { + return r, nil + } + f, err := os.Open(ff.f.Name()) + if err != nil { + return nil, fmt.Errorf("cannot open already opened file: %s", err) + } + rr := &bigRangeReader{ + ff: ff, + f: f, + sta: make([]int, 64), + end: make([]int, 64), + buf: make([]byte, 4096), + hf: true, + he: false, + hasCurRangeBodyHeader: false, + cRange: 0, + boundary: randomBoundary(), + } + rr.sta, rr.end, rr.buf = rr.sta[:0], rr.end[:0], rr.buf[:0] + return rr, nil +} + // bigFileReader attempts to trigger sendfile // for sending big files over the wire. type bigFileReader struct { f *os.File ff *fsFile - r io.Reader - lr io.LimitedReader -} - -func (r *bigFileReader) UpdateByteRange(startPos, endPos int) error { - if _, err := r.f.Seek(int64(startPos), 0); err != nil { - return err - } - r.r = &r.lr - r.lr.R = r.f - r.lr.N = int64(endPos - startPos + 1) - return nil } func (r *bigFileReader) Read(p []byte) (int, error) { - return r.r.Read(p) + return r.f.Read(p) } func (r *bigFileReader) WriteTo(w io.Writer) (int64, error) { if rf, ok := w.(io.ReaderFrom); ok { // fast path. Send file must be triggered - return rf.ReadFrom(r.r) + return rf.ReadFrom(r.f) } // slow path - return copyZeroAlloc(w, r.r) + return copyZeroAlloc(w, r.f) } func (r *bigFileReader) Close() error { - r.r = r.f n, err := r.f.Seek(0, 0) if err == nil { - if n == 0 { - ff := r.ff - ff.bigFilesLock.Lock() - ff.bigFiles = append(ff.bigFiles, r) - ff.bigFilesLock.Unlock() - } else { - _ = r.f.Close() - err = errors.New("bug: File.Seek(0,0) returned (non-zero, nil)") + if n != 0 { + panic("BUG: File.Seek(0,0) returned (non-zero, nil)") } + ff := r.ff + ff.bigFilesLock.Lock() + ff.bigFiles = append(ff.bigFiles, r) + ff.bigFilesLock.Unlock() } else { - _ = r.f.Close() + r.f.Close() } r.ff.decReadersCount() return err } type fsSmallFileReader struct { - ff *fsFile - startPos int - endPos int + ff *fsFile + offset int64 } func (r *fsSmallFileReader) Close() error { ff := r.ff ff.decReadersCount() r.ff = nil - r.startPos = 0 - r.endPos = 0 + r.offset = 0 ff.h.smallFileReaderPool.Put(r) return nil } -func (r *fsSmallFileReader) UpdateByteRange(startPos, endPos int) error { - r.startPos = startPos - r.endPos = endPos + 1 - return nil -} - func (r *fsSmallFileReader) Read(p []byte) (int, error) { - tailLen := r.endPos - r.startPos - if tailLen <= 0 { - return 0, io.EOF - } - if len(p) > tailLen { - p = p[:tailLen] - } - ff := r.ff + if ff.f != nil { - n, err := ff.f.ReadAt(p, int64(r.startPos)) - r.startPos += n + n, err := ff.f.ReadAt(p, r.offset) + r.offset += int64(n) return n, err } - - n := copy(p, ff.dirIndex[r.startPos:]) - r.startPos += n + if r.offset == int64(len(ff.dirIndex)) { + return 0, io.EOF + } + n := copy(p, ff.dirIndex[r.offset:]) + r.offset += int64(n) return n, nil } func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) { + if r.offset != 0 { + panic("BUG: no-zero offset! Read() mustn't be called before WriteTo()") + } ff := r.ff var n int var err error if ff.f == nil { - n, err = w.Write(ff.dirIndex[r.startPos:r.endPos]) + n, err = w.Write(ff.dirIndex) return int64(n), err } - if rf, ok := w.(io.ReaderFrom); ok { return rf.ReadFrom(r) } - - curPos := r.startPos bufv := copyBufPool.Get() buf := bufv.([]byte) for err == nil { - tailLen := r.endPos - curPos - if tailLen <= 0 { - break - } - if len(buf) > tailLen { - buf = buf[:tailLen] - } - n, err = ff.f.ReadAt(buf, int64(curPos)) + n, err = ff.f.ReadAt(buf, r.offset) nw, errw := w.Write(buf[:n]) - curPos += nw + r.offset += int64(nw) if errw == nil && nw != n { - errw = errors.New("bug: Write(p) returned (n, nil), where n != len(p)") + panic("BUG: Write(p) returned (n, nil), where n != len(p)") } if err == nil { err = errw } } copyBufPool.Put(bufv) - if err == io.EOF { err = nil } - return int64(curPos - r.startPos), err + return r.offset, err } func (h *fsHandler) cleanCache(pendingFiles []*fsFile) []*fsFile { @@ -916,7 +951,14 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) { return } - r, err := ff.NewReader() + var r io.Reader + var err error + if h.acceptByteRange && len(byteRange) > 0 { + r, err = ff.NewRangeReader() + } else { + r, err = ff.NewReader() + } + if err != nil { ctx.Logger().Printf("cannot obtain file reader for path=%q: %v", path, err) ctx.Error("Internal Server Error", StatusInternalServerError) @@ -937,23 +979,27 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) { if h.acceptByteRange { hdr.setNonSpecial(strAcceptRanges, strBytes) if len(byteRange) > 0 { - startPos, endPos, err := ParseByteRange(byteRange, contentLength) + staList, endList, err := ParseByteRanges(byteRange, contentLength) if err != nil { - _ = r.(io.Closer).Close() - ctx.Logger().Printf("cannot parse byte range %q for path=%q: %v", byteRange, path, err) + r.(io.Closer).Close() + ctx.Logger().Printf("Cannot parse byte range %q for path=%q,error=%s", byteRange, path, err) ctx.Error("Range Not Satisfiable", StatusRequestedRangeNotSatisfiable) return } - if err = r.(byteRangeUpdater).UpdateByteRange(startPos, endPos); err != nil { - _ = r.(io.Closer).Close() - ctx.Logger().Printf("cannot seek byte range %q for path=%q: %v", byteRange, path, err) + if err = r.(byteRangeHandler).ByteRangeUpdate(staList, endList); err != nil { + r.(io.Closer).Close() + ctx.Logger().Printf("Cannot seek byte range %q for path=%q, error=%s", byteRange, path, err) ctx.Error("Internal Server Error", StatusInternalServerError) return } - - hdr.SetContentRange(startPos, endPos, contentLength) - contentLength = endPos - startPos + 1 + switch { + case len(staList) == 1: + hdr.SetContentRange(staList[0], endList[0], contentLength) + case len(staList) > 1: + hdr.SetContentType(fmt.Sprintf("multipart/byteranges; boundary=%s", r.(byteRangeHandler).Boundary())) + } + contentLength = r.(byteRangeHandler).ByteRangeLength() statusCode = StatusPartialContent } } @@ -980,62 +1026,63 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) { ctx.SetStatusCode(statusCode) } -type byteRangeUpdater interface { - UpdateByteRange(startPos, endPos int) error -} - // ParseByteRange parses 'Range: bytes=...' header value. // // It follows https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 . -func ParseByteRange(byteRange []byte, contentLength int) (startPos, endPos int, err error) { +func ParseByteRanges(byteRange []byte, contentLength int) (startPos, endPos []int, err error) { b := byteRange if !bytes.HasPrefix(b, strBytes) { - return 0, 0, fmt.Errorf("unsupported range units: %q. Expecting %q", byteRange, strBytes) + return startPos, endPos, fmt.Errorf("unsupported range units: %q. Expecting %q", byteRange, strBytes) } - b = b[len(strBytes):] if len(b) == 0 || b[0] != '=' { - return 0, 0, fmt.Errorf("missing byte range in %q", byteRange) + return startPos, endPos, fmt.Errorf("missing byte range in %q", byteRange) } b = b[1:] - n := bytes.IndexByte(b, '-') - if n < 0 { - return 0, 0, fmt.Errorf("missing the end position of byte range in %q", byteRange) - } + var n, sta, end, v int - if n == 0 { - v, err := ParseUint(b[n+1:]) - if err != nil { - return 0, 0, err + for _, ra := range bytes.Split(b, s2b(",")) { + n = bytes.IndexByte(ra, '-') + if n < 0 { + return startPos, endPos, fmt.Errorf("missing the end position of byte range in %q", byteRange) } - startPos := contentLength - v - if startPos < 0 { - startPos = 0 + if n == 0 { + v, err = ParseUint(ra[n+1:]) + if err != nil { + return startPos, endPos, err + } + sta = contentLength - v + if sta < 0 { + sta = 0 + } + startPos = append(startPos, sta) + endPos = append(endPos, contentLength-1) + continue } - return startPos, contentLength - 1, nil - } - - if startPos, err = ParseUint(b[:n]); err != nil { - return 0, 0, err - } - if startPos >= contentLength { - return 0, 0, fmt.Errorf("the start position of byte range cannot exceed %d. byte range %q", contentLength-1, byteRange) - } - - b = b[n+1:] - if len(b) == 0 { - return startPos, contentLength - 1, nil - } - - if endPos, err = ParseUint(b); err != nil { - return 0, 0, err - } - if endPos >= contentLength { - endPos = contentLength - 1 - } - if endPos < startPos { - return 0, 0, fmt.Errorf("the start position of byte range cannot exceed the end position. byte range %q", byteRange) + if sta, err = ParseUint(ra[:n]); err != nil { + return startPos, endPos, err + } + if sta >= contentLength { + return startPos, endPos, fmt.Errorf("the start position of byte range cannot exceed %d. byte range %q", contentLength-1, byteRange) + } + ra = ra[n+1:] + if len(ra) == 0 { + startPos = append(startPos, sta) + endPos = append(endPos, contentLength-1) + continue + } + if end, err = ParseUint(ra); err != nil { + return startPos, endPos, err + } + if end >= contentLength { + end = contentLength - 1 + } + if end < sta { + return startPos, endPos, fmt.Errorf("the start position of byte range cannot exceed the end position. byte range %q", byteRange) + } + startPos = append(startPos, sta) + endPos = append(endPos, end) } return startPos, endPos, nil } @@ -1451,3 +1498,499 @@ func getFileLock(absPath string) *sync.Mutex { filelock := v.(*sync.Mutex) return filelock } + +type byteRangeHandler interface { + ByteRangeUpdate(sta, end []int) error + ByteRangeLength() int + Boundary() string + IsMultiRange() bool +} + +type smallRangeReader struct { + ff *fsFile + sta []int + end []int + + // handle multi range + buf []byte + hf bool + he bool + hasCurRangeBodyHeader bool + cRange int + boundary string +} + +func (r *smallRangeReader) ByteRangeUpdate(sta, end []int) error { + for i := range sta { + r.sta = append(r.sta, sta[i]) + r.end = append(r.end, end[i]+1) + } + return nil +} + +func (r *smallRangeReader) IsMultiRange() bool { + return len(r.sta) > 1 +} + +func (r *smallRangeReader) Boundary() string { + return r.boundary +} + +func (r *smallRangeReader) ByteRangeLength() int { + if !r.IsMultiRange() { + return r.end[0] - r.sta[0] + } + return multiRangeLength(r.sta, r.end, r.ff.contentLength, r.ff.contentType, r.boundary) +} + +func (r *smallRangeReader) Close() error { + ff := r.ff + ff.decReadersCount() + r.ff = nil + + r.sta, r.end, r.buf = r.sta[:0], r.end[:0], r.buf[:0] + r.cRange = 0 + r.hf = true + r.he = false + r.hasCurRangeBodyHeader = false + r.boundary = "" + ff.h.smallRangeReaderPool.Put(r) + return nil +} + +func (r *smallRangeReader) Read(p []byte) (int, error) { + ff := r.ff + var err error + cPos, cLen, n := 0, 0, 0 + + if r.cRange >= len(r.sta) && len(r.buf) == 0 && r.he { + return 0, io.EOF + } + + if r.IsMultiRange() { + cLen = len(p[cPos:]) + n = copy(p[cPos:], r.buf) + if len(r.buf) > cLen { + r.buf = r.buf[n:] + return n, nil + } + cPos += n + r.buf = r.buf[:0] + } + + for i := r.cRange; i < len(r.sta); i++ { + if r.sta[i] >= r.end[i] { + continue + } + + if r.IsMultiRange() && !r.hasCurRangeBodyHeader { + r.hasCurRangeBodyHeader = true + multiRangeBodyHeader(&r.buf, r.sta[i], r.end[i], r.ff.contentLength, r.ff.contentType, r.boundary, r.hf) + r.hf = false + cLen = len(p[cPos:]) + n = copy(p[cPos:], r.buf) + cPos += n + if len(r.buf) > cLen { + r.buf = r.buf[n:] + return cPos, nil + } + r.buf = r.buf[:0] + } + + cLen = len(p[cPos:]) + + // handle file + if ff.f != nil { + if r.end[i]-r.sta[i] > cLen { + n, err = ff.f.ReadAt(p[cPos:], int64(r.sta[i])) + r.sta[i] += n + return cPos + n, err + } + + // todo use pool + cBody := make([]byte, r.end[i]-r.sta[i]) + n, err = ff.f.ReadAt(cBody, int64(r.sta[i])) + if err != nil && err != io.EOF { + return cPos + n, err + } + n = copy(p[cPos:], cBody) + } else { + // handle dir + n = copy(p[cPos:], ff.dirIndex[r.sta[i]:r.end[i]]) + if r.end[i]-r.sta[i] > cLen { + r.sta[i] += n + return cPos + n, nil + } + } + + cPos += n + r.cRange = i + 1 + r.hasCurRangeBodyHeader = false + } + + if r.IsMultiRange() && !r.he { + multiRangeBodyEnd(&r.buf, r.boundary) + r.he = true + n = copy(p[cPos:], r.buf) + if len(r.buf) > len(p[cPos:]) { + r.buf = r.buf[n:] + return cPos + n, nil + } + cPos += n + r.buf = r.buf[:0] + } + return cPos, io.EOF +} + +func (r *smallRangeReader) WriteTo(w io.Writer) (int64, error) { + ff := r.ff + var err error + cPos, cLen, n, sum := 0, 0, 0, 0 + bufv := copyBufPool.Get() + buf := bufv.([]byte) + hf := true + if ff.f == nil { + for i := range r.sta { + + if r.IsMultiRange() { + buf = buf[:0] + if i > 0 { + hf = false + } + multiRangeBodyHeader(&buf, r.sta[i], r.end[i], ff.contentLength, ff.contentType, r.boundary, hf) + nw, errw := w.Write(buf) + if errw == nil && nw != len(buf) { + panic("BUG: buf returned(n, nil),where n != len(buf)") + } + sum += nw + if errw != nil { + return int64(sum), errw + } + } + + n, err = w.Write(ff.dirIndex[r.sta[i]:r.end[i]]) + sum += n + if err != nil { + return int64(sum), err + } + } + if r.IsMultiRange() { + buf = buf[:0] + multiRangeBodyEnd(&buf, r.boundary) + nw, errw := w.Write(buf) + if errw == nil && nw != len(buf) { + panic("BUG: buf returned (n, nil), where n != len(buf)") + } + sum += nw + err = errw + } + return int64(sum), err + } + + if rf, ok := w.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + + for i := range r.sta { + + if r.IsMultiRange() { + buf = buf[:0] + if i > 0 { + hf = false + } + multiRangeBodyHeader(&buf, r.sta[i], r.end[i], ff.contentLength, ff.contentType, r.boundary, hf) + nw, errw := w.Write(buf) + if errw == nil && nw != len(buf) { + panic("BUG: buf returned(n, nil),where n != len(buf)") + } + sum += nw + if errw != nil { + return int64(sum), errw + } + } + + cPos = r.sta[i] + buf = buf[:4096] + for err == nil { + cLen = r.end[i] - cPos + if cLen <= 0 { + break + } + if len(buf) > cLen { + buf = buf[:cLen] + } + n, err = ff.f.ReadAt(buf, int64(cPos)) + nw, errw := w.Write(buf[:n]) + cPos += nw + sum += nw + if errw == nil && nw != n { + panic("BUG: Write(p) returned (n, nil), where n != len(p)") + } + if err == nil { + err = errw + } + } + } + if err == io.EOF { + err = nil + } + if r.IsMultiRange() { + buf = buf[:0] + multiRangeBodyEnd(&buf, r.boundary) + nw, errw := w.Write(buf) + if errw == nil && nw != len(buf) { + panic("BUG: buf returned (n, nil), where n != len(buf)") + } + sum += nw + if err == nil { + err = errw + } + } + return int64(sum), err +} + +type bigRangeReader struct { + ff *fsFile + f *os.File + sta []int + end []int + + // handle multi range + buf []byte + hf bool + he bool + hasCurRangeBodyHeader bool + cRange int + boundary string +} + +func (r *bigRangeReader) ByteRangeUpdate(sta, end []int) error { + for i := range sta { + r.sta = append(r.sta, sta[i]) + r.end = append(r.end, end[i]+1) + } + return nil +} + +func (r *bigRangeReader) IsMultiRange() bool { + return len(r.sta) > 1 +} + +func (r *bigRangeReader) Boundary() string { + return r.boundary +} + +func (r *bigRangeReader) ByteRangeLength() int { + if !r.IsMultiRange() { + return r.end[0] - r.sta[0] + } + return multiRangeLength(r.sta, r.end, r.ff.contentLength, r.ff.contentType, r.boundary) +} + +func (r *bigRangeReader) Close() error { + n, err := r.f.Seek(0, 0) + if err == nil { + if n != 0 { + panic("BUG: File.Seek(0,0) returned (non-zero, nil)") + } + ff := r.ff + ff.rangeBigFilesLock.Lock() + r.end, r.sta, r.buf = r.end[:0], r.sta[:0], r.buf[:0] + r.cRange = 0 + r.hf = true + r.he = false + r.hasCurRangeBodyHeader = false + + ff.rangeBigFiles = append(ff.rangeBigFiles, r) + ff.rangeBigFilesLock.Unlock() + } else { + r.f.Close() + } + r.ff.decReadersCount() + return err +} + +func (r *bigRangeReader) Read(p []byte) (int, error) { + if r.cRange >= len(r.sta) && len(r.buf) == 0 && r.he { + return 0, io.EOF + } + + ff := r.ff + var err error + cPos, cLen, n := 0, 0, 0 + + if r.IsMultiRange() { + cLen = len(p[cPos:]) + n = copy(p[cPos:], r.buf) + if len(r.buf) > cLen { + r.buf = r.buf[n:] + return n, nil + } + cPos += n + r.buf = r.buf[:0] + } + + for i := r.cRange; i < len(r.sta); i++ { + if r.sta[i] >= r.end[i] { + continue + } + + if r.IsMultiRange() && !r.hasCurRangeBodyHeader { + r.hasCurRangeBodyHeader = true + multiRangeBodyHeader(&r.buf, r.sta[i], r.end[i], r.ff.contentLength, r.ff.contentType, r.boundary, r.hf) + r.hf = false + cLen = len(p[cPos:]) + n = copy(p[cPos:], r.buf) + cPos += n + if len(r.buf) > cLen { + r.buf = r.buf[n:] + return cPos, nil + } + r.buf = r.buf[:0] + } + + cLen = len(p[cPos:]) + if r.end[i]-r.sta[i] > cLen { + n, err = r.f.ReadAt(p[cPos:], int64(r.sta[i])) + r.sta[i] += n + return cPos + n, err + } + + // todo use pool + cBody := make([]byte, r.end[i]-r.sta[i]) + n, err = ff.f.ReadAt(cBody, int64(r.sta[i])) + if err != nil && err != io.EOF { + return cPos + n, err + } + n = copy(p[cPos:], cBody) + cPos += n + r.cRange = i + 1 + r.hasCurRangeBodyHeader = false + } + + if r.IsMultiRange() && !r.he { + multiRangeBodyEnd(&r.buf, r.boundary) + r.he = true + n = copy(p[cPos:], r.buf) + if len(r.buf) > len(p[cPos:]) { + r.buf = r.buf[n:] + return cPos + n, nil + } + cPos += n + r.buf = r.buf[:0] + } + return cPos, io.EOF +} + +func (r *bigRangeReader) WriteTo(w io.Writer) (int64, error) { + if rf, ok := w.(io.ReaderFrom); ok { + return rf.ReadFrom(r) + } + + var err error + cPos, cLen, n, sum := 0, 0, 0, 0 + hf := true + + bufv := copyBufPool.Get() + buf := bufv.([]byte) + + for i := range r.sta { + + if r.IsMultiRange() { + buf = buf[:0] + if i > 0 { + hf = false + } + multiRangeBodyHeader(&buf, r.sta[i], r.end[i], r.ff.contentLength, r.ff.contentType, r.boundary, hf) + nw, errw := w.Write(buf) + if errw == nil && nw != len(buf) { + panic("BUG: buf returned(n, nil),where n != len(buf)") + } + sum += nw + if errw != nil { + return int64(sum), errw + } + } + + cPos = r.sta[i] + buf = buf[:4096] + for err == nil { + cLen = r.end[i] - cPos + if cLen <= 0 { + break + } + if len(buf) > cLen { + buf = buf[:cLen] + } + n, err = r.f.ReadAt(buf, int64(cPos)) + nw, errw := w.Write(buf[:n]) + cPos += nw + sum += nw + if errw == nil && nw != n { + panic("BUG: Write(p) returned (n, nil), where n != len(p)") + } + if err == nil { + err = errw + } + } + } + if err == io.EOF { + err = nil + } + if r.IsMultiRange() { + buf = buf[:0] + multiRangeBodyEnd(&buf, r.boundary) + nw, errw := w.Write(buf) + if err == nil && nw != len(buf) { + panic("BUG: buf returned (n, nil), where n != len(buf)") + } + sum += nw + err = errw + } + return int64(sum), err +} + +func multiRangeLength(sta, end []int, cl int, ct, bd string) int { + sum := 0 + hf := true + bufv := copyBufPool.Get() + buf := bufv.([]byte) + for i := range sta { + buf = buf[:0] + if i > 0 { + hf = false + } + multiRangeBodyHeader(&buf, sta[i], end[i], cl, ct, bd, hf) + sum += len(buf) + sum += end[i] - sta[i] + } + buf = buf[:0] + multiRangeBodyEnd(&buf, bd) + sum += len(buf) + copyBufPool.Put(bufv) + return sum +} + +func multiRangeBodyHeader(b *[]byte, sta, end, size int, ct, boundary string, hf bool) { + if !hf { + *b = append(*b, s2b("\r\n")...) + } + *b = append(*b, s2b(fmt.Sprintf("--%s\r\n", boundary))...) + *b = append(*b, s2b(fmt.Sprintf("%s: %s\r\n", HeaderContentRange, + fmt.Sprintf("bytes %d-%d/%d", sta, end-1, size)))...) + *b = append(*b, s2b(fmt.Sprintf("%s: %s\r\n", HeaderContentType, ct))...) + *b = append(*b, s2b("\r\n")...) +} + +func multiRangeBodyEnd(b *[]byte, boundary string) { + *b = append(*b, s2b(fmt.Sprintf("\r\n--%s--\r\n", boundary))...) +} + +func randomBoundary() string { + var buf [30]byte + _, err := io.ReadFull(rand.Reader, buf[:]) + if err != nil { + panic(err) + } + return fmt.Sprintf("%x", buf[:]) +} diff --git a/fs_test.go b/fs_test.go index 22a9b33c2b..053f3b0923 100644 --- a/fs_test.go +++ b/fs_test.go @@ -10,6 +10,7 @@ import ( "path" "runtime" "sort" + "strings" "testing" "time" ) @@ -301,7 +302,7 @@ func TestServeFileUncompressed(t *testing.T) { } } -func TestFSByteRangeConcurrent(t *testing.T) { +func TestFSSingleByteRangeConcurrent(t *testing.T) { // This test can't run parallel as files in / might by changed by other tests. stop := make(chan struct{}) @@ -319,8 +320,10 @@ func TestFSByteRangeConcurrent(t *testing.T) { for i := 0; i < concurrency; i++ { go func() { for j := 0; j < 5; j++ { - testFSByteRange(t, h, "/fs.go") - testFSByteRange(t, h, "/README.md") + testFSSingleByteRangeOfRead(t, h, "/fs.go") + testFSSingleByteRangeOfWriteTo(t, h, "/fs.go") + testFSSingleByteRangeOfRead(t, h, "/README.md") + testFSSingleByteRangeOfWriteTo(t, h, "/README.md") } ch <- struct{}{} }() @@ -348,28 +351,33 @@ func TestFSByteRangeSingleThread(t *testing.T) { } h := fs.NewRequestHandler() - testFSByteRange(t, h, "/fs.go") - testFSByteRange(t, h, "/README.md") + testFSSingleByteRangeOfRead(t, h, "/fs.go") + testFSSingleByteRangeOfWriteTo(t, h, "/fs.go") + testFSSingleByteRangeOfRead(t, h, "/README.md") + testFSSingleByteRangeOfWriteTo(t, h, "/README.md") } -func testFSByteRange(t *testing.T, h RequestHandler, filePath string) { +func testFSSingleByteRangeOfRead(t *testing.T, h RequestHandler, filePath string) { var ctx RequestCtx ctx.Init(&Request{}, nil, nil) expectedBody, err := getFileContents(filePath) if err != nil { - t.Fatalf("cannot read file %q: %v", filePath, err) + t.Fatalf("cannot read file %q: %s", filePath, err) } fileSize := len(expectedBody) - startPos := rand.Intn(fileSize) - endPos := rand.Intn(fileSize) - if endPos < startPos { - startPos, endPos = endPos, startPos + startPos, endPos := make([]int, 0), make([]int, 0) + start := rand.Intn(fileSize) + end := rand.Intn(fileSize) + if end < start { + start, end = end, start } + startPos = append(startPos, start) + endPos = append(endPos, end) ctx.Request.SetRequestURI(filePath) - ctx.Request.Header.SetByteRange(startPos, endPos) + ctx.Request.Header.SetByteRanges(startPos, endPos) h(&ctx) var resp Response @@ -381,23 +389,91 @@ func testFSByteRange(t *testing.T, h RequestHandler, filePath string) { if resp.StatusCode() != StatusPartialContent { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusPartialContent, filePath) } + cr := resp.Header.Peek(HeaderContentRange) - expectedCR := fmt.Sprintf("bytes %d-%d/%d", startPos, endPos, fileSize) + expectedCR := fmt.Sprintf("bytes %d-%d/%d", start, end, fileSize) if string(cr) != expectedCR { t.Fatalf("unexpected content-range %q. Expecting %q. filePath=%q", cr, expectedCR, filePath) } body := resp.Body() - bodySize := endPos - startPos + 1 + bodySize := end - start + 1 if len(body) != bodySize { - t.Fatalf("unexpected body size %d. Expecting %d. filePath=%q, startPos=%d, endPos=%d", - len(body), bodySize, filePath, startPos, endPos) + t.Fatalf("unexpected body size %d. Expecting %d. filePath=%q, start=%d, end=%d", + len(body), bodySize, filePath, start, end) } - expectedBody = expectedBody[startPos : endPos+1] + expectedBody = expectedBody[start : end+1] if !bytes.Equal(body, expectedBody) { - t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, startPos=%d, endPos=%d", - body, expectedBody, filePath, startPos, endPos) + t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, start=%d, end=%d", + body, expectedBody, filePath, start, end) + } +} + +func testFSSingleByteRangeOfWriteTo(t *testing.T, h RequestHandler, filePath string) { + var ctx RequestCtx + ctx.Init(&Request{}, nil, nil) + + expectedBody, err := getFileContents(filePath) + if err != nil { + t.Fatalf("cannot read file %q: %s", filePath, err) + } + + fileSize := len(expectedBody) + startPos, endPos := make([]int, 0), make([]int, 0) + start := rand.Intn(fileSize) + end := rand.Intn(fileSize) + if end < start { + start, end = end, start + } + startPos = append(startPos, start) + endPos = append(endPos, end) + + ctx.Request.SetRequestURI(filePath) + ctx.Request.Header.SetByteRanges(startPos, endPos) + h(&ctx) + + bodySize := end - start + 1 + + // test WriteTo(w io.Writer) + if fileSize > maxSmallFileSize { + reader, ok := ctx.Response.bodyStream.(*bigRangeReader) + if !ok { + t.Fatal("expected bigRangeReader") + } + buf := bytes.NewBuffer(nil) + + n, err := reader.WriteTo(pureWriter{buf}) + if err != nil { + t.Fatal(err) + } + if n != int64(bodySize) { + t.Fatalf("expected %d bytes, got %d bytes", bodySize, n) + } + body1 := buf.String() + if body1 != b2s(expectedBody[start:end+1]) { + t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, start=%d, end=%d", + body1, b2s(expectedBody[start:end+1]), filePath, start, end) + } + } else { + reader, ok := ctx.Response.bodyStream.(*smallRangeReader) + if !ok { + t.Fatal("expected smallRangeReader") + } + buf := bytes.NewBuffer(nil) + + n, err := reader.WriteTo(pureWriter{buf}) + if err != nil { + t.Fatal(err) + } + if n != int64(bodySize) { + t.Fatalf("expected %d bytes, got %d bytes", bodySize, n) + } + body1 := buf.String() + if body1 != b2s(expectedBody[start:end+1]) { + t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, start=%d, end=%d", + body1, b2s(expectedBody[start:end+1]), filePath, start, end) + } } } @@ -411,36 +487,325 @@ func getFileContents(path string) ([]byte, error) { return io.ReadAll(f) } -func TestParseByteRangeSuccess(t *testing.T) { +func TestFSMultiByteRangeConcurrent(t *testing.T) { + t.Parallel() + + fs := &FS{ + Root: ".", + AcceptByteRange: true, + } + h := fs.NewRequestHandler() + + concurrency := 10 + ch := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + for j := 0; j < 5; j++ { + testFSMultiByteRangeOfRead(t, h, "/fs.go") + testFSMultiByteRangeOfWriteTo(t, h, "/fs.go") + } + ch <- struct{}{} + }() + } + + for i := 0; i < concurrency; i++ { + select { + case <-time.After(time.Second): + t.Fatalf("timeout") + case <-ch: + } + } +} + +func TestFSMultiByteRangeSingleThread(t *testing.T) { + t.Parallel() + + fs := &FS{ + Root: ".", + AcceptByteRange: true, + } + h := fs.NewRequestHandler() + + testFSMultiByteRangeOfRead(t, h, "/fs.go") + testFSMultiByteRangeOfWriteTo(t, h, "/fs.go") +} + +func testFSMultiByteRangeOfWriteTo(t *testing.T, h RequestHandler, filePath string) { + var ctx RequestCtx + ctx.Init(&Request{}, nil, nil) + + expectedBody, err := getFileContents(filePath) + if err != nil { + t.Fatalf("cannot read file %q: %s", filePath, err) + } + + num := rand.Intn(5) + 2 + + fileSize := len(expectedBody) + startPos, endPos := make([]int, 0), make([]int, 0) + + for i := 0; i < num; i++ { + start := rand.Intn(fileSize) + end := rand.Intn(fileSize) + if end < start { + start, end = end, start + } + startPos = append(startPos, start) + endPos = append(endPos, end) + } + + ctx.Request.SetRequestURI(filePath) + ctx.Request.Header.SetByteRanges(startPos, endPos) + h(&ctx) + + var body string + var boundary string + + if fileSize > maxSmallFileSize { + reader, ok := ctx.Response.bodyStream.(*bigRangeReader) + boundary = reader.Boundary() + if !ok { + t.Fatal("expected bigRangeReader") + } + buf := bytes.NewBuffer(nil) + + _, err := reader.WriteTo(pureWriter{buf}) + if err != nil { + t.Fatal(err) + } + body = buf.String() + } else { + reader, ok := ctx.Response.bodyStream.(*smallRangeReader) + boundary = reader.Boundary() + if !ok { + t.Fatal("expected smallRangeReader") + } + buf := bytes.NewBuffer(nil) + + _, err := reader.WriteTo(pureWriter{buf}) + if err != nil { + t.Fatal(err) + } + body = buf.String() + } + + singleBodys := make([]byte, 0) + + // compare with single range + for i := 0; i < num; i++ { + var ctx1 RequestCtx + ctx1.Init(&Request{}, nil, nil) + ctx1.Request.SetRequestURI(filePath) + ctx1.Request.Header.SetByteRanges([]int{startPos[i]}, []int{endPos[i]}) + h(&ctx1) + + var r1 Response + s1 := ctx1.Response.String() + + br1 := bufio.NewReader(bytes.NewBufferString(s1)) + if err1 := r1.Read(br1); err1 != nil { + t.Fatalf("unexpected error: %s. filePath=%q", err1, filePath) + } + if r1.StatusCode() != StatusPartialContent { + t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", r1.StatusCode(), StatusPartialContent, filePath) + } + + cr1 := r1.Header.Peek(HeaderContentRange) + expectedCR1 := fmt.Sprintf("bytes %d-%d/%d", startPos[i], endPos[i], fileSize) + if string(cr1) != expectedCR1 { + t.Fatalf("unexpected content-range %q. Expecting %q. filePath=%q", cr1, expectedCR1, filePath) + } + + body1 := r1.Body() + bodySize := endPos[i] - startPos[i] + 1 + if len(body1) != bodySize { + t.Fatalf("unexpected body size %d. Expecting %d. filePath=%q, startPos=%d, endPos=%d", + len(body), bodySize, filePath, startPos[i], endPos[i]) + } + + expectedBody1 := expectedBody[startPos[i] : endPos[i]+1] + if !bytes.Equal(body1, expectedBody1) { + t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, startPos=%d, endPos=%d", + body1, expectedBody1, filePath, startPos[i], endPos[i]) + } + buf := make([]byte, 0) + first := true + if i > 0 { + first = false + } + ct1 := r1.Header.Peek(HeaderContentType) + multiRangeBodyHeader(&buf, startPos[i], endPos[i]+1, fileSize, string(ct1), boundary, first) + singleBodys = append(singleBodys, buf...) + singleBodys = append(singleBodys, body1...) + } + buf := make([]byte, 0) + multiRangeBodyEnd(&buf, boundary) + singleBodys = append(singleBodys, buf...) + if body != string(singleBodys) { + t.Fatalf("multipart ranges content is invalid") + } +} + +func testFSMultiByteRangeOfRead(t *testing.T, h RequestHandler, filePath string) { + var ctx RequestCtx + ctx.Init(&Request{}, nil, nil) + + expectedBody, err := getFileContents(filePath) + if err != nil { + t.Fatalf("cannot read file %q: %s", filePath, err) + } + + num := rand.Intn(5) + 2 + + fileSize := len(expectedBody) + startPos, endPos := make([]int, 0), make([]int, 0) + + for i := 0; i < num; i++ { + start := rand.Intn(fileSize) + end := rand.Intn(fileSize) + if end < start { + start, end = end, start + } + startPos = append(startPos, start) + endPos = append(endPos, end) + } + + ctx.Request.SetRequestURI(filePath) + ctx.Request.Header.SetByteRanges(startPos, endPos) + h(&ctx) + + var resp Response + s := ctx.Response.String() + + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) + } + if resp.StatusCode() != StatusPartialContent { + t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusPartialContent, filePath) + } + + ct := resp.Header.Peek(HeaderContentType) + expectedCT := "multipart/byteranges; boundary=" + if !strings.HasPrefix(string(ct), expectedCT) { + t.Fatalf("unexpected content-type %q. Expecting prefix %q. filePath=%q", ct, expectedCT, filePath) + } + + cl := resp.Header.Peek(HeaderContentLength) + + body := resp.Body() + if fmt.Sprintf("%d", len(body)) != b2s(cl) { + t.Fatalf("Content-Length error") + } + + boundary := string(ct)[len(expectedCT):] + + singleBodys := make([]byte, 0) + + // compare with single range + for i := 0; i < num; i++ { + var ctx1 RequestCtx + ctx1.Init(&Request{}, nil, nil) + ctx1.Request.SetRequestURI(filePath) + ctx1.Request.Header.SetByteRanges([]int{startPos[i]}, []int{endPos[i]}) + h(&ctx1) + + var r1 Response + s1 := ctx1.Response.String() + + br1 := bufio.NewReader(bytes.NewBufferString(s1)) + if err1 := r1.Read(br1); err1 != nil { + t.Fatalf("unexpected error: %s. filePath=%q", err1, filePath) + } + if r1.StatusCode() != StatusPartialContent { + t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", r1.StatusCode(), StatusPartialContent, filePath) + } + + cr1 := r1.Header.Peek(HeaderContentRange) + expectedCR1 := fmt.Sprintf("bytes %d-%d/%d", startPos[i], endPos[i], fileSize) + if string(cr1) != expectedCR1 { + t.Fatalf("unexpected content-range %q. Expecting %q. filePath=%q", cr1, expectedCR1, filePath) + } + + body1 := r1.Body() + bodySize := endPos[i] - startPos[i] + 1 + if len(body1) != bodySize { + t.Fatalf("unexpected body size %d. Expecting %d. filePath=%q, startPos=%d, endPos=%d", + len(body), bodySize, filePath, startPos[i], endPos[i]) + } + + expectedBody1 := expectedBody[startPos[i] : endPos[i]+1] + if !bytes.Equal(body1, expectedBody1) { + t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, startPos=%d, endPos=%d", + body1, expectedBody1, filePath, startPos[i], endPos[i]) + } + buf := make([]byte, 0) + first := true + if i > 0 { + first = false + } + ct1 := r1.Header.Peek(HeaderContentType) + multiRangeBodyHeader(&buf, startPos[i], endPos[i]+1, fileSize, string(ct1), boundary, first) + singleBodys = append(singleBodys, buf...) + singleBodys = append(singleBodys, body1...) + } + buf := make([]byte, 0) + multiRangeBodyEnd(&buf, boundary) + singleBodys = append(singleBodys, buf...) + if string(body) != string(singleBodys) { + t.Fatalf("multipart ranges content is invalid") + } +} + +func TestParseByteSingleRangeSuccess(t *testing.T) { t.Parallel() - testParseByteRangeSuccess(t, "bytes=0-0", 1, 0, 0) - testParseByteRangeSuccess(t, "bytes=1234-6789", 6790, 1234, 6789) + testParseByteRangeSuccess(t, "bytes=0-0", 1, []int{0}, []int{0}) + testParseByteRangeSuccess(t, "bytes=1234-6789", 6790, []int{1234}, []int{6789}) + + testParseByteRangeSuccess(t, "bytes=123-", 456, []int{123}, []int{455}) + testParseByteRangeSuccess(t, "bytes=-1", 1, []int{0}, []int{0}) + testParseByteRangeSuccess(t, "bytes=-123", 456, []int{333}, []int{455}) + + // End position exceeding content-length. It should be updated to content-length-1. + // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 + testParseByteRangeSuccess(t, "bytes=1-2345", 234, []int{1}, []int{233}) + testParseByteRangeSuccess(t, "bytes=0-2345", 2345, []int{0}, []int{2344}) + + // Start position overflow. Whole range must be returned. + // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 + testParseByteRangeSuccess(t, "bytes=-567", 56, []int{0}, []int{55}) +} + +func TestParseByteMultiRangeSuccess(t *testing.T) { + t.Parallel() - testParseByteRangeSuccess(t, "bytes=123-", 456, 123, 455) - testParseByteRangeSuccess(t, "bytes=-1", 1, 0, 0) - testParseByteRangeSuccess(t, "bytes=-123", 456, 333, 455) + testParseByteRangeSuccess(t, "bytes=1234-6789,23-342", 6790, []int{1234, 23}, []int{6789, 342}) + testParseByteRangeSuccess(t, "bytes=123-,-123", 456, []int{123, 333}, []int{455, 455}) // End position exceeding content-length. It should be updated to content-length-1. // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 - testParseByteRangeSuccess(t, "bytes=1-2345", 234, 1, 233) - testParseByteRangeSuccess(t, "bytes=0-2345", 2345, 0, 2344) + testParseByteRangeSuccess(t, "bytes=1-2345,1-345", 234, []int{1, 1}, []int{233, 233}) + + testParseByteRangeSuccess(t, "bytes=0-2345,23-1234", 2345, []int{0, 23}, []int{2344, 1234}) // Start position overflow. Whole range must be returned. // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 - testParseByteRangeSuccess(t, "bytes=-567", 56, 0, 55) + testParseByteRangeSuccess(t, "bytes=-567,-765", 56, []int{0, 0}, []int{55, 55}) } -func testParseByteRangeSuccess(t *testing.T, v string, contentLength, startPos, endPos int) { - startPos1, endPos1, err := ParseByteRange([]byte(v), contentLength) +func testParseByteRangeSuccess(t *testing.T, v string, contentLength int, startPos, endPos []int) { + startPos1, endPos1, err := ParseByteRanges([]byte(v), contentLength) if err != nil { - t.Fatalf("unexpected error: %v. v=%q, contentLength=%d", err, v, contentLength) - } - if startPos1 != startPos { - t.Fatalf("unexpected startPos=%d. Expecting %d. v=%q, contentLength=%d", startPos1, startPos, v, contentLength) + t.Fatalf("unexpected error: %s. v=%q, contentLength=%d", err, v, contentLength) } - if endPos1 != endPos { - t.Fatalf("unexpected endPos=%d. Expectind %d. v=%q, contentLenght=%d", endPos1, endPos, v, contentLength) + for i := range startPos1 { + if startPos1[i] != startPos[i] { + t.Fatalf("unexpected startPos=%d. Expecting %d. v=%q, contentLength=%d", startPos1[i], startPos[i], v, contentLength) + } + if endPos1[i] != endPos[i] { + t.Fatalf("unexpected endPos=%d. Expectind %d. v=%q, contentLength=%d", endPos1[i], endPos[i], v, contentLength) + } } } @@ -461,9 +826,6 @@ func TestParseByteRangeError(t *testing.T) { testParseByteRangeError(t, "bytes=1-foobar", 123) testParseByteRangeError(t, "bytes=df-344", 545) - // multiple byte ranges - testParseByteRangeError(t, "bytes=1-2,4-6", 123) - // byte range exceeding contentLength testParseByteRangeError(t, "bytes=123-", 12) @@ -472,7 +834,7 @@ func TestParseByteRangeError(t *testing.T) { } func testParseByteRangeError(t *testing.T, v string, contentLength int) { - _, _, err := ParseByteRange([]byte(v), contentLength) + _, _, err := ParseByteRanges([]byte(v), contentLength) if err == nil { t.Fatalf("expecting error when parsing byte range %q", v) } diff --git a/header.go b/header.go index fe4f669553..82927678e4 100644 --- a/header.go +++ b/header.go @@ -113,24 +113,33 @@ func (h *ResponseHeader) SetContentRange(startPos, endPos, contentLength int) { // // - If startPos is negative, then 'bytes=-startPos' value is set. // - If endPos is negative, then 'bytes=startPos-' value is set. -func (h *RequestHeader) SetByteRange(startPos, endPos int) { +func (h *RequestHeader) SetByteRanges(startPos, endPos []int) { b := h.bufKV.value[:0] b = append(b, strBytes...) b = append(b, '=') - if startPos >= 0 { - b = AppendUint(b, startPos) - } else { - endPos = -startPos - } - b = append(b, '-') - if endPos >= 0 { - b = AppendUint(b, endPos) + for i := range startPos { + if i > 0 { + b = append(b, ',') + } + if startPos[i] >= 0 { + b = AppendUint(b, startPos[i]) + } else { + endPos[i] = -startPos[i] + } + b = append(b, '-') + if endPos[i] >= 0 { + b = AppendUint(b, endPos[i]) + } } h.bufKV.value = b h.setNonSpecial(strRange, h.bufKV.value) } +func (h *RequestHeader) SetByteRange(startPos, endPos int) { + h.SetByteRanges([]int{startPos}, []int{endPos}) +} + // StatusCode returns response status code. func (h *ResponseHeader) StatusCode() int { if h.statusCode == 0 { diff --git a/header_test.go b/header_test.go index 7d2942551d..31b594117f 100644 --- a/header_test.go +++ b/header_test.go @@ -974,7 +974,7 @@ func TestRequestHeaderSetByteRange(t *testing.T) { func testRequestHeaderSetByteRange(t *testing.T, startPos, endPos int, expectedV string) { var h RequestHeader - h.SetByteRange(startPos, endPos) + h.SetByteRanges([]int{startPos}, []int{endPos}) v := h.Peek(HeaderRange) if string(v) != expectedV { t.Fatalf("unexpected range: %q. Expecting %q. startPos=%d, endPos=%d", v, expectedV, startPos, endPos)