Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read Limit Fix #537

Merged
merged 5 commits into from Aug 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 48 additions & 10 deletions conn.go
Expand Up @@ -260,10 +260,12 @@ type Conn struct {
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser

// Read fields
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
reader io.ReadCloser // the current reader returned to the application
readErr error
br *bufio.Reader
// bytes remaining in current frame.
// set setReadRemaining to safely update this value and prevent overflow
readRemaining int64
readFinal bool // true the current message has more frames.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
Expand Down Expand Up @@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
return c
}

// setReadRemaining tracks the number of bytes remaining on the connection. If n
// overflows, an ErrReadLimit is returned.
func (c *Conn) setReadRemaining(n int64) error {
if n < 0 {
return ErrReadLimit
}

c.readRemaining = n
return nil
}

// Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string {
return c.subprotocol
Expand Down Expand Up @@ -790,7 +803,7 @@ func (c *Conn) advanceFrame() (int, error) {
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f)
c.setReadRemaining(int64(p[1] & 0x7f))

c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
Expand Down Expand Up @@ -824,21 +837,37 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
}

// 3. Read and parse frame length.
// 3. Read and parse frame length as per
// https://tools.ietf.org/html/rfc6455#section-5.2
//
// The length of the "Payload data", in bytes: if 0-125, that is the payload
// length.
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
// integer are the payload length.
// - If 127, the following 8 bytes interpreted as
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
// payload length. Multibyte length quantities are expressed in network byte
// order.

switch c.readRemaining {
case 126:
p, err := c.read(2)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint16(p))

if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
return noFrame, err
}
case 127:
p, err := c.read(8)
if err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint64(p))

if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
return noFrame, err
}
}

// 4. Handle frame masking.
Expand All @@ -861,6 +890,12 @@ func (c *Conn) advanceFrame() (int, error) {
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {

c.readLength += c.readRemaining
// Don't allow readLength to overflow in the presence of a large readRemaining
// counter.
if c.readLength < 0 {
return noFrame, ErrReadLimit
}

if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
Expand All @@ -874,7 +909,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte
if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining))
c.readRemaining = 0
c.setReadRemaining(0)
if err != nil {
return noFrame, err
}
Expand Down Expand Up @@ -947,6 +982,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readErr = hideTempErr(err)
break
}

if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c}
c.reader = c.messageReader
Expand Down Expand Up @@ -987,7 +1023,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
}
c.readRemaining -= int64(n)
rem := c.readRemaining
rem -= int64(n)
c.setReadRemaining(rem)
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
Expand Down
115 changes: 88 additions & 27 deletions conn_test.go
Expand Up @@ -55,7 +55,10 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
}

func TestFraming(t *testing.T) {
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
frameSizes := []int{
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
// 65536, 65537
}
var readChunkers = []struct {
name string
f func(io.Reader) io.Reader
Expand Down Expand Up @@ -120,6 +123,8 @@ func TestFraming(t *testing.T) {
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
continue
}

t.Logf("frame size: %d", n)
rbuf, err := ioutil.ReadAll(r)
if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
Expand Down Expand Up @@ -458,37 +463,93 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
}

func TestReadLimit(t *testing.T) {
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
const readLimit = 512
message := make([]byte, readLimit+1)

const readLimit = 512
message := make([]byte, readLimit+1)
var b1, b2 bytes.Buffer
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
rc := newTestConn(&b1, &b2, true)
rc.SetReadLimit(readLimit)

var b1, b2 bytes.Buffer
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
rc := newTestConn(&b1, &b2, true)
rc.SetReadLimit(readLimit)
// Send message at the limit with interleaved pong.
w, _ := wc.NextWriter(BinaryMessage)
w.Write(message[:readLimit-1])
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
w.Write(message[:1])
w.Close()

// Send message at the limit with interleaved pong.
w, _ := wc.NextWriter(BinaryMessage)
w.Write(message[:readLimit-1])
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
w.Write(message[:1])
w.Close()
// Send message larger than the limit.
wc.WriteMessage(BinaryMessage, message[:readLimit+1])

// Send message larger than the limit.
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
op, _, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("1: NextReader() returned %d, %v", op, err)
}
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err)
}
})

op, _, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("1: NextReader() returned %d, %v", op, err)
}
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err)
}
t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
const readLimit = 1

var b1, b2 bytes.Buffer
rc := newTestConn(&b1, &b2, true)
rc.SetReadLimit(readLimit)

// First, send a non-final binary message
b1.Write([]byte("\x02\x81"))

// Mask key
b1.Write([]byte("\x00\x00\x00\x00"))

// First payload
b1.Write([]byte("A"))

// Next, send a negative-length, non-final continuation frame
b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))

// Mask key
b1.Write([]byte("\x00\x00\x00\x00"))

// Next, send a too long, final continuation frame
b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))

// Mask key
b1.Write([]byte("\x00\x00\x00\x00"))

// Too-long payload
b1.Write([]byte("BCDEF"))

op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("1: NextReader() returned %d, %v", op, err)
}

var buf [10]byte
var read int
n, err := r.Read(buf[:])
if err != nil && err != ErrReadLimit {
t.Fatalf("unexpected error testing read limit: %v", err)
}
read += n

n, err = r.Read(buf[:])
if err != nil && err != ErrReadLimit {
t.Fatalf("unexpected error testing read limit: %v", err)
}
read += n

if err == nil && read > readLimit {
t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
}
})
}

func TestAddrs(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
@@ -1 +1,3 @@
module github.com/gorilla/websocket

go 1.12