Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return write buffer to pool on write error #427

Merged
merged 1 commit into from Sep 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 24 additions & 23 deletions conn.go
Expand Up @@ -451,7 +451,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err
}

func (c *Conn) prepWrite(messageType int) error {
// beginMessage prepares a connection and message writer for a new message.
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
Expand All @@ -471,6 +472,10 @@ func (c *Conn) prepWrite(messageType int) error {
return err
}

mw.c = c
mw.frameType = messageType
mw.pos = maxFrameHeaderSize

if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData)
if ok {
Expand All @@ -491,16 +496,11 @@ func (c *Conn) prepWrite(messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil {
var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
}

mw := &messageWriter{
c: c,
frameType: messageType,
pos: maxFrameHeaderSize,
}
c.writer = mw
c.writer = &mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true
Expand All @@ -517,10 +517,16 @@ type messageWriter struct {
err error
}

func (w *messageWriter) fatal(err error) error {
func (w *messageWriter) endMessage(err error) error {
if w.err != nil {
w.err = err
w.c.writer = nil
return err
}
c := w.c
w.err = err
c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
return err
}
Expand All @@ -534,7 +540,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
// Check for invalid control frames.
if isControl(w.frameType) &&
(!final || length > maxControlFramePayloadSize) {
return w.fatal(errInvalidControlFrame)
return w.endMessage(errInvalidControlFrame)
}

b0 := byte(w.frameType)
Expand Down Expand Up @@ -579,7 +585,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 {
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
}
}

Expand All @@ -600,15 +606,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c.isWriting = false

if err != nil {
return w.fatal(err)
return w.endMessage(err)
}

if final {
c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
w.endMessage(errWriteClosed)
return nil
}

Expand Down Expand Up @@ -709,7 +711,6 @@ func (w *messageWriter) Close() error {
if err := w.flushFrame(true, nil); err != nil {
return err
}
w.err = errWriteClosed
return nil
}

Expand Down Expand Up @@ -742,10 +743,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame.

if err := c.prepWrite(messageType); err != nil {
var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return err
}
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n
data = data[n:]
Expand Down
60 changes: 57 additions & 3 deletions conn_test.go
Expand Up @@ -196,11 +196,16 @@ func (p *simpleBufferPool) Put(v interface{}) {
}

func TestWriteBufferPool(t *testing.T) {
const message = "Now is the time for all good people to come to the aid of the party."

var buf bytes.Buffer
var pool simpleBufferPool
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
rc := newTestConn(&buf, nil, false)

// Specify writeBufferSize smaller than message size to ensure that pooling
// works with fragmented messages.
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)

if wc.writeBuf != nil {
t.Fatal("writeBuf not nil after create")
}
Expand All @@ -218,8 +223,6 @@ func TestWriteBufferPool(t *testing.T) {

writeBufAddr := &wc.writeBuf[0]

const message = "Hello World!"

if _, err := io.WriteString(w, message); err != nil {
t.Fatalf("io.WriteString(w, message) returned %v", err)
}
Expand Down Expand Up @@ -269,6 +272,7 @@ func TestWriteBufferPool(t *testing.T) {
}
}

// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
func TestWriteBufferPoolSync(t *testing.T) {
var buf bytes.Buffer
var pool sync.Pool
Expand All @@ -290,6 +294,56 @@ func TestWriteBufferPoolSync(t *testing.T) {
}
}

// errorWriter is an io.Writer than returns an error on all writes.
type errorWriter struct{}

func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }

// TestWriteBufferPoolError ensures that buffer is returned to pool after error
// on write.
func TestWriteBufferPoolError(t *testing.T) {

// Part 1: Test NextWriter/Write/Close

var pool simpleBufferPool
wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)

w, err := wc.NextWriter(TextMessage)
if err != nil {
t.Fatalf("wc.NextWriter() returned %v", err)
}

if wc.writeBuf == nil {
t.Fatal("writeBuf is nil after NextWriter")
}

writeBufAddr := &wc.writeBuf[0]

if _, err := io.WriteString(w, "Hello"); err != nil {
t.Fatalf("io.WriteString(w, message) returned %v", err)
}

if err := w.Close(); err == nil {
t.Fatalf("w.Close() did not return error")
}

if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool")
}

// Part 2: Test WriteMessage

wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)

if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
t.Fatalf("wc.WriteMessage did not return error")
}

if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool")
}
}

func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
const bufSize = 512

Expand Down