From 9f33eb1c967625fd881833adc9e2f436d5251b6c Mon Sep 17 00:00:00 2001
From: Olav Loite
Date: Mon, 13 May 2019 18:58:27 +0200
Subject: [PATCH] use gapic client and rely on gax for retries

This change contains the following global changes:
1. spanner.Client uses the generated gapic client for gRPC calls, instead of
   a gRPC connection and a spannerpb.SpannerClient.
2. The gapic client uses the default gax retry logic.
3. Most custom retry logic has been removed, except:
   * retry on aborted transactions
   * retry for resumableStreamDecoder.next()

The change also includes an in-memory Spanner server for test purposes. The
server requires the user to mock the results of queries and update
statements. Sessions and transactions are handled automatically. It also
allows the user to register specific errors to be returned for each gRPC
function. This test server makes it easier to develop test cases that verify
the behavior of the client library for an entire transaction in situations
that cannot easily be created in an integration test against a real Cloud
Spanner instance, such as aborted transactions or temporary retryable
errors. The test cases can use the standard Spanner client without the need
to mock any of the server functions, other than specifying the results for
queries and updates.

Fixes #1418 and #1384

Change-Id: If0a8bbed50b512b32d73a8ef7ad74cdb1192294b
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/41131
Reviewed-by: kokoro
Reviewed-by: Jean de Klerk
---
 spanner/batch.go                              |   5 +-
 spanner/client.go                             |  97 ++-
 spanner/client_test.go                        | 359 +++++++++
 .../internal/testutil/inmem_spanner_server.go | 711 ++++++++++++++++++
 .../testutil/inmem_spanner_server_test.go     | 598 +++++++++++++++
 spanner/mocked_inmem_server.go                | 179 +++++
 spanner/pdml.go                               |  19 +-
 spanner/pdml_test.go                          |  26 +-
 spanner/retry.go                              |  60 --
 spanner/retry_test.go                         |  68 +-
 spanner/session.go                            |  50 +-
 spanner/session_test.go                       | 328 ++++----
 spanner/transaction.go                        | 135 ++--
 spanner/transaction_test.go                   | 141 ++--
 14 files changed, 2258 insertions(+), 518 deletions(-)
 create mode 100644 spanner/internal/testutil/inmem_spanner_server.go
 create mode 100644 spanner/internal/testutil/inmem_spanner_server_test.go
 create mode 100644 spanner/mocked_inmem_server.go

diff --git a/spanner/batch.go b/spanner/batch.go
index 413d8fd202b..374a963c48f 100644
--- a/spanner/batch.go
+++ b/spanner/batch.go
@@ -221,10 +221,7 @@ func (t *BatchReadOnlyTransaction) Cleanup(ctx context.Context) {
 	}
 	t.sh = nil
 	sid, client := sh.getID(), sh.getClient()
-	err := runRetryable(ctx, func(ctx context.Context) error {
-		_, e := client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sid})
-		return e
-	})
+	err := client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sid})
 	if err != nil {
 		log.Printf("Failed to delete session %v. Error: %v", sid, err)
 	}
diff --git a/spanner/client.go b/spanner/client.go
index fc8443b0b0e..38c3f4968d8 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -25,8 +25,9 @@ import (
 	"cloud.google.com/go/internal/trace"
 	"cloud.google.com/go/internal/version"
+	vkit "cloud.google.com/go/spanner/apiv1"
+	"cloud.google.com/go/spanner/internal/backoff"
 	"google.golang.org/api/option"
-	gtransport "google.golang.org/api/transport/grpc"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
@@ -71,8 +72,7 @@ func validDatabaseName(db string) error {
 type Client struct {
 	// rr must be accessed through atomic operations.
rr uint32 - conns []*grpc.ClientConn - clients []sppb.SpannerClient + clients []*vkit.Client database string // Metadata to be sent with each request. @@ -170,19 +170,18 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf // TODO(deklerk): This should be replaced with a balancer with // config.NumChannels connections, instead of config.NumChannels - // clientconns. + // clients. for i := 0; i < config.NumChannels; i++ { - conn, err := gtransport.Dial(ctx, allOpts...) + client, err := vkit.NewClient(ctx, allOpts...) if err != nil { return nil, errDial(i, err) } - c.conns = append(c.conns, conn) - c.clients = append(c.clients, sppb.NewSpannerClient(conn)) + c.clients = append(c.clients, client) } // Prepare session pool. - config.SessionPoolConfig.getRPCClient = func() (sppb.SpannerClient, error) { - // TODO: support more loadbalancing options. + // TODO: support more loadbalancing options. + config.SessionPoolConfig.getRPCClient = func() (*vkit.Client, error) { return c.rrNext(), nil } config.SessionPoolConfig.sessionLabels = c.sessionLabels @@ -195,9 +194,9 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf return c, nil } -// rrNext returns the next available Cloud Spanner RPC client in a round-robin -// manner. -func (c *Client) rrNext() sppb.SpannerClient { +// rrNext returns the next available vkit Cloud Spanner RPC client in a +// round-robin manner. +func (c *Client) rrNext() *vkit.Client { return c.clients[atomic.AddUint32(&c.rr, 1)%uint32(len(c.clients))] } @@ -206,8 +205,8 @@ func (c *Client) Close() { if c.idleSessions != nil { c.idleSessions.close() } - for _, conn := range c.conns { - conn.Close() + for _, gpc := range c.clients { + gpc.Close() } } @@ -279,26 +278,20 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound sh = &sessionHandle{session: s} // Begin transaction. 
-	err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error {
-		res, e := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{
-			Session: sh.getID(),
-			Options: &sppb.TransactionOptions{
-				Mode: &sppb.TransactionOptions_ReadOnly_{
-					ReadOnly: buildTransactionOptionsReadOnly(tb, true),
-				},
+	res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{
+		Session: sh.getID(),
+		Options: &sppb.TransactionOptions{
+			Mode: &sppb.TransactionOptions_ReadOnly_{
+				ReadOnly: buildTransactionOptionsReadOnly(tb, true),
 			},
-		})
-		if e != nil {
-			return e
-		}
-		tx = res.Id
-		if res.ReadTimestamp != nil {
-			rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos))
-		}
-		return nil
+		},
 	})
 	if err != nil {
-		return nil, err
+		return nil, toSpannerError(err)
+	}
+	tx = res.Id
+	if res.ReadTimestamp != nil {
+		rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos))
 	}
 	t := &BatchReadOnlyTransaction{
@@ -377,7 +370,7 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex
 		ts time.Time
 		sh *sessionHandle
 	)
-	err = runRetryableNoWrap(ctx, func(ctx context.Context) error {
+	err = runWithRetryOnAborted(ctx, func(ctx context.Context) error {
 		var (
 			err error
 			t   *ReadWriteTransaction
@@ -402,8 +395,7 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex
 		trace.TracePrintf(ctx, map[string]interface{}{"transactionID": string(sh.getTransactionID())}, "Starting transaction attempt")
 		if err = t.begin(ctx); err != nil {
-			// Mask error from begin operation as retryable error.
-			return errRetry(err)
+			return err
 		}
 		ts, err = t.runInTransaction(ctx, f)
 		return err
@@ -414,6 +406,43 @@ func (c *Client) ReadWriteTransaction(ctx context.Context, f func(context.Contex
 	return ts, err
 }

+func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) error {
+	var funcErr error
+	retryCount := 0
+	for {
+		select {
+		case <-ctx.Done():
+			// Check the context here so that even if f() fails to do so (for
+			// example, due to a gRPC implementation bug), the loop still has
+			// a chance to exit as expected.
+			return errContextCanceled(ctx, funcErr)
+		default:
+		}
+		funcErr = f(ctx)
+		if funcErr == nil {
+			return nil
+		}
+		// Only retry on ABORTED.
+		if isAbortErr(funcErr) {
+			// Aborted, do exponential backoff and continue.
+			b, ok := extractRetryDelay(funcErr)
+			if !ok {
+				b = backoff.DefaultBackoff.Delay(retryCount)
+			}
+			trace.TracePrintf(ctx, nil, "Backing off after ABORTED for %s, then retrying", b)
+			select {
+			case <-ctx.Done():
+				return errContextCanceled(ctx, funcErr)
+			case <-time.After(b):
+			}
+			retryCount++
+			continue
+		}
+		// The error isn't ABORTED; return it immediately.
+		return funcErr
+	}
+}
+
 // applyOption controls the behavior of Client.Apply.
 type applyOption struct {
 	// If atLeastOnce == true, Client.Apply will execute the mutations on Cloud
diff --git a/spanner/client_test.go b/spanner/client_test.go
index 54c52631f05..3df53669306 100644
--- a/spanner/client_test.go
+++ b/spanner/client_test.go
@@ -17,8 +17,8 @@
package spanner import ( + "context" + "io" "strings" "testing" + + "cloud.google.com/go/spanner/internal/testutil" + "google.golang.org/api/iterator" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" ) // Test validDatabaseName() @@ -48,3 +55,355 @@ func TestReadOnlyTransactionClose(t *testing.T) { tx := c.ReadOnlyTransaction() tx.Close() } + +func TestClient_Single(t *testing.T) { + t.Parallel() + err := testSingleQuery(t, nil) + if err != nil { + t.Fatal(err) + } +} + +func TestClient_Single_Unavailable(t *testing.T) { + t.Parallel() + err := testSingleQuery(t, gstatus.Error(codes.Unavailable, "Temporary unavailable")) + if err != nil { + t.Fatal(err) + } +} + +func TestClient_Single_InvalidArgument(t *testing.T) { + t.Parallel() + err := testSingleQuery(t, gstatus.Error(codes.InvalidArgument, "Invalid argument")) + if err == nil { + t.Fatalf("missing expected error") + } else if gstatus.Code(err) != codes.InvalidArgument { + t.Fatal(err) + } +} + +func testSingleQuery(t *testing.T, serverError error) error { + config := ClientConfig{} + server, client := newSpannerInMemTestServerWithConfig(t, config) + defer server.teardown(client) + if serverError != nil { + server.testSpanner.SetError(serverError) + } + ctx := context.Background() + iter := client.Single().Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return nil +} + +func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]testutil.SimulatedExecutionTime { + errors := make([]error, 2) + errors[0] = gstatus.Error(codes.Unavailable, "Temporary unavailable") + errors[1] = gstatus.Error(codes.Unavailable, "Temporary unavailable") + executionTimes := make(map[string]testutil.SimulatedExecutionTime) + executionTimes[method] = testutil.SimulatedExecutionTime{ + Errors: errors, + } + return executionTimes +} + +func TestClient_ReadOnlyTransaction(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, make(map[string]testutil.SimulatedExecutionTime)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnSessionCreate(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodCreateSession)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnBeginTransaction(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodBeginTransaction)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { + t.Parallel() + if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(testutil.MethodExecuteStreamingSql)); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) { + t.Parallel() + exec := map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + } + if err := 
testReadOnlyTransaction(t, exec); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) { + t.Parallel() + exec := map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.InvalidArgument, "Invalid argument")}}, + } + if err := testReadOnlyTransaction(t, exec); err == nil { + t.Fatalf("Missing expected exception") + } else if gstatus.Code(err) != codes.InvalidArgument { + t.Fatalf("Got unexpected exception: %v", err) + } +} + +func testReadOnlyTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime) error { + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + for method, exec := range executionTimes { + server.testSpanner.PutExecutionTime(method, exec) + } + ctx := context.Background() + tx := client.ReadOnlyTransaction() + defer tx.Close() + iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return nil +} + +func TestClient_ReadWriteTransaction(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, make(map[string]testutil.SimulatedExecutionTime), 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + }, 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) { + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + }, 1); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) { + t.Parallel() + if err := 
testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + }, 3); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + }, 4); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCommitTransaction: { + Errors: []error{ + gstatus.Error(codes.Aborted, "Transaction aborted"), + gstatus.Error(codes.Unavailable, "Unavailable"), + }, + }, + }, 2); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { + t.Parallel() + if err := testReadWriteTransaction(t, map[string]testutil.SimulatedExecutionTime{ + testutil.MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.AlreadyExists, "A row with this key already exists")}}, + }, 1); err != nil { + if gstatus.Code(err) != codes.AlreadyExists { + t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists) + } + } else { + t.Fatalf("Missing expected exception") + } +} + +func testReadWriteTransaction(t *testing.T, executionTimes map[string]testutil.SimulatedExecutionTime, expectedAttempts int) error { + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + for method, exec := range executionTimes { + server.testSpanner.PutExecutionTime(method, exec) + } + var attempts int + ctx := context.Background() + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return nil + }) + if err != nil { + return err + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + return nil +} + +func TestClient_ApplyAtLeastOnce(t *testing.T) { + t.Parallel() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + ms := []*Mutation{ + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}), + } + server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, + }) + _, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce()) + if err != nil { + 
t.Fatal(err) + } +} + +// PartitionedUpdate should not retry on aborted. +func TestClient_PartitionedUpdate(t *testing.T) { + t.Parallel() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + // PartitionedDML transactions are not committed. + server.testSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, + testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, + }) + _, err := client.PartitionedUpdate(context.Background(), NewStatement(updateBarSetFoo)) + if err == nil { + t.Fatalf("Missing expected Aborted exception") + } else { + if gstatus.Code(err) != codes.Aborted { + t.Fatalf("Got unexpected error %v, expected Aborted", err) + } + } +} + +func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) { + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + var attempts int + ctx := context.Background() + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(selectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + var singerID, albumID int64 + var albumTitle string + if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { + return err + } + } + return io.ErrUnexpectedEOF + }) + if err != io.ErrUnexpectedEOF { + t.Fatalf("Missing expected error %v, got %v", io.ErrUnexpectedEOF, err) + } + if attempts != 1 { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, 1) + } +} diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go new file mode 100644 index 00000000000..963d33467de --- /dev/null +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -0,0 +1,711 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + emptypb "github.com/golang/protobuf/ptypes/empty" + structpb "github.com/golang/protobuf/ptypes/struct" + spannerpb "google.golang.org/genproto/googleapis/spanner/v1" +) + +import ( + "context" + "fmt" + "math/rand" + "sort" + "strings" + "sync" + "time" + + "github.com/golang/protobuf/ptypes/timestamp" + "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" +) + +// StatementResultType indicates the type of result returned by a SQL +// statement. +type StatementResultType int + +const ( + // StatementResultError indicates that the sql statement returns an error. + StatementResultError StatementResultType = 0 + // StatementResultResultSet indicates that the sql statement returns a + // result set. + StatementResultResultSet StatementResultType = 1 + // StatementResultUpdateCount indicates that the sql statement returns an + // update count. 
+	StatementResultUpdateCount StatementResultType = 2
+)
+
+// The method names that can be used to register execution times and errors.
+const (
+	MethodBeginTransaction    string = "BEGIN_TRANSACTION"
+	MethodCommitTransaction   string = "COMMIT_TRANSACTION"
+	MethodCreateSession       string = "CREATE_SESSION"
+	MethodDeleteSession       string = "DELETE_SESSION"
+	MethodGetSession          string = "GET_SESSION"
+	MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL"
+)
+
+// StatementResult represents a mocked result on the test server. The result
+// can be either a ResultSet, an update count or an error.
+type StatementResult struct {
+	Type        StatementResultType
+	Err         error
+	ResultSet   *spannerpb.ResultSet
+	UpdateCount int64
+}
+
+// Converts a ResultSet to a PartialResultSet. This method is used to convert
+// a mocked result to a PartialResultSet when one of the streaming methods is
+// called.
+func (s *StatementResult) toPartialResultSet() *spannerpb.PartialResultSet {
+	values := make([]*structpb.Value,
+		len(s.ResultSet.Rows)*len(s.ResultSet.Metadata.RowType.Fields))
+	var idx int
+	for _, row := range s.ResultSet.Rows {
+		for colIdx := range s.ResultSet.Metadata.RowType.Fields {
+			values[idx] = row.Values[colIdx]
+			idx++
+		}
+	}
+	return &spannerpb.PartialResultSet{
+		Metadata: s.ResultSet.Metadata,
+		Values:   values,
+	}
+}
+
+func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet {
+	return &spannerpb.PartialResultSet{
+		Stats: s.convertUpdateCountToResultSet(exact).Stats,
+	}
+}
+
+// Converts an update count to a ResultSet, as DML statements also return the
+// update count as the statistics of a ResultSet.
+func (s *StatementResult) convertUpdateCountToResultSet(exact bool) *spannerpb.ResultSet {
+	if exact {
+		return &spannerpb.ResultSet{
+			Stats: &spannerpb.ResultSetStats{
+				RowCount: &spannerpb.ResultSetStats_RowCountExact{
+					RowCountExact: s.UpdateCount,
+				},
+			},
+		}
+	}
+	return &spannerpb.ResultSet{
+		Stats: &spannerpb.ResultSetStats{
+			RowCount: &spannerpb.ResultSetStats_RowCountLowerBound{
+				RowCountLowerBound: s.UpdateCount,
+			},
+		},
+	}
+}
+
+// SimulatedExecutionTime represents the time the execution of a method
+// should take, and any errors that should be returned by the method.
+type SimulatedExecutionTime struct {
+	MinimumExecutionTime time.Duration
+	RandomExecutionTime  time.Duration
+	Errors               []error
+	// Keep error after execution. The error will continue to be returned until
+	// it is cleared.
+	KeepError bool
+}
+
+// InMemSpannerServer contains the SpannerServer interface plus a couple
+// of specific methods for adding mocked results and resetting the server.
+type InMemSpannerServer interface {
+	spannerpb.SpannerServer
+
+	// Stops this server.
+	Stop()
+
+	// Resets the in-mem server to its default state, deleting all sessions and
+	// transactions that have been created on the server. Mocked results are
+	// not deleted.
+	Reset()
+
+	// Sets an error that will be returned by the next server call. The server
+	// call will also automatically clear the error.
+	SetError(err error)
+
+	// Puts a mocked result on the server for a specific sql statement. The
+	// server does not parse the SQL string in any way; it is merely used as
+	// a key to the mocked result. The result will be used for all methods that
+	// expect a SQL statement, including (batch) DML methods.
+	PutStatementResult(sql string, result *StatementResult) error
+
+	// Removes a mocked result on the server for a specific sql statement.
+	RemoveStatementResult(sql string)
+
+	// Aborts the specified transaction. This method can be used to test
+	// transaction retry logic.
+	AbortTransaction(id []byte)
+
+	// Puts a simulated execution time for one of the Spanner methods.
+	PutExecutionTime(method string, executionTime SimulatedExecutionTime)
+	// Freeze stalls all requests.
+	Freeze()
+	// Unfreeze restores processing requests.
+	Unfreeze()
+
+	TotalSessionsCreated() uint
+	TotalSessionsDeleted() uint
+
+	ReceivedRequests() chan interface{}
+	DumpSessions() map[string]bool
+	DumpPings() []string
+}
+
+type inMemSpannerServer struct {
+	// Embed for forward compatibility.
+	// Tests will keep working if more methods are added
+	// in the future.
+	spannerpb.SpannerServer
+
+	mu sync.Mutex
+
+	// If set, all calls return this error.
+	err error
+	// The mock server creates session IDs using this counter.
+	sessionCounter uint64
+	// The sessions that have been created on this mock server.
+	sessions map[string]*spannerpb.Session
+	// Last use times per session.
+	sessionLastUseTime map[string]time.Time
+
+	// The mock server creates transaction IDs per session using these
+	// counters.
+	transactionCounters map[string]*uint64
+	// The transactions that have been created on this mock server.
+	transactions map[string]*spannerpb.Transaction
+	// The transactions that have been (manually) aborted on the server.
+	abortedTransactions map[string]bool
+	// The transactions that are marked as partitioned DML transactions.
+	partitionedDmlTransactions map[string]bool
+
+	// The mocked results for this server.
+	statementResults map[string]*StatementResult
+	// The simulated execution times per method.
+	executionTimes map[string]*SimulatedExecutionTime
+	// The server will stall on any request while this channel is open.
+	freezed chan struct{}
+
+	totalSessionsCreated uint
+	totalSessionsDeleted uint
+	receivedRequests     chan interface{}
+	// Session ping history.
+	pings []string
+}
+
+// NewInMemSpannerServer creates a new in-mem test server.
+func NewInMemSpannerServer() InMemSpannerServer {
+	res := &inMemSpannerServer{}
+	res.initDefaults()
+	res.statementResults = make(map[string]*StatementResult)
+	res.executionTimes = make(map[string]*SimulatedExecutionTime)
+	res.receivedRequests = make(chan interface{}, 1000000)
+	// Produce a closed channel, so that the default action of ready is not to
+	// block.
+	res.Freeze()
+	res.Unfreeze()
+	return res
+}
+
+func (s *inMemSpannerServer) Stop() {
+	close(s.receivedRequests)
+}
+
+// Resets the test server to its initial state, deleting all sessions and
+// transactions that have been created on the server. This method will not
+// remove mocked results.
+func (s *inMemSpannerServer) Reset() {
+	close(s.receivedRequests)
+	s.receivedRequests = make(chan interface{}, 1000000)
+	s.initDefaults()
+}
+
+func (s *inMemSpannerServer) SetError(err error) {
+	s.err = err
+}
+
+// Registers a mocked result for a SQL statement on the server.
+func (s *inMemSpannerServer) PutStatementResult(sql string, result *StatementResult) error {
+	s.statementResults[sql] = result
+	return nil
+}
+
+func (s *inMemSpannerServer) RemoveStatementResult(sql string) {
+	delete(s.statementResults, sql)
+}
+
+func (s *inMemSpannerServer) AbortTransaction(id []byte) {
+	s.abortedTransactions[string(id)] = true
+}
+
+func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime SimulatedExecutionTime) {
+	s.executionTimes[method] = &executionTime
+}
+
+// Freeze stalls all requests.
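+// A test can combine Freeze and Unfreeze to simulate a server that is
+// temporarily stalled; a minimal sketch (hypothetical test code, not part of
+// this change):
+//
+//   server := NewInMemSpannerServer()
+//   server.Freeze()   // all requests sent from here on will block
+//   go func() {
+//       time.Sleep(10 * time.Millisecond)
+//       server.Unfreeze()   // blocked requests are released
+//   }()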
+func (s *inMemSpannerServer) Freeze() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.freezed = make(chan struct{})
+}
+
+// Unfreeze restores processing requests.
+func (s *inMemSpannerServer) Unfreeze() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	close(s.freezed)
+}
+
+// ready blocks while the server is frozen and returns as soon as the server
+// is unfrozen.
+func (s *inMemSpannerServer) ready() {
+	s.mu.Lock()
+	freezed := s.freezed
+	s.mu.Unlock()
+	// Block while the server is frozen.
+	<-freezed
+}
+
+func (s *inMemSpannerServer) TotalSessionsCreated() uint {
+	return s.totalSessionsCreated
+}
+
+func (s *inMemSpannerServer) TotalSessionsDeleted() uint {
+	return s.totalSessionsDeleted
+}
+
+func (s *inMemSpannerServer) ReceivedRequests() chan interface{} {
+	return s.receivedRequests
+}
+
+// DumpPings dumps the ping history.
+func (s *inMemSpannerServer) DumpPings() []string {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return append([]string(nil), s.pings...)
+}
+
+// DumpSessions dumps the internal session table.
+func (s *inMemSpannerServer) DumpSessions() map[string]bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	st := map[string]bool{}
+	for s := range s.sessions {
+		st[s] = true
+	}
+	return st
+}
+
+func (s *inMemSpannerServer) initDefaults() {
+	s.sessionCounter = 0
+	s.sessions = make(map[string]*spannerpb.Session)
+	s.sessionLastUseTime = make(map[string]time.Time)
+	s.transactions = make(map[string]*spannerpb.Transaction)
+	s.abortedTransactions = make(map[string]bool)
+	s.partitionedDmlTransactions = make(map[string]bool)
+	s.transactionCounters = make(map[string]*uint64)
+}
+
+func (s *inMemSpannerServer) generateSessionName(database string) string {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.sessionCounter++
+	return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter)
+}
+
+func (s *inMemSpannerServer) findSession(name string) (*spannerpb.Session, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	session := s.sessions[name]
+	if session == nil {
+		return nil, gstatus.Error(codes.NotFound, fmt.Sprintf("Session %s not found", name))
+	}
+	return session, nil
+}
+
+func (s *inMemSpannerServer) updateSessionLastUseTime(session string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.sessionLastUseTime[session] = time.Now()
+}
+
+func getCurrentTimestamp() *timestamp.Timestamp {
+	t := time.Now()
+	return &timestamp.Timestamp{Seconds: t.Unix(), Nanos: int32(t.Nanosecond())}
+}
+
+// Gets the transaction id from the transaction selector. If the selector
+// specifies that a new transaction should be started, this method will start
+// a new transaction and return the id of that transaction.
+func (s *inMemSpannerServer) getTransactionID(session *spannerpb.Session, txSelector *spannerpb.TransactionSelector) []byte {
+	var res []byte
+	if txSelector.GetBegin() != nil {
+		// Start a new transaction.
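+		// beginTransaction also registers the new transaction on the server,
+		// so subsequent requests that reference the returned id can find it.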
+ res = s.beginTransaction(session, txSelector.GetBegin()).Id + } else if txSelector.GetId() != nil { + res = txSelector.GetId() + } + return res +} + +func (s *inMemSpannerServer) generateTransactionName(session string) string { + s.mu.Lock() + defer s.mu.Unlock() + counter, ok := s.transactionCounters[session] + if !ok { + counter = new(uint64) + s.transactionCounters[session] = counter + } + *counter++ + return fmt.Sprintf("%s/transactions/%d", session, *counter) +} + +func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, options *spannerpb.TransactionOptions) *spannerpb.Transaction { + id := s.generateTransactionName(session.Name) + res := &spannerpb.Transaction{ + Id: []byte(id), + ReadTimestamp: getCurrentTimestamp(), + } + s.mu.Lock() + s.transactions[id] = res + s.partitionedDmlTransactions[id] = options.GetPartitionedDml() != nil + s.mu.Unlock() + return res +} + +func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { + s.mu.Lock() + defer s.mu.Unlock() + tx, ok := s.transactions[string(id)] + if !ok { + return nil, gstatus.Error(codes.NotFound, "Transaction not found") + } + aborted, ok := s.abortedTransactions[string(id)] + if ok && aborted { + return nil, gstatus.Error(codes.Aborted, "Transaction has been aborted") + } + return tx, nil +} + +func (s *inMemSpannerServer) removeTransaction(tx *spannerpb.Transaction) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.transactions, string(tx.Id)) + delete(s.partitionedDmlTransactions, string(tx.Id)) +} + +func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, error) { + result, ok := s.statementResults[sql] + if !ok { + return nil, gstatus.Error(codes.Internal, fmt.Sprintf("No result found for statement %v", sql)) + } + return result, nil +} + +func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error { + s.receivedRequests <- req + s.ready() + if s.err != nil { + err := s.err + s.err = nil + return err + } + executionTime, ok := s.executionTimes[method] + if ok { + var randTime int64 + if executionTime.RandomExecutionTime > 0 { + randTime = rand.Int63n(int64(executionTime.RandomExecutionTime)) + } + totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) + <-time.After(totalExecutionTime) + if executionTime.Errors != nil && len(executionTime.Errors) > 0 { + err := executionTime.Errors[0] + if !executionTime.KeepError { + executionTime.Errors = executionTime.Errors[1:] + } + return err + } + } + return nil +} + +func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) { + if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil { + return nil, err + } + if req.Database == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing database") + } + sessionName := s.generateSessionName(req.Database) + ts := getCurrentTimestamp() + session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} + s.mu.Lock() + s.totalSessionsCreated++ + s.sessions[sessionName] = session + s.mu.Unlock() + return session, nil +} + +func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { + if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { + return nil, err + } + s.mu.Lock() + s.pings = append(s.pings, req.Name) + s.mu.Unlock() + if req.Name == "" { + return nil, gstatus.Error(codes.InvalidArgument, 
"Missing session name") + } + session, err := s.findSession(req.Name) + if err != nil { + return nil, err + } + return session, nil +} + +func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest) (*spannerpb.ListSessionsResponse, error) { + s.receivedRequests <- req + if req.Database == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing database") + } + expectedSessionName := req.Database + "/sessions/" + var sessions []*spannerpb.Session + s.mu.Lock() + for _, session := range s.sessions { + if strings.Index(session.Name, expectedSessionName) == 0 { + sessions = append(sessions, session) + } + } + s.mu.Unlock() + sort.Slice(sessions[:], func(i, j int) bool { + return sessions[i].Name < sessions[j].Name + }) + res := &spannerpb.ListSessionsResponse{Sessions: sessions} + return res, nil +} + +func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) { + if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil { + return nil, err + } + if req.Name == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + if _, err := s.findSession(req.Name); err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + s.totalSessionsDeleted++ + delete(s.sessions, req.Name) + return &emptypb.Empty{}, nil +} + +func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) { + s.receivedRequests <- req + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + var id []byte + s.updateSessionLastUseTime(session.Name) + if id = s.getTransactionID(session, req.Transaction); id != nil { + _, err = s.getTransactionByID(id) + if err != nil { + return nil, err + } + } + statementResult, err := s.getStatementResult(req.Sql) + if err != nil { + return nil, err + } + switch statementResult.Type { + case StatementResultError: + return nil, statementResult.Err + case StatementResultResultSet: + return statementResult.ResultSet, nil + case StatementResultUpdateCount: + return statementResult.convertUpdateCountToResultSet(!s.partitionedDmlTransactions[string(id)]), nil + } + return nil, gstatus.Error(codes.Internal, "Unknown result type") +} + +func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error { + if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil { + return err + } + if req.Session == "" { + return gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return err + } + s.updateSessionLastUseTime(session.Name) + var id []byte + if id = s.getTransactionID(session, req.Transaction); id != nil { + _, err = s.getTransactionByID(id) + if err != nil { + return err + } + } + statementResult, err := s.getStatementResult(req.Sql) + if err != nil { + return err + } + switch statementResult.Type { + case StatementResultError: + return statementResult.Err + case StatementResultResultSet: + part := statementResult.toPartialResultSet() + if err := stream.Send(part); err != nil { + return err + } + return nil + case StatementResultUpdateCount: + part := statementResult.updateCountToPartialResultSet(!s.partitionedDmlTransactions[string(id)]) + if err := stream.Send(part); err != nil 
{ + return err + } + return nil + } + return gstatus.Error(codes.Internal, "Unknown result type") +} + +func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + s.receivedRequests <- req + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + var id []byte + if id = s.getTransactionID(session, req.Transaction); id != nil { + _, err = s.getTransactionByID(id) + if err != nil { + return nil, err + } + } + resp := &spannerpb.ExecuteBatchDmlResponse{} + resp.ResultSets = make([]*spannerpb.ResultSet, len(req.Statements)) + for idx, batchStatement := range req.Statements { + statementResult, err := s.getStatementResult(batchStatement.Sql) + if err != nil { + return nil, err + } + switch statementResult.Type { + case StatementResultError: + resp.Status = &status.Status{Code: int32(codes.Unknown)} + case StatementResultResultSet: + return nil, gstatus.Error(codes.InvalidArgument, fmt.Sprintf("Not an update statement: %v", batchStatement.Sql)) + case StatementResultUpdateCount: + resp.ResultSets[idx] = statementResult.convertUpdateCountToResultSet(!s.partitionedDmlTransactions[string(id)]) + resp.Status = &status.Status{Code: int32(codes.OK)} + } + } + return resp, nil +} + +func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadRequest) (*spannerpb.ResultSet, error) { + s.receivedRequests <- req + return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} + +func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error { + s.receivedRequests <- req + return gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} + +func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) { + if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil { + return nil, err + } + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + tx := s.beginTransaction(session, req.Options) + return tx, nil +} + +func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) { + if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil { + return nil, err + } + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + var tx *spannerpb.Transaction + if req.GetSingleUseTransaction() != nil { + tx = s.beginTransaction(session, req.GetSingleUseTransaction()) + } else if req.GetTransactionId() != nil { + tx, err = s.getTransactionByID(req.GetTransactionId()) + if err != nil { + return nil, err + } + } else { + return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request") + } + s.removeTransaction(tx) + return &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}, nil +} + +func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.RollbackRequest) (*emptypb.Empty, error) { + 
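+	// Note: Rollback does not go through simulateExecutionTime, so no errors
+	// or delays can be injected for this method; the request is only recorded.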
s.receivedRequests <- req + if req.Session == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return nil, err + } + s.updateSessionLastUseTime(session.Name) + tx, err := s.getTransactionByID(req.TransactionId) + if err != nil { + return nil, err + } + s.removeTransaction(tx) + return &emptypb.Empty{}, nil +} + +func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { + s.receivedRequests <- req + return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} + +func (s *inMemSpannerServer) PartitionRead(ctx context.Context, req *spannerpb.PartitionReadRequest) (*spannerpb.PartitionResponse, error) { + s.receivedRequests <- req + return nil, gstatus.Error(codes.Unimplemented, "Method not yet implemented") +} diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go new file mode 100644 index 00000000000..d563ff4b617 --- /dev/null +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -0,0 +1,598 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "strconv" + + structpb "github.com/golang/protobuf/ptypes/struct" + spannerpb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" +) + +import ( + "context" + "flag" + "fmt" + "log" + "net" + "os" + "strings" + "testing" + + apiv1 "cloud.google.com/go/spanner/apiv1" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/grpc" + gstatus "google.golang.org/grpc/status" +) + +// clientOpt is the option tests should use to connect to the test server. +// It is initialized by TestMain. +var serverAddress string +var clientOpt option.ClientOption +var testSpanner InMemSpannerServer + +// Mocked selectSQL statement. +const selectSQL = "SELECT FOO FROM BAR" +const selectRowCount int64 = 2 +const selectColCount int = 1 + +var selectValues = [...]int64{1, 2} + +// Mocked DML statement. +const updateSQL = "UPDATE FOO SET BAR=1 WHERE ID=ID" +const updateRowCount int64 = 2 + +func TestMain(m *testing.M) { + flag.Parse() + + testSpanner = NewInMemSpannerServer() + serv := grpc.NewServer() + spannerpb.RegisterSpannerServer(serv, testSpanner) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + log.Fatal(err) + } + go serv.Serve(lis) + + serverAddress = lis.Addr().String() + conn, err := grpc.Dial(serverAddress, grpc.WithInsecure()) + if err != nil { + log.Fatal(err) + } + clientOpt = option.WithGRPCConn(conn) + + os.Exit(m.Run()) +} + +// Resets the mock server to its default values and registers a mocked result +// for the statements "SELECT FOO FROM BAR" and +// "UPDATE FOO SET BAR=1 WHERE ID=ID". 
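+// A statement can also be mocked to return an error instead of a result; a
+// sketch (hypothetical statement, for illustration only):
+//
+//   testSpanner.PutStatementResult("SELECT 1", &StatementResult{
+//       Type: StatementResultError,
+//       Err:  gstatus.Error(codes.NotFound, "Table not found"),
+//   })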
+func setup() {
+	testSpanner.Reset()
+	fields := make([]*spannerpb.StructType_Field, selectColCount)
+	fields[0] = &spannerpb.StructType_Field{
+		Name: "FOO",
+		Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64},
+	}
+	rowType := &spannerpb.StructType{
+		Fields: fields,
+	}
+	metadata := &spannerpb.ResultSetMetadata{
+		RowType: rowType,
+	}
+	rows := make([]*structpb.ListValue, selectRowCount)
+	for idx, value := range selectValues {
+		rowValue := make([]*structpb.Value, selectColCount)
+		rowValue[0] = &structpb.Value{
+			Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(value, 10)},
+		}
+		rows[idx] = &structpb.ListValue{
+			Values: rowValue,
+		}
+	}
+	resultSet := &spannerpb.ResultSet{
+		Metadata: metadata,
+		Rows:     rows,
+	}
+	result := &StatementResult{Type: StatementResultResultSet, ResultSet: resultSet}
+	testSpanner.PutStatementResult(selectSQL, result)
+
+	updateResult := &StatementResult{Type: StatementResultUpdateCount, UpdateCount: updateRowCount}
+	testSpanner.PutStatementResult(updateSQL, updateResult)
+}
+
+func TestSpannerCreateSession(t *testing.T) {
+	testSpanner.Reset()
+	var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var request = &spannerpb.CreateSessionRequest{
+		Database: formattedDatabase,
+	}
+
+	c, err := apiv1.NewClient(context.Background(), clientOpt)
+	if err != nil {
+		t.Fatal(err)
+	}
+	resp, err := c.CreateSession(context.Background(), request)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if strings.Index(resp.Name, expectedName) != 0 {
+		t.Errorf("wrong name %s, should start with %s", resp.Name, expectedName)
+	}
+}
+
+func TestSpannerCreateSession_Unavailable(t *testing.T) {
+	testSpanner.Reset()
+	var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var request = &spannerpb.CreateSessionRequest{
+		Database: formattedDatabase,
+	}
+
+	c, err := apiv1.NewClient(context.Background(), clientOpt)
+	if err != nil {
+		t.Fatal(err)
+	}
+	testSpanner.SetError(gstatus.Error(codes.Unavailable, "Temporary unavailable"))
+	resp, err := c.CreateSession(context.Background(), request)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if strings.Index(resp.Name, expectedName) != 0 {
+		t.Errorf("wrong name %s, should start with %s", resp.Name, expectedName)
+	}
+}
+
+func TestSpannerGetSession(t *testing.T) {
+	testSpanner.Reset()
+	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var createRequest = &spannerpb.CreateSessionRequest{
+		Database: formattedDatabase,
+	}
+
+	c, err := apiv1.NewClient(context.Background(), clientOpt)
+	if err != nil {
+		t.Fatal(err)
+	}
+	createResp, err := c.CreateSession(context.Background(), createRequest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var getRequest = &spannerpb.GetSessionRequest{
+		Name: createResp.Name,
+	}
+	getResp, err := c.GetSession(context.Background(), getRequest)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if getResp.Name != getRequest.Name {
+		t.Errorf("wrong name %s, expected %s", getResp.Name, getRequest.Name)
+	}
+}
+
+func TestSpannerListSessions(t *testing.T) {
+	testSpanner.Reset()
+	const expectedNumberOfSessions = 5
+	var expectedName = fmt.Sprintf("projects/%s/instances/%s/databases/%s/sessions/", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var createRequest = &spannerpb.CreateSessionRequest{
+		Database: formattedDatabase,
+	}
+
+	c, err := apiv1.NewClient(context.Background(), clientOpt)
+	if err != nil {
+		t.Fatal(err)
+	}
+	for i := 0; i < expectedNumberOfSessions; i++ {
+		_, err := c.CreateSession(context.Background(), createRequest)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	var listRequest = &spannerpb.ListSessionsRequest{
+		Database: formattedDatabase,
+	}
+	var sessionCount int
+	listResp := c.ListSessions(context.Background(), listRequest)
+	for {
+		session, err := listResp.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+		if strings.Index(session.Name, expectedName) != 0 {
+			t.Errorf("wrong name %s, should start with %s", session.Name, expectedName)
+		}
+		sessionCount++
+	}
+	if sessionCount != expectedNumberOfSessions {
+		t.Errorf("wrong number of sessions: %d, expected %d", sessionCount, expectedNumberOfSessions)
+	}
+}
+
+func TestSpannerDeleteSession(t *testing.T) {
+	testSpanner.Reset()
+	const expectedNumberOfSessions = 5
+	var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+	var createRequest = &spannerpb.CreateSessionRequest{
+		Database: formattedDatabase,
+	}
+
+	c, err := apiv1.NewClient(context.Background(), clientOpt)
+	if err != nil {
+		t.Fatal(err)
+	}
+	for i := 0; i < expectedNumberOfSessions; i++ {
+		_, err := c.CreateSession(context.Background(), createRequest)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	var listRequest = &spannerpb.ListSessionsRequest{
+		Database: formattedDatabase,
+	}
+	var sessionCount int
+	listResp := c.ListSessions(context.Background(), listRequest)
+	for {
+		session, err := listResp.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+		var deleteRequest = &spannerpb.DeleteSessionRequest{
+			Name: session.Name,
+		}
+		c.DeleteSession(context.Background(), deleteRequest)
+		sessionCount++
+	}
+	if sessionCount != expectedNumberOfSessions {
+		t.Errorf("wrong number of sessions: %d, expected %d", sessionCount, expectedNumberOfSessions)
+	}
+	// Re-list all sessions. This should now be empty.
+ listResp = c.ListSessions(context.Background(), listRequest) + _, err = listResp.Next() + if err != iterator.Done { + t.Errorf("expected empty session iterator") + } +} + +func TestSpannerExecuteSql(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + request := &spannerpb.ExecuteSqlRequest{ + Session: session.Name, + Sql: selectSQL, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_SingleUse{ + SingleUse: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadOnly_{ + ReadOnly: &spannerpb.TransactionOptions_ReadOnly{ + ReturnReadTimestamp: false, + TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{ + Strong: true, + }, + }, + }, + }, + }, + }, + Seqno: 1, + QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, + } + response, err := c.ExecuteSql(context.Background(), request) + if err != nil { + t.Fatal(err) + } + var rowCount int64 + for _, row := range response.Rows { + if len(row.Values) != selectColCount { + t.Fatalf("unexpected number of columns: %d, expected %d", len(row.Values), selectColCount) + } + rowCount++ + } + if rowCount != selectRowCount { + t.Fatalf("unexpected number of rows: %d, expected %d", rowCount, selectRowCount) + } +} + +func TestSpannerExecuteSqlDml(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + request := &spannerpb.ExecuteSqlRequest{ + Session: session.Name, + Sql: updateSQL, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_Begin{ + Begin: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, + }, + }, + }, + }, + Seqno: 1, + QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, + } + response, err := c.ExecuteSql(context.Background(), request) + if err != nil { + t.Fatal(err) + } + var rowCount int64 = response.Stats.GetRowCountExact() + if rowCount != updateRowCount { + t.Fatalf("unexpected number of rows updated: %d, expected %d", rowCount, updateRowCount) + } +} + +func TestSpannerExecuteStreamingSql(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + request := &spannerpb.ExecuteSqlRequest{ + Session: session.Name, + Sql: selectSQL, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_SingleUse{ + SingleUse: &spannerpb.TransactionOptions{ + Mode: 
&spannerpb.TransactionOptions_ReadOnly_{ + ReadOnly: &spannerpb.TransactionOptions_ReadOnly{ + ReturnReadTimestamp: false, + TimestampBound: &spannerpb.TransactionOptions_ReadOnly_Strong{ + Strong: true, + }, + }, + }, + }, + }, + }, + Seqno: 1, + QueryMode: spannerpb.ExecuteSqlRequest_NORMAL, + } + response, err := c.ExecuteStreamingSql(context.Background(), request) + if err != nil { + t.Fatal(err) + } + partial, err := response.Recv() + if err != nil { + t.Fatal(err) + } + var rowIndex int64 + colCount := len(partial.Metadata.RowType.Fields) + if colCount != selectColCount { + t.Fatalf("unexpected number of columns: %d, expected %d", colCount, selectColCount) + } + for { + for col := 0; col < colCount; col++ { + val, err := strconv.ParseInt(partial.Values[rowIndex*int64(colCount)+int64(col)].GetStringValue(), 10, 64) + if err != nil { + t.Fatal(err) + } + if val != selectValues[rowIndex] { + t.Fatalf("Unexpected value at index %d. Expected %d, got %d", rowIndex, selectValues[rowIndex], val) + } + } + rowIndex++ + if rowIndex == selectRowCount { + break + } + } + if rowIndex != selectRowCount { + t.Fatalf("unexpected number of rows: %d, expected %d", rowIndex, selectRowCount) + } +} + +func TestSpannerExecuteBatchDml(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + statements := make([]*spannerpb.ExecuteBatchDmlRequest_Statement, 3) + for idx := 0; idx < len(statements); idx++ { + statements[idx] = &spannerpb.ExecuteBatchDmlRequest_Statement{Sql: updateSQL} + } + executeBatchDmlRequest := &spannerpb.ExecuteBatchDmlRequest{ + Session: session.Name, + Statements: statements, + Transaction: &spannerpb.TransactionSelector{ + Selector: &spannerpb.TransactionSelector_Begin{ + Begin: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: &spannerpb.TransactionOptions_ReadWrite{}, + }, + }, + }, + }, + Seqno: 1, + } + response, err := c.ExecuteBatchDml(context.Background(), executeBatchDmlRequest) + if err != nil { + t.Fatal(err) + } + var totalRowCount int64 + for _, res := range response.ResultSets { + var rowCount int64 = res.Stats.GetRowCountExact() + if rowCount != updateRowCount { + t.Fatalf("unexpected number of rows updated: %d, expected %d", rowCount, updateRowCount) + } + totalRowCount += rowCount + } + if totalRowCount != updateRowCount*int64(len(statements)) { + t.Fatalf("unexpected number of total rows updated: %d, expected %d", totalRowCount, updateRowCount*int64(len(statements))) + } +} + +func TestBeginTransaction(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + beginRequest := &spannerpb.BeginTransactionRequest{ + Session: session.Name, + Options: &spannerpb.TransactionOptions{ + Mode: &spannerpb.TransactionOptions_ReadWrite_{ + ReadWrite: 
&spannerpb.TransactionOptions_ReadWrite{},
+ },
+ },
+ }
+ tx, err := c.BeginTransaction(context.Background(), beginRequest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expectedName := fmt.Sprintf("%s/transactions/", session.Name)
+ if strings.Index(string(tx.Id), expectedName) != 0 {
+ t.Errorf("wrong name %s, should start with %s", string(tx.Id), expectedName)
+ }
+}
+
+func TestCommitTransaction(t *testing.T) {
+ setup()
+ c, err := apiv1.NewClient(context.Background(), clientOpt)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+ var createRequest = &spannerpb.CreateSessionRequest{
+ Database: formattedDatabase,
+ }
+ session, err := c.CreateSession(context.Background(), createRequest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ beginRequest := &spannerpb.BeginTransactionRequest{
+ Session: session.Name,
+ Options: &spannerpb.TransactionOptions{
+ Mode: &spannerpb.TransactionOptions_ReadWrite_{
+ ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
+ },
+ },
+ }
+ tx, err := c.BeginTransaction(context.Background(), beginRequest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ commitRequest := &spannerpb.CommitRequest{
+ Session: session.Name,
+ Transaction: &spannerpb.CommitRequest_TransactionId{
+ TransactionId: tx.Id,
+ },
+ }
+ resp, err := c.Commit(context.Background(), commitRequest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if resp.CommitTimestamp == nil {
+ t.Fatalf("no commit timestamp returned")
+ }
+}
+
+func TestRollbackTransaction(t *testing.T) {
+ setup()
+ c, err := apiv1.NewClient(context.Background(), clientOpt)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
+ var createRequest = &spannerpb.CreateSessionRequest{
+ Database: formattedDatabase,
+ }
+ session, err := c.CreateSession(context.Background(), createRequest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ beginRequest := &spannerpb.BeginTransactionRequest{
+ Session: session.Name,
+ Options: &spannerpb.TransactionOptions{
+ Mode: &spannerpb.TransactionOptions_ReadWrite_{
+ ReadWrite: &spannerpb.TransactionOptions_ReadWrite{},
+ },
+ },
+ }
+ tx, err := c.BeginTransaction(context.Background(), beginRequest)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rollbackRequest := &spannerpb.RollbackRequest{
+ Session: session.Name,
+ TransactionId: tx.Id,
+ }
+ err = c.Rollback(context.Background(), rollbackRequest)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/spanner/mocked_inmem_server.go b/spanner/mocked_inmem_server.go
new file mode 100644
index 00000000000..b98b17724ac
--- /dev/null
+++ b/spanner/mocked_inmem_server.go
@@ -0,0 +1,179 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package spanner + +import ( + "context" + "fmt" + "net" + "strconv" + "testing" + + "cloud.google.com/go/spanner/internal/testutil" + structpb "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/api/option" + spannerpb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc" +) + +// The SQL statements and results that are already mocked for this test server. +const selectFooFromBar = "SELECT FOO FROM BAR" +const selectFooFromBarRowCount int64 = 2 +const selectFooFromBarColCount int = 1 + +var selectFooFromBarResults = [...]int64{1, 2} + +const selectSingerIDAlbumIDAlbumTitleFromAlbums = "SELECT SingerId, AlbumId, AlbumTitle FROM Albums" +const selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount int64 = 3 +const selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount int = 3 + +const updateBarSetFoo = "UPDATE FOO SET BAR=1 WHERE BAZ=2" +const updateBarSetFooRowCount = 5 + +// An InMemSpannerServer with results for a number of SQL statements readily +// mocked. +type spannerInMemTestServer struct { + testSpanner testutil.InMemSpannerServer + server *grpc.Server +} + +// Create a spannerInMemTestServer with default configuration. +func newSpannerInMemTestServer(t *testing.T) (*spannerInMemTestServer, *Client) { + s := &spannerInMemTestServer{} + client := s.setup(t) + return s, client +} + +// Create a spannerInMemTestServer with the specified configuration. +func newSpannerInMemTestServerWithConfig(t *testing.T, config ClientConfig) (*spannerInMemTestServer, *Client) { + s := &spannerInMemTestServer{} + client := s.setupWithConfig(t, config) + return s, client +} + +func (s *spannerInMemTestServer) setup(t *testing.T) *Client { + return s.setupWithConfig(t, ClientConfig{}) +} + +func (s *spannerInMemTestServer) setupWithConfig(t *testing.T, config ClientConfig) *Client { + s.testSpanner = testutil.NewInMemSpannerServer() + s.setupFooResults() + s.setupSingersResults() + s.server = grpc.NewServer() + spannerpb.RegisterSpannerServer(s.server, s.testSpanner) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + go s.server.Serve(lis) + + serverAddress := lis.Addr().String() + ctx := context.Background() + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + client, err := NewClientWithConfig(ctx, formattedDatabase, config, + option.WithEndpoint(serverAddress), + option.WithGRPCDialOption(grpc.WithInsecure()), + option.WithoutAuthentication(), + ) + if err != nil { + t.Fatal(err) + } + return client +} + +func (s *spannerInMemTestServer) setupFooResults() { + fields := make([]*spannerpb.StructType_Field, selectFooFromBarColCount) + fields[0] = &spannerpb.StructType_Field{ + Name: "FOO", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, + } + rowType := &spannerpb.StructType{ + Fields: fields, + } + metadata := &spannerpb.ResultSetMetadata{ + RowType: rowType, + } + rows := make([]*structpb.ListValue, selectFooFromBarRowCount) + for idx, value := range selectFooFromBarResults { + rowValue := make([]*structpb.Value, selectFooFromBarColCount) + rowValue[0] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(value, 10)}, + } + rows[idx] = &structpb.ListValue{ + Values: rowValue, + } + } + resultSet := &spannerpb.ResultSet{ + Metadata: metadata, + Rows: rows, + } + result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} + s.testSpanner.PutStatementResult(selectFooFromBar, result) + 
s.testSpanner.PutStatementResult(updateBarSetFoo, &testutil.StatementResult{ + Type: testutil.StatementResultUpdateCount, + UpdateCount: updateBarSetFooRowCount, + }) +} + +func (s *spannerInMemTestServer) setupSingersResults() { + fields := make([]*spannerpb.StructType_Field, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) + fields[0] = &spannerpb.StructType_Field{ + Name: "SingerId", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, + } + fields[1] = &spannerpb.StructType_Field{ + Name: "AlbumId", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_INT64}, + } + fields[2] = &spannerpb.StructType_Field{ + Name: "AlbumTitle", + Type: &spannerpb.Type{Code: spannerpb.TypeCode_STRING}, + } + rowType := &spannerpb.StructType{ + Fields: fields, + } + metadata := &spannerpb.ResultSetMetadata{ + RowType: rowType, + } + rows := make([]*structpb.ListValue, selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + var idx int64 + for idx = 0; idx < selectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount; idx++ { + rowValue := make([]*structpb.Value, selectSingerIDAlbumIDAlbumTitleFromAlbumsColCount) + rowValue[0] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx+1, 10)}, + } + rowValue[1] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: strconv.FormatInt(idx*10+idx, 10)}, + } + rowValue[2] = &structpb.Value{ + Kind: &structpb.Value_StringValue{StringValue: fmt.Sprintf("Album title %d", idx)}, + } + rows[idx] = &structpb.ListValue{ + Values: rowValue, + } + } + resultSet := &spannerpb.ResultSet{ + Metadata: metadata, + Rows: rows, + } + result := &testutil.StatementResult{Type: testutil.StatementResultResultSet, ResultSet: resultSet} + s.testSpanner.PutStatementResult(selectSingerIDAlbumIDAlbumTitleFromAlbums, result) +} + +func (s *spannerInMemTestServer) teardown(client *Client) { + client.Close() + s.server.Stop() +} diff --git a/spanner/pdml.go b/spanner/pdml.go index 168c46a9ff1..c03c13a60e0 100644 --- a/spanner/pdml.go +++ b/spanner/pdml.go @@ -53,27 +53,20 @@ func (c *Client) PartitionedUpdate(ctx context.Context, statement Statement) (co defer s.delete(ctx) sh = &sessionHandle{session: s} // Begin transaction. 
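For illustration, a test built on the mocked in-mem server above could look roughly like this. This is a minimal sketch, not part of the patch: it assumes it lives in the spanner package with the helpers and constants from mocked_inmem_server.go in scope, plus imports of context, testing and google.golang.org/api/iterator. The pdml.go hunk continues after this sketch.

    func TestQueryWithInMemServer(t *testing.T) {
        server, client := newSpannerInMemTestServer(t)
        defer server.teardown(client)

        // selectFooFromBar is one of the statements that setupFooResults
        // pre-mocks on the server.
        iter := client.Single().Query(context.Background(), NewStatement(selectFooFromBar))
        defer iter.Stop()
        var got []int64
        for {
            row, err := iter.Next()
            if err == iterator.Done {
                break
            }
            if err != nil {
                t.Fatal(err)
            }
            var v int64
            if err := row.Columns(&v); err != nil {
                t.Fatal(err)
            }
            got = append(got, v)
        }
        if int64(len(got)) != selectFooFromBarRowCount {
            t.Fatalf("got %d rows, want %d", len(got), selectFooFromBarRowCount)
        }
    }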
- err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error {
- res, e := sc.BeginTransaction(ctx, &sppb.BeginTransactionRequest{
- Session: sh.getID(),
- Options: &sppb.TransactionOptions{
- Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}},
- },
- })
- if e != nil {
- return e
- }
- tx = res.Id
- return nil
+ res, err := sc.BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{
+ Session: sh.getID(),
+ Options: &sppb.TransactionOptions{
+ Mode: &sppb.TransactionOptions_PartitionedDml_{PartitionedDml: &sppb.TransactionOptions_PartitionedDml{}},
+ },
 })
 if err != nil {
 return 0, toSpannerError(err)
 }
-
 params, paramTypes, err := statement.convertParams()
 if err != nil {
 return 0, toSpannerError(err)
 }
+ tx = res.Id

 req := &sppb.ExecuteSqlRequest{
 Session: sh.getID(),
diff --git a/spanner/pdml_test.go b/spanner/pdml_test.go
index 4c5a743ef14..743f5113cc1 100644
--- a/spanner/pdml_test.go
+++ b/spanner/pdml_test.go
@@ -16,28 +16,23 @@ package spanner

 import (
 "context"
- "io"
 "testing"

- "cloud.google.com/go/spanner/internal/testutil"
- sppb "google.golang.org/genproto/googleapis/spanner/v1"
 "google.golang.org/grpc/codes"
 )

 func TestMockPartitionedUpdate(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- ms := testutil.NewMockCloudSpanner(t, trxTs)
- ms.Serve()
- mc := sppb.NewSpannerClient(dialMock(t, ms))
- client := &Client{database: "mockdb"}
- client.clients = append(client.clients, mc)
- stmt := NewStatement("UPDATE t SET x = 2 WHERE x = 1")
+ server, client := newSpannerInMemTestServer(t)
+ defer server.teardown(client)
+
+ stmt := NewStatement(updateBarSetFoo)
 rowCount, err := client.PartitionedUpdate(ctx, stmt)
 if err != nil {
 t.Fatal(err)
 }
- want := int64(3)
+ want := int64(updateBarSetFooRowCount)
 if rowCount != want {
 t.Errorf("got %d, want %d", rowCount, want)
 }
@@ -46,13 +41,10 @@ func TestMockPartitionedUpdate(t *testing.T) {
 func TestMockPartitionedUpdateWithQuery(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- ms := testutil.NewMockCloudSpanner(t, trxTs)
- ms.AddMsg(io.EOF, true)
- ms.Serve()
- mc := sppb.NewSpannerClient(dialMock(t, ms))
- client := &Client{database: "mockdb"}
- client.clients = append(client.clients, mc)
- stmt := NewStatement("SELECT t.key key, t.value value FROM t_mock t")
+ server, client := newSpannerInMemTestServer(t)
+ defer server.teardown(client)
+
+ stmt := NewStatement(selectFooFromBar)
 _, err := client.PartitionedUpdate(ctx, stmt)
 wantCode := codes.InvalidArgument
 if serr, ok := err.(*Error); !ok || serr.Code != wantCode {
diff --git a/spanner/retry.go b/spanner/retry.go
index 033529c5c80..c45e80b4d63 100644
--- a/spanner/retry.go
+++ b/spanner/retry.go
@@ -18,12 +18,9 @@ package spanner

 import (
 "context"
- "fmt"
 "strings"
 "time"

- "cloud.google.com/go/internal/trace"
- "cloud.google.com/go/spanner/internal/backoff"
 "github.com/golang/protobuf/proto"
 "github.com/golang/protobuf/ptypes"
 edpb "google.golang.org/genproto/googleapis/rpc/errdetails"
@@ -35,16 +32,6 @@ const (
 retryInfoKey = "google.rpc.retryinfo-bin"
 )

-// errRetry returns an unavailable error under error namespace EsOther. It is a
-// generic retryable error that is used to mask and recover unretryable errors
-// in a retry loop.
-func errRetry(err error) error { - if se, ok := err.(*Error); ok { - return &Error{codes.Unavailable, fmt.Sprintf("generic Cloud Spanner retryable error: { %v }", se.Error()), se.trailers} - } - return spannerErrorf(codes.Unavailable, "generic Cloud Spanner retryable error: { %v }", err.Error()) -} - // isErrorClosing reports whether the error is generated by gRPC layer talking // to a closed server. func isErrorClosing(err error) bool { @@ -158,50 +145,3 @@ func extractRetryDelay(err error) (time.Duration, bool) { } return delay, true } - -// runRetryable keeps attempting to run f until one of the following happens: -// 1) f returns nil error or an unretryable error; -// 2) context is cancelled or timeout. -// -// TODO: consider using https://github.com/googleapis/gax-go/v2 once it -// becomes available internally. -func runRetryable(ctx context.Context, f func(context.Context) error) error { - return toSpannerError(runRetryableNoWrap(ctx, f)) -} - -// Like runRetryable, but doesn't wrap the returned error in a spanner.Error. -func runRetryableNoWrap(ctx context.Context, f func(context.Context) error) error { - var funcErr error - retryCount := 0 - for { - select { - case <-ctx.Done(): - // Do context check here so that even f() failed to do so (for - // example, gRPC implementation bug), the loop can still have a - // chance to exit as expected. - return errContextCanceled(ctx, funcErr) - default: - } - funcErr = f(ctx) - if funcErr == nil { - return nil - } - if isRetryable(funcErr) { - // Error is retryable, do exponential backoff and continue. - b, ok := extractRetryDelay(funcErr) - if !ok { - b = backoff.DefaultBackoff.Delay(retryCount) - } - trace.TracePrintf(ctx, nil, "Backing off for %s, then retrying", b) - select { - case <-ctx.Done(): - return errContextCanceled(ctx, funcErr) - case <-time.After(b): - } - retryCount++ - continue - } - // Error isn't retryable / no error, return immediately. - return funcErr - } -} diff --git a/spanner/retry_test.go b/spanner/retry_test.go index 64516f647db..979f164e54c 100644 --- a/spanner/retry_test.go +++ b/spanner/retry_test.go @@ -17,9 +17,6 @@ limitations under the License. package spanner import ( - "context" - "errors" - "fmt" "testing" "time" @@ -31,69 +28,6 @@ import ( "google.golang.org/grpc/status" ) -// Test if runRetryable loop deals with various errors correctly. 
-func TestRetry(t *testing.T) {
- if testing.Short() {
- t.SkipNow()
- }
- responses := []error{
- status.Errorf(codes.Internal, "transport is closing"),
- status.Errorf(codes.Unknown, "unexpected EOF"),
- status.Errorf(codes.Internal, "unexpected EOF"),
- status.Errorf(codes.Internal, "stream terminated by RST_STREAM with error code: 2"),
- status.Errorf(codes.Unavailable, "service is currently unavailable"),
- errRetry(fmt.Errorf("just retry it")),
- }
- err := runRetryable(context.Background(), func(ct context.Context) error {
- var r error
- if len(responses) > 0 {
- r = responses[0]
- responses = responses[1:]
- }
- return r
- })
- if err != nil {
- t.Errorf("runRetryable should be able to survive all retryable errors, but it returns %v", err)
- }
- // Unretryable errors
- injErr := errors.New("this is unretryable")
- err = runRetryable(context.Background(), func(ct context.Context) error {
- return injErr
- })
- if wantErr := toSpannerError(injErr); !testEqual(err, wantErr) {
- t.Errorf("runRetryable returns error %v, want %v", err, wantErr)
- }
- // Timeout
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- retryErr := errRetry(fmt.Errorf("still retrying"))
- err = runRetryable(ctx, func(ct context.Context) error {
- // Expect to trigger timeout in retryable runner after 10 executions.
- <-time.After(100 * time.Millisecond)
- // Let retryable runner to retry so that timeout will eventually happen.
- return retryErr
- })
- // Check error code and error message.
- if wantErrCode, wantErr := codes.DeadlineExceeded, errContextCanceled(ctx, retryErr); ErrCode(err) != wantErrCode || !testEqual(err, wantErr) {
- t.Errorf("=\n<%v, %v>, want:\n<%v, %v>", ErrCode(err), err, wantErrCode, wantErr)
- }
- // Cancellation
- ctx, cancel = context.WithCancel(context.Background())
- retries := 3
- retryErr = errRetry(fmt.Errorf("retry before cancel"))
- err = runRetryable(ctx, func(ct context.Context) error {
- retries--
- if retries == 0 {
- cancel()
- }
- return retryErr
- })
- // Check error code, error message, retry count.
- if wantErrCode, wantErr := codes.Canceled, errContextCanceled(ctx, retryErr); ErrCode(err) != wantErrCode || !testEqual(err, wantErr) || retries != 0 {
- t.Errorf("=\n<%v, %v, %v>, want:\n<%v, %v, %v>", ErrCode(err), err, retries, wantErrCode, wantErr, 0)
- }
-}
-
 func TestRetryInfo(t *testing.T) {
 b, _ := proto.Marshal(&edpb.RetryInfo{
 RetryDelay: ptypes.DurationProto(time.Second),
@@ -101,7 +35,7 @@ func TestRetryInfo(t *testing.T) {
 trailers := map[string]string{
 retryInfoKey: string(b),
 }
- gotDelay, ok := extractRetryDelay(errRetry(toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, ""), metadata.New(trailers))))
+ gotDelay, ok := extractRetryDelay(toSpannerErrorWithMetadata(status.Errorf(codes.Aborted, ""), metadata.New(trailers)))
 if !ok || !testEqual(time.Second, gotDelay) {
 t.Errorf("<ok, retryDelay> = <%t, %v>, want <true, %v>", ok, gotDelay, time.Second)
 }
diff --git a/spanner/session.go b/spanner/session.go
index 64158fff41b..27bad1bf26d 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -28,6 +28,7 @@ import (
 "time"

 "cloud.google.com/go/internal/trace"
+ vkit "cloud.google.com/go/spanner/apiv1"
 sppb "google.golang.org/genproto/googleapis/spanner/v1"
 "google.golang.org/grpc/codes"
 "google.golang.org/grpc/metadata"
@@ -72,7 +73,7 @@ func (sh *sessionHandle) getID() string {

 // getClient gets the Cloud Spanner RPC client associated with the session ID
 // in sessionHandle.
-func (sh *sessionHandle) getClient() sppb.SpannerClient { +func (sh *sessionHandle) getClient() *vkit.Client { sh.mu.Lock() defer sh.mu.Unlock() if sh.session == nil { @@ -121,7 +122,7 @@ func (sh *sessionHandle) destroy() { type session struct { // client is the RPC channel to Cloud Spanner. It is set only once during // session's creation. - client sppb.SpannerClient + client *vkit.Client // id is the unique id of the session in Cloud Spanner. It is set only once // during session's creation. id string @@ -183,11 +184,9 @@ func (s *session) String() string { func (s *session) ping() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - return runRetryable(ctx, func(ctx context.Context) error { - // s.getID is safe even when s is invalid. - _, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()}) - return err - }) + // s.getID is safe even when s is invalid. + _, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()}) + return err } // setHcIndex atomically sets the session's index in the healthcheck queue and @@ -290,13 +289,9 @@ func (s *session) destroy(isExpire bool) bool { } func (s *session) delete(ctx context.Context) { - // Ignore the error returned by runRetryable because even if we fail to - // explicitly destroy the session, it will be eventually garbage collected - // by Cloud Spanner. - err := runRetryable(ctx, func(ctx context.Context) error { - _, e := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()}) - return e - }) + // Ignore the error because even if we fail to explicitly destroy the + // session, it will be eventually garbage collected by Cloud Spanner. + err := s.client.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: s.getID()}) if err != nil { log.Printf("Failed to delete session %v. Error: %v", s.getID(), err) } @@ -320,7 +315,7 @@ func (s *session) prepareForWrite(ctx context.Context) error { type SessionPoolConfig struct { // getRPCClient is the caller supplied method for getting a gRPC client to // Cloud Spanner, this makes session pool able to use client pooling. - getRPCClient func() (sppb.SpannerClient, error) + getRPCClient func() (*vkit.Client, error) // MaxOpened is the maximum number of opened sessions allowed by the session // pool. If the client tries to open a session and there are already @@ -539,6 +534,9 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) { doneCreate(false) // Should return error directly because of the previous retries on // CreateSession RPC. + // If the error is a timeout, there is a chance that the session was + // created on the server but is not known to the session pool. This + // session will then be garbage collected by the server after 1 hour. 
return nil, err } s.pool = p @@ -547,23 +545,17 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) { return s, nil } -func createSession(ctx context.Context, sc sppb.SpannerClient, db string, labels map[string]string, md metadata.MD) (*session, error) { +func createSession(ctx context.Context, sc *vkit.Client, db string, labels map[string]string, md metadata.MD) (*session, error) { var s *session - err := runRetryable(ctx, func(ctx context.Context) error { - sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{ - Database: db, - Session: &sppb.Session{Labels: labels}, - }) - if e != nil { - return e - } - // If no error, construct the new session. - s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md} - return nil + sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{ + Database: db, + Session: &sppb.Session{Labels: labels}, }) - if err != nil { - return nil, err + if e != nil { + return nil, toSpannerError(e) } + // If no error, construct the new session. + s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md} return s, nil } diff --git a/spanner/session_test.go b/spanner/session_test.go index d04e93d2e62..2f3f7dc32c4 100644 --- a/spanner/session_test.go +++ b/spanner/session_test.go @@ -23,13 +23,11 @@ import ( "fmt" "math/rand" "sync" - "sync/atomic" "testing" "time" + vkit "cloud.google.com/go/spanner/apiv1" "cloud.google.com/go/spanner/internal/testutil" - sppb "google.golang.org/genproto/googleapis/spanner/v1" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -37,8 +35,9 @@ import ( // TestSessionPoolConfigValidation tests session pool config validation. func TestSessionPoolConfigValidation(t *testing.T) { t.Parallel() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) - sc := testutil.NewMockCloudSpannerClient(t) for _, test := range []struct { spc SessionPoolConfig err error @@ -49,8 +48,8 @@ func TestSessionPoolConfigValidation(t *testing.T) { }, { SessionPoolConfig{ - getRPCClient: func() (sppb.SpannerClient, error) { - return sc, nil + getRPCClient: func() (*vkit.Client, error) { + return client.clients[0], nil }, MinOpened: 10, MaxOpened: 5, @@ -68,8 +67,9 @@ func TestSessionPoolConfigValidation(t *testing.T) { func TestSessionCreation(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) + sp := client.idleSessions // Take three sessions from session pool, this should trigger session pool // to create three new sessions. @@ -87,7 +87,7 @@ func TestSessionCreation(t *testing.T) { if len(gotDs) != len(shs) { t.Fatalf("session pool created %v sessions, want %v", len(gotDs), len(shs)) } - if wantDs := mock.DumpSessions(); !testEqual(gotDs, wantDs) { + if wantDs := server.testSpanner.DumpSessions(); !testEqual(gotDs, wantDs) { t.Fatalf("session pool creates sessions %v, want %v", gotDs, wantDs) } // Verify that created sessions are recorded correctly in session pool. @@ -119,8 +119,12 @@ func TestTakeFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. 
- _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxIdle: 10}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{MaxIdle: 10}, + }) + defer server.teardown(client) + sp := client.idleSessions // Take ten sessions from session pool and recycle them. shs := make([]*sessionHandle, 10) @@ -139,7 +143,7 @@ func TestTakeFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := mock.DumpSessions() + wantSessions := server.testSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -165,8 +169,12 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxIdle: 20}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{MaxIdle: 20}, + }) + defer server.teardown(client) + sp := client.idleSessions // Take ten sessions from session pool and recycle them. shs := make([]*sessionHandle, 10) @@ -185,7 +193,7 @@ func TestTakeWriteSessionFromIdleList(t *testing.T) { } // Further session requests from session pool won't cause mockclient to // create more sessions. - wantSessions := mock.DumpSessions() + wantSessions := server.testSpanner.DumpSessions() // Take ten sessions from session pool again, this time all sessions should // come from idle list. gotSessions := map[string]bool{} @@ -211,12 +219,16 @@ func TestTakeFromIdleListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - MaxIdle: 1, - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxIdle: 1, + HealthCheckInterval: 50 * time.Millisecond, + healthCheckSampleInterval: 10 * time.Millisecond, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. sp.hc.close() @@ -251,7 +263,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take // reschedules the next healthcheck. - if got, want := mock.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -260,10 +272,10 @@ func TestTakeFromIdleListChecked(t *testing.T) { // Inject session error to server stub, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. 
- mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, status.Errorf(codes.NotFound, "Session not found") - } + server.testSpanner.PutExecutionTime(testutil.MethodGetSession, + testutil.SimulatedExecutionTime{ + Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, + }) // Delay to trigger sessionPool.Take to ping the session. // TODO(deklerk): get rid of this @@ -276,7 +288,7 @@ func TestTakeFromIdleListChecked(t *testing.T) { if err != nil { t.Fatalf("failed to get session: %v", err) } - ds := mock.DumpSessions() + ds := server.testSpanner.DumpSessions() if len(ds) != 1 { t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID()) } @@ -292,12 +304,16 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { ctx := context.Background() // Make sure maintainer keeps the idle sessions. - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - MaxIdle: 1, - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxIdle: 1, + HealthCheckInterval: 50 * time.Millisecond, + healthCheckSampleInterval: 10 * time.Millisecond, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Stop healthcheck workers to simulate slow pings. sp.hc.close() @@ -330,7 +346,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { } // The two back-to-back session requests shouldn't trigger any session // pings because sessionPool.Take reschedules the next healthcheck. - if got, want := mock.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { + if got, want := server.testSpanner.DumpPings(), ([]string{wantSid}); !testEqual(got, want) { t.Fatalf("%v - got ping session requests: %v, want %v", i, got, want) } sh.recycle() @@ -339,10 +355,10 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) { // Inject session error to mockclient, and take the session from the // session pool, the old session should be destroyed and the session pool // will create a new session. - mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, status.Errorf(codes.NotFound, "Session not found") - } + server.testSpanner.PutExecutionTime(testutil.MethodGetSession, + testutil.SimulatedExecutionTime{ + Errors: []error{status.Errorf(codes.NotFound, "Session not found")}, + }) // Delay to trigger sessionPool.Take to ping the session. 
// TODO(deklerk): get rid of this
@@ -352,7 +368,7 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) {
 if err != nil {
 t.Fatalf("failed to get session: %v", err)
 }
- ds := mock.DumpSessions()
+ ds := server.testSpanner.DumpSessions()
 if len(ds) != 1 {
 t.Fatalf("dumped sessions from mockclient: %v, want %v", ds, sh.getID())
 }
@@ -365,8 +381,14 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) {
 func TestMaxOpenedSessions(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MaxOpened: 1})
- defer cleanup()
+ server, client := newSpannerInMemTestServerWithConfig(t,
+ ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MaxOpened: 1,
+ },
+ })
+ defer server.teardown(client)
+ sp := client.idleSessions

 sh1, err := sp.take(ctx)
 if err != nil {
@@ -404,8 +426,14 @@ func TestMaxOpenedSessions(t *testing.T) {
 func TestMinOpenedSessions(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1})
- defer cleanup()
+ server, client := newSpannerInMemTestServerWithConfig(t,
+ ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 1,
+ },
+ })
+ defer server.teardown(client)
+ sp := client.idleSessions

 // Take ten sessions from session pool and recycle them.
 var ss []*session
@@ -441,20 +469,21 @@ func TestMinOpenedSessions(t *testing.T) {
 func TestMaxBurst(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{MaxBurst: 1})
- defer cleanup()
+ server, client := newSpannerInMemTestServerWithConfig(t,
+ ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MaxBurst: 1,
+ },
+ })
+ defer server.teardown(client)
+ sp := client.idleSessions

 // Will cause session creation RPC to be retried forever.
- allowRequests := make(chan struct{})
- mock.CreateSessionFn = func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) {
- select {
- case <-allowRequests:
- return mock.MockCloudSpannerClient.CreateSession(c, r, opts...)
- default:
- mock.MockCloudSpannerClient.ReceivedRequests <- r
- return nil, status.Errorf(codes.Unavailable, "try later")
- }
- }
+ server.testSpanner.PutExecutionTime(testutil.MethodCreateSession,
+ testutil.SimulatedExecutionTime{
+ Errors: []error{status.Errorf(codes.Unavailable, "try later")},
+ KeepError: true,
+ })

 // This session request will never finish until the injected error is
 // cleared.
@@ -483,7 +512,10 @@
 }

 // Let the first session request succeed.
- close(allowRequests)
+ server.testSpanner.Freeze()
+ server.testSpanner.PutExecutionTime(testutil.MethodCreateSession, testutil.SimulatedExecutionTime{})
+ server.testSpanner.Unfreeze()

 // Now new session request can proceed because the first session request will eventually succeed.
 sh, err := sp.take(ctx)
@@ -499,8 +531,15 @@
 func TestSessionRecycle(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1, MaxIdle: 5})
- defer cleanup()
+ server, client := newSpannerInMemTestServerWithConfig(t,
+ ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 1,
+ MaxIdle: 5,
+ },
+ })
+ defer server.teardown(client)
+ sp := client.idleSessions

 // Test session is correctly recycled and reused.
for i := 0; i < 20; i++ { @@ -530,8 +569,14 @@ func TestSessionDestroy(t *testing.T) { t.Skip("s.destroy(true) is flakey") t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: 1}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + }, + }) + defer server.teardown(client) + sp := client.idleSessions <-time.After(10 * time.Millisecond) // maintainer will create one session, we wait for it create session to avoid flakiness in test sh, err := sp.take(ctx) @@ -587,11 +632,15 @@ func TestHcHeap(t *testing.T) { func TestHealthCheckScheduler(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + HealthCheckInterval: 50 * time.Millisecond, + healthCheckSampleInterval: 10 * time.Millisecond, + }, + }) + defer server.teardown(client) + sp := client.idleSessions // Create 50 sessions. ss := []string{} @@ -605,7 +654,7 @@ func TestHealthCheckScheduler(t *testing.T) { // Wait for 10-30 pings per session. waitFor(t, func() error { - dp := mock.DumpPings() + dp := server.testSpanner.DumpPings() gotPings := map[string]int64{} for _, p := range dp { gotPings[p]++ @@ -625,8 +674,15 @@ func TestHealthCheckScheduler(t *testing.T) { func TestWriteSessionsPrepared(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{WriteSessions: 0.5, MaxIdle: 20}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + WriteSessions: 0.5, + MaxIdle: 20, + }, + }) + defer server.teardown(client) + sp := client.idleSessions shs := make([]*sessionHandle, 10) var err error @@ -688,8 +744,16 @@ func TestWriteSessionsPrepared(t *testing.T) { func TestTakeFromWriteQueue(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MaxOpened: 1, WriteSessions: 1.0, MaxIdle: 1}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MaxOpened: 1, + WriteSessions: 1.0, + MaxIdle: 1, + }, + }) + defer server.teardown(client) + sp := client.idleSessions sh, err := sp.take(ctx) if err != nil { @@ -718,20 +782,15 @@ func TestTakeFromWriteQueue(t *testing.T) { func TestSessionHealthCheck(t *testing.T) { t.Parallel() ctx := context.Background() - _, sp, mock, cleanup := serverClientMock(t, SessionPoolConfig{ - HealthCheckInterval: 50 * time.Millisecond, - healthCheckSampleInterval: 10 * time.Millisecond, - }) - defer cleanup() - - var requestShouldErr int64 // 0 == false, 1 == true - mock.GetSessionFn = func(c context.Context, r *sppb.GetSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - if shouldErr := atomic.LoadInt64(&requestShouldErr); shouldErr == 1 { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, status.Errorf(codes.NotFound, "Session not found") - } - return mock.MockCloudSpannerClient.GetSession(c, r, opts...) 
- }
+ server, client := newSpannerInMemTestServerWithConfig(t,
+ ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ HealthCheckInterval: 50 * time.Millisecond,
+ healthCheckSampleInterval: 10 * time.Millisecond,
+ },
+ })
+ defer server.teardown(client)
+ sp := client.idleSessions

 // Test pinging sessions.
 sh, err := sp.take(ctx)
@@ -741,7 +800,7 @@ func TestSessionHealthCheck(t *testing.T) {

 // Wait for healthchecker to send pings to session.
 waitFor(t, func() error {
- pings := mock.DumpPings()
+ pings := server.testSpanner.DumpPings()
 if len(pings) == 0 || pings[0] != sh.getID() {
 return fmt.Errorf("healthchecker didn't send any ping to session %v", sh.getID())
 }
@@ -753,7 +812,14 @@ func TestSessionHealthCheck(t *testing.T) {
 t.Fatalf("cannot get session from session pool: %v", err)
 }

- atomic.SwapInt64(&requestShouldErr, 1)
+ server.testSpanner.Freeze()
+ server.testSpanner.PutExecutionTime(testutil.MethodGetSession,
+ testutil.SimulatedExecutionTime{
+ Errors: []error{status.Errorf(codes.NotFound, "Session not found")},
+ KeepError: true,
+ })
+ server.testSpanner.Unfreeze()

 // Wait for healthcheck workers to find the broken session and tear it down.
 // TODO(deklerk): get rid of this
@@ -764,7 +830,9 @@ func TestSessionHealthCheck(t *testing.T) {
 t.Fatalf("session(%v) is still alive, want it to be dropped by healthcheck workers", s)
 }

- atomic.SwapInt64(&requestShouldErr, 0)
+ server.testSpanner.Freeze()
+ server.testSpanner.PutExecutionTime(testutil.MethodGetSession, testutil.SimulatedExecutionTime{})
+ server.testSpanner.Unfreeze()

 // Test garbage collection.
 sh, err = sp.take(ctx)
@@ -805,18 +873,17 @@ func TestStressSessionPool(t *testing.T) {
 cfg.HealthCheckInterval = 50 * time.Millisecond
 cfg.healthCheckSampleInterval = 10 * time.Millisecond
 cfg.HealthCheckWorkers = 50
- sc := testutil.NewMockCloudSpannerClient(t)
- cfg.getRPCClient = func() (sppb.SpannerClient, error) {
- return sc, nil
- }
- sp, _ := newSessionPool("mockdb", cfg, nil)
- defer sp.hc.close()
- defer sp.close()
+
+ server, client := newSpannerInMemTestServerWithConfig(t,
+ ClientConfig{
+ SessionPoolConfig: cfg,
+ })
+ sp := client.idleSessions

 for i := 0; i < 100; i++ {
 wg.Add(1)
 // Schedule a test worker.
- go func(idx int, pool *sessionPool, client sppb.SpannerClient) {
+ go func(idx int, pool *sessionPool, client *Client) {
 defer wg.Done()
 // Test worker iterates 1K times and tries different
 // session / session pool operations.
@@ -868,7 +935,7 @@ func TestStressSessionPool(t *testing.T) {
 // recycle the session.
 sh.recycle()
 }
- }(i, sp, sc)
+ }(i, sp, client)
 }
 wg.Wait()
 sp.hc.close()
@@ -876,7 +943,7 @@ func TestStressSessionPool(t *testing.T) {
 // stable.
 idleSessions := map[string]bool{}
 hcSessions := map[string]bool{}
- mockSessions := sc.DumpSessions()
+ mockSessions := server.testSpanner.DumpSessions()
 // Dump session pool's idle list.
 for sl := sp.idleList.Front(); sl != nil; sl = sl.Next() {
 s := sl.Value.(*session)
@@ -912,14 +979,23 @@ func TestStressSessionPool(t *testing.T) {
 if !testEqual(idleSessions, hcSessions) {
 t.Fatalf("%v: sessions in idle list (%v) != sessions in healthcheck queue (%v)", ti, idleSessions, hcSessions)
 }
- if !testEqual(hcSessions, mockSessions) {
- t.Fatalf("%v: sessions in healthcheck queue (%v) != sessions in mockclient (%v)", ti, hcSessions, mockSessions)
+ // The server may contain more sessions than the health check queue.
+ // This can be caused by a timeout client side during a CreateSession
+ // request.
The request may still be received and executed by the + // server, but the session pool will not register the session. + for id, b := range hcSessions { + if b && !mockSessions[id] { + t.Fatalf("%v: session in healthcheck queue (%v) was not found on server", ti, id) + } } sp.close() - mockSessions = sc.DumpSessions() - if len(mockSessions) != 0 { - t.Fatalf("Found live sessions: %v", mockSessions) + mockSessions = server.testSpanner.DumpSessions() + for id, b := range hcSessions { + if b && mockSessions[id] { + t.Fatalf("Found session from pool still live on server: %v", id) + } } + server.teardown(client) } } @@ -941,8 +1017,15 @@ func TestMaintainer(t *testing.T) { minOpened := uint64(5) maxIdle := uint64(4) - _, sp, _, cleanup := serverClientMock(t, SessionPoolConfig{MinOpened: minOpened, MaxIdle: maxIdle}) - defer cleanup() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: minOpened, + MaxIdle: maxIdle, + }, + }) + defer server.teardown(client) + sp := client.idleSessions sampleInterval := sp.SessionPoolConfig.healthCheckSampleInterval @@ -1003,40 +1086,25 @@ func TestMaintainer(t *testing.T) { // // Historical context: This test also checks that a low // healthCheckSampleInterval does not prevent it from opening connections. +// The low healthCheckSampleInterval will however sometimes cause session +// creations to time out. That should not be considered a problem, but it +// could cause the test case to fail if it happens too often. // See: https://github.com/googleapis/google-cloud-go/issues/1259 func TestMaintainer_CreatesSessions(t *testing.T) { t.Parallel() - - rawServerStub := testutil.NewMockCloudSpannerClient(t) - serverClientMock := testutil.FuncMock{MockCloudSpannerClient: rawServerStub} - serverClientMock.CreateSessionFn = func(c context.Context, r *sppb.CreateSessionRequest, opts ...grpc.CallOption) (*sppb.Session, error) { - time.Sleep(10 * time.Millisecond) - return rawServerStub.CreateSession(c, r, opts...) 
- } spc := SessionPoolConfig{ MinOpened: 10, MaxIdle: 10, - healthCheckSampleInterval: time.Millisecond, - getRPCClient: func() (sppb.SpannerClient, error) { - return &serverClientMock, nil - }, - } - db := "mockdb" - sp, err := newSessionPool(db, spc, nil) - if err != nil { - t.Fatalf("cannot create session pool: %v", err) + healthCheckSampleInterval: 20 * time.Millisecond, } - client := Client{ - database: db, - idleSessions: sp, - } - defer func() { - client.Close() - sp.hc.close() - sp.close() - }() + server, client := newSpannerInMemTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: spc, + }) + defer server.teardown(client) + sp := client.idleSessions - timeoutAmt := 2 * time.Second + timeoutAmt := 4 * time.Second timeout := time.After(timeoutAmt) var numOpened uint64 loop: diff --git a/spanner/transaction.go b/spanner/transaction.go index 4a6a336de08..a53486b674f 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -22,7 +22,10 @@ import ( "sync/atomic" "time" + "github.com/googleapis/gax-go/v2" + "cloud.google.com/go/internal/trace" + vkit "cloud.google.com/go/spanner/apiv1" "google.golang.org/api/iterator" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc" @@ -364,24 +367,22 @@ func (t *ReadOnlyTransaction) begin(ctx context.Context) error { if err != nil { return err } - err = runRetryable(contextWithOutgoingMetadata(ctx, sh.getMetadata()), func(ctx context.Context) error { - res, e := sh.getClient().BeginTransaction(ctx, &sppb.BeginTransactionRequest{ - Session: sh.getID(), - Options: &sppb.TransactionOptions{ - Mode: &sppb.TransactionOptions_ReadOnly_{ - ReadOnly: buildTransactionOptionsReadOnly(t.getTimestampBound(), true), - }, + res, err := sh.getClient().BeginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.BeginTransactionRequest{ + Session: sh.getID(), + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadOnly_{ + ReadOnly: buildTransactionOptionsReadOnly(t.getTimestampBound(), true), }, - }) - if e != nil { - return e - } + }, + }) + if err == nil { tx = res.Id if res.ReadTimestamp != nil { rts = time.Unix(res.ReadTimestamp.Seconds, int64(res.ReadTimestamp.Nanos)) } - return nil - }) + } else { + err = toSpannerError(err) + } t.mu.Lock() // defer function will be executed with t.mu being held. 
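The transaction.go refactoring below is exercised in transaction_test.go by scripting errors on the in-mem server. For illustration, a retry-on-abort scenario can be set up as follows; this is a sketch that mirrors the PutExecutionTime pattern and the TestApply_RetryOnAbort test further down, not an addition to the patch:

    func TestApplyAbortedOnce(t *testing.T) {
        server, client := newSpannerInMemTestServer(t)
        defer server.teardown(client)

        // Fail only the first Commit with ABORTED; the client should retry
        // the transaction and succeed on the second attempt.
        server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction,
            testutil.SimulatedExecutionTime{
                Errors: []error{gstatus.Error(codes.Aborted, "transaction aborted")},
            })
        ms := []*Mutation{
            Insert("FOO", []string{"BAR"}, []interface{}{int64(1)}),
        }
        if _, err := client.Apply(context.Background(), ms); err != nil {
            t.Fatal(err)
        }
    }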
@@ -798,27 +799,19 @@ func (t *ReadWriteTransaction) release(err error) { } } -func beginTransaction(ctx context.Context, sid string, client sppb.SpannerClient) (transactionID, error) { - var tx transactionID - err := runRetryable(ctx, func(ctx context.Context) error { - res, e := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ - Session: sid, - Options: &sppb.TransactionOptions{ - Mode: &sppb.TransactionOptions_ReadWrite_{ - ReadWrite: &sppb.TransactionOptions_ReadWrite{}, - }, +func beginTransaction(ctx context.Context, sid string, client *vkit.Client) (transactionID, error) { + res, err := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ + Session: sid, + Options: &sppb.TransactionOptions{ + Mode: &sppb.TransactionOptions_ReadWrite_{ + ReadWrite: &sppb.TransactionOptions_ReadWrite{}, }, - }) - if e != nil { - return e - } - tx = res.Id - return nil + }, }) if err != nil { return nil, err } - return tx, nil + return res.Id, nil } // begin starts a read-write transacton on Cloud Spanner, it is always called @@ -857,23 +850,21 @@ func (t *ReadWriteTransaction) commit(ctx context.Context) (time.Time, error) { if sid == "" || client == nil { return ts, errSessionClosed(t.sh) } - err = runRetryable(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), func(ctx context.Context) error { - var trailer metadata.MD - res, e := client.Commit(ctx, &sppb.CommitRequest{ - Session: sid, - Transaction: &sppb.CommitRequest_TransactionId{ - TransactionId: t.tx, - }, - Mutations: mPb, - }, grpc.Trailer(&trailer)) - if e != nil { - return toSpannerErrorWithMetadata(e, trailer) - } - if tstamp := res.GetCommitTimestamp(); tstamp != nil { - ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) - } - return nil - }) + + var trailer metadata.MD + res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), &sppb.CommitRequest{ + Session: sid, + Transaction: &sppb.CommitRequest_TransactionId{ + TransactionId: t.tx, + }, + Mutations: mPb, + }, gax.WithGRPCOptions(grpc.Trailer(&trailer))) + if e != nil { + return ts, toSpannerErrorWithMetadata(e, trailer) + } + if tstamp := res.GetCommitTimestamp(); tstamp != nil { + ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos)) + } if shouldDropSession(err) { t.sh.destroy() } @@ -893,12 +884,9 @@ func (t *ReadWriteTransaction) rollback(ctx context.Context) { if sid == "" || client == nil { return } - err := runRetryable(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), func(ctx context.Context) error { - _, e := client.Rollback(ctx, &sppb.RollbackRequest{ - Session: sid, - TransactionId: t.tx, - }) - return e + err := client.Rollback(contextWithOutgoingMetadata(ctx, t.sh.getMetadata()), &sppb.RollbackRequest{ + Session: sid, + TransactionId: t.tx, }) if shouldDropSession(err) { t.sh.destroy() @@ -920,7 +908,6 @@ func (t *ReadWriteTransaction) runInTransaction(ctx context.Context, f func(cont // Retry the transaction using the same session on ABORT error. // Cloud Spanner will create the new transaction with the previous // one's wound-wait priority. - err = errRetry(err) return ts, err } // Not going to commit, according to API spec, should rollback the @@ -956,19 +943,21 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta // Malformed mutation found, just return the error. return ts, err } - err = runRetryable(ctx, func(ct context.Context) error { - var e error - var trailers metadata.MD + + var trailers metadata.MD + // Retry-loop for aborted transactions. + // TODO: Replace with generic retryer. 
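As an illustration of the TODO above, the hand-rolled loop could eventually delegate the backoff decision to a gax Retryer. A hypothetical sketch only; the backoff values are assumed, and err/ts refer to the variables of applyAtLeastOnce:

    retryer := gax.OnCodes([]codes.Code{codes.Aborted}, gax.Backoff{
        Initial:    20 * time.Millisecond, // assumed values, not from this patch
        Max:        32 * time.Second,
        Multiplier: 1.3,
    })
    if delay, shouldRetry := retryer.Retry(err); shouldRetry {
        select {
        case <-ctx.Done():
            return ts, toSpannerError(ctx.Err())
        case <-time.After(delay):
        }
        continue
    }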
+ for {
 if sh == nil || sh.getID() == "" || sh.getClient() == nil {
 // No usable session for doing the commit, take one from pool.
- sh, e = t.sp.take(ctx)
- if e != nil {
+ sh, err = t.sp.take(ctx)
+ if err != nil {
 // sessionPool.Take already retries for session
 // creations/retrievals.
- return e
+ return ts, err
 }
 }
- res, e := sh.getClient().Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.CommitRequest{
+ res, err := sh.getClient().Commit(contextWithOutgoingMetadata(ctx, sh.getMetadata()), &sppb.CommitRequest{
 Session: sh.getID(),
 Transaction: &sppb.CommitRequest_SingleUseTransaction{
 SingleUseTransaction: &sppb.TransactionOptions{
@@ -978,28 +967,24 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta
 },
 },
 Mutations: mPb,
- }, grpc.Trailer(&trailers))
- if e != nil {
- if isAbortErr(e) {
- // Mask ABORT error as retryable, because aborted transactions
- // are allowed to be retried.
- return errRetry(toSpannerErrorWithMetadata(e, trailers))
- }
- if shouldDropSession(e) {
+ }, gax.WithGRPCOptions(grpc.Trailer(&trailers)))
+ if err != nil && !isAbortErr(err) {
+ if shouldDropSession(err) {
 // Discard the bad session.
 sh.destroy()
 }
- return e
- }
- if tstamp := res.GetCommitTimestamp(); tstamp != nil {
- ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
+ return ts, toSpannerError(err)
+ } else if err == nil {
+ if tstamp := res.GetCommitTimestamp(); tstamp != nil {
+ ts = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))
+ }
+ break
 }
- return nil
- })
+ }
 if sh != nil {
 sh.recycle()
 }
- return ts, err
+ return ts, toSpannerError(err)
 }

 // isAbortErr returns true if the error indicates that a gRPC call is
diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go
index 9698aab742f..9c10ac4d819 100644
--- a/spanner/transaction_test.go
+++ b/spanner/transaction_test.go
@@ -27,16 +27,16 @@ import (

 "cloud.google.com/go/spanner/internal/testutil"
 sppb "google.golang.org/genproto/googleapis/spanner/v1"
- "google.golang.org/grpc"
 "google.golang.org/grpc/codes"
+ gstatus "google.golang.org/grpc/status"
 )

 // Single can only be used once.
 func TestSingle(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
- defer cleanup()
+ server, client := newSpannerInMemTestServer(t)
+ defer server.teardown(client)

 txn := client.Single()
 defer txn.Close()
@@ -50,7 +50,7 @@ func TestSingle(t *testing.T) {
 }

 // Only one CreateSessionRequest is sent.
- if _, err := shouldHaveReceived(mock, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
+ if _, err := shouldHaveReceived(server.testSpanner, []interface{}{&sppb.CreateSessionRequest{}}); err != nil {
 t.Fatal(err)
 }
 }

@@ -59,23 +59,18 @@
 func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) {
 t.Parallel()
 ctx := context.Background()
- client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{})
- defer cleanup()
+ server, client := newSpannerInMemTestServer(t)
+ defer server.teardown(client)

 txn := client.ReadOnlyTransaction()
 defer txn.Close()

- // First request will fail, which should trigger a retry.
- errUsr := errors.New("error") - firstCall := true - mock.BeginTransactionFn = func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) { - if firstCall { - mock.MockCloudSpannerClient.ReceivedRequests <- r - firstCall = false - return nil, errUsr - } - return mock.MockCloudSpannerClient.BeginTransaction(c, r, opts...) - } + // First request will fail. + errUsr := gstatus.Error(codes.Unknown, "error") + server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{errUsr}, + }) _, _, e := txn.acquire(ctx) if wantErr := toSpannerError(errUsr); !testEqual(e, wantErr) { @@ -91,8 +86,8 @@ func TestReadOnlyTransaction_RecoverFromFailure(t *testing.T) { func TestReadOnlyTransaction_UseAfterClose(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, _, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) txn := client.ReadOnlyTransaction() txn.Close() @@ -107,12 +102,12 @@ func TestReadOnlyTransaction_UseAfterClose(t *testing.T) { func TestReadOnlyTransaction_Concurrent(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) txn := client.ReadOnlyTransaction() defer txn.Close() - mock.Freeze() + server.testSpanner.Freeze() var ( sh1 *sessionHandle sh2 *sessionHandle @@ -135,7 +130,7 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) { // TODO(deklerk): Get rid of this. <-time.After(100 * time.Millisecond) - mock.Unfreeze() + server.testSpanner.Unfreeze() wg.Wait() if sh1.session.id != sh2.session.id { t.Fatalf("Expected acquire to get same session handle, got %v and %v.", sh1, sh2) @@ -148,8 +143,8 @@ func TestReadOnlyTransaction_Concurrent(t *testing.T) { func TestApply_Single(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) ms := []*Mutation{ Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}), @@ -159,7 +154,7 @@ func TestApply_Single(t *testing.T) { t.Fatalf("applyAtLeastOnce retry on abort, got %v, want nil.", e) } - if _, err := shouldHaveReceived(mock, []interface{}{ + if _, err := shouldHaveReceived(server.testSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.CommitRequest{}, }); err != nil { @@ -171,20 +166,15 @@ func TestApply_Single(t *testing.T) { func TestApply_RetryOnAbort(t *testing.T) { ctx := context.Background() t.Parallel() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) // First commit will fail, and the retry will begin a new transaction. errAbrt := spannerErrorf(codes.Aborted, "") - firstCommitCall := true - mock.CommitFn = func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) { - if firstCommitCall { - mock.MockCloudSpannerClient.ReceivedRequests <- r - firstCommitCall = false - return nil, errAbrt - } - return mock.MockCloudSpannerClient.Commit(c, r, opts...) 
- } + server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{errAbrt}, + }) ms := []*Mutation{ Insert("Accounts", []string{"AccountId"}, []interface{}{int64(1)}), @@ -194,7 +184,7 @@ func TestApply_RetryOnAbort(t *testing.T) { t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", e) } - if _, err := shouldHaveReceived(mock, []interface{}{ + if _, err := shouldHaveReceived(server.testSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, &sppb.CommitRequest{}, // First commit fails. @@ -209,18 +199,18 @@ func TestApply_RetryOnAbort(t *testing.T) { func TestTransaction_NotFound(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) wantErr := spannerErrorf(codes.NotFound, "Session not found") - mock.BeginTransactionFn = func(c context.Context, r *sppb.BeginTransactionRequest, opts ...grpc.CallOption) (*sppb.Transaction, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, wantErr - } - mock.CommitFn = func(c context.Context, r *sppb.CommitRequest, opts ...grpc.CallOption) (*sppb.CommitResponse, error) { - mock.MockCloudSpannerClient.ReceivedRequests <- r - return nil, wantErr - } + server.testSpanner.PutExecutionTime(testutil.MethodBeginTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{wantErr, wantErr, wantErr}, + }) + server.testSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{wantErr, wantErr, wantErr}, + }) txn := client.ReadOnlyTransaction() defer txn.Close() @@ -253,8 +243,8 @@ func TestTransaction_NotFound(t *testing.T) { func TestReadWriteTransaction_ErrorReturned(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) want := errors.New("an error") _, got := client.ReadWriteTransaction(ctx, func(context.Context, *ReadWriteTransaction) error { @@ -263,7 +253,7 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) { if got != want { t.Fatalf("got %+v, want %+v", got, want) } - requests := drainRequests(mock) + requests := drainRequestsFromServer(server.testSpanner) if err := compareRequests([]interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, @@ -287,27 +277,27 @@ func TestReadWriteTransaction_ErrorReturned(t *testing.T) { func TestBatchDML_WithMultipleDML(t *testing.T) { t.Parallel() ctx := context.Background() - client, _, mock, cleanup := serverClientMock(t, SessionPoolConfig{}) - defer cleanup() + server, client := newSpannerInMemTestServer(t) + defer server.teardown(client) _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { - if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil { + if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { return err } - if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}, {SQL: "SELECT * FROM whatever"}}); err != nil { + if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}, {SQL: updateBarSetFoo}}); err != nil { return err } - if _, err = tx.Update(ctx, Statement{SQL: "SELECT * FROM whatever"}); err != nil { + if _, err = 
tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { return err } - _, err = tx.BatchUpdate(ctx, []Statement{{SQL: "SELECT * FROM whatever"}}) + _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}}) return err }) if err != nil { t.Fatal(err) } - gotReqs, err := shouldHaveReceived(mock, []interface{}{ + gotReqs, err := shouldHaveReceived(server.testSpanner, []interface{}{ &sppb.CreateSessionRequest{}, &sppb.BeginTransactionRequest{}, &sppb.ExecuteSqlRequest{}, @@ -339,8 +329,8 @@ func TestBatchDML_WithMultipleDML(t *testing.T) { // // Note: this in-place modifies serverClientMock by popping items off the // ReceivedRequests channel. -func shouldHaveReceived(mock *testutil.FuncMock, want []interface{}) ([]interface{}, error) { - got := drainRequests(mock) +func shouldHaveReceived(server testutil.InMemSpannerServer, want []interface{}) ([]interface{}, error) { + got := drainRequestsFromServer(server) return got, compareRequests(want, got) } @@ -368,12 +358,12 @@ func compareRequests(want []interface{}, got []interface{}) error { return nil } -func drainRequests(mock *testutil.FuncMock) []interface{} { +func drainRequestsFromServer(server testutil.InMemSpannerServer) []interface{} { var reqs []interface{} loop: for { select { - case req := <-mock.ReceivedRequests: + case req := <-server.ReceivedRequests(): reqs = append(reqs, req) default: break loop @@ -381,30 +371,3 @@ loop: } return reqs } - -// serverClientMock sets up a client configured to a NewMockCloudSpannerClient -// that is wrapped with a function-injectable wrapper. -// -// Note: be sure to call cleanup! -func serverClientMock(t *testing.T, spc SessionPoolConfig) (_ *Client, _ *sessionPool, _ *testutil.FuncMock, cleanup func()) { - rawServerStub := testutil.NewMockCloudSpannerClient(t) - serverClientMock := testutil.FuncMock{MockCloudSpannerClient: rawServerStub} - spc.getRPCClient = func() (sppb.SpannerClient, error) { - return &serverClientMock, nil - } - db := "mockdb" - sp, err := newSessionPool(db, spc, nil) - if err != nil { - t.Fatalf("cannot create session pool: %v", err) - } - client := Client{ - database: db, - idleSessions: sp, - } - cleanup = func() { - client.Close() - sp.hc.close() - sp.close() - } - return &client, sp, &serverClientMock, cleanup -}
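With runRetryable and errRetry deleted in the retry.go hunks above, the call sites in session.go and transaction.go now invoke the gapic client directly, and generic retrying happens inside the gapic layer via gax call options. For reference, a gax-driven retry for a unary call has roughly this shape; a self-contained sketch with assumed backoff values, not code from this patch:

    package main

    import (
        "context"
        "time"

        gax "github.com/googleapis/gax-go/v2"
        "google.golang.org/grpc/codes"
    )

    // invokeWithRetry shows, schematically, what gax does for a retryable
    // unary RPC: invoke the call, classify the error, back off, and repeat
    // until the call succeeds, fails permanently, or the context expires.
    func invokeWithRetry(ctx context.Context, call func(context.Context) error) error {
        return gax.Invoke(ctx, func(ctx context.Context, _ gax.CallSettings) error {
            return call(ctx)
        }, gax.WithRetry(func() gax.Retryer {
            return gax.OnCodes([]codes.Code{codes.Unavailable}, gax.Backoff{
                Initial:    250 * time.Millisecond, // assumed values
                Max:        32 * time.Second,
                Multiplier: 1.3,
            })
        }))
    }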