Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve protocol error messages #754

Merged
merged 3 commits into from Jan 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 40 additions & 17 deletions conn.go
Expand Up @@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
Expand Down Expand Up @@ -794,47 +795,69 @@ func (c *Conn) advanceFrame() (int, error) {
}

// 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.

var errors []string

p, err := c.read(2)
if err != nil {
return noFrame, err
}

final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))

c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
if rsv1 {
if c.newDecompressionReader != nil {
c.readDecompress = true
} else {
errors = append(errors, "RSV1 set")
}
}

if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
if rsv2 {
errors = append(errors, "RSV2 set")
}

if rsv3 {
errors = append(errors, "RSV3 set")
}

switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
errors = append(errors, "len > 125 for control")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
}

if mask != c.isServer {
errors = append(errors, "bad MASK")
}

if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
}

// 3. Read and parse frame length as per
Expand Down Expand Up @@ -872,10 +895,6 @@ func (c *Conn) advanceFrame() (int, error) {

// 4. Handle frame masking.

if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}

if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
Expand Down Expand Up @@ -935,7 +954,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
Expand All @@ -952,7 +971,11 @@ func (c *Conn) advanceFrame() (int, error) {
}

func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}

Expand Down