Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(bigquery/storage/managedstorage): improve internal locking #6304

Merged
merged 10 commits into from
Jul 7, 2022
190 changes: 97 additions & 93 deletions bigquery/storage/managedwriter/managed_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,94 +255,93 @@ func (ms *ManagedStream) append(requestCtx context.Context, pw *pendingWrite, op
r = settings.Retry()
}

var arc *storagepb.BigQueryWrite_AppendRowsClient
var ch chan *pendingWrite
var err error

for {
// critical section: Things that need to happen inside the critical section:
//
// * Getting the stream connection (in case of reconnects)
// * Issuing the append request
// * Adding the pending write to the channel to keep ordering correct on response
ms.mu.Lock()
appendErr, numRows := func() (error, int64) {
shollyman marked this conversation as resolved.
Show resolved Hide resolved
// critical section: Things that need to happen inside the critical section:
//
// * Getting the stream connection (in case of reconnects)
// * Issuing the append request
// * Adding the pending write to the channel to keep ordering correct on response
ms.mu.Lock()
defer ms.mu.Unlock()

var arc *storagepb.BigQueryWrite_AppendRowsClient
var ch chan *pendingWrite
var err error

// Don't bother calling/retrying if this append's context is already expired.
if err = requestCtx.Err(); err != nil {
return err, 0
}

// Don't bother calling/retrying if this append's context is already expired.
if err = requestCtx.Err(); err != nil {
return err
}
// If an updated schema is present, we need to reconnect the stream and update the reference
// schema for the stream.
reconnect := false
if pw.newSchema != nil && !proto.Equal(pw.newSchema, ms.schemaDescriptor) {
reconnect = true
ms.schemaDescriptor = proto.Clone(pw.newSchema).(*descriptorpb.DescriptorProto)
}
arc, ch, err = ms.getStream(arc, reconnect)
if err != nil {
return err, 0
}

// If an updated schema is present, we need to reconnect the stream and update the reference
// schema for the stream.
reconnect := false
if pw.newSchema != nil && !proto.Equal(pw.newSchema, ms.schemaDescriptor) {
reconnect = true
ms.schemaDescriptor = proto.Clone(pw.newSchema).(*descriptorpb.DescriptorProto)
}
arc, ch, err = ms.getStream(arc, reconnect)
if err != nil {
return err
}
// Resolve the special work for the first append on a stream.
var req *storagepb.AppendRowsRequest
ms.streamSetup.Do(func() {
reqCopy := proto.Clone(pw.request).(*storagepb.AppendRowsRequest)
reqCopy.WriteStream = ms.streamSettings.streamID
reqCopy.GetProtoRows().WriterSchema = &storagepb.ProtoSchema{
ProtoDescriptor: ms.schemaDescriptor,
}
if ms.streamSettings.TraceID != "" {
reqCopy.TraceId = ms.streamSettings.TraceID
}
req = reqCopy
})

// Resolve the special work for the first append on a stream.
var req *storagepb.AppendRowsRequest
ms.streamSetup.Do(func() {
reqCopy := proto.Clone(pw.request).(*storagepb.AppendRowsRequest)
reqCopy.WriteStream = ms.streamSettings.streamID
reqCopy.GetProtoRows().WriterSchema = &storagepb.ProtoSchema{
ProtoDescriptor: ms.schemaDescriptor,
if req != nil {
// First append in a new connection needs properties like schema and stream name set.
err = (*arc).Send(req)
} else {
// Subsequent requests need no modification.
err = (*arc).Send(pw.request)
}
if ms.streamSettings.TraceID != "" {
reqCopy.TraceId = ms.streamSettings.TraceID
if err != nil {
return err, 0
}
req = reqCopy
})

if req != nil {
// First append in a new connection needs properties like schema and stream name set.
err = (*arc).Send(req)
} else {
// Subsequent requests need no modification.
err = (*arc).Send(pw.request)
}
if err == nil {
// Compute numRows, once we pass ownership to the channel the request may be
// cleared.
numRows := int64(len(pw.request.GetProtoRows().Rows.GetSerializedRows()))
ch <- pw
// We've passed ownership of the pending write to the channel.
// It's now responsible for marking the request done, we're done
// with the critical section.
ms.mu.Unlock()
return nil, numRows
}()

// Record stats and return.
recordStat(ms.ctx, AppendRequests, 1)
recordStat(ms.ctx, AppendRequestBytes, int64(pw.reqSize))
recordStat(ms.ctx, AppendRequestRows, numRows)
return nil
}
// Unlock the mutex for error cases.
ms.mu.Unlock()

// Append yielded an error. Retry by continuing or return.
status := grpcstatus.Convert(err)
if status != nil {
ctx, _ := tag.New(ms.ctx, tag.Insert(keyError, status.Code().String()))
recordStat(ctx, AppendRequestErrors, 1)
}
bo, shouldRetry := r.Retry(err)
if shouldRetry {
if err := gax.Sleep(ms.ctx, bo); err != nil {
return err
if appendErr != nil {
// Append yielded an error. Retry by continuing or return.
status := grpcstatus.Convert(appendErr)
if status != nil {
ctx, _ := tag.New(ms.ctx, tag.Insert(keyError, status.Code().String()))
recordStat(ctx, AppendRequestErrors, 1)
}
continue
bo, shouldRetry := r.Retry(appendErr)
if shouldRetry {
if err := gax.Sleep(ms.ctx, bo); err != nil {
return err
}
continue
}
// We've got a non-retriable error, so propagate that up and mark the write done.
ms.mu.Lock()
ms.err = appendErr
pw.markDone(NoStreamOffset, appendErr, ms.fc)
ms.mu.Unlock()
return appendErr
}
// We've got a non-retriable error, so propagate that up. and mark the write done.
ms.mu.Lock()
ms.err = err
pw.markDone(NoStreamOffset, err, ms.fc)
ms.mu.Unlock()
return err
recordStat(ms.ctx, AppendRequests, 1)
recordStat(ms.ctx, AppendRequestBytes, int64(pw.reqSize))
recordStat(ms.ctx, AppendRequestRows, numRows)
return nil
}
}

Expand All @@ -351,28 +350,33 @@ func (ms *ManagedStream) Close() error {

var arc *storagepb.BigQueryWrite_AppendRowsClient

// Critical section: get connection, close, mark closed.
ms.mu.Lock()
arc, ch, err := ms.getStream(arc, false)
if err != nil {
return err
}
if ms.arc == nil {
return fmt.Errorf("no stream exists")
}
err = (*arc).CloseSend()
if err == nil {
closeErr := func() error {
shollyman marked this conversation as resolved.
Show resolved Hide resolved
// Critical section: get connection, close, mark closed.
ms.mu.Lock()
defer ms.mu.Unlock()
arc, ch, err := ms.getStream(arc, false)
if err != nil {
return err
}
if ms.arc == nil {
return fmt.Errorf("no stream exists")
}
err = (*arc).CloseSend()
if err != nil {
ms.err = err
return err
}
// Mark the stream closed, then return
close(ch)
}
ms.err = io.EOF

// Done with the critical section.
ms.mu.Unlock()
// Propagate cancellation.
ms.err = io.EOF
return nil
}()
// Cancel the underlying context for the stream as well.
if ms.cancel != nil {
ms.cancel()
ms.cancel = nil
}
return err
return closeErr
}

// AppendRows sends the append requests to the service, and returns a single AppendResult for tracking
Expand Down
92 changes: 92 additions & 0 deletions bigquery/storage/managedwriter/managed_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package managedwriter

import (
"context"
"errors"
"runtime"
"testing"
"time"
Expand Down Expand Up @@ -94,6 +95,7 @@ type testAppendRowsClient struct {
requests []*storagepb.AppendRowsRequest
sendF func(*storagepb.AppendRowsRequest) error
recvF func() (*storagepb.AppendRowsResponse, error)
closeF func() error
}

func (tarc *testAppendRowsClient) Send(req *storagepb.AppendRowsRequest) error {
Expand All @@ -104,6 +106,10 @@ func (tarc *testAppendRowsClient) Recv() (*storagepb.AppendRowsResponse, error)
return tarc.recvF()
}

// CloseSend satisfies the AppendRowsClient interface by delegating to the
// close hook configured on the test client.
func (tarc *testAppendRowsClient) CloseSend() error {
	closeFn := tarc.closeF
	return closeFn()
}

// openTestArc handles wiring in a test AppendRowsClient into a managedstream by providing the open function.
func openTestArc(testARC *testAppendRowsClient, sendF func(req *storagepb.AppendRowsRequest) error, recvF func() (*storagepb.AppendRowsResponse, error)) func(s string, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
sF := func(req *storagepb.AppendRowsRequest) error {
Expand All @@ -123,6 +129,9 @@ func openTestArc(testARC *testAppendRowsClient, sendF func(req *storagepb.Append
}
testARC.sendF = sF
testARC.recvF = rF
testARC.closeF = func() error {
return nil
}
return func(s string, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
testARC.openCount = testARC.openCount + 1
return testARC, nil
Expand Down Expand Up @@ -291,6 +300,89 @@ func TestManagedStream_AppendWithDeadline(t *testing.T) {

}

// TestManagedStream_AppendDeadlocks exercises append and Close under
// benign and failing conditions (expired caller contexts, errored stream
// opens), issuing each operation twice per case to verify the stream's
// internal mutex is always released and no path deadlocks.
func TestManagedStream_AppendDeadlocks(t *testing.T) {
	// Ensure we don't deadlock by issuing two appends.
	testCases := []struct {
		desc       string
		openErrors []error
		ctx        context.Context
		respErr    error
	}{
		{
			desc:       "no errors",
			openErrors: []error{nil, nil},
			ctx:        context.Background(),
			respErr:    nil,
		},
		{
			desc:       "cancelled caller context",
			openErrors: []error{nil, nil},
			ctx: func() context.Context {
				cctx, cancel := context.WithCancel(context.Background())
				cancel()
				return cctx
			}(),
			respErr: context.Canceled,
		},
		{
			desc:       "expired caller context",
			openErrors: []error{nil, nil},
			ctx: func() context.Context {
				cctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
				defer cancel()
				// Sleep past the deadline so the context is already expired
				// when append inspects it.
				time.Sleep(2 * time.Millisecond)
				return cctx
			}(),
			respErr: context.DeadlineExceeded,
		},
		{
			desc:       "errored getstream",
			openErrors: []error{status.Errorf(codes.ResourceExhausted, "some error"), status.Errorf(codes.ResourceExhausted, "some error")},
			ctx:        context.Background(),
			respErr:    status.Errorf(codes.ResourceExhausted, "some error"),
		},
	}

	for _, tc := range testCases {
		openF := openTestArc(&testAppendRowsClient{}, nil, nil)
		ms := &ManagedStream{
			ctx: context.Background(),
			// open pops the next scripted error; a nil entry defers to the
			// normal test AppendRowsClient. Running out of entries means the
			// stream opened more times than the case expected.
			open: func(s string, opts ...gax.CallOption) (storagepb.BigQueryWrite_AppendRowsClient, error) {
				if len(tc.openErrors) == 0 {
					panic("out of open errors")
				}
				curErr := tc.openErrors[0]
				tc.openErrors = tc.openErrors[1:]
				if curErr == nil {
					return openF(s, opts...)
				}
				return nil, curErr
			},
			streamSettings: &streamSettings{
				streamID: "foo",
			},
		}

		// first append
		pw := newPendingWrite([][]byte{[]byte("foo")})
		gotErr := ms.append(tc.ctx, pw)
		if !errors.Is(gotErr, tc.respErr) {
			t.Errorf("%s first response: got %v, want %v", tc.desc, gotErr, tc.respErr)
		}
		// second append
		pw = newPendingWrite([][]byte{[]byte("bar")})
		gotErr = ms.append(tc.ctx, pw)
		if !errors.Is(gotErr, tc.respErr) {
			t.Errorf("%s second response: got %v, want %v", tc.desc, gotErr, tc.respErr)
		}

		// Issue two closes, to ensure we're not deadlocking there either.
		ms.Close()
		ms.Close()
	}

}

func TestManagedStream_LeakingGoroutines(t *testing.T) {
ctx := context.Background()

Expand Down
8 changes: 7 additions & 1 deletion bigquery/storage/managedwriter/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package managedwriter

import (
"context"
"errors"
"time"

"github.com/googleapis/gax-go/v2"
Expand All @@ -31,7 +33,11 @@ func (r *defaultRetryer) Retry(err error) (pause time.Duration, shouldRetry bool
// retry predicates in addition to statuscode-based.
s, ok := status.FromError(err)
if !ok {
// non-status based errors as retryable
// Treat context errors as non-retriable.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return r.bo.Pause(), false
}
// Any other non-status based errors treated as retryable.
return r.bo.Pause(), true
}
switch s.Code() {
Expand Down