diff --git a/bigquery/storage/managedwriter/managed_stream.go b/bigquery/storage/managedwriter/managed_stream.go index 2131d9337f1..b7682d58115 100644 --- a/bigquery/storage/managedwriter/managed_stream.go +++ b/bigquery/storage/managedwriter/managed_stream.go @@ -16,7 +16,6 @@ package managedwriter import ( "context" - "errors" "fmt" "io" "sync" @@ -316,9 +315,7 @@ func (ms *ManagedStream) lockingAppend(requestCtx context.Context, pw *pendingWr err = (*arc).Send(pw.request) } if err != nil { - // Transient connection loss. If we got io.EOF from a send, we want subsequent appends to - // reconnect the network connection for the stream. - if errors.Is(err, io.EOF) { + if shouldReconnect(err) { ms.reconnect = true } return 0, err diff --git a/bigquery/storage/managedwriter/retry.go b/bigquery/storage/managedwriter/retry.go index e598a2d806b..7ace796c163 100644 --- a/bigquery/storage/managedwriter/retry.go +++ b/bigquery/storage/managedwriter/retry.go @@ -17,6 +17,7 @@ package managedwriter import ( "context" "errors" + "io" "time" "github.com/googleapis/gax-go/v2" @@ -47,3 +48,19 @@ func (r *defaultRetryer) Retry(err error) (pause time.Duration, shouldRetry bool return r.bo.Pause(), false } } + +// shouldReconnect is akin to a retry predicate, in that it evaluates whether we should force +// our bidi stream to close/reopen based on the responses error. Errors here signal that no +// further appends will succeed. +func shouldReconnect(err error) bool { + var knownErrors = []error{ + io.EOF, + status.Error(codes.Unavailable, "the connection is draining"), // errStreamDrain in gRPC transport + } + for _, ke := range knownErrors { + if errors.Is(err, ke) { + return true + } + } + return false +} diff --git a/bigquery/storage/managedwriter/retry_test.go b/bigquery/storage/managedwriter/retry_test.go new file mode 100644 index 00000000000..ca4272339c1 --- /dev/null +++ b/bigquery/storage/managedwriter/retry_test.go @@ -0,0 +1,64 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package managedwriter + +import ( + "fmt" + "io" + "testing" + + "github.com/googleapis/gax-go/v2/apierror" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestManagedStream_ShouldReconnect(t *testing.T) { + + testCases := []struct { + err error + want bool + }{ + { + err: fmt.Errorf("random error"), + want: false, + }, + { + err: io.EOF, + want: true, + }, + { + err: status.Error(codes.Unavailable, "nope"), + want: false, + }, + { + err: status.Error(codes.Unavailable, "the connection is draining"), + want: true, + }, + { + err: func() error { + // wrap the underlying error in a gax apierror + ai, _ := apierror.FromError(status.Error(codes.Unavailable, "the connection is draining")) + return ai + }(), + want: true, + }, + } + + for _, tc := range testCases { + if got := shouldReconnect(tc.err); got != tc.want { + t.Errorf("got %t, want %t for error: %+v", got, tc.want, tc.err) + } + } +}