Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

buffer: Use a double-buffering scheme to prevent data races #943

Merged
merged 6 commits into from
Apr 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 54 additions & 0 deletions benchmark_test.go
Expand Up @@ -317,3 +317,57 @@ func BenchmarkExecContext(b *testing.B) {
})
}
}

// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
// "size=" means size of each blobs.
func BenchmarkQueryRawBytes(b *testing.B) {
var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
db := initDB(b,
"DROP TABLE IF EXISTS bench_rawbytes",
"CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
)
defer db.Close()

blob := make([]byte, sizes[len(sizes)-1])
for i := range blob {
blob[i] = 42
}
for i := 0; i < 100; i++ {
_, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
if err != nil {
b.Fatal(err)
}
}

for _, s := range sizes {
b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
db.SetMaxIdleConns(0)
db.SetMaxIdleConns(1)
b.ReportAllocs()
b.ResetTimer()

for j := 0; j < b.N; j++ {
rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
if err != nil {
b.Fatal(err)
}
nrows := 0
for rows.Next() {
var buf sql.RawBytes
err := rows.Scan(&buf)
if err != nil {
b.Fatal(err)
}
if len(buf) != s {
b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
}
nrows++
}
rows.Close()
if nrows != 100 {
b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows)
}
}
})
}
}
48 changes: 35 additions & 13 deletions buffer.go
Expand Up @@ -15,47 +15,69 @@ import (
)

const defaultBufSize = 4096
const maxCachedBufSize = 256 * 1024

// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
// This buffer is backed by two byte slices in a double-buffering scheme
type buffer struct {
buf []byte // buf is a byte buffer who's length and capacity are equal.
nc net.Conn
idx int
length int
timeout time.Duration
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
flipcnt uint // flipccnt is the current buffer counter for double-buffering
}

// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
fg := make([]byte, defaultBufSize)
return buffer{
buf: make([]byte, defaultBufSize),
nc: nc,
buf: fg,
nc: nc,
dbuf: [2][]byte{fg, nil},
}
}

// flip replaces the active buffer with the background buffer
// this is a delayed flip that simply increases the buffer counter;
// the actual flip will be performed the next time we call `buffer.fill`
func (b *buffer) flip() {
b.flipcnt += 1
}

// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
n := b.length
// fill data into its double-buffering target: if we've called
// flip on this buffer, we'll be copying to the background buffer,
// and then filling it with network data; otherwise we'll just move
// the contents of the current buffer to the front before filling it
dest := b.dbuf[b.flipcnt&1]

// grow buffer if necessary to fit the whole packet.
if need > len(dest) {
// Round up to the next multiple of the default size
dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
vmg marked this conversation as resolved.
Show resolved Hide resolved

// move existing data to the beginning
if n > 0 && b.idx > 0 {
copy(b.buf[0:n], b.buf[b.idx:])
// if the allocated buffer is not too large, move it to backing storage
// to prevent extra allocations on applications that perform large reads
if len(dest) <= maxCachedBufSize {
b.dbuf[b.flipcnt&1] = dest
}
}

// grow buffer if necessary
// TODO: let the buffer shrink again at some point
// Maybe keep the org buf slice and swap back?
if need > len(b.buf) {
// Round up to the next multiple of the default size
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
copy(newBuf, b.buf)
b.buf = newBuf
// if we're filling the fg buffer, move the existing data to the start of it.
// if we're filling the bg buffer, copy over the data
if n > 0 {
copy(dest[:n], b.buf[b.idx:])
}

b.buf = dest
b.idx = 0

for {
Expand Down
55 changes: 55 additions & 0 deletions driver_test.go
Expand Up @@ -2938,3 +2938,58 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
})
}

// TestRawBytesAreNotModified checks for a race condition that arises when a query context
// is canceled while a user is calling rows.Scan. This is a more stringent test than the one
// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using
// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit
// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers.
func TestRawBytesAreNotModified(t *testing.T) {
const blob = "abcdefghijklmnop"
const contextRaceIterations = 20
const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row.
const insertRows = 4

var sqlBlobs = [2]string{
strings.Repeat(blob, blobSize/len(blob)),
strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)),
}

runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
for i := 0; i < insertRows; i++ {
dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1])
}

for i := 0; i < contextRaceIterations; i++ {
func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`)
if err != nil {
t.Fatal(err)
}

var b int
var raw sql.RawBytes
for rows.Next() {
if err := rows.Scan(&b, &raw); err != nil {
t.Fatal(err)
}

before := string(raw)
// Ensure cancelling the query does not corrupt the contents of `raw`
cancel()
time.Sleep(time.Microsecond * 100)
after := string(raw)

if before != after {
t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
}
}
rows.Close()
}()
}
})
}
7 changes: 7 additions & 0 deletions rows.go
Expand Up @@ -111,6 +111,13 @@ func (rows *mysqlRows) Close() (err error) {
return err
}

// flip the buffer for this connection if we need to drain it.
// note that for a successful query (i.e. one where rows.next()
// has been called until it returns false), `rows.mc` will be nil
// by the time the user calls `(*Rows).Close`, so we won't reach this
// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
mc.buf.flip()

// Remove unread packets from stream
if !rows.rs.done {
err = mc.readUntilEOF()
Expand Down