diff --git a/spanner/batch.go b/spanner/batch.go index 413d8fd202b..374a963c48f 100644 --- a/spanner/batch.go +++ b/spanner/batch.go @@ -221,10 +221,7 @@ func (t *BatchReadOnlyTransaction) Cleanup(ctx context.Context) { } t.sh = nil sid, client := sh.getID(), sh.getClient() - err := runRetryable(ctx, func(ctx context.Context) error { - _, e := client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sid}) - return e - }) + err := client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sid}) if err != nil { log.Printf("Failed to delete session %v. Error: %v", sid, err) } diff --git a/spanner/client.go b/spanner/client.go index fc8443b0b0e..38c3f4968d8 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -25,8 +25,9 @@ import ( "cloud.google.com/go/internal/trace" "cloud.google.com/go/internal/version" + vkit "cloud.google.com/go/spanner/apiv1" + "cloud.google.com/go/spanner/internal/backoff" "google.golang.org/api/option" - gtransport "google.golang.org/api/transport/grpc" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -71,8 +72,7 @@ func validDatabaseName(db string) error { type Client struct { // rr must be accessed through atomic operations. rr uint32 - conns []*grpc.ClientConn - clients []sppb.SpannerClient + clients []*vkit.Client database string // Metadata to be sent with each request. @@ -170,19 +170,18 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf // TODO(deklerk): This should be replaced with a balancer with // config.NumChannels connections, instead of config.NumChannels - // clientconns. + // clients. for i := 0; i < config.NumChannels; i++ { - conn, err := gtransport.Dial(ctx, allOpts...) + client, err := vkit.NewClient(ctx, allOpts...) if err != nil { return nil, errDial(i, err) } - c.conns = append(c.conns, conn) - c.clients = append(c.clients, sppb.NewSpannerClient(conn)) + c.clients = append(c.clients, client) } // Prepare session pool. - config.SessionPoolConfig.getRPCClient = func() (sppb.SpannerClient, error) { - // TODO: support more loadbalancing options. + // TODO: support more loadbalancing options. + config.SessionPoolConfig.getRPCClient = func() (*vkit.Client, error) { return c.rrNext(), nil } config.SessionPoolConfig.sessionLabels = c.sessionLabels @@ -195,9 +194,9 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf return c, nil } -// rrNext returns the next available Cloud Spanner RPC client in a round-robin -// manner. -func (c *Client) rrNext() sppb.SpannerClient { +// rrNext returns the next available vkit Cloud Spanner RPC client in a +// round-robin manner. +func (c *Client) rrNext() *vkit.Client { return c.clients[atomic.AddUint32(&c.rr, 1)%uint32(len(c.clients))] } @@ -206,8 +205,8 @@ func (c *Client) Close() { if c.idleSessions != nil { c.idleSessions.close() } - for _, conn := range c.conns { - conn.Close() + for _, gpc := range c.clients { + gpc.Close() } } @@ -279,26 +278,20 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound sh = &sessionHandle{session: s} // Begin transaction. - err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error { - res, e := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{ - Session: sh.getID(), - Options: &sppb.TransactionOptions{ - Mode: &sppb.TransactionOptions_ReadOnly_{ - ReadOnly: buildTransactionOptionsReadOnly(tb, true), - }, + res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{ + Session: sh.getID(), + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: buildTransactionOptionsReadOnly(tb, true), }, - }) - if e != nil { - return e - } - tx = res.Id - if res.ReadTimestamp != nil { - rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos)) - } - return nil + }, }) if err != nil { - return nil, err + return nil, toSpannerError(err) + } + tx = res.Id + if res.ReadTimestamp != nil { + rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos)) } t := &BatchReadOnlyTransaction{ @@ -377,7 +370,7 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex ts time.Time sh *sessionHandle ) - err = runRetryableNoWrap(ctx, func(ctx context.Context) error { + err = runWithRetryOnAborted(ctx, func(ctx context.Context) error { var ( err error t *ReadWriteTransaction @@ -402,8 +395,7 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex trace.TracePrintf(ctx, map[string]interface{}{"transactionID": string(sh.getTransactionID())}, "Starting transaction attempt") if err = t.begin(ctx); err != nil { - // Mask error from begin operation as retryable error. - return errRetry(err) + return err } ts, err = t.runInTransaction(ctx, f) return err @@ -414,6 +406,43 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex return ts, err } +func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) error { + var funcErr error + retryCount := 0 + for { + select { + case <-ctx.Done(): + // Do context check here so that even f() failed to do so (for + // example, gRPC implementation bug), the loop can still have a + // chance to exit as expected. + return errContextCanceled(ctx, funcErr) + default: + } + funcErr = f(ctx) + if funcErr == nil { + return nil + } + // Only retry on ABORTED. + if isAbortErr(funcErr) { + // Aborted, do exponential backoff and continue. + b, ok := extractRetryDelay(funcErr) + if !ok { + b = backoff.DefaultBackoff.Delay(retryCount) + } + trace.TracePrintf(ctx, nil, "Backing off after ABORTED for %s, then retrying", b) + select { + case <-ctx.Done(): + return errContextCanceled(ctx, funcErr) + case <-time.After(b): + } + retryCount++ + continue + } + // Error isn't ABORTED / no error, return immediately. + return funcErr + } +} + // applyOption controls the behavior of Client.Apply. type applyOption struct { // If atLeastOnce == true, Client.Apply will execute the mutations on Cloud diff --git a/spanner/client_test.go b/spanner/client_test.go index 54c52631f05..3df53669306 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -17,8 +17,15 @@ limitations under the License. package spanner import ( + "context" + "io" "strings" "testing" + + "cloud.google.com/go/spanner/internal/testutil" + "google.golang.org/api/iterator" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" ) // Test validDatabaseName() @@ -48,3 +55,355 @@ func TestReadOnlyTransactionClose(t *testing.T) { tx := c.ReadOnlyTransaction() tx.Close() } + +func TestClient_Single(t *testing.T) { + t.Parallel() + err := testSingleQuery(t, nil) + if err != nil { + t.Fatal(err) + } +} + +func TestClient_Single_Unavailable(t *testing.T) { + t.Parallel() + err := testSingleQuery(t, gstatus.Error(codes.Unavailable, "Temporary unavailable")) + if err != nil { + t.Fatal(err) + } +} + +func TestClient_Single_InvalidArgument(t *testing.T) { + t.Parallel() + err := testSingleQuery(t, gstatus.Error(codes.InvalidArgument, "Invalid argument")) + if err == nil { + t.Fatalf("missing expected error") + } else if gstatus.Code(err) != codes.InvalidArgument { + t.Fatal(err) + } +} + +func testSingleQuery(t *testing.T, serverError error) error { + config := ClientConfig{} + server, client := newSpannerInMemTestServerWithConfig(t, config) + defer server.teardown(client) + if serverError != nil { + server.testSpanner.SetError(serverError) + } + ctx := context.Background() + iter := client.Single().Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return nil +} + +func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]testutil.SimulatedExecutionTime { + errors := make([]error, 2) + errors[0] = gstatus.Error(codes.Unavailable, "Temporary unavailable") + errors[1] = gstatus.Error(codes.Unavailable, "Temporary unavailable") + executionTimes := make(map[string]testutil.SimulatedExecutionTime) + executionTimes[method] = testutil.SimulatedExecutionTime{ + Errors: errors, + } + return executionTimes +} + +func TestClient_ReadOnlyTransaction(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, make(map[string]testutil.SimulatedExecutionTime)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnSessionCreate(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodCreateSession)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnBeginTransaction(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodBeginTransaction)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodExecuteStreamingSql)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) { + t.Parallel() + exec := map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + } + if err := testReadOnlyTransaction(t, exec); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) { + t.Parallel() + exec := map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.InvalidArgument, "Invalid argument")}}, + } + if err := testReadOnlyTransaction(t, exec); err == nil { + t.Fatalf("Missing expected exception") + } else if gstatus.Code(err) != codes.InvalidArgument { + t.Fatalf("Got unexpected exception: %v", err) + } +} + +func testReadOnlyTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime) error { + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + for method, exec := range executionTimes { + server.testSpanner.PutExecutionTime(method, exec) + } + ctx := context.Background() + tx := client.ReadOnlyTransaction() + defer tx.Close() + iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return nil +} + +func TestClient_ReadWriteTransaction(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, make(map[string]testutil.SimulatedExecutionTime), 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + }, 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) { + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + }, 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + }, 3); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + }, 4); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCommitTransaction: { + Errors: []error{ + gstatus.Error(codes.Aborted, "Transaction aborted"), + gstatus.Error(codes.Unavailable, "Unavailable"), + }, + }, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.AlreadyExists, "A row with this key already exists")}}, + }, 1); err != nil { + if gstatus.Code(err) != codes.AlreadyExists { + t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists) + } + } else { + t.Fatalf("Missing expected exception") + } +} + +func testReadWriteTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime, expectedAttempts int) error { + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + for method, exec := range executionTimes { + server.testSpanner.PutExecutionTime(method, exec) + } + var attempts int + ctx := context.Background() + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return nil + }) + if err != nil { + return err + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + return nil +} + +func TestClient_ApplyAtLeastOnce(t *testing.T) { + t.Parallel() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + ms := []*Mutation{ + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), + } + server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, + }) + _, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce()) + if err != nil { + t.Fatal(err) + } +} + +// PartitionedUpdate should not retry on aborted. +func TestClient_PartitionedUpdate(t *testing.T) { + t.Parallel() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + // PartitionedDML transactions are not committed. + server.testSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, + testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, + }) + _, err := client.PartitionedUpdate(context.Background(), NewStatement(updateBarSetFoo)) + if err == nil { + t.Fatalf("Missing expected Aborted exception") + } else { + if gstatus.Code(err) != codes.Aborted { + t.Fatalf("Got unexpected error %v, expected Aborted", err) + } + } +} + +func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) { + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + var attempts int + ctx := context.Background() + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return io.ErrUnexpectedEOF + }) + if err != io.ErrUnexpectedEOF { + t.Fatalf("Missing expected error %v, got %v", io.ErrUnexpectedEOF, err) + } + if attempts != 1 { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, 1) + } +} diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go new file mode 100644 index 00000000000..963d33467de --- /dev/null +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -0,0 +1,711 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + emptypb "github.com/golang/protobuf/ptypes/empty" + structpb "github.com/golang/protobuf/ptypes/struct" + spannerpb "google.golang.org/genproto/googleapis/spanner/v1" +) + +import ( + "context" + "fmt" + "math/rand" + "sort" + "strings" + "sync" + "time" + + "github.com/golang/protobuf/ptypes/timestamp" + "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" +) + +// StatementResultType indicates the type of result returned by a SQL +// statement. +type StatementResultType int + +const ( + // StatementResultError indicates that the sql statement returns an error. + StatementResultError StatementResultType = 0 + // StatementResultResultSet indicates that the sql statement returns a + // result set. + StatementResultResultSet StatementResultType = 1 + // StatementResultUpdateCount indicates that the sql statement returns an + // update count. + StatementResultUpdateCount StatementResultType = 2 +) + +// The method names that can be used to register execution times and errors. +const ( + MethodBeginTransaction string = "BEGIN_TRANSACTION" + MethodCommitTransaction string = "COMMIT_TRANSACTION" + MethodCreateSession string = "CREATE_SESSION" + MethodDeleteSession string = "DELETE_SESSION" + MethodGetSession string = "GET_SESSION" + MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" +) + +// StatementResult represents a mocked result on the test server. Th result can +// be either a ResultSet, an update count or an error. +type StatementResult struct { + Type StatementResultType + Err error + ResultSet *spannerpb.ResultSet + UpdateCount int64 +} + +// Converts a ResultSet to a PartialResultSet. This method is used to convert +// a mocked result to a PartialResultSet when one of the streaming methods are +// called. +func (s *StatementResult) toPartialResultSet() *spannerpb.PartialResultSet { + values := make([]*structpb.Value, + len(s.ResultSet.Rows)*len(s.ResultSet.Metadata.RowType.Fields)) + var idx int + for _, row := range s.ResultSet.Rows { + for colIdx := range s.ResultSet.Metadata.RowType.Fields { + values[idx] = row.Values[colIdx] + idx++ + } + } + return &spannerpb.PartialResultSet{ + Metadata: s.ResultSet.Metadata, + Values: values, + } +} + +func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet { + return &spannerpb.PartialResultSet{ + Stats: s.convertUpdateCountToResultSet(exact).Stats, + } +} + +// Converts an update count to a ResultSet, as DML statements also return the +// update count as the statistics of a ResultSet. +func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet { + if exact { + return &spannerpb.ResultSet{ + Stats: &spannerpb.ResultSetStats{ + RowCount: &spannerpb.ResultSetStats_RowCountExact{ + RowCountExact: s.UpdateCount, + }, + }, + } + } + return &spannerpb.ResultSet{ + Stats: &spannerpb.ResultSetStats{ + RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{ + RowCountLowerBound: s.UpdateCount, + }, + }, + } +} + +// SimulatedExecutionTime represents the time the execution of a method +// should take, and any errors that should be returned by the method. +type SimulatedExecutionTime struct { + MinimumExecutionTime time.Duration + RandomExecutionTime time.Duration + Errors []error + // Keep error after execution. The error will continue to be returned until + // it is cleared. + KeepError bool +} + +// InMemSpannerServer contains the SpannerServer interface plus a couple +// of specific methods for adding mocked results and resetting the server. +type InMemSpannerServer interface { + spannerpb.SpannerServer + + // Stops this server. + Stop() + + // Resets the in-mem server to its default state, deleting all sessions and + // transactions that have been created on the server. Mocked results are + // not deleted. + Reset() + + // Sets an error that will be returned by the next server call. The server + // call will also automatically clear the error. + SetError(err error) + + // Puts a mocked result on the server for a specific sql statement. The + // server does not parse the SQL string in any way, it is merely used as + // a key to the mocked result. The result will be used for all methods that + // expect a SQL statement, including (batch) DML methods. + PutStatementResult(sql string, result *StatementResult) error + + // Removes a mocked result on the server for a specific sql statement. + RemoveStatementResult(sql string) + + // Aborts the specified transaction . This method can be used to test + // transaction retry logic. + AbortTransaction(id []byte) + + // Puts a simulated execution time for one of the Spanner methods. + PutExecutionTime(method string, executionTime SimulatedExecutionTime) + // Freeze stalls all requests. + Freeze() + // Unfreeze restores processing requests. + Unfreeze() + + TotalSessionsCreated() uint + TotalSessionsDeleted() uint + + ReceivedRequests() chan interface{} + DumpSessions() map[string]bool + DumpPings() []string +} + +type inMemSpannerServer struct { + // Embed for forward compatibility. + // Tests will keep working if more methods are added + // in the future. + spannerpb.SpannerServer + + mu sync.Mutex + + // If set, all calls return this error. + err error + // The mock server creates session IDs using this counter. + sessionCounter uint64 + // The sessions that have been created on this mock server. + sessions map[string]*spannerpb.Session + // Last use times per session. + sessionLastUseTime map[string]time.Time + + // The mock server creates transaction IDs per session using these + // counters. + transactionCounters map[string]*uint64 + // The transactions that have been created on this mock server. + transactions map[string]*spannerpb.Transaction + // The transactions that have been (manually) aborted on the server. + abortedTransactions map[string]bool + // The transactions that are marked as PartitionedDMLTransaction + partitionedDmlTransactions map[string]bool + + // The mocked results for this server. + statementResults map[string]*StatementResult + // The simulated execution times per method. + executionTimes map[string]*SimulatedExecutionTime + // Server will stall on any requests. + freezed chan struct{} + + totalSessionsCreated uint + totalSessionsDeleted uint + receivedRequests chan interface{} + // Session ping history. + pings []string +} + +// NewInMemSpannerServer creates a new in-mem test server. +func NewInMemSpannerServer() InMemSpannerServer { + res := &inMemSpannerServer{} + res.initDefaults() + res.statementResults = make(map[string]*StatementResult) + res.executionTimes = make(map[string]*SimulatedExecutionTime) + res.receivedRequests = make(chan interface{}, 1000000) + // Produce a closed channel, so the default action of ready is to not block. + res.Freeze() + res.Unfreeze() + return res +} + +func (s *inMemSpannerServer) Stop() { + close(s.receivedRequests) +} + +// Resets the test server to its initial state, deleting all sessions and +// transactions that have been created on the server. This method will not +// remove mocked results. +func (s *inMemSpannerServer) Reset() { + close(s.receivedRequests) + s.receivedRequests = make(chan interface{}, 1000000) + s.initDefaults() +} + +func (s *inMemSpannerServer) SetError(err error) { + s.err = err +} + +// Registers a mocked result for a SQL statement on the server. +func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error { + s.statementResults[sql] = result + return nil +} + +func (s *inMemSpannerServer) RemoveStatementResult(sql string) { + delete(s.statementResults, sql) +} + +func (s *inMemSpannerServer) AbortTransaction(id []byte) { + s.abortedTransactions[string(id)] = true +} + +func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) { + s.executionTimes[method] = &executionTime +} + +// Freeze stalls all requests. +func (s *inMemSpannerServer) Freeze() { + s.mu.Lock() + defer s.mu.Unlock() + s.freezed = make(chan struct{}) +} + +// Unfreeze restores processing requests. +func (s *inMemSpannerServer) Unfreeze() { + s.mu.Lock() + defer s.mu.Unlock() + close(s.freezed) +} + +// ready checks conditions before executing requests +func (s *inMemSpannerServer) ready() { + s.mu.Lock() + freezed := s.freezed + s.mu.Unlock() + // check if server should be freezed + <-freezed +} + +func (s *inMemSpannerServer) TotalSessionsCreated() uint { + return s.totalSessionsCreated +} + +func (s *inMemSpannerServer) TotalSessionsDeleted() uint { + return s.totalSessionsDeleted +} + +func (s *inMemSpannerServer) ReceivedRequests() chan interface{} { + return s.receivedRequests +} + +// DumpPings dumps the ping history. +func (s *inMemSpannerServer) DumpPings() []string { + s.mu.Lock() + defer s.mu.Unlock() + return append([]string(nil), s.pings...) +} + +// DumpSessions dumps the internal session table. +func (s *inMemSpannerServer) DumpSessions() map[string]bool { + s.mu.Lock() + defer s.mu.Unlock() + st := map[string]bool{} + for s := range s.sessions { + st[s] = true + } + return st +} + +func (s *inMemSpannerServer) initDefaults() { + s.sessionCounter = 0 + s.sessions = make(map[string]*spannerpb.Session) + s.sessionLastUseTime = make(map[string]time.Time) + s.transactions = make(map[string]*spannerpb.Transaction) + s.abortedTransactions = make(map[string]bool) + s.partitionedDmlTransactions = make(map[string]bool) + s.transactionCounters = make(map[string]*uint64) +} + +func (s *inMemSpannerServer) generateSessionName(database string) string { + s.mu.Lock() + defer s.mu.Unlock() + s.sessionCounter++ + return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter) +} + +func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + session := s.sessions[name] + if session == nil { + return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name)) + } + return session, nil +} + +func (s *inMemSpannerServer) updateSessionLastUseTime(session string) { + s.mu.Lock() + defer s.mu.Unlock() + s.sessionLastUseTime[session] = time.Now() +} + +func getCurrentTimestamp() *timestamp.Timestamp { + t := time.Now() + return ×tamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())} +} + +// Gets the transaction id from the transaction selector. If the selector +// specifies that a new transaction should be started, this method will start +// a new transaction and return the id of that transaction. +func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte { + var res []byte + if txSelector.GetBegin() != nil { + // Start a new transaction. + res = s.beginTransaction(session, txSelector.GetBegin()).Id + } else if txSelector.GetId() != nil { + res = txSelector.GetId() + } + return res +} + +func (s *inMemSpannerServer) generateTransactionName(session string) string { + s.mu.Lock() + defer s.mu.Unlock() + counter, ok := s.transactionCounters[session] + if !ok { + counter = new(uint64) + s.transactionCounters[session] = counter + } + *counter++ + return fmt.Sprintf("%s/transactions/%d", session, *counter) +} + +func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction { + id := s.generateTransactionName(session.Name) + res := &spannerpb.Transaction{ + Id: []byte(id), + ReadTimestamp: getCurrentTimestamp(), + } + s.mu.Lock() + s.transactions[id] = res + s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil + s.mu.Unlock() + return res +} + +func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { + s.mu.Lock() + defer s.mu.Unlock() + tx, ok := s.transactions[string(id)] + if !ok { + return nil, gstatus.Error(codes.NotFound, "Transaction not found") + } + aborted, ok := s.abortedTransactions[string(id)] + if ok && aborted { + return nil, gstatus.Error(codes.Aborted, "Transaction has been aborted") + } + return tx, nil +} + +func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.transactions, string(tx.Id)) + delete(s.partitionedDmlTransactions, string(tx.Id)) +} + +func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) { + result, ok := s.statementResults[sql] + if !ok { + return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql)) + } + return result, nil +} + +func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { + s.receivedRequests <- req + s.ready() + if s.err != nil { + err := s.err + s.err = nil + return err + } + executionTime, ok := s.executionTimes[method] + if ok { + var randTime int64 + if executionTime.RandomExecutionTime > 0 { + randTime = rand.Int63n(int64(executionTime.RandomExecutionTime)) + } + totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) + <-time.After(totalExecutionTime) + if executionTime.Errors != nil && len(executionTime.Errors) > 0 { + err := executionTime.Errors[0] + if !executionTime.KeepError { + executionTime.Errors = executionTime.Errors[1:] + } + return err + } + } + return nil +} + +func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { + if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { + return nil, err + } + if req.Database == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing database") + } + sessionName := s.generateSessionName(req.Database) + ts := getCurrentTimestamp() + session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} + s.mu.Lock() + s.totalSessionsCreated++ + s.sessions[sessionName] = session + s.mu.Unlock() + return session, nil +} + +func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { + if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { + return nil, err + } + s.mu.Lock() + s.pings = append(s.pings, req.Name) + s.mu.Unlock() + if req.Name == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Name) + if err != nil { + return nil, err + } + return session, nil +} + +func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { + s.receivedRequests <- req + if req.Database == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing database") + } + expectedSessionName := req.Database + "/sessions/" + var sessions []*spannerpb.Session + s.mu.Lock() + for _, session := range s.sessions { + if strings.Index(session.Name, expectedSessionName) == 0 { + sessions = append(sessions, session) + } + } + s.mu.Unlock() + sort.Slice(sessions[:], func(i, j int) bool { + return sessions[i].Name < sessions[j].Name + }) + res := &spannerpb.ListSessionsResponse{Sessions: sessions} + return res, nil +} + +func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { + if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { + return nil, err + } + if req.Name == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + if _, err := s.findSession(req.Name); err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + s.totalSessionsDeleted++ + delete(s.sessions, req.Name) + return &emptypb.Empty{}, nil +} + +func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { + s.receivedRequests <- req + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + var id []byte + s.updateSessionLastUseTime(session.Name) + if id = s.getTransactionID(session, req.Transaction); id != nil { + _, err = s.getTransactionByID(id) + if err != nil { + return nil, err + } + } + statementResult, err := s.getStatementResult(req.Sql) + if err != nil { + return nil, err + } + switch statementResult.Type { + case StatementResultError: + return nil, statementResult.Err + case StatementResultResultSet: + return statementResult.ResultSet, nil + case StatementResultUpdateCount: + return statementResult.convertUpdateCountToResultSet(!s.partitionedDmlTransactions[string(id)]), nil + } + return nil, gstatus.Error(codes.Internal, "Unknown result type") +} + +func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { + if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { + return err + } + if req.Session == "" { + return gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return err + } + s.updateSessionLastUseTime(session.Name) + var id []byte + if id = s.getTransactionID(session, req.Transaction); id != nil { + _, err = s.getTransactionByID(id) + if err != nil { + return err + } + } + statementResult, err := s.getStatementResult(req.Sql) + if err != nil { + return err + } + switch statementResult.Type { + case StatementResultError: + return statementResult.Err + case StatementResultResultSet: + part := statementResult.toPartialResultSet() + if err := stream.Send(part); err != nil { + return err + } + return nil + case StatementResultUpdateCount: + part := statementResult.updateCountToPartialResultSet(!s.partitionedDmlTransactions[string(id)]) + if err := stream.Send(part); err != nil { + return err + } + return nil + } + return gstatus.Error(codes.Internal, "Unknown result type") +} + +func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + s.receivedRequests <- req + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + var id []byte + if id = s.getTransactionID(session, req.Transaction); id != nil { + _, err = s.getTransactionByID(id) + if err != nil { + return nil, err + } + } + resp := &spannerpb.ExecuteBatchDmlResponse{} + resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements)) + for idx, batchStatement := range req.Statements { + statementResult, err := s.getStatementResult(batchStatement.Sql) + if err != nil { + return nil, err + } + switch statementResult.Type { + case StatementResultError: + resp.Status = &status.Status{Code: int32(codes.Unknown)} + case StatementResultResultSet: + return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql)) + case StatementResultUpdateCount: + resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!s.partitionedDmlTransactions[string(id)]) + resp.Status = &status.Status{Code: int32(codes.OK)} + } + } + return resp, nil +} + +func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { + s.receivedRequests <- req + return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} + +func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { + s.receivedRequests <- req + return gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} + +func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { + if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { + return nil, err + } + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + tx := s.beginTransaction(session, req.Options) + return tx, nil +} + +func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) { + if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil { + return nil, err + } + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + var tx *spannerpb.Transaction + if req.GetSingleUseTransaction() != nil { + tx = s.beginTransaction(session, req.GetSingleUseTransaction()) + } else if req.GetTransactionId() != nil { + tx, err = s.getTransactionByID(req.GetTransactionId()) + if err != nil { + return nil, err + } + } else { + return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") + } + s.removeTransaction(tx) + return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil +} + +func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { + s.receivedRequests <- req + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + tx, err := s.getTransactionByID(req.TransactionId) + if err != nil { + return nil, err + } + s.removeTransaction(tx) + return &emptypb.Empty{}, nil +} + +func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { + s.receivedRequests <- req + return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} + +func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { + s.receivedRequests <- req + return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go new file mode 100644 index 00000000000..d563ff4b617 --- /dev/null +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -0,0 +1,598 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "strconv" + + structpb "github.com/golang/protobuf/ptypes/struct" + spannerpb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" +) + +import ( + "context" + "flag" + "fmt" + "log" + "net" + "os" + "strings" + "testing" + + apiv1 "cloud.google.com/go/spanner/apiv1" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/grpc" + gstatus "google.golang.org/grpc/status" +) + +// clientOpt is the option tests should use to connect to the test server. +// It is initialized by TestMain. +var serverAddress string +var clientOpt option.ClientOption +var testSpanner InMemSpannerServer + +// Mocked selectSQL statement. +const selectSQL = "SELECT FOO FROM BAR" +const selectRowCount int64 = 2 +const selectColCount int = 1 + +var selectValues = [...]int64{1, 2} + +// Mocked DML statement. +const updateSQL = "UPDATE FOO SET BAR=1 WHERE ID=ID" +const updateRowCount int64 = 2 + +func TestMain(m *testing.M) { + flag.Parse() + + testSpanner = NewInMemSpannerServer() + serv := grpc.NewServer() + spannerpb.RegisterSpannerServer(serv, testSpanner) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + log.Fatal(err) + } + go serv.Serve(lis) + + serverAddress = lis.Addr().String() + conn, err := grpc.Dial(serverAddress, grpc.WithInsecure()) + if err != nil { + log.Fatal(err) + } + clientOpt = option.WithGRPCConn(conn) + + os.Exit(m.Run()) +} + +// Resets the mock server to its default values and registers a mocked result +// for the statements "SELECT FOO FROM BAR" and +// "UPDATE FOO SET BAR=1 WHERE ID=ID". +func setup() { + testSpanner.Reset() + fields := make([]*spannerpb.StructType_Field, selectColCount) + fields[0] = &spannerpb.StructType_Field{ + Name: "FOO", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, + } + rowType := &spannerpb.StructType{ + Fields: fields, + } + metadata := &spannerpb.ResultSetMetadata{ + RowType: rowType, + } + rows := make([]*structpb.ListValue, selectRowCount) + for idx, value := range selectValues { + rowValue := make([]*structpb.Value, selectColCount) + rowValue[0] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(value, 10)}, + } + rows[idx] = &structpb.ListValue{ + Values: rowValue, + } + } + resultSet := &spannerpb.ResultSet{ + Metadata: metadata, + Rows: rows, + } + result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet} + testSpanner.PutStatementResult(selectSQL, result) + + updateResult := &StatementResult{Type: StatementResultUpdateCount, UpdateCount: updateRowCount} + testSpanner.PutStatementResult(updateSQL, updateResult) +} + +func TestSpannerCreateSession(t *testing.T) { + testSpanner.Reset() + var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + resp, err := c.CreateSession(context.Background(), request) + if err != nil { + t.Fatal(err) + } + if strings.Index(resp.Name, expectedName) != 0 { + t.Errorf("wrong name %s, should start with %s)", resp.Name, expectedName) + } +} + +func TestSpannerCreateSession_Unavailable(t *testing.T) { + testSpanner.Reset() + var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var request = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + testSpanner.SetError(gstatus.Error(codes.Unavailable, "Temporary unavailable")) + resp, err := c.CreateSession(context.Background(), request) + if err != nil { + t.Fatal(err) + } + if strings.Index(resp.Name, expectedName) != 0 { + t.Errorf("wrong name %s, should start with %s)", resp.Name, expectedName) + } +} + +func TestSpannerGetSession(t *testing.T) { + testSpanner.Reset() + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + createResp, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + var getRequest = &spannerpb.GetSessionRequest{ + Name: createResp.Name, + } + getResp, err := c.GetSession(context.Background(), getRequest) + if err != nil { + t.Fatal(err) + } + if getResp.Name != getRequest.Name { + t.Errorf("wrong name %s, expected %s)", getResp.Name, getRequest.Name) + } +} + +func TestSpannerListSessions(t *testing.T) { + testSpanner.Reset() + const expectedNumberOfSessions = 5 + var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + for i := 0; i < expectedNumberOfSessions; i++ { + _, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + } + var listRequest = &spannerpb.ListSessionsRequest{ + Database: formattedDatabase, + } + var sessionCount int + listResp := c.ListSessions(context.Background(), listRequest) + for { + session, err := listResp.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + if strings.Index(session.Name, expectedName) != 0 { + t.Errorf("wrong name %s, should start with %s)", session.Name, expectedName) + } + sessionCount++ + } + if sessionCount != expectedNumberOfSessions { + t.Errorf("wrong number of sessions: %d, expected %d", sessionCount, expectedNumberOfSessions) + } +} + +func TestSpannerDeleteSession(t *testing.T) { + testSpanner.Reset() + const expectedNumberOfSessions = 5 + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + for i := 0; i < expectedNumberOfSessions; i++ { + _, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + } + var listRequest = &spannerpb.ListSessionsRequest{ + Database: formattedDatabase, + } + var sessionCount int + listResp := c.ListSessions(context.Background(), listRequest) + for { + session, err := listResp.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + var deleteRequest = &spannerpb.DeleteSessionRequest{ + Name: session.Name, + } + c.DeleteSession(context.Background(), deleteRequest) + sessionCount++ + } + if sessionCount != expectedNumberOfSessions { + t.Errorf("wrong number of sessions: %d, expected %d", sessionCount, expectedNumberOfSessions) + } + // Re-list all sessions. This should now be empty. + listResp = c.ListSessions(context.Background(), listRequest) + _, err = listResp.Next() + if err != iterator.Done { + t.Errorf("expected empty session iterator") + } +} + +func TestSpannerExecuteSql(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + request := &spannerpb.ExecuteSqlRequest{ + Session: session.Name, + Sql: selectSQL, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_SingleUse{ + SingleUse: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ + ReadOnly: &spannerpb.TransactionOptions_ReadOnly{ + ReturnReadTimestamp: false, + TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{ + Strong: true, + }, + }, + }, + }, + }, + }, + Seqno: 1, + QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, + } + response, err := c.ExecuteSql(context.Background(), request) + if err != nil { + t.Fatal(err) + } + var rowCount int64 + for _, row := range response.Rows { + if len(row.Values) != selectColCount { + t.Fatalf("unexpected number of columns: %d, expected %d", len(row.Values), selectColCount) + } + rowCount++ + } + if rowCount != selectRowCount { + t.Fatalf("unexpected number of rows: %d, expected %d", rowCount, selectRowCount) + } +} + +func TestSpannerExecuteSqlDml(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + request := &spannerpb.ExecuteSqlRequest{ + Session: session.Name, + Sql: updateSQL, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_Begin{ + Begin: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, + }, + }, + }, + }, + Seqno: 1, + QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, + } + response, err := c.ExecuteSql(context.Background(), request) + if err != nil { + t.Fatal(err) + } + var rowCount int64 = response.Stats.GetRowCountExact() + if rowCount != updateRowCount { + t.Fatalf("unexpected number of rows updated: %d, expected %d", rowCount, updateRowCount) + } +} + +func TestSpannerExecuteStreamingSql(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + request := &spannerpb.ExecuteSqlRequest{ + Session: session.Name, + Sql: selectSQL, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_SingleUse{ + SingleUse: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ + ReadOnly: &spannerpb.TransactionOptions_ReadOnly{ + ReturnReadTimestamp: false, + TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{ + Strong: true, + }, + }, + }, + }, + }, + }, + Seqno: 1, + QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, + } + response, err := c.ExecuteStreamingSql(context.Background(), request) + if err != nil { + t.Fatal(err) + } + partial, err := response.Recv() + if err != nil { + t.Fatal(err) + } + var rowIndex int64 + colCount := len(partial.Metadata.RowType.Fields) + if colCount != selectColCount { + t.Fatalf("unexpected number of columns: %d, expected %d", colCount, selectColCount) + } + for { + for col := 0; col < colCount; col++ { + val, err := strconv.ParseInt(partial.Values[rowIndex*int64(colCount)+int64(col)].GetStringValue(), 10, 64) + if err != nil { + t.Fatal(err) + } + if val != selectValues[rowIndex] { + t.Fatalf("Unexpected value at index %d. Expected %d, got %d", rowIndex, selectValues[rowIndex], val) + } + } + rowIndex++ + if rowIndex == selectRowCount { + break + } + } + if rowIndex != selectRowCount { + t.Fatalf("unexpected number of rows: %d, expected %d", rowIndex, selectRowCount) + } +} + +func TestSpannerExecuteBatchDml(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + statements := make([]*spannerpb.ExecuteBatchDmlRequest_Statement, 3) + for idx := 0; idx < len(statements); idx++ { + statements[idx] = &spannerpb.ExecuteBatchDmlRequest_Statement{Sql: updateSQL} + } + executeBatchDmlRequest := &spannerpb.ExecuteBatchDmlRequest{ + Session: session.Name, + Statements: statements, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_Begin{ + Begin: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, + }, + }, + }, + }, + Seqno: 1, + } + response, err := c.ExecuteBatchDml(context.Background(), executeBatchDmlRequest) + if err != nil { + t.Fatal(err) + } + var totalRowCount int64 + for _, res := range response.ResultSets { + var rowCount int64 = res.Stats.GetRowCountExact() + if rowCount != updateRowCount { + t.Fatalf("unexpected number of rows updated: %d, expected %d", rowCount, updateRowCount) + } + totalRowCount += rowCount + } + if totalRowCount != updateRowCount*int64(len(statements)) { + t.Fatalf("unexpected number of total rows updated: %d, expected %d", totalRowCount, updateRowCount*int64(len(statements))) + } +} + +func TestBeginTransaction(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + beginRequest := &spannerpb.BeginTransactionRequest{ + Session: session.Name, + Options: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, + }, + }, + } + tx, err := c.BeginTransaction(context.Background(), beginRequest) + if err != nil { + t.Fatal(err) + } + expectedName := fmt.Sprintf("%s/transactions/", session.Name) + if strings.Index(string(tx.Id), expectedName) != 0 { + t.Errorf("wrong name %s, should start with %s)", string(tx.Id), expectedName) + } +} + +func TestCommitTransaction(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + beginRequest := &spannerpb.BeginTransactionRequest{ + Session: session.Name, + Options: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, + }, + }, + } + tx, err := c.BeginTransaction(context.Background(), beginRequest) + if err != nil { + t.Fatal(err) + } + commitRequest := &spannerpb.CommitRequest{ + Session: session.Name, + Transaction: &spannerpb.CommitRequest_TransactionId{ + TransactionId: tx.Id, + }, + } + resp, err := c.Commit(context.Background(), commitRequest) + if err != nil { + t.Fatal(err) + } + if resp.CommitTimestamp == nil { + t.Fatalf("No commit timestamp returned") + } +} + +func TestRollbackTransaction(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + beginRequest := &spannerpb.BeginTransactionRequest{ + Session: session.Name, + Options: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, + }, + }, + } + tx, err := c.BeginTransaction(context.Background(), beginRequest) + if err != nil { + t.Fatal(err) + } + rollbackRequest := &spannerpb.RollbackRequest{ + Session: session.Name, + TransactionId: tx.Id, + } + err = c.Rollback(context.Background(), rollbackRequest) + if err != nil { + t.Fatal(err) + } +} diff --git a/spanner/mocked_inmem_server.go b/spanner/mocked_inmem_server.go new file mode 100644 index 00000000000..b98b17724ac --- /dev/null +++ b/spanner/mocked_inmem_server.go @@ -0,0 +1,179 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "fmt" + "net" + "strconv" + "testing" + + "cloud.google.com/go/spanner/internal/testutil" + structpb "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/api/option" + spannerpb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" +) + +// The SQL statements and results that are already mocked for this test server. +const selectFooFromBar = "SELECT FOO FROM BAR" +const selectFooFromBarRowCount int64 = 2 +const selectFooFromBarColCount int = 1 + +var selectFooFromBarResults = [...]int64{1, 2} + +const selectSingerIDAlbumIDAlbumTitleFromAlbums = "SELECT SingerId, AlbumId, AlbumTitle FROM Albums" +const selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount int64 = 3 +const selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount int = 3 + +const updateBarSetFoo = "UPDATE FOO SET BAR=1 WHERE BAZ=2" +const updateBarSetFooRowCount = 5 + +// An InMemSpannerServer with results for a number of SQL statements readily +// mocked. +type spannerInMemTestServer struct { + testSpanner testutil.InMemSpannerServer + server *grpc.Server +} + +// Create a spannerInMemTestServer with default configuration. +func newSpannerInMemTestServer(t *testing.T) (*spannerInMemTestServer, *Client) { + s := &spannerInMemTestServer{} + client := s.setup(t) + return s, client +} + +// Create a spannerInMemTestServer with the specified configuration. +func newSpannerInMemTestServerWithConfig(t *testing.T, config ClientConfig) (*spannerInMemTestServer, *Client) { + s := &spannerInMemTestServer{} + client := s.setupWithConfig(t, config) + return s, client +} + +func (s *spannerInMemTestServer) setup(t *testing.T) *Client { + return s.setupWithConfig(t, ClientConfig{}) +} + +func (s *spannerInMemTestServer) setupWithConfig(t *testing.T, config ClientConfig) *Client { + s.testSpanner = testutil.NewInMemSpannerServer() + s.setupFooResults() + s.setupSingersResults() + s.server = grpc.NewServer() + spannerpb.RegisterSpannerServer(s.server, s.testSpanner) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + go s.server.Serve(lis) + + serverAddress := lis.Addr().String() + ctx := context.Background() + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + client, err := NewClientWithConfig(ctx, formattedDatabase, config, + option.WithEndpoint(serverAddress), + option.WithGRPCDialOption(grpc.WithInsecure()), + option.WithoutAuthentication(), + ) + if err != nil { + t.Fatal(err) + } + return client +} + +func (s *spannerInMemTestServer) setupFooResults() { + fields := make([]*spannerpb.StructType_Field, selectFooFromBarColCount) + fields[0] = &spannerpb.StructType_Field{ + Name: "FOO", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, + } + rowType := &spannerpb.StructType{ + Fields: fields, + } + metadata := &spannerpb.ResultSetMetadata{ + RowType: rowType, + } + rows := make([]*structpb.ListValue, selectFooFromBarRowCount) + for idx, value := range selectFooFromBarResults { + rowValue := make([]*structpb.Value, selectFooFromBarColCount) + rowValue[0] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(value, 10)}, + } + rows[idx] = &structpb.ListValue{ + Values: rowValue, + } + } + resultSet := &spannerpb.ResultSet{ + Metadata: metadata, + Rows: rows, + } + result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} + s.testSpanner.PutStatementResult(selectFooFromBar, result) + s.testSpanner.PutStatementResult(updateBarSetFoo, &testutil.StatementResult{ + Type: testutil.StatementResultUpdateCount, + UpdateCount: updateBarSetFooRowCount, + }) +} + +func (s *spannerInMemTestServer) setupSingersResults() { + fields := make([]*spannerpb.StructType_Field, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) + fields[0] = &spannerpb.StructType_Field{ + Name: "SingerId", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, + } + fields[1] = &spannerpb.StructType_Field{ + Name: "AlbumId", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, + } + fields[2] = &spannerpb.StructType_Field{ + Name: "AlbumTitle", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, + } + rowType := &spannerpb.StructType{ + Fields: fields, + } + metadata := &spannerpb.ResultSetMetadata{ + RowType: rowType, + } + rows := make([]*structpb.ListValue, selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + var idx int64 + for idx = 0; idx < selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ { + rowValue := make([]*structpb.Value, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) + rowValue[0] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx+1, 10)}, + } + rowValue[1] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx*10+idx, 10)}, + } + rowValue[2] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("Album title %d", idx)}, + } + rows[idx] = &structpb.ListValue{ + Values: rowValue, + } + } + resultSet := &spannerpb.ResultSet{ + Metadata: metadata, + Rows: rows, + } + result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} + s.testSpanner.PutStatementResult(selectSingerIDAlbumIDAlbumTitleFromAlbums, result) +} + +func (s *spannerInMemTestServer) teardown(client *Client) { + client.Close() + s.server.Stop() +} diff --git a/spanner/pdml.go b/spanner/pdml.go index 168c46a9ff1..c03c13a60e0 100644 --- a/spanner/pdml.go +++ b/spanner/pdml.go @@ -53,27 +53,20 @@ func (c *Client) PartitionedUpdate(ctx context.Context, statement Statement) (co defer s.delete(ctx) sh = &sessionHandle{session: s} // Begin transaction. - err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error { - res, e := sc.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ - Session: sh.getID(), - Options: &sppb.TransactionOptions{ - Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}}, - }, - }) - if e != nil { - return e - } - tx = res.Id - return nil + res, err := sc.BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{ + Session: sh.getID(), + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}}, + }, }) if err != nil { return 0, toSpannerError(err) } - params, paramTypes, err := statement.convertParams() if err != nil { return 0, toSpannerError(err) } + tx = res.Id req := &sppb.ExecuteSqlRequest{ Session: sh.getID(), diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go index 4c5a743ef14..743f5113cc1 100644 --- a/spanner/pdml_test.go +++ b/spanner/pdml_test.go @@ -16,28 +16,23 @@ package spanner import ( "context" - "io" "testing" - "cloud.google.com/go/spanner/internal/testutil" - sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" ) func TestMockPartitionedUpdate(t *testing.T) { t.Parallel() ctx := context.Background() - ms := testutil.NewMockCloudSpanner(t, trxTs) - ms.Serve() - mc := sppb.NewSpannerClient(dialMock(t, ms)) - client := &Client{database: "mockdb"} - client.clients = append(client.clients, mc) - stmt := NewStatement("UPDATE t SET x = 2 WHERE x = 1") + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + + stmt := NewStatement(updateBarSetFoo) rowCount, err := client.PartitionedUpdate(ctx, stmt) if err != nil { t.Fatal(err) } - want := int64(3) + want := int64(updateBarSetFooRowCount) if rowCount != want { t.Errorf("got %d, want %d", rowCount, want) } @@ -46,13 +41,10 @@ func TestMockPartitionedUpdate(t *testing.T) { func TestMockPartitionedUpdateWithQuery(t *testing.T) { t.Parallel() ctx := context.Background() - ms := testutil.NewMockCloudSpanner(t, trxTs) - ms.AddMsg(io.EOF, true) - ms.Serve() - mc := sppb.NewSpannerClient(dialMock(t, ms)) - client := &Client{database: "mockdb"} - client.clients = append(client.clients, mc) - stmt := NewStatement("SELECT t.key key, t.value value FROM t_mock t") + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + + stmt := NewStatement(selectFooFromBar) _, err := client.PartitionedUpdate(ctx, stmt) wantCode := codes.InvalidArgument if serr, ok := err.(*Error); !ok || serr.Code != wantCode { diff --git a/spanner/retry.go b/spanner/retry.go index 033529c5c80..c45e80b4d63 100644 --- a/spanner/retry.go +++ b/spanner/retry.go @@ -18,12 +18,9 @@ package spanner import ( "context" - "fmt" "strings" "time" - "cloud.google.com/go/internal/trace" - "cloud.google.com/go/spanner/internal/backoff" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" edpb "google.golang.org/genproto/googleapis/rpc/errdetails" @@ -35,16 +32,6 @@ const ( retryInfoKey = "google.rpc.retryinfo-bin" ) -// errRetry returns an unavailable error under error namespace EsOther. It is a -// generic retryable error that is used to mask and recover unretryable errors -// in a retry loop. -func errRetry(err error) error { - if se, ok := err.(*Error); ok { - return &Error{codes.Unavailable, fmt.Sprintf("generic Cloud Spanner retryable error: { %v }", se.Error()), se.trailers} - } - return spannerErrorf(codes.Unavailable, "generic Cloud Spanner retryable error: { %v }", err.Error()) -} - // isErrorClosing reports whether the error is generated by gRPC layer talking // to a closed server. func isErrorClosing(err error) bool { @@ -158,50 +145,3 @@ func extractRetryDelay(err error) (time.Duration, bool) { } return delay, true } - -// runRetryable keeps attempting to run f until one of the following happens: -// 1) f returns nil error or an unretryable error; -// 2) context is cancelled or timeout. -// -// TODO: consider using https://github.com/googleapis/gax-go/v2 once it -// becomes available internally. -func runRetryable(ctx context.Context, f func(context.Context) error) error { - return toSpannerError(runRetryableNoWrap(ctx, f)) -} - -// Like runRetryable, but doesn't wrap the returned error in a spanner.Error. -func runRetryableNoWrap(ctx context.Context, f func(context.Context) error) error { - var funcErr error - retryCount := 0 - for { - select { - case <-ctx.Done(): - // Do context check here so that even f() failed to do so (for - // example, gRPC implementation bug), the loop can still have a - // chance to exit as expected. - return errContextCanceled(ctx, funcErr) - default: - } - funcErr = f(ctx) - if funcErr == nil { - return nil - } - if isRetryable(funcErr) { - // Error is retryable, do exponential backoff and continue. - b, ok := extractRetryDelay(funcErr) - if !ok { - b = backoff.DefaultBackoff.Delay(retryCount) - } - trace.TracePrintf(ctx, nil, "Backing off for %s, then retrying", b) - select { - case <-ctx.Done(): - return errContextCanceled(ctx, funcErr) - case <-time.After(b): - } - retryCount++ - continue - } - // Error isn't retryable / no error, return immediately. - return funcErr - } -} diff --git a/spanner/retry_test.go b/spanner/retry_test.go index 64516f647db..979f164e54c 100644 --- a/spanner/retry_test.go +++ b/spanner/retry_test.go @@ -17,9 +17,6 @@ limitations under the License. package spanner import ( - "context" - "errors" - "fmt" "testing" "time" @@ -31,69 +28,6 @@ import ( "google.golang.org/grpc/status" ) -// Test if runRetryable loop deals with various errors correctly. -func TestRetry(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - responses := []error{ - status.Errorf(codes.Internal, "transport is closing"), - status.Errorf(codes.Unknown, "unexpected EOF"), - status.Errorf(codes.Internal, "unexpected EOF"), - status.Errorf(codes.Internal, "stream terminated by RST_STREAM with error code: 2"), - status.Errorf(codes.Unavailable, "service is currently unavailable"), - errRetry(fmt.Errorf("just retry it")), - } - err := runRetryable(context.Background(), func(ct context.Context) error { - var r error - if len(responses) > 0 { - r = responses[0] - responses = responses[1:] - } - return r - }) - if err != nil { - t.Errorf("runRetryable should be able to survive all retryable errors, but it returns %v", err) - } - // Unretryable errors - injErr := errors.New("this is unretryable") - err = runRetryable(context.Background(), func(ct context.Context) error { - return injErr - }) - if wantErr := toSpannerError(injErr); !testEqual(err, wantErr) { - t.Errorf("runRetryable returns error %v, want %v", err, wantErr) - } - // Timeout - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - retryErr := errRetry(fmt.Errorf("still retrying")) - err = runRetryable(ctx, func(ct context.Context) error { - // Expect to trigger timeout in retryable runner after 10 executions. - <-time.After(100 * time.Millisecond) - // Let retryable runner to retry so that timeout will eventually happen. - return retryErr - }) - // Check error code and error message. - if wantErrCode, wantErr := codes.DeadlineExceeded, errContextCanceled(ctx, retryErr); ErrCode(err) != wantErrCode || !testEqual(err, wantErr) { - t.Errorf("=\n<%v, %v>, want:\n<%v, %v>", ErrCode(err), err, wantErrCode, wantErr) - } - // Cancellation - ctx, cancel = context.WithCancel(context.Background()) - retries := 3 - retryErr = errRetry(fmt.Errorf("retry before cancel")) - err = runRetryable(ctx, func(ct context.Context) error { - retries-- - if retries == 0 { - cancel() - } - return retryErr - }) - // Check error code, error message, retry count. - if wantErrCode, wantErr := codes.Canceled, errContextCanceled(ctx, retryErr); ErrCode(err) != wantErrCode || !testEqual(err, wantErr) || retries != 0 { - t.Errorf("=\n<%v, %v, %v>, want:\n<%v, %v, %v>", ErrCode(err), err, retries, wantErrCode, wantErr, 0) - } -} - func TestRetryInfo(t *testing.T) { b, _ := proto.Marshal(&edpb.RetryInfo{ RetryDelay: ptypes.DurationProto(time.Second), @@ -101,7 +35,7 @@ func TestRetryInfo(t *testing.T) { trailers := map[string]string{ retryInfoKey: string(b), } - gotDelay, ok := extractRetryDelay(errRetry(toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, ""), metadata.New(trailers)))) + gotDelay, ok := extractRetryDelay(toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, ""), metadata.New(trailers))) if !ok || !testEqual(time.Second, gotDelay) { t.Errorf(" = <%t, %v>, want ", ok, gotDelay, time.Second) } diff --git a/spanner/session.go b/spanner/session.go index 64158fff41b..27bad1bf26d 100644 --- a/spanner/session.go +++ b/spanner/session.go @@ -28,6 +28,7 @@ import ( "time" "cloud.google.com/go/internal/trace" + vkit "cloud.google.com/go/spanner/apiv1" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -72,7 +73,7 @@ func (sh *sessionHandle) getID() string { // getClient gets the Cloud Spanner RPC client associated with the session ID // in sessionHandle. -func (sh *sessionHandle) getClient() sppb.SpannerClient { +func (sh *sessionHandle) getClient() *vkit.Client { sh.mu.Lock() defer sh.mu.Unlock() if sh.session == nil { @@ -121,7 +122,7 @@ func (sh *sessionHandle) destroy() { type session struct { // client is the RPC channel to Cloud Spanner. It is set only once during // session's creation. - client sppb.SpannerClient + client *vkit.Client // id is the unique id of the session in Cloud Spanner. It is set only once // during session's creation. id string @@ -183,11 +184,9 @@ func (s *session) String() string { func (s *session) ping() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - return runRetryable(ctx, func(ctx context.Context) error { - // s.getID is safe even when s is invalid. - _, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()}) - return err - }) + // s.getID is safe even when s is invalid. + _, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()}) + return err } // setHcIndex atomically sets the session's index in the healthcheck queue and @@ -290,13 +289,9 @@ func (s *session) destroy(isExpire bool) bool { } func (s *session) delete(ctx context.Context) { - // Ignore the error returned by runRetryable because even if we fail to - // explicitly destroy the session, it will be eventually garbage collected - // by Cloud Spanner. - err := runRetryable(ctx, func(ctx context.Context) error { - _, e := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()}) - return e - }) + // Ignore the error because even if we fail to explicitly destroy the + // session, it will be eventually garbage collected by Cloud Spanner. + err := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()}) if err != nil { log.Printf("Failed to delete session %v. Error: %v", s.getID(), err) } @@ -320,7 +315,7 @@ func (s *session) prepareForWrite(ctx context.Context) error { type SessionPoolConfig struct { // getRPCClient is the caller supplied method for getting a gRPC client to // Cloud Spanner, this makes session pool able to use client pooling. - getRPCClient func() (sppb.SpannerClient, error) + getRPCClient func() (*vkit.Client, error) // MaxOpened is the maximum number of opened sessions allowed by the session // pool. If the client tries to open a session and there are already @@ -539,6 +534,9 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) { doneCreate(false) // Should return error directly because of the previous retries on // CreateSession RPC. + // If the error is a timeout, there is a chance that the session was + // created on the server but is not known to the session pool. This + // session will then be garbage collected by the server after 1 hour. return nil, err } s.pool = p @@ -547,23 +545,17 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) { return s, nil } -func createSession(ctx context.Context, sc sppb.SpannerClient, db string, labels map[string]string, md metadata.MD) (*session, error) { +func createSession(ctx context.Context, sc *vkit.Client, db string, labels map[string]string, md metadata.MD) (*session, error) { var s *session - err := runRetryable(ctx, func(ctx context.Context) error { - sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{ - Database: db, - Session: &sppb.Session{Labels: labels}, - }) - if e != nil { - return e - } - // If no error, construct the new session. - s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md} - return nil + sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{ + Database: db, + Session: &sppb.Session{Labels: labels}, }) - if err != nil { - return nil, err + if e != nil { + return nil, toSpannerError(e) } + // If no error, construct the new session. + s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md} return s, nil } diff --git a/spanner/session_test.go b/spanner/session_test.go index d04e93d2e62..2f3f7dc32c4 100644 --- a/spanner/session_test.go +++ b/spanner/session_test.go @@ -23,13 +23,11 @@ import ( "fmt" "math/rand" "sync" - "sync/atomic" "testing" "time" + vkit "cloud.google.com/go/spanner/apiv1" "cloud.google.com/go/spanner/internal/testutil" - sppb "google.golang.org/genproto/googleapis/spanner/v1" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -37,8 +35,9 @@ import ( // TestSessionPoolConfigValidation tests session pool config validation. func TestSessionPoolConfigValidation(t *testing.T) { t.Parallel() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) - sc := testutil.NewMockCloudSpannerClient(t) for _, test := range []struct { spc SessionPoolConfig err error @@ -49,8 +48,8 @@ func TestSessionPoolConfigValidation(t *testing.T) { }, { SessionPoolConfig{ - getRPCClient: func() (sppb.SpannerClient, error) { - return sc, nil + getRPCClient: func() (*vkit.Client, error) { + return client.clients[0], nil }, MinOpened: 10, MaxOpened: 5, @@ -68,8 +67,9 @@ func TestSessionPoolConfigValidation(t *testing.T) { func TestSessionCreation(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + sp := client.idleSessions // Take three sessions from session pool, this should trigger session pool // to create three new sessions. @@ -87,7 +87,7 @@ func TestSessionCreation(t *testing.T) { if len(gotDs) != len(shs) { t.Fatalf("session pool created %v sessions, want %v", len(gotDs), len(shs)) } - if wantDs := mock.DumpSessions(); !testEqual(gotDs, wantDs) { + if wantDs := server.testSpanner.DumpSessions(); !testEqual(gotDs, wantDs) { t.Fatalf("session pool creates sessions %v, want %v", gotDs, wantDs) } // Verify that created sessions are recorded correctly in session pool. @@ -119,8 +119,12 @@ func TestTakeFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxIdle: 10}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{MaxIdle: 10}, + }) + defer server.teardown(client) + sp := client.idleSessions // Take ten sessions from session pool and recycle them. shs := make([]*sessionHandle, 10) @@ -139,7 +143,7 @@ func TestTakeFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := mock.DumpSessions() + wantSessions := server.testSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -165,8 +169,12 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxIdle: 20}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{MaxIdle: 20}, + }) + defer server.teardown(client) + sp := client.idleSessions // Take ten sessions from session pool and recycle them. shs := make([]*sessionHandle, 10) @@ -185,7 +193,7 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := mock.DumpSessions() + wantSessions := server.testSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -211,12 +219,16 @@ func TestTakeFromIdleListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - MaxIdle: 1, - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxIdle: 1, + HealthCheckInterval: 50 * time.Millisecond, + healthCheckSampleInterval: 10 * time.Millisecond, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. sp.hc.close() @@ -251,7 +263,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take // reschedules the next healthcheck. - if got, want := mock.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -260,10 +272,10 @@ func TestTakeFromIdleListChecked(t *testing.T) { // Inject session error to server stub, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. - mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, status.Errorf(codes.NotFound, "Session not found") - } + server.testSpanner.PutExecutionTime(testutil.MethodGetSession, + testutil.SimulatedExecutionTime{ + Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, + }) // Delay to trigger sessionPool.Take to ping the session. // TODO(deklerk): get rid of this @@ -276,7 +288,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { if err != nil { t.Fatalf("failed to get session: %v", err) } - ds := mock.DumpSessions() + ds := server.testSpanner.DumpSessions() if len(ds) != 1 { t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) } @@ -292,12 +304,16 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - MaxIdle: 1, - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxIdle: 1, + HealthCheckInterval: 50 * time.Millisecond, + healthCheckSampleInterval: 10 * time.Millisecond, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. sp.hc.close() @@ -330,7 +346,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { } // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take reschedules the next healthcheck. - if got, want := mock.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -339,10 +355,10 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { // Inject session error to mockclient, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. - mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, status.Errorf(codes.NotFound, "Session not found") - } + server.testSpanner.PutExecutionTime(testutil.MethodGetSession, + testutil.SimulatedExecutionTime{ + Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, + }) // Delay to trigger sessionPool.Take to ping the session. // TOOD(deklerk) get rid of this @@ -352,7 +368,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { if err != nil { t.Fatalf("failed to get session: %v", err) } - ds := mock.DumpSessions() + ds := server.testSpanner.DumpSessions() if len(ds) != 1 { t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) } @@ -365,8 +381,14 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { func TestMaxOpenedSessions(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MaxOpened: 1}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxOpened: 1, + }, + }) + defer server.teardown(client) + sp := client.idleSessions sh1, err := sp.take(ctx) if err != nil { @@ -404,8 +426,14 @@ func TestMaxOpenedSessions(t *testing.T) { func TestMinOpenedSessions(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Take ten sessions from session pool and recycle them. var ss []*session @@ -441,20 +469,21 @@ func TestMinOpenedSessions(t *testing.T) { func TestMaxBurst(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxBurst: 1}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxBurst: 1, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Will cause session creation RPC to be retried forever. - allowRequests := make(chan struct{}) - mock.CreateSessionFn = func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - select { - case <-allowRequests: - return mock.MockCloudSpannerClient.CreateSession(c, r, opts...) - default: - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, status.Errorf(codes.Unavailable, "try later") - } - } + server.testSpanner.PutExecutionTime(testutil.MethodCreateSession, + testutil.SimulatedExecutionTime{ + Errors: []error{status.Errorf(codes.Unavailable, "try later")}, + KeepError: true, + }) // This session request will never finish until the injected error is // cleared. @@ -483,7 +512,10 @@ func TestMaxBurst(t *testing.T) { } // Let the first session request succeed. - close(allowRequests) + server.testSpanner.Freeze() + server.testSpanner.PutExecutionTime(testutil.MethodCreateSession, testutil.SimulatedExecutionTime{}) + //close(allowRequests) + server.testSpanner.Unfreeze() // Now new session request can proceed because the first session request will eventually succeed. sh, err := sp.take(ctx) @@ -499,8 +531,15 @@ func TestMaxBurst(t *testing.T) { func TestSessionRecycle(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1, MaxIdle: 5}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxIdle: 5, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Test session is correctly recycled and reused. for i := 0; i < 20; i++ { @@ -530,8 +569,14 @@ func TestSessionDestroy(t *testing.T) { t.Skip("s.destroy(true) is flakey") t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + }, + }) + defer server.teardown(client) + sp := client.idleSessions <-time.After(10 * time.Millisecond) // maintainer will create one session, we wait for it create session to avoid flakiness in test sh, err := sp.take(ctx) @@ -587,11 +632,15 @@ func TestHcHeap(t *testing.T) { func TestHealthCheckScheduler(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + HealthCheckInterval: 50 * time.Millisecond, + healthCheckSampleInterval: 10 * time.Millisecond, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Create 50 sessions. ss := []string{} @@ -605,7 +654,7 @@ func TestHealthCheckScheduler(t *testing.T) { // Wait for 10-30 pings per session. waitFor(t, func() error { - dp := mock.DumpPings() + dp := server.testSpanner.DumpPings() gotPings := map[string]int64{} for _, p := range dp { gotPings[p]++ @@ -625,8 +674,15 @@ func TestHealthCheckScheduler(t *testing.T) { func TestWriteSessionsPrepared(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{WriteSessions: 0.5, MaxIdle: 20}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + WriteSessions: 0.5, + MaxIdle: 20, + }, + }) + defer server.teardown(client) + sp := client.idleSessions shs := make([]*sessionHandle, 10) var err error @@ -688,8 +744,16 @@ func TestWriteSessionsPrepared(t *testing.T) { func TestTakeFromWriteQueue(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MaxOpened: 1, WriteSessions: 1.0, MaxIdle: 1}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxOpened: 1, + WriteSessions: 1.0, + MaxIdle: 1, + }, + }) + defer server.teardown(client) + sp := client.idleSessions sh, err := sp.take(ctx) if err != nil { @@ -718,20 +782,15 @@ func TestTakeFromWriteQueue(t *testing.T) { func TestSessionHealthCheck(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() - - var requestShouldErr int64 // 0 == false, 1 == true - mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - if shouldErr := atomic.LoadInt64(&requestShouldErr); shouldErr == 1 { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, status.Errorf(codes.NotFound, "Session not found") - } - return mock.MockCloudSpannerClient.GetSession(c, r, opts...) - } + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + HealthCheckInterval: 50 * time.Millisecond, + healthCheckSampleInterval: 10 * time.Millisecond, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Test pinging sessions. sh, err := sp.take(ctx) @@ -741,7 +800,7 @@ func TestSessionHealthCheck(t *testing.T) { // Wait for healthchecker to send pings to session. waitFor(t, func() error { - pings := mock.DumpPings() + pings := server.testSpanner.DumpPings() if len(pings) == 0 || pings[0] != sh.getID() { return fmt.Errorf("healthchecker didn't send any ping to session %v", sh.getID()) } @@ -753,7 +812,14 @@ func TestSessionHealthCheck(t *testing.T) { t.Fatalf("cannot get session from session pool: %v", err) } - atomic.SwapInt64(&requestShouldErr, 1) + server.testSpanner.Freeze() + server.testSpanner.PutExecutionTime(testutil.MethodGetSession, + testutil.SimulatedExecutionTime{ + Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, + KeepError: true, + }) + server.testSpanner.Unfreeze() + //atomic.SwapInt64(&requestShouldErr, 1) // Wait for healthcheck workers to find the broken session and tear it down. // TODO(deklerk): get rid of this @@ -764,7 +830,9 @@ func TestSessionHealthCheck(t *testing.T) { t.Fatalf("session(%v) is still alive, want it to be dropped by healthcheck workers", s) } - atomic.SwapInt64(&requestShouldErr, 0) + server.testSpanner.Freeze() + server.testSpanner.PutExecutionTime(testutil.MethodGetSession, testutil.SimulatedExecutionTime{}) + server.testSpanner.Unfreeze() // Test garbage collection. sh, err = sp.take(ctx) @@ -805,18 +873,17 @@ func TestStressSessionPool(t *testing.T) { cfg.HealthCheckInterval = 50 * time.Millisecond cfg.healthCheckSampleInterval = 10 * time.Millisecond cfg.HealthCheckWorkers = 50 - sc := testutil.NewMockCloudSpannerClient(t) - cfg.getRPCClient = func() (sppb.SpannerClient, error) { - return sc, nil - } - sp, _ := newSessionPool("mockdb", cfg, nil) - defer sp.hc.close() - defer sp.close() + + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: cfg, + }) + sp := client.idleSessions for i := 0; i < 100; i++ { wg.Add(1) // Schedule a test worker. - go func(idx int, pool *sessionPool, client sppb.SpannerClient) { + go func(idx int, pool *sessionPool, client *Client) { defer wg.Done() // Test worker iterates 1K times and tries different // session / session pool operations. @@ -868,7 +935,7 @@ func TestStressSessionPool(t *testing.T) { // recycle the session. sh.recycle() } - }(i, sp, sc) + }(i, sp, client) } wg.Wait() sp.hc.close() @@ -876,7 +943,7 @@ func TestStressSessionPool(t *testing.T) { // stable. idleSessions := map[string]bool{} hcSessions := map[string]bool{} - mockSessions := sc.DumpSessions() + mockSessions := server.testSpanner.DumpSessions() // Dump session pool's idle list. for sl := sp.idleList.Front(); sl != nil; sl = sl.Next() { s := sl.Value.(*session) @@ -912,14 +979,23 @@ func TestStressSessionPool(t *testing.T) { if !testEqual(idleSessions, hcSessions) { t.Fatalf("%v: sessions in idle list (%v) != sessions in healthcheck queue (%v)", ti, idleSessions, hcSessions) } - if !testEqual(hcSessions, mockSessions) { - t.Fatalf("%v: sessions in healthcheck queue (%v) != sessions in mockclient (%v)", ti, hcSessions, mockSessions) + // The server may contain more sessions than the health check queue. + // This can be caused by a timeout client side during a CreateSession + // request. The request may still be received and executed by the + // server, but the session pool will not register the session. + for id, b := range hcSessions { + if b && !mockSessions[id] { + t.Fatalf("%v: session in healthcheck queue (%v) was not found on server", ti, id) + } } sp.close() - mockSessions = sc.DumpSessions() - if len(mockSessions) != 0 { - t.Fatalf("Found live sessions: %v", mockSessions) + mockSessions = server.testSpanner.DumpSessions() + for id, b := range hcSessions { + if b && mockSessions[id] { + t.Fatalf("Found session from pool still live on server: %v", id) + } } + server.teardown(client) } } @@ -941,8 +1017,15 @@ func TestMaintainer(t *testing.T) { minOpened := uint64(5) maxIdle := uint64(4) - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: minOpened, MaxIdle: maxIdle}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: minOpened, + MaxIdle: maxIdle, + }, + }) + defer server.teardown(client) + sp := client.idleSessions sampleInterval := sp.SessionPoolConfig.healthCheckSampleInterval @@ -1003,40 +1086,25 @@ func TestMaintainer(t *testing.T) { // // Historical context: This test also checks that a low // healthCheckSampleInterval does not prevent it from opening connections. +// The low healthCheckSampleInterval will however sometimes cause session +// creations to time out. That should not be considered a problem, but it +// could cause the test case to fail if it happens too often. // See: https://github.com/googleapis/google-cloud-go/issues/1259 func TestMaintainer_CreatesSessions(t *testing.T) { t.Parallel() - - rawServerStub := testutil.NewMockCloudSpannerClient(t) - serverClientMock := testutil.FuncMock{MockCloudSpannerClient: rawServerStub} - serverClientMock.CreateSessionFn = func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - time.Sleep(10 * time.Millisecond) - return rawServerStub.CreateSession(c, r, opts...) - } spc := SessionPoolConfig{ MinOpened: 10, MaxIdle: 10, - healthCheckSampleInterval: time.Millisecond, - getRPCClient: func() (sppb.SpannerClient, error) { - return &serverClientMock, nil - }, - } - db := "mockdb" - sp, err := newSessionPool(db, spc, nil) - if err != nil { - t.Fatalf("cannot create session pool: %v", err) + healthCheckSampleInterval: 20 * time.Millisecond, } - client := Client{ - database: db, - idleSessions: sp, - } - defer func() { - client.Close() - sp.hc.close() - sp.close() - }() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: spc, + }) + defer server.teardown(client) + sp := client.idleSessions - timeoutAmt := 2 * time.Second + timeoutAmt := 4 * time.Second timeout := time.After(timeoutAmt) var numOpened uint64 loop: diff --git a/spanner/transaction.go b/spanner/transaction.go index 4a6a336de08..a53486b674f 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -22,7 +22,10 @@ import ( "sync/atomic" "time" + "github.com/googleapis/gax-go/v2" + "cloud.google.com/go/internal/trace" + vkit "cloud.google.com/go/spanner/apiv1" "google.golang.org/api/iterator" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc" @@ -364,24 +367,22 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { if err != nil { return err } - err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error { - res, e := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{ - Session: sh.getID(), - Options: &sppb.TransactionOptions{ - Mode: &sppb.TransactionOptions_ReadOnly_{ - ReadOnly: buildTransactionOptionsReadOnly(t.getTimestampBound(), true), - }, + res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{ + Session: sh.getID(), + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: buildTransactionOptionsReadOnly(t.getTimestampBound(), true), }, - }) - if e != nil { - return e - } + }, + }) + if err == nil { tx = res.Id if res.ReadTimestamp != nil { rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos)) } - return nil - }) + } else { + err = toSpannerError(err) + } t.mu.Lock() // defer function will be executed with t.mu being held. @@ -798,27 +799,19 @@ func (t *ReadWriteTransaction) release(err error) { } } -func beginTransaction(ctx context.Context, sid string, client sppb.SpannerClient) (transactionID, error) { - var tx transactionID - err := runRetryable(ctx, func(ctx context.Context) error { - res, e := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ - Session: sid, - Options: &sppb.TransactionOptions{ - Mode: &sppb.TransactionOptions_ReadWrite_{ - ReadWrite: &sppb.TransactionOptions_ReadWrite{}, - }, +func beginTransaction(ctx context.Context, sid string, client *vkit.Client) (transactionID, error) { + res, err := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ + Session: sid, + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadWrite_{ + ReadWrite: &sppb.TransactionOptions_ReadWrite{}, }, - }) - if e != nil { - return e - } - tx = res.Id - return nil + }, }) if err != nil { return nil, err } - return tx, nil + return res.Id, nil } // begin starts a read-write transacton on Cloud Spanner, it is always called @@ -857,23 +850,21 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (time.Time, error) { if sid == "" || client == nil { return ts, errSessionClosed(t.sh) } - err = runRetryable(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), func(ctx context.Context) error { - var trailer metadata.MD - res, e := client.Commit(ctx, &sppb.CommitRequest{ - Session: sid, - Transaction: &sppb.CommitRequest_TransactionId{ - TransactionId: t.tx, - }, - Mutations: mPb, - }, grpc.Trailer(&trailer)) - if e != nil { - return toSpannerErrorWithMetadata(e, trailer) - } - if tstamp := res.GetCommitTimestamp(); tstamp != nil { - ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) - } - return nil - }) + + var trailer metadata.MD + res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), &sppb.CommitRequest{ + Session: sid, + Transaction: &sppb.CommitRequest_TransactionId{ + TransactionId: t.tx, + }, + Mutations: mPb, + }, gax.WithGRPCOptions(grpc.Trailer(&trailer))) + if e != nil { + return ts, toSpannerErrorWithMetadata(e, trailer) + } + if tstamp := res.GetCommitTimestamp(); tstamp != nil { + ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) + } if shouldDropSession(err) { t.sh.destroy() } @@ -893,12 +884,9 @@ func (t *ReadWriteTransaction) rollback(ctx context.Context) { if sid == "" || client == nil { return } - err := runRetryable(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), func(ctx context.Context) error { - _, e := client.Rollback(ctx, &sppb.RollbackRequest{ - Session: sid, - TransactionId: t.tx, - }) - return e + err := client.Rollback(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), &sppb.RollbackRequest{ + Session: sid, + TransactionId: t.tx, }) if shouldDropSession(err) { t.sh.destroy() @@ -920,7 +908,6 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont // Retry the transaction using the same session on ABORT error. // Cloud Spanner will create the new transaction with the previous // one's wound-wait priority. - err = errRetry(err) return ts, err } // Not going to commit, according to API spec, should rollback the @@ -956,19 +943,21 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta // Malformed mutation found, just return the error. return ts, err } - err = runRetryable(ctx, func(ct context.Context) error { - var e error - var trailers metadata.MD + + var trailers metadata.MD + // Retry-loop for aborted transactions. + // TODO: Replace with generic retryer. + for { if sh == nil || sh.getID() == "" || sh.getClient() == nil { // No usable session for doing the commit, take one from pool. - sh, e = t.sp.take(ctx) - if e != nil { + sh, err = t.sp.take(ctx) + if err != nil { // sessionPool.Take already retries for session // creations/retrivals. - return e + return ts, err } } - res, e := sh.getClient().Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.CommitRequest{ + res, err := sh.getClient().Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.CommitRequest{ Session: sh.getID(), Transaction: &sppb.CommitRequest_SingleUseTransaction{ SingleUseTransaction: &sppb.TransactionOptions{ @@ -978,28 +967,24 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta }, }, Mutations: mPb, - }, grpc.Trailer(&trailers)) - if e != nil { - if isAbortErr(e) { - // Mask ABORT error as retryable, because aborted transactions - // are allowed to be retried. - return errRetry(toSpannerErrorWithMetadata(e, trailers)) - } - if shouldDropSession(e) { + }, gax.WithGRPCOptions(grpc.Trailer(&trailers))) + if err != nil && !isAbortErr(err) { + if shouldDropSession(err) { // Discard the bad session. sh.destroy() } - return e - } - if tstamp := res.GetCommitTimestamp(); tstamp != nil { - ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) + return ts, toSpannerError(err) + } else if err == nil { + if tstamp := res.GetCommitTimestamp(); tstamp != nil { + ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) + } + break } - return nil - }) + } if sh != nil { sh.recycle() } - return ts, err + return ts, toSpannerError(err) } // isAbortedErr returns true if the error indicates that an gRPC call is diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 9698aab742f..9c10ac4d819 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -27,16 +27,16 @@ import ( "cloud.google.com/go/spanner/internal/testutil" sppb "google.golang.org/genproto/googleapis/spanner/v1" - "google.golang.org/grpc" "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" ) // Single can only be used once. func TestSingle(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) txn := client.Single() defer txn.Close() @@ -50,7 +50,7 @@ func TestSingle(t *testing.T) { } // Only one CreateSessionRequest is sent. - if _, err := shouldHaveReceived(mock, []interface{}{&sppb.CreateSessionRequest{}}); err != nil { + if _, err := shouldHaveReceived(server.testSpanner, []interface{}{&sppb.CreateSessionRequest{}}); err != nil { t.Fatal(err) } } @@ -59,23 +59,18 @@ func TestSingle(t *testing.T) { func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) txn := client.ReadOnlyTransaction() defer txn.Close() - // First request will fail, which should trigger a retry. - errUsr := errors.New("error") - firstCall := true - mock.BeginTransactionFn = func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) { - if firstCall { - mock.MockCloudSpannerClient.ReceivedRequests <- r - firstCall = false - return nil, errUsr - } - return mock.MockCloudSpannerClient.BeginTransaction(c, r, opts...) - } + // First request will fail. + errUsr := gstatus.Error(codes.Unknown, "error") + server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{errUsr}, + }) _, _, e := txn.acquire(ctx) if wantErr := toSpannerError(errUsr); !testEqual(e, wantErr) { @@ -91,8 +86,8 @@ func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) { func TestReadOnlyTransaction_UseAfterClose(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, _, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) txn := client.ReadOnlyTransaction() txn.Close() @@ -107,12 +102,12 @@ func TestReadOnlyTransaction_UseAfterClose(t *testing.T) { func TestReadOnlyTransaction_Concurrent(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) txn := client.ReadOnlyTransaction() defer txn.Close() - mock.Freeze() + server.testSpanner.Freeze() var ( sh1 *sessionHandle sh2 *sessionHandle @@ -135,7 +130,7 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) { // TODO(deklerk): Get rid of this. <-time.After(100 * time.Millisecond) - mock.Unfreeze() + server.testSpanner.Unfreeze() wg.Wait() if sh1.session.id != sh2.session.id { t.Fatalf("Expected acquire to get same session handle, got %v and %v.", sh1, sh2) @@ -148,8 +143,8 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) { func TestApply_Single(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) ms := []*Mutation{ Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), @@ -159,7 +154,7 @@ func TestApply_Single(t *testing.T) { t.Fatalf("applyAtLeastOnce retry on abort, got %v, want nil.", e) } - if _, err := shouldHaveReceived(mock, []interface{}{ + if _, err := shouldHaveReceived(server.testSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.CommitRequest{}, }); err != nil { @@ -171,20 +166,15 @@ func TestApply_Single(t *testing.T) { func TestApply_RetryOnAbort(t *testing.T) { ctx := context.Background() t.Parallel() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) // First commit will fail, and the retry will begin a new transaction. errAbrt := spannerErrorf(codes.Aborted, "") - firstCommitCall := true - mock.CommitFn = func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) { - if firstCommitCall { - mock.MockCloudSpannerClient.ReceivedRequests <- r - firstCommitCall = false - return nil, errAbrt - } - return mock.MockCloudSpannerClient.Commit(c, r, opts...) - } + server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{errAbrt}, + }) ms := []*Mutation{ Insert("Accounts", []string{"AccountId"}, []interface{}{int64(1)}), @@ -194,7 +184,7 @@ func TestApply_RetryOnAbort(t *testing.T) { t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", e) } - if _, err := shouldHaveReceived(mock, []interface{}{ + if _, err := shouldHaveReceived(server.testSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, &sppb.CommitRequest{}, // First commit fails. @@ -209,18 +199,18 @@ func TestApply_RetryOnAbort(t *testing.T) { func TestTransaction_NotFound(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) wantErr := spannerErrorf(codes.NotFound, "Session not found") - mock.BeginTransactionFn = func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, wantErr - } - mock.CommitFn = func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, wantErr - } + server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{wantErr, wantErr, wantErr}, + }) + server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{wantErr, wantErr, wantErr}, + }) txn := client.ReadOnlyTransaction() defer txn.Close() @@ -253,8 +243,8 @@ func TestTransaction_NotFound(t *testing.T) { func TestReadWriteTransaction_ErrorReturned(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) want := errors.New("an error") _, got := client.ReadWriteTransaction(ctx, func(context.Context, *ReadWriteTransaction) error { @@ -263,7 +253,7 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) { if got != want { t.Fatalf("got %+v, want %+v", got, want) } - requests := drainRequests(mock) + requests := drainRequestsFromServer(server.testSpanner) if err := compareRequests([]interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, @@ -287,27 +277,27 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) { func TestBatchDML_WithMultipleDML(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { - if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil { + if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { return err } - if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}, {SQL: "SELECT * FROM whatever"}}); err != nil { + if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}, {SQL: updateBarSetFoo}}); err != nil { return err } - if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil { + if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { return err } - _, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}}) + _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}}) return err }) if err != nil { t.Fatal(err) } - gotReqs, err := shouldHaveReceived(mock, []interface{}{ + gotReqs, err := shouldHaveReceived(server.testSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, &sppb.ExecuteSqlRequest{}, @@ -339,8 +329,8 @@ func TestBatchDML_WithMultipleDML(t *testing.T) { // // Note: this in-place modifies serverClientMock by popping items off the // ReceivedRequests channel. -func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) ([]interface{}, error) { - got := drainRequests(mock) +func shouldHaveReceived(server testutil.InMemSpannerServer, want []interface{}) ([]interface{}, error) { + got := drainRequestsFromServer(server) return got, compareRequests(want, got) } @@ -368,12 +358,12 @@ func compareRequests(want []interface{}, got []interface{}) error { return nil } -func drainRequests(mock *testutil.FuncMock) []interface{} { +func drainRequestsFromServer(server testutil.InMemSpannerServer) []interface{} { var reqs []interface{} loop: for { select { - case req := <-mock.ReceivedRequests: + case req := <-server.ReceivedRequests(): reqs = append(reqs, req) default: break loop @@ -381,30 +371,3 @@ loop: } return reqs } - -// serverClientMock sets up a client configured to a NewMockCloudSpannerClient -// that is wrapped with a function-injectable wrapper. -// -// Note: be sure to call cleanup! -func serverClientMock(t *testing.T, spc SessionPoolConfig) (_ *Client, _ *sessionPool, _ *testutil.FuncMock, cleanup func()) { - rawServerStub := testutil.NewMockCloudSpannerClient(t) - serverClientMock := testutil.FuncMock{MockCloudSpannerClient: rawServerStub} - spc.getRPCClient = func() (sppb.SpannerClient, error) { - return &serverClientMock, nil - } - db := "mockdb" - sp, err := newSessionPool(db, spc, nil) - if err != nil { - t.Fatalf("cannot create session pool: %v", err) - } - client := Client{ - database: db, - idleSessions: sp, - } - cleanup = func() { - client.Close() - sp.hc.close() - sp.close() - } - return &client, sp, &serverClientMock, cleanup -}