diff --git a/conn.go b/conn.go index ca46d2f7..2296bf37 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,7 @@ import ( "math/rand" "net" "strconv" + "strings" "sync" "time" "unicode/utf8" @@ -794,47 +795,69 @@ func (c *Conn) advanceFrame() (int, error) { } // 2. Read and parse first two bytes of frame header. + // To aid debugging, collect and report all errors in the first two bytes + // of the header. + + var errors []string p, err := c.read(2) if err != nil { return noFrame, err } - final := p[0]&finalBit != 0 frameType := int(p[0] & 0xf) + final := p[0]&finalBit != 0 + rsv1 := p[0]&rsv1Bit != 0 + rsv2 := p[0]&rsv2Bit != 0 + rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 c.setReadRemaining(int64(p[1] & 0x7f)) c.readDecompress = false - if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { - c.readDecompress = true - p[0] &^= rsv1Bit + if rsv1 { + if c.newDecompressionReader != nil { + c.readDecompress = true + } else { + errors = append(errors, "RSV1 set") + } } - if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { - return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + if rsv2 { + errors = append(errors, "RSV2 set") + } + + if rsv3 { + errors = append(errors, "RSV3 set") } switch frameType { case CloseMessage, PingMessage, PongMessage: if c.readRemaining > maxControlFramePayloadSize { - return noFrame, c.handleProtocolError("control frame length > 125") + errors = append(errors, "len > 125 for control") } if !final { - return noFrame, c.handleProtocolError("control frame not final") + errors = append(errors, "FIN not set on control") } case TextMessage, BinaryMessage: if !c.readFinal { - return noFrame, c.handleProtocolError("message start before final message frame") + errors = append(errors, "data before FIN") } c.readFinal = final case continuationFrame: if c.readFinal { - return noFrame, c.handleProtocolError("continuation after final message frame") + errors = append(errors, "continuation after FIN") } c.readFinal = final default: - return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) + } + + if mask != c.isServer { + errors = append(errors, "bad MASK") + } + + if len(errors) > 0 { + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) } // 3. Read and parse frame length as per @@ -872,10 +895,6 @@ func (c *Conn) advanceFrame() (int, error) { // 4. Handle frame masking. - if mask != c.isServer { - return noFrame, c.handleProtocolError("incorrect mask flag") - } - if mask { c.readMaskPos = 0 p, err := c.read(len(c.readMaskKey)) @@ -935,7 +954,7 @@ func (c *Conn) advanceFrame() (int, error) { if len(payload) >= 2 { closeCode = int(binary.BigEndian.Uint16(payload)) if !isValidReceivedCloseCode(closeCode) { - return noFrame, c.handleProtocolError("invalid close code") + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) } closeText = string(payload[2:]) if !utf8.ValidString(closeText) { @@ -952,7 +971,11 @@ func (c *Conn) advanceFrame() (int, error) { } func (c *Conn) handleProtocolError(message string) error { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + data := FormatCloseMessage(CloseProtocolError, message) + if len(data) > maxControlFramePayloadSize { + data = data[:maxControlFramePayloadSize] + } + c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) }