From 267787eb245d9307cf78304c1ce34bdfb2aaf5ab Mon Sep 17 00:00:00 2001 From: shollyman Date: Thu, 24 Jun 2021 14:06:47 -0700 Subject: [PATCH 1/4] feat(bigquery): enable project autodetection, expose project ids further (#4312) PR supersedes: https://github.com/googleapis/google-cloud-go/pull/4076 Related: https://github.com/googleapis/google-cloud-go/issues/1294 With this change, project autodetection is enabled via use of a sentinel value, and the retained project identifier is now exposed on the Client and Job resources via the Project() function. --- bigquery/bigquery.go | 39 ++++++++++++++++++++++++++++++++++++ bigquery/integration_test.go | 18 +++++++++++++++++ bigquery/job.go | 5 +++++ 3 files changed, 62 insertions(+) diff --git a/bigquery/bigquery.go b/bigquery/bigquery.go index 0a32f02c3d52..28ea5446f1f3 100644 --- a/bigquery/bigquery.go +++ b/bigquery/bigquery.go @@ -16,6 +16,7 @@ package bigquery import ( "context" + "errors" "fmt" "io" "net/http" @@ -29,6 +30,7 @@ import ( bq "google.golang.org/api/bigquery/v2" "google.golang.org/api/googleapi" "google.golang.org/api/option" + "google.golang.org/api/transport" ) const ( @@ -56,8 +58,20 @@ type Client struct { bqs *bq.Service } +// DetectProjectID is a sentinel value that instructs NewClient to detect the +// project ID. It is given in place of the projectID argument. NewClient will +// use the project ID from the given credentials or the default credentials +// (https://developers.google.com/accounts/docs/application-default-credentials) +// if no credentials were provided. When providing credentials, not all +// options will allow NewClient to extract the project ID. Specifically a JWT +// does not have the project ID encoded. +const DetectProjectID = "*detect-project-id*" + // NewClient constructs a new Client which can perform BigQuery operations. // Operations performed via the client are billed to the specified GCP project. +// +// If the project ID is set to DetectProjectID, NewClient will attempt to detect +// the project ID from credentials. func NewClient(ctx context.Context, projectID string, opts ...option.ClientOption) (*Client, error) { o := []option.ClientOption{ option.WithScopes(Scope), @@ -68,6 +82,14 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio if err != nil { return nil, fmt.Errorf("bigquery: constructing client: %v", err) } + + if projectID == DetectProjectID { + projectID, err = detectProjectID(ctx, opts...) + if err != nil { + return nil, fmt.Errorf("failed to detect project: %v", err) + } + } + c := &Client{ projectID: projectID, bqs: bqs, @@ -75,6 +97,12 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio return c, nil } +// Project returns the project ID or number for this instance of the client, which may have +// either been explicitly specified or autodetected. +func (c *Client) Project() string { + return c.projectID +} + // Close closes any resources held by the client. // Close should be called when the client is no longer needed. // It need not be called at program exit. @@ -82,6 +110,17 @@ func (c *Client) Close() error { return nil } +func detectProjectID(ctx context.Context, opts ...option.ClientOption) (string, error) { + creds, err := transport.Creds(ctx, opts...) + if err != nil { + return "", fmt.Errorf("fetching creds: %v", err) + } + if creds.ProjectID == "" { + return "", errors.New("credentials did not provide a valid ProjectID") + } + return creds.ProjectID, nil +} + // Calls the Jobs.Insert RPC and returns a Job. func (c *Client) insertJob(ctx context.Context, job *bq.Job, media io.Reader) (*Job, error) { call := c.bqs.Jobs.Insert(c.projectID, job).Context(ctx) diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go index 9cca940d876d..8b7b2a1e5ad6 100644 --- a/bigquery/integration_test.go +++ b/bigquery/integration_test.go @@ -228,6 +228,24 @@ func initTestState(client *Client, t time.Time) func() { } } +func TestIntegration_DetectProjectID(t *testing.T) { + ctx := context.Background() + testCreds := testutil.Credentials(ctx) + if testCreds == nil { + t.Skip("test credentials not present, skipping") + } + + if _, err := NewClient(ctx, DetectProjectID, option.WithCredentials(testCreds)); err != nil { + t.Errorf("test NewClient: %v", err) + } + + badTS := testutil.ErroringTokenSource{} + + if badClient, err := NewClient(ctx, DetectProjectID, option.WithTokenSource(badTS)); err == nil { + t.Errorf("expected error from bad token source, NewClient succeeded with project: %s", badClient.Project()) + } +} + func TestIntegration_TableCreate(t *testing.T) { // Check that creating a record field with an empty schema is an error. if client == nil { diff --git a/bigquery/job.go b/bigquery/job.go index 2d259f910b45..725775dfd68a 100644 --- a/bigquery/job.go +++ b/bigquery/job.go @@ -63,6 +63,11 @@ func (c *Client) JobFromIDLocation(ctx context.Context, id, location string) (j return bqToJob(bqjob, c) } +// Project returns the job's project. +func (j *Job) Project() string { + return j.projectID +} + // ID returns the job's ID. func (j *Job) ID() string { return j.jobID From b34783a4d7a8c88204e0f44bd411795d8267d811 Mon Sep 17 00:00:00 2001 From: Christopher Wilcox Date: Fri, 25 Jun 2021 11:19:16 -0700 Subject: [PATCH 2/4] feat(firestore): Add support for PartitionQuery (#4206) --- firestore/collgroupref.go | 148 +++++++++++++++++++++++++++++++++ firestore/collgroupref_test.go | 73 ++++++++++++++++ firestore/integration_test.go | 85 +++++++++++++++++++ firestore/order.go | 8 ++ 4 files changed, 314 insertions(+) create mode 100644 firestore/collgroupref_test.go diff --git a/firestore/collgroupref.go b/firestore/collgroupref.go index e43a7e648c54..c13ff1f160b8 100644 --- a/firestore/collgroupref.go +++ b/firestore/collgroupref.go @@ -14,6 +14,16 @@ package firestore +import ( + "context" + "errors" + "fmt" + "sort" + + "google.golang.org/api/iterator" + firestorepb "google.golang.org/genproto/googleapis/firestore/v1" +) + // A CollectionGroupRef is a reference to a group of collections sharing the // same ID. type CollectionGroupRef struct { @@ -36,3 +46,141 @@ func newCollectionGroupRef(c *Client, dbPath, collectionID string) *CollectionGr }, } } + +// GetPartitionedQueries returns a slice of Query objects, each containing a +// partition of a collection group. partitionCount must be a positive value and +// the number of returned partitions may be less than the requested number if +// providing the desired number would result in partitions with very few documents. +// +// If a Collection Group Query would return a large number of documents, this +// can help to subdivide the query to smaller working units that can be distributed. +func (cgr CollectionGroupRef) GetPartitionedQueries(ctx context.Context, partitionCount int) ([]Query, error) { + qp, err := cgr.getPartitions(ctx, partitionCount) + if err != nil { + return nil, err + } + queries := make([]Query, len(qp)) + for _, part := range qp { + queries = append(queries, part.toQuery()) + } + return queries, nil +} + +// getPartitions returns a slice of queryPartition objects, describing a start +// and end range to query a subsection of the collection group. partitionCount +// must be a positive value and the number of returned partitions may be less +// than the requested number if providing the desired number would result in +// partitions with very few documents. +func (cgr CollectionGroupRef) getPartitions(ctx context.Context, partitionCount int) ([]queryPartition, error) { + orderedQuery := cgr.query().OrderBy(DocumentID, Asc) + + if partitionCount <= 0 { + return nil, errors.New("a positive partitionCount must be provided") + } else if partitionCount == 1 { + return []queryPartition{{CollectionGroupQuery: orderedQuery}}, nil + } + + db := cgr.c.path() + ctx = withResourceHeader(ctx, db) + + // CollectionGroup Queries need to be ordered by __name__ ASC. + query, err := orderedQuery.toProto() + if err != nil { + return nil, err + } + structuredQuery := &firestorepb.PartitionQueryRequest_StructuredQuery{ + StructuredQuery: query, + } + + // Uses default PageSize + pbr := &firestorepb.PartitionQueryRequest{ + Parent: db + "/documents", + PartitionCount: int64(partitionCount), + QueryType: structuredQuery, + } + cursorReferences := make([]*firestorepb.Value, 0, partitionCount) + iter := cgr.c.c.PartitionQuery(ctx, pbr) + for { + cursor, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("GetPartitions: %v", err) + } + cursorReferences = append(cursorReferences, cursor.GetValues()...) + } + + // From Proto documentation: + // To obtain a complete result set ordered with respect to the results of the + // query supplied to PartitionQuery, the results sets should be merged: + // cursor A, cursor B, cursor M, cursor Q, cursor U, cursor W + // Once we have exhausted the pages, the cursor values need to be sorted in + // lexicographical order by segment (areas between '/'). + sort.Sort(byFirestoreValue(cursorReferences)) + + queryPartitions := make([]queryPartition, 0, len(cursorReferences)) + previousCursor := "" + + for _, cursor := range cursorReferences { + cursorRef := cursor.GetReferenceValue() + + // remove the root path from the reference, as queries take cursors + // relative to a collection + cursorRef = cursorRef[len(orderedQuery.path)+1:] + + qp := queryPartition{ + CollectionGroupQuery: orderedQuery, + StartAt: previousCursor, + EndBefore: cursorRef, + } + queryPartitions = append(queryPartitions, qp) + previousCursor = cursorRef + } + + // In the case there were no partitions, we still add a single partition to + // the result, that covers the complete range. + lastPart := queryPartition{CollectionGroupQuery: orderedQuery} + if len(cursorReferences) > 0 { + cursorRef := cursorReferences[len(cursorReferences)-1].GetReferenceValue() + lastPart.StartAt = cursorRef[len(orderedQuery.path)+1:] + } + queryPartitions = append(queryPartitions, lastPart) + + return queryPartitions, nil +} + +// queryPartition provides a Collection Group Reference and start and end split +// points allowing for a section of a collection group to be queried. This is +// used by GetPartitions which, given a CollectionGroupReference returns smaller +// sub-queries or partitions +type queryPartition struct { + // CollectionGroupQuery is an ordered query on a CollectionGroupReference. + // This query must be ordered Asc on __name__. + // Example: client.CollectionGroup("collectionID").query().OrderBy(DocumentID, Asc) + CollectionGroupQuery Query + + // StartAt is a document reference value, relative to the collection, not + // a complete parent path. + // Example: "documents/collectionName/documentName" + StartAt string + + // EndBefore is a document reference value, relative to the collection, not + // a complete parent path. + // Example: "documents/collectionName/documentName" + EndBefore string +} + +// toQuery converts a queryPartition object to a Query object +func (qp queryPartition) toQuery() Query { + q := *qp.CollectionGroupQuery.query() + + // Remove the leading path before calling StartAt, EndBefore + if qp.StartAt != "" { + q = q.StartAt(qp.StartAt) + } + if qp.EndBefore != "" { + q = q.EndBefore(qp.EndBefore) + } + return q +} diff --git a/firestore/collgroupref_test.go b/firestore/collgroupref_test.go new file mode 100644 index 000000000000..bbcbfe4ebb62 --- /dev/null +++ b/firestore/collgroupref_test.go @@ -0,0 +1,73 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "context" + "testing" +) + +func TestCGR_TestQueryPartition_ToQuery(t *testing.T) { + cgr := newCollectionGroupRef(testClient, testClient.path(), "collectionID") + qp := queryPartition{ + CollectionGroupQuery: cgr.Query.OrderBy(DocumentID, Asc), + StartAt: "documents/start/at", + EndBefore: "documents/end/before", + } + + got := qp.toQuery() + + want := Query{ + c: testClient, + path: "projects/projectID/databases/(default)", + parentPath: "projects/projectID/databases/(default)/documents", + collectionID: "collectionID", + startVals: []interface{}{"documents/start/at"}, + endVals: []interface{}{"documents/end/before"}, + startBefore: true, + endBefore: true, + allDescendants: true, + orders: []order{{fieldPath: []string{"__name__"}, dir: 1}}, + } + + if !testEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func TestCGR_TestGetPartitions(t *testing.T) { + cgr := newCollectionGroupRef(testClient, testClient.path(), "collectionID") + _, err := cgr.getPartitions(context.Background(), 0) + if err == nil { + t.Error("Expected an error when requested partition count is < 1") + } + + parts, err := cgr.getPartitions(context.Background(), 1) + if err != nil { + t.Error("Didn't expect an error when requested partition count is 1") + } + if len(parts) != 1 { + t.Fatal("Expected 1 queryPartition") + } + got := parts[0] + want := queryPartition{ + CollectionGroupQuery: cgr.Query.OrderBy(DocumentID, Asc), + StartAt: "", + EndBefore: "", + } + if !testEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} diff --git a/firestore/integration_test.go b/firestore/integration_test.go index e398150f84fa..ca841858f988 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -1601,3 +1601,88 @@ func TestDetectProjectID(t *testing.T) { t.Errorf("expected an error while using TokenSource that does not have a project ID") } } + +func TestIntegration_ColGroupRefPartitions(t *testing.T) { + h := testHelper{t} + coll := integrationColl(t) + ctx := context.Background() + + // Create a doc in the test collection so a collectionID is live for testing + doc := coll.NewDoc() + h.mustCreate(doc, integrationTestMap) + + for _, tc := range []struct { + collectionID string + expectedPartitionCount int + }{ + // Verify no failures if a collection doesn't exist + {collectionID: "does-not-exist", expectedPartitionCount: 1}, + // Verify a collectionID with a small number of results returns a partition + {collectionID: coll.collectionID, expectedPartitionCount: 1}, + } { + colGroup := iClient.CollectionGroup(tc.collectionID) + partitions, err := colGroup.getPartitions(ctx, 10) + if err != nil { + t.Fatalf("getPartitions: received unexpected error: %v", err) + } + if got, want := len(partitions), tc.expectedPartitionCount; got != want { + t.Errorf("Unexpected Partition Count: got %d, want %d", got, want) + } + } +} + +func TestIntegration_ColGroupRefPartitionsLarge(t *testing.T) { + // Create collection with enough documents to have multiple partitions. + coll := integrationColl(t) + collectionID := coll.collectionID + "largeCollection" + coll = iClient.Collection(collectionID) + + ctx := context.Background() + + documentCount := 2*128 + 127 // Minimum partition size is 128. + + // Create documents in a collection sufficient to trigger multiple partitions. + batch := iClient.Batch() + deleteBatch := iClient.Batch() + for i := 0; i < documentCount; i++ { + doc := coll.Doc(fmt.Sprintf("doc%d", i)) + batch.Create(doc, integrationTestMap) + deleteBatch.Delete(doc) + } + batch.Commit(ctx) + defer deleteBatch.Commit(ctx) + + // Verify that we retrieve 383 documents for the colGroup (128*2 + 127) + colGroup := iClient.CollectionGroup(collectionID) + docs, err := colGroup.Documents(ctx).GetAll() + if err != nil { + t.Fatalf("GetAll(): received unexpected error: %v", err) + } + if got, want := len(docs), documentCount; got != want { + t.Errorf("Unexpected number of documents in collection group: got %d, want %d", got, want) + } + + // Get partitions, allow up to 10 to come back, expect less will be returned. + partitions, err := colGroup.GetPartitionedQueries(ctx, 10) + if err != nil { + t.Fatalf("GetPartitionedQueries: received unexpected error: %v", err) + } + if len(partitions) < 2 { + t.Errorf("Unexpected Partition Count. Expected 2 or more: got %d, want 2+", len(partitions)) + } + + // Verify that we retrieve 383 documents across all partitions. (128*2 + 127) + totalCount := 0 + for _, query := range partitions { + + allDocs, err := query.Documents(ctx).GetAll() + if err != nil { + t.Fatalf("GetAll(): received unexpected error: %v", err) + } + totalCount += len(allDocs) + } + + if got, want := totalCount, documentCount; got != want { + t.Errorf("Unexpected number of documents across partitions: got %d, want %d", got, want) + } +} diff --git a/firestore/order.go b/firestore/order.go index e5ee1e09fb53..c495a141fd36 100644 --- a/firestore/order.go +++ b/firestore/order.go @@ -22,6 +22,7 @@ import ( "strings" tspb "github.com/golang/protobuf/ptypes/timestamp" + firestorepb "google.golang.org/genproto/googleapis/firestore/v1" pb "google.golang.org/genproto/googleapis/firestore/v1" ) @@ -214,3 +215,10 @@ func typeOrder(v *pb.Value) int { panic(fmt.Sprintf("bad value type: %v", v)) } } + +// byReferenceValue implements sort.Interface for []*firestorepb.Value +type byFirestoreValue []*firestorepb.Value + +func (a byFirestoreValue) Len() int { return len(a) } +func (a byFirestoreValue) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byFirestoreValue) Less(i, j int) bool { return compareValues(a[i], a[j]) < 0 } From ae34396b1a2a970a0d871cd5496527294f3310d4 Mon Sep 17 00:00:00 2001 From: tmdiep Date: Sat, 26 Jun 2021 07:17:55 +1000 Subject: [PATCH 3/4] fix(pubsublite): wire user context to api clients (#4318) --- pubsublite/pscompat/integration_test.go | 60 ++++++++++++++++++++----- pubsublite/pscompat/publisher.go | 5 +-- pubsublite/pscompat/subscriber.go | 23 +++++----- pubsublite/pscompat/subscriber_test.go | 4 +- 4 files changed, 66 insertions(+), 26 deletions(-) diff --git a/pubsublite/pscompat/integration_test.go b/pubsublite/pscompat/integration_test.go index a22358ead8e7..ab14ad7b973b 100644 --- a/pubsublite/pscompat/integration_test.go +++ b/pubsublite/pscompat/integration_test.go @@ -30,7 +30,9 @@ import ( "cloud.google.com/go/pubsublite/internal/wire" "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" "google.golang.org/api/option" + "google.golang.org/grpc/codes" vkit "cloud.google.com/go/pubsublite/apiv1" pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" @@ -167,7 +169,7 @@ func partitionNumbers(partitionCount int) []int { func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgs ...*pubsub.Message) { ctx := context.Background() - publisher := publisherClient(ctx, t, settings, topic) + publisher := publisherClient(context.Background(), t, settings, topic) defer publisher.Stop() var pubResults []*pubsub.PublishResult @@ -179,7 +181,7 @@ func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPat func publishPrefixedMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgPrefix string, msgCount, msgSize int) []string { ctx := context.Background() - publisher := publisherClient(ctx, t, settings, topic) + publisher := publisherClient(context.Background(), t, settings, topic) defer publisher.Stop() orderingSender := test.NewOrderingSender() @@ -271,7 +273,7 @@ func receiveAllMessages(t *testing.T, msgTracker *test.MsgTracker, settings Rece } } - subscriber := subscriberClient(cctx, t, settings, subscription) + subscriber := subscriberClient(context.Background(), t, settings, subscription) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) } @@ -298,7 +300,7 @@ func receiveAndVerifyMessage(t *testing.T, want *pubsub.Message, settings Receiv } } - subscriber := subscriberClient(cctx, t, settings, subscription) + subscriber := subscriberClient(context.Background(), t, settings, subscription) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) } @@ -383,7 +385,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { } got.Nack() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) if gotErr := subscriber.Receive(cctx, messageReceiver1); !test.ErrorEqual(gotErr, errNackCalled) { t.Errorf("Receive() got err: (%v), want err: (%v)", gotErr, errNackCalled) } @@ -400,7 +402,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { } return fmt.Errorf("Received unexpected message: %q", truncateMsg(string(msg.Data))) } - subscriber = subscriberClient(cctx, t, customSettings, subscriptionPath) + subscriber = subscriberClient(context.Background(), t, customSettings, subscriptionPath) messageReceiver2 := func(ctx context.Context, got *pubsub.Message) { got.Nack() @@ -434,7 +436,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { got.Ack() stopSubscriber() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) // The message receiver stops the subscriber after receiving the first // message. However, the subscriber isn't guaranteed to immediately stop, so @@ -485,7 +487,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { // next test, which would receive an incorrect message. got.Ack() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) @@ -539,6 +541,44 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { receiveAllMessages(t, msgTracker, recvSettings, subscriptionPath) }) + // Verifies that cancelling the context passed to NewPublisherClient can shut + // down the publisher. + t.Run("CancelPublisherContext", func(t *testing.T) { + cctx, cancel := context.WithCancel(context.Background()) + publisher := publisherClient(cctx, t, DefaultPublishSettings, topicPath) + + cancel() + + wantCode := codes.Canceled + result := publisher.Publish(ctx, &pubsub.Message{Data: []byte("cancel_publisher_context")}) + if _, err := result.Get(ctx); !test.ErrorHasCode(err, wantCode) { + t.Errorf("Publish() got err: %v, want code: %v", err, wantCode) + } + if err := xerrors.Unwrap(publisher.Error()); !test.ErrorHasCode(err, wantCode) { + t.Errorf("Error() got err: %v, want code: %v", err, wantCode) + } + publisher.Stop() + }) + + // Verifies that cancelling the context passed to NewSubscriberClient can shut + // down the subscriber. + t.Run("CancelSubscriberContext", func(t *testing.T) { + msg := &pubsub.Message{Data: []byte("cancel_subscriber_context")} + publishMessages(t, DefaultPublishSettings, topicPath, msg) + + cctx, cancel := context.WithCancel(context.Background()) + subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + + subsErr := subscriber.Receive(context.Background(), func(ctx context.Context, got *pubsub.Message) { + got.Ack() + cancel() + }) + + if err, wantCode := xerrors.Unwrap(subsErr), codes.Canceled; !test.ErrorHasCode(err, wantCode) { + t.Errorf("Receive() got err: %v, want code: %v", err, wantCode) + } + }) + // NOTE: This should be the last test case. // Verifies that increasing the number of topic partitions is handled // correctly by publishers. @@ -547,7 +587,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { const pollPeriod = 5 * time.Second pubSettings := DefaultPublishSettings pubSettings.configPollPeriod = pollPeriod // Poll updates more frequently - publisher := publisherClient(ctx, t, pubSettings, topicPath) + publisher := publisherClient(context.Background(), t, pubSettings, topicPath) defer publisher.Stop() // Update the number of partitions. @@ -661,7 +701,7 @@ func TestIntegration_PublishSubscribeMultiPartition(t *testing.T) { for i := 0; i < subscriberCount; i++ { // Subscribers must be started in a goroutine as Receive() blocks. g.Go(func() error { - subscriber := subscriberClient(cctx, t, DefaultReceiveSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath) err := subscriber.Receive(cctx, messageReceiver) if err != nil { t.Errorf("Receive() got err: %v", err) diff --git a/pubsublite/pscompat/publisher.go b/pubsublite/pscompat/publisher.go index 6f4e9415a822..b9f8c82e19ab 100644 --- a/pubsublite/pscompat/publisher.go +++ b/pubsublite/pscompat/publisher.go @@ -82,10 +82,7 @@ func NewPublisherClientWithSettings(ctx context.Context, topic string, settings return nil, err } - // Note: ctx is not used to create the wire publisher, because if it is - // cancelled, the publisher will not be able to perform graceful shutdown - // (e.g. flush pending messages). - wirePub, err := wire.NewPublisher(context.Background(), settings.toWireSettings(), region, topic, opts...) + wirePub, err := wire.NewPublisher(ctx, settings.toWireSettings(), region, topic, opts...) if err != nil { return nil, err } diff --git a/pubsublite/pscompat/subscriber.go b/pubsublite/pscompat/subscriber.go index dafb25ee526f..d76cc9d290bf 100644 --- a/pubsublite/pscompat/subscriber.go +++ b/pubsublite/pscompat/subscriber.go @@ -72,7 +72,7 @@ func (ah *pslAckHandler) OnNack() { // wireSubscriberFactory is a factory for creating wire subscribers, which can // be overridden with a mock in unit tests. type wireSubscriberFactory interface { - New(wire.MessageReceiverFunc) (wire.Subscriber, error) + New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error) } type wireSubscriberFactoryImpl struct { @@ -82,8 +82,8 @@ type wireSubscriberFactoryImpl struct { options []option.ClientOption } -func (f *wireSubscriberFactoryImpl) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { - return wire.NewSubscriber(context.Background(), f.settings, receiver, f.region, f.subscription.String(), f.options...) +func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { + return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...) } type messageReceiverFunc = func(context.Context, *pubsub.Message) @@ -103,8 +103,8 @@ type subscriberInstance struct { err error } -func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) { - recvCtx, recvCancel := context.WithCancel(ctx) +func newSubscriberInstance(recvCtx, clientCtx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) { + recvCtx, recvCancel := context.WithCancel(recvCtx) subInstance := &subscriberInstance{ settings: settings, recvCtx: recvCtx, @@ -112,10 +112,11 @@ func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, s receiver: receiver, } - // Note: ctx is not used to create the wire subscriber, because if it is - // cancelled, the subscriber will not be able to perform graceful shutdown - // (e.g. process acks and commit the final cursor offset). - wireSub, err := factory.New(subInstance.onMessage) + // Note: The context from Receive (recvCtx) should not be used, as when it is + // cancelled, the gRPC streams will be disconnected and the subscriber will + // not be able to process acks and commit the final cursor offset. Use the + // context from NewSubscriberClient (clientCtx) instead. + wireSub, err := factory.New(clientCtx, subInstance.onMessage) if err != nil { return nil, err } @@ -229,6 +230,7 @@ func (si *subscriberInstance) Wait(ctx context.Context) error { // See https://cloud.google.com/pubsub/lite/docs/subscribing for more // information about receiving messages. type SubscriberClient struct { + clientCtx context.Context settings ReceiveSettings wireSubFactory wireSubscriberFactory @@ -265,6 +267,7 @@ func NewSubscriberClientWithSettings(ctx context.Context, subscription string, s options: opts, } subClient := &SubscriberClient{ + clientCtx: ctx, settings: settings, wireSubFactory: factory, } @@ -303,7 +306,7 @@ func (s *SubscriberClient) Receive(ctx context.Context, f func(context.Context, defer s.setReceiveActive(false) // Initialize a subscriber instance. - subInstance, err := newSubscriberInstance(ctx, s.wireSubFactory, s.settings, f) + subInstance, err := newSubscriberInstance(ctx, s.clientCtx, s.wireSubFactory, s.settings, f) if err != nil { return err } diff --git a/pubsublite/pscompat/subscriber_test.go b/pubsublite/pscompat/subscriber_test.go index 429c9f551416..5c737f8ff424 100644 --- a/pubsublite/pscompat/subscriber_test.go +++ b/pubsublite/pscompat/subscriber_test.go @@ -113,7 +113,7 @@ func (ms *mockWireSubscriber) WaitStopped() error { type mockWireSubscriberFactory struct{} -func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { +func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { return &mockWireSubscriber{ receiver: receiver, msgsC: make(chan *wire.ReceivedMessage, 10), @@ -122,7 +122,7 @@ func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire } func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver messageReceiverFunc) *subscriberInstance { - sub, _ := newSubscriberInstance(ctx, new(mockWireSubscriberFactory), settings, receiver) + sub, _ := newSubscriberInstance(ctx, context.Background(), new(mockWireSubscriberFactory), settings, receiver) return sub } From 12f3042716d51fb0d7a23071d00a20f9751bac91 Mon Sep 17 00:00:00 2001 From: shollyman Date: Mon, 28 Jun 2021 10:50:55 -0700 Subject: [PATCH 4/4] fix(bigquery): update streaming insert error test (#4321) More backend changes appear to be inducing two forms of error message, depending on which component intercepts the error first. This test captures both outcomes (we always get an error). Co-authored-by: Tyler Bui-Palsulich <26876514+tbpg@users.noreply.github.com> --- bigquery/integration_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go index 8b7b2a1e5ad6..9bf2f91030e8 100644 --- a/bigquery/integration_test.go +++ b/bigquery/integration_test.go @@ -1378,9 +1378,11 @@ func TestIntegration_InsertErrors(t *testing.T) { if !ok { t.Errorf("Wanted googleapi.Error, got: %v", err) } - want := "Request payload size exceeds the limit" - if !strings.Contains(e.Message, want) { - t.Errorf("Error didn't contain expected message (%s): %s", want, e.Message) + if e.Code != http.StatusRequestEntityTooLarge { + want := "Request payload size exceeds the limit" + if !strings.Contains(e.Message, want) { + t.Errorf("Error didn't contain expected message (%s): %#v", want, e) + } } // Case 2: Very Large Request // Request so large it gets rejected by intermediate infra (3x 10MB rows)