diff --git a/bigquery/bigquery.go b/bigquery/bigquery.go index 0a32f02c3d52..28ea5446f1f3 100644 --- a/bigquery/bigquery.go +++ b/bigquery/bigquery.go @@ -16,6 +16,7 @@ package bigquery import ( "context" + "errors" "fmt" "io" "net/http" @@ -29,6 +30,7 @@ import ( bq "google.golang.org/api/bigquery/v2" "google.golang.org/api/googleapi" "google.golang.org/api/option" + "google.golang.org/api/transport" ) const ( @@ -56,8 +58,20 @@ type Client struct { bqs *bq.Service } +// DetectProjectID is a sentinel value that instructs NewClient to detect the +// project ID. It is given in place of the projectID argument. NewClient will +// use the project ID from the given credentials or the default credentials +// (https://developers.google.com/accounts/docs/application-default-credentials) +// if no credentials were provided. When providing credentials, not all +// options will allow NewClient to extract the project ID. Specifically a JWT +// does not have the project ID encoded. +const DetectProjectID = "*detect-project-id*" + // NewClient constructs a new Client which can perform BigQuery operations. // Operations performed via the client are billed to the specified GCP project. +// +// If the project ID is set to DetectProjectID, NewClient will attempt to detect +// the project ID from credentials. func NewClient(ctx context.Context, projectID string, opts ...option.ClientOption) (*Client, error) { o := []option.ClientOption{ option.WithScopes(Scope), @@ -68,6 +82,14 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio if err != nil { return nil, fmt.Errorf("bigquery: constructing client: %v", err) } + + if projectID == DetectProjectID { + projectID, err = detectProjectID(ctx, opts...) + if err != nil { + return nil, fmt.Errorf("failed to detect project: %v", err) + } + } + c := &Client{ projectID: projectID, bqs: bqs, @@ -75,6 +97,12 @@ func NewClient(ctx context.Context, projectID string, opts ...option.ClientOptio return c, nil } +// Project returns the project ID or number for this instance of the client, which may have +// either been explicitly specified or autodetected. +func (c *Client) Project() string { + return c.projectID +} + // Close closes any resources held by the client. // Close should be called when the client is no longer needed. // It need not be called at program exit. @@ -82,6 +110,17 @@ func (c *Client) Close() error { return nil } +func detectProjectID(ctx context.Context, opts ...option.ClientOption) (string, error) { + creds, err := transport.Creds(ctx, opts...) + if err != nil { + return "", fmt.Errorf("fetching creds: %v", err) + } + if creds.ProjectID == "" { + return "", errors.New("credentials did not provide a valid ProjectID") + } + return creds.ProjectID, nil +} + // Calls the Jobs.Insert RPC and returns a Job. func (c *Client) insertJob(ctx context.Context, job *bq.Job, media io.Reader) (*Job, error) { call := c.bqs.Jobs.Insert(c.projectID, job).Context(ctx) diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go index 9cca940d876d..9bf2f91030e8 100644 --- a/bigquery/integration_test.go +++ b/bigquery/integration_test.go @@ -228,6 +228,24 @@ func initTestState(client *Client, t time.Time) func() { } } +func TestIntegration_DetectProjectID(t *testing.T) { + ctx := context.Background() + testCreds := testutil.Credentials(ctx) + if testCreds == nil { + t.Skip("test credentials not present, skipping") + } + + if _, err := NewClient(ctx, DetectProjectID, option.WithCredentials(testCreds)); err != nil { + t.Errorf("test NewClient: %v", err) + } + + badTS := testutil.ErroringTokenSource{} + + if badClient, err := NewClient(ctx, DetectProjectID, option.WithTokenSource(badTS)); err == nil { + t.Errorf("expected error from bad token source, NewClient succeeded with project: %s", badClient.Project()) + } +} + func TestIntegration_TableCreate(t *testing.T) { // Check that creating a record field with an empty schema is an error. if client == nil { @@ -1360,9 +1378,11 @@ func TestIntegration_InsertErrors(t *testing.T) { if !ok { t.Errorf("Wanted googleapi.Error, got: %v", err) } - want := "Request payload size exceeds the limit" - if !strings.Contains(e.Message, want) { - t.Errorf("Error didn't contain expected message (%s): %s", want, e.Message) + if e.Code != http.StatusRequestEntityTooLarge { + want := "Request payload size exceeds the limit" + if !strings.Contains(e.Message, want) { + t.Errorf("Error didn't contain expected message (%s): %#v", want, e) + } } // Case 2: Very Large Request // Request so large it gets rejected by intermediate infra (3x 10MB rows) diff --git a/bigquery/job.go b/bigquery/job.go index 2d259f910b45..725775dfd68a 100644 --- a/bigquery/job.go +++ b/bigquery/job.go @@ -63,6 +63,11 @@ func (c *Client) JobFromIDLocation(ctx context.Context, id, location string) (j return bqToJob(bqjob, c) } +// Project returns the job's project. +func (j *Job) Project() string { + return j.projectID +} + // ID returns the job's ID. func (j *Job) ID() string { return j.jobID diff --git a/firestore/collgroupref.go b/firestore/collgroupref.go index e43a7e648c54..c13ff1f160b8 100644 --- a/firestore/collgroupref.go +++ b/firestore/collgroupref.go @@ -14,6 +14,16 @@ package firestore +import ( + "context" + "errors" + "fmt" + "sort" + + "google.golang.org/api/iterator" + firestorepb "google.golang.org/genproto/googleapis/firestore/v1" +) + // A CollectionGroupRef is a reference to a group of collections sharing the // same ID. type CollectionGroupRef struct { @@ -36,3 +46,141 @@ func newCollectionGroupRef(c *Client, dbPath, collectionID string) *CollectionGr }, } } + +// GetPartitionedQueries returns a slice of Query objects, each containing a +// partition of a collection group. partitionCount must be a positive value and +// the number of returned partitions may be less than the requested number if +// providing the desired number would result in partitions with very few documents. +// +// If a Collection Group Query would return a large number of documents, this +// can help to subdivide the query to smaller working units that can be distributed. +func (cgr CollectionGroupRef) GetPartitionedQueries(ctx context.Context, partitionCount int) ([]Query, error) { + qp, err := cgr.getPartitions(ctx, partitionCount) + if err != nil { + return nil, err + } + queries := make([]Query, len(qp)) + for _, part := range qp { + queries = append(queries, part.toQuery()) + } + return queries, nil +} + +// getPartitions returns a slice of queryPartition objects, describing a start +// and end range to query a subsection of the collection group. partitionCount +// must be a positive value and the number of returned partitions may be less +// than the requested number if providing the desired number would result in +// partitions with very few documents. +func (cgr CollectionGroupRef) getPartitions(ctx context.Context, partitionCount int) ([]queryPartition, error) { + orderedQuery := cgr.query().OrderBy(DocumentID, Asc) + + if partitionCount <= 0 { + return nil, errors.New("a positive partitionCount must be provided") + } else if partitionCount == 1 { + return []queryPartition{{CollectionGroupQuery: orderedQuery}}, nil + } + + db := cgr.c.path() + ctx = withResourceHeader(ctx, db) + + // CollectionGroup Queries need to be ordered by __name__ ASC. + query, err := orderedQuery.toProto() + if err != nil { + return nil, err + } + structuredQuery := &firestorepb.PartitionQueryRequest_StructuredQuery{ + StructuredQuery: query, + } + + // Uses default PageSize + pbr := &firestorepb.PartitionQueryRequest{ + Parent: db + "/documents", + PartitionCount: int64(partitionCount), + QueryType: structuredQuery, + } + cursorReferences := make([]*firestorepb.Value, 0, partitionCount) + iter := cgr.c.c.PartitionQuery(ctx, pbr) + for { + cursor, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("GetPartitions: %v", err) + } + cursorReferences = append(cursorReferences, cursor.GetValues()...) + } + + // From Proto documentation: + // To obtain a complete result set ordered with respect to the results of the + // query supplied to PartitionQuery, the results sets should be merged: + // cursor A, cursor B, cursor M, cursor Q, cursor U, cursor W + // Once we have exhausted the pages, the cursor values need to be sorted in + // lexicographical order by segment (areas between '/'). + sort.Sort(byFirestoreValue(cursorReferences)) + + queryPartitions := make([]queryPartition, 0, len(cursorReferences)) + previousCursor := "" + + for _, cursor := range cursorReferences { + cursorRef := cursor.GetReferenceValue() + + // remove the root path from the reference, as queries take cursors + // relative to a collection + cursorRef = cursorRef[len(orderedQuery.path)+1:] + + qp := queryPartition{ + CollectionGroupQuery: orderedQuery, + StartAt: previousCursor, + EndBefore: cursorRef, + } + queryPartitions = append(queryPartitions, qp) + previousCursor = cursorRef + } + + // In the case there were no partitions, we still add a single partition to + // the result, that covers the complete range. + lastPart := queryPartition{CollectionGroupQuery: orderedQuery} + if len(cursorReferences) > 0 { + cursorRef := cursorReferences[len(cursorReferences)-1].GetReferenceValue() + lastPart.StartAt = cursorRef[len(orderedQuery.path)+1:] + } + queryPartitions = append(queryPartitions, lastPart) + + return queryPartitions, nil +} + +// queryPartition provides a Collection Group Reference and start and end split +// points allowing for a section of a collection group to be queried. This is +// used by GetPartitions which, given a CollectionGroupReference returns smaller +// sub-queries or partitions +type queryPartition struct { + // CollectionGroupQuery is an ordered query on a CollectionGroupReference. + // This query must be ordered Asc on __name__. + // Example: client.CollectionGroup("collectionID").query().OrderBy(DocumentID, Asc) + CollectionGroupQuery Query + + // StartAt is a document reference value, relative to the collection, not + // a complete parent path. + // Example: "documents/collectionName/documentName" + StartAt string + + // EndBefore is a document reference value, relative to the collection, not + // a complete parent path. + // Example: "documents/collectionName/documentName" + EndBefore string +} + +// toQuery converts a queryPartition object to a Query object +func (qp queryPartition) toQuery() Query { + q := *qp.CollectionGroupQuery.query() + + // Remove the leading path before calling StartAt, EndBefore + if qp.StartAt != "" { + q = q.StartAt(qp.StartAt) + } + if qp.EndBefore != "" { + q = q.EndBefore(qp.EndBefore) + } + return q +} diff --git a/firestore/collgroupref_test.go b/firestore/collgroupref_test.go new file mode 100644 index 000000000000..bbcbfe4ebb62 --- /dev/null +++ b/firestore/collgroupref_test.go @@ -0,0 +1,73 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "context" + "testing" +) + +func TestCGR_TestQueryPartition_ToQuery(t *testing.T) { + cgr := newCollectionGroupRef(testClient, testClient.path(), "collectionID") + qp := queryPartition{ + CollectionGroupQuery: cgr.Query.OrderBy(DocumentID, Asc), + StartAt: "documents/start/at", + EndBefore: "documents/end/before", + } + + got := qp.toQuery() + + want := Query{ + c: testClient, + path: "projects/projectID/databases/(default)", + parentPath: "projects/projectID/databases/(default)/documents", + collectionID: "collectionID", + startVals: []interface{}{"documents/start/at"}, + endVals: []interface{}{"documents/end/before"}, + startBefore: true, + endBefore: true, + allDescendants: true, + orders: []order{{fieldPath: []string{"__name__"}, dir: 1}}, + } + + if !testEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} + +func TestCGR_TestGetPartitions(t *testing.T) { + cgr := newCollectionGroupRef(testClient, testClient.path(), "collectionID") + _, err := cgr.getPartitions(context.Background(), 0) + if err == nil { + t.Error("Expected an error when requested partition count is < 1") + } + + parts, err := cgr.getPartitions(context.Background(), 1) + if err != nil { + t.Error("Didn't expect an error when requested partition count is 1") + } + if len(parts) != 1 { + t.Fatal("Expected 1 queryPartition") + } + got := parts[0] + want := queryPartition{ + CollectionGroupQuery: cgr.Query.OrderBy(DocumentID, Asc), + StartAt: "", + EndBefore: "", + } + if !testEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +} diff --git a/firestore/integration_test.go b/firestore/integration_test.go index e398150f84fa..ca841858f988 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -1601,3 +1601,88 @@ func TestDetectProjectID(t *testing.T) { t.Errorf("expected an error while using TokenSource that does not have a project ID") } } + +func TestIntegration_ColGroupRefPartitions(t *testing.T) { + h := testHelper{t} + coll := integrationColl(t) + ctx := context.Background() + + // Create a doc in the test collection so a collectionID is live for testing + doc := coll.NewDoc() + h.mustCreate(doc, integrationTestMap) + + for _, tc := range []struct { + collectionID string + expectedPartitionCount int + }{ + // Verify no failures if a collection doesn't exist + {collectionID: "does-not-exist", expectedPartitionCount: 1}, + // Verify a collectionID with a small number of results returns a partition + {collectionID: coll.collectionID, expectedPartitionCount: 1}, + } { + colGroup := iClient.CollectionGroup(tc.collectionID) + partitions, err := colGroup.getPartitions(ctx, 10) + if err != nil { + t.Fatalf("getPartitions: received unexpected error: %v", err) + } + if got, want := len(partitions), tc.expectedPartitionCount; got != want { + t.Errorf("Unexpected Partition Count: got %d, want %d", got, want) + } + } +} + +func TestIntegration_ColGroupRefPartitionsLarge(t *testing.T) { + // Create collection with enough documents to have multiple partitions. + coll := integrationColl(t) + collectionID := coll.collectionID + "largeCollection" + coll = iClient.Collection(collectionID) + + ctx := context.Background() + + documentCount := 2*128 + 127 // Minimum partition size is 128. + + // Create documents in a collection sufficient to trigger multiple partitions. + batch := iClient.Batch() + deleteBatch := iClient.Batch() + for i := 0; i < documentCount; i++ { + doc := coll.Doc(fmt.Sprintf("doc%d", i)) + batch.Create(doc, integrationTestMap) + deleteBatch.Delete(doc) + } + batch.Commit(ctx) + defer deleteBatch.Commit(ctx) + + // Verify that we retrieve 383 documents for the colGroup (128*2 + 127) + colGroup := iClient.CollectionGroup(collectionID) + docs, err := colGroup.Documents(ctx).GetAll() + if err != nil { + t.Fatalf("GetAll(): received unexpected error: %v", err) + } + if got, want := len(docs), documentCount; got != want { + t.Errorf("Unexpected number of documents in collection group: got %d, want %d", got, want) + } + + // Get partitions, allow up to 10 to come back, expect less will be returned. + partitions, err := colGroup.GetPartitionedQueries(ctx, 10) + if err != nil { + t.Fatalf("GetPartitionedQueries: received unexpected error: %v", err) + } + if len(partitions) < 2 { + t.Errorf("Unexpected Partition Count. Expected 2 or more: got %d, want 2+", len(partitions)) + } + + // Verify that we retrieve 383 documents across all partitions. (128*2 + 127) + totalCount := 0 + for _, query := range partitions { + + allDocs, err := query.Documents(ctx).GetAll() + if err != nil { + t.Fatalf("GetAll(): received unexpected error: %v", err) + } + totalCount += len(allDocs) + } + + if got, want := totalCount, documentCount; got != want { + t.Errorf("Unexpected number of documents across partitions: got %d, want %d", got, want) + } +} diff --git a/firestore/order.go b/firestore/order.go index e5ee1e09fb53..c495a141fd36 100644 --- a/firestore/order.go +++ b/firestore/order.go @@ -22,6 +22,7 @@ import ( "strings" tspb "github.com/golang/protobuf/ptypes/timestamp" + firestorepb "google.golang.org/genproto/googleapis/firestore/v1" pb "google.golang.org/genproto/googleapis/firestore/v1" ) @@ -214,3 +215,10 @@ func typeOrder(v *pb.Value) int { panic(fmt.Sprintf("bad value type: %v", v)) } } + +// byReferenceValue implements sort.Interface for []*firestorepb.Value +type byFirestoreValue []*firestorepb.Value + +func (a byFirestoreValue) Len() int { return len(a) } +func (a byFirestoreValue) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byFirestoreValue) Less(i, j int) bool { return compareValues(a[i], a[j]) < 0 } diff --git a/pubsublite/pscompat/integration_test.go b/pubsublite/pscompat/integration_test.go index a22358ead8e7..ab14ad7b973b 100644 --- a/pubsublite/pscompat/integration_test.go +++ b/pubsublite/pscompat/integration_test.go @@ -30,7 +30,9 @@ import ( "cloud.google.com/go/pubsublite/internal/wire" "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" "google.golang.org/api/option" + "google.golang.org/grpc/codes" vkit "cloud.google.com/go/pubsublite/apiv1" pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" @@ -167,7 +169,7 @@ func partitionNumbers(partitionCount int) []int { func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgs ...*pubsub.Message) { ctx := context.Background() - publisher := publisherClient(ctx, t, settings, topic) + publisher := publisherClient(context.Background(), t, settings, topic) defer publisher.Stop() var pubResults []*pubsub.PublishResult @@ -179,7 +181,7 @@ func publishMessages(t *testing.T, settings PublishSettings, topic wire.TopicPat func publishPrefixedMessages(t *testing.T, settings PublishSettings, topic wire.TopicPath, msgPrefix string, msgCount, msgSize int) []string { ctx := context.Background() - publisher := publisherClient(ctx, t, settings, topic) + publisher := publisherClient(context.Background(), t, settings, topic) defer publisher.Stop() orderingSender := test.NewOrderingSender() @@ -271,7 +273,7 @@ func receiveAllMessages(t *testing.T, msgTracker *test.MsgTracker, settings Rece } } - subscriber := subscriberClient(cctx, t, settings, subscription) + subscriber := subscriberClient(context.Background(), t, settings, subscription) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) } @@ -298,7 +300,7 @@ func receiveAndVerifyMessage(t *testing.T, want *pubsub.Message, settings Receiv } } - subscriber := subscriberClient(cctx, t, settings, subscription) + subscriber := subscriberClient(context.Background(), t, settings, subscription) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) } @@ -383,7 +385,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { } got.Nack() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) if gotErr := subscriber.Receive(cctx, messageReceiver1); !test.ErrorEqual(gotErr, errNackCalled) { t.Errorf("Receive() got err: (%v), want err: (%v)", gotErr, errNackCalled) } @@ -400,7 +402,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { } return fmt.Errorf("Received unexpected message: %q", truncateMsg(string(msg.Data))) } - subscriber = subscriberClient(cctx, t, customSettings, subscriptionPath) + subscriber = subscriberClient(context.Background(), t, customSettings, subscriptionPath) messageReceiver2 := func(ctx context.Context, got *pubsub.Message) { got.Nack() @@ -434,7 +436,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { got.Ack() stopSubscriber() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) // The message receiver stops the subscriber after receiving the first // message. However, the subscriber isn't guaranteed to immediately stop, so @@ -485,7 +487,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { // next test, which would receive an incorrect message. got.Ack() } - subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, recvSettings, subscriptionPath) if err := subscriber.Receive(cctx, messageReceiver); err != nil { t.Errorf("Receive() got err: %v", err) @@ -539,6 +541,44 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { receiveAllMessages(t, msgTracker, recvSettings, subscriptionPath) }) + // Verifies that cancelling the context passed to NewPublisherClient can shut + // down the publisher. + t.Run("CancelPublisherContext", func(t *testing.T) { + cctx, cancel := context.WithCancel(context.Background()) + publisher := publisherClient(cctx, t, DefaultPublishSettings, topicPath) + + cancel() + + wantCode := codes.Canceled + result := publisher.Publish(ctx, &pubsub.Message{Data: []byte("cancel_publisher_context")}) + if _, err := result.Get(ctx); !test.ErrorHasCode(err, wantCode) { + t.Errorf("Publish() got err: %v, want code: %v", err, wantCode) + } + if err := xerrors.Unwrap(publisher.Error()); !test.ErrorHasCode(err, wantCode) { + t.Errorf("Error() got err: %v, want code: %v", err, wantCode) + } + publisher.Stop() + }) + + // Verifies that cancelling the context passed to NewSubscriberClient can shut + // down the subscriber. + t.Run("CancelSubscriberContext", func(t *testing.T) { + msg := &pubsub.Message{Data: []byte("cancel_subscriber_context")} + publishMessages(t, DefaultPublishSettings, topicPath, msg) + + cctx, cancel := context.WithCancel(context.Background()) + subscriber := subscriberClient(cctx, t, recvSettings, subscriptionPath) + + subsErr := subscriber.Receive(context.Background(), func(ctx context.Context, got *pubsub.Message) { + got.Ack() + cancel() + }) + + if err, wantCode := xerrors.Unwrap(subsErr), codes.Canceled; !test.ErrorHasCode(err, wantCode) { + t.Errorf("Receive() got err: %v, want code: %v", err, wantCode) + } + }) + // NOTE: This should be the last test case. // Verifies that increasing the number of topic partitions is handled // correctly by publishers. @@ -547,7 +587,7 @@ func TestIntegration_PublishSubscribeSinglePartition(t *testing.T) { const pollPeriod = 5 * time.Second pubSettings := DefaultPublishSettings pubSettings.configPollPeriod = pollPeriod // Poll updates more frequently - publisher := publisherClient(ctx, t, pubSettings, topicPath) + publisher := publisherClient(context.Background(), t, pubSettings, topicPath) defer publisher.Stop() // Update the number of partitions. @@ -661,7 +701,7 @@ func TestIntegration_PublishSubscribeMultiPartition(t *testing.T) { for i := 0; i < subscriberCount; i++ { // Subscribers must be started in a goroutine as Receive() blocks. g.Go(func() error { - subscriber := subscriberClient(cctx, t, DefaultReceiveSettings, subscriptionPath) + subscriber := subscriberClient(context.Background(), t, DefaultReceiveSettings, subscriptionPath) err := subscriber.Receive(cctx, messageReceiver) if err != nil { t.Errorf("Receive() got err: %v", err) diff --git a/pubsublite/pscompat/publisher.go b/pubsublite/pscompat/publisher.go index 6f4e9415a822..b9f8c82e19ab 100644 --- a/pubsublite/pscompat/publisher.go +++ b/pubsublite/pscompat/publisher.go @@ -82,10 +82,7 @@ func NewPublisherClientWithSettings(ctx context.Context, topic string, settings return nil, err } - // Note: ctx is not used to create the wire publisher, because if it is - // cancelled, the publisher will not be able to perform graceful shutdown - // (e.g. flush pending messages). - wirePub, err := wire.NewPublisher(context.Background(), settings.toWireSettings(), region, topic, opts...) + wirePub, err := wire.NewPublisher(ctx, settings.toWireSettings(), region, topic, opts...) if err != nil { return nil, err } diff --git a/pubsublite/pscompat/subscriber.go b/pubsublite/pscompat/subscriber.go index dafb25ee526f..d76cc9d290bf 100644 --- a/pubsublite/pscompat/subscriber.go +++ b/pubsublite/pscompat/subscriber.go @@ -72,7 +72,7 @@ func (ah *pslAckHandler) OnNack() { // wireSubscriberFactory is a factory for creating wire subscribers, which can // be overridden with a mock in unit tests. type wireSubscriberFactory interface { - New(wire.MessageReceiverFunc) (wire.Subscriber, error) + New(context.Context, wire.MessageReceiverFunc) (wire.Subscriber, error) } type wireSubscriberFactoryImpl struct { @@ -82,8 +82,8 @@ type wireSubscriberFactoryImpl struct { options []option.ClientOption } -func (f *wireSubscriberFactoryImpl) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { - return wire.NewSubscriber(context.Background(), f.settings, receiver, f.region, f.subscription.String(), f.options...) +func (f *wireSubscriberFactoryImpl) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { + return wire.NewSubscriber(ctx, f.settings, receiver, f.region, f.subscription.String(), f.options...) } type messageReceiverFunc = func(context.Context, *pubsub.Message) @@ -103,8 +103,8 @@ type subscriberInstance struct { err error } -func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) { - recvCtx, recvCancel := context.WithCancel(ctx) +func newSubscriberInstance(recvCtx, clientCtx context.Context, factory wireSubscriberFactory, settings ReceiveSettings, receiver messageReceiverFunc) (*subscriberInstance, error) { + recvCtx, recvCancel := context.WithCancel(recvCtx) subInstance := &subscriberInstance{ settings: settings, recvCtx: recvCtx, @@ -112,10 +112,11 @@ func newSubscriberInstance(ctx context.Context, factory wireSubscriberFactory, s receiver: receiver, } - // Note: ctx is not used to create the wire subscriber, because if it is - // cancelled, the subscriber will not be able to perform graceful shutdown - // (e.g. process acks and commit the final cursor offset). - wireSub, err := factory.New(subInstance.onMessage) + // Note: The context from Receive (recvCtx) should not be used, as when it is + // cancelled, the gRPC streams will be disconnected and the subscriber will + // not be able to process acks and commit the final cursor offset. Use the + // context from NewSubscriberClient (clientCtx) instead. + wireSub, err := factory.New(clientCtx, subInstance.onMessage) if err != nil { return nil, err } @@ -229,6 +230,7 @@ func (si *subscriberInstance) Wait(ctx context.Context) error { // See https://cloud.google.com/pubsub/lite/docs/subscribing for more // information about receiving messages. type SubscriberClient struct { + clientCtx context.Context settings ReceiveSettings wireSubFactory wireSubscriberFactory @@ -265,6 +267,7 @@ func NewSubscriberClientWithSettings(ctx context.Context, subscription string, s options: opts, } subClient := &SubscriberClient{ + clientCtx: ctx, settings: settings, wireSubFactory: factory, } @@ -303,7 +306,7 @@ func (s *SubscriberClient) Receive(ctx context.Context, f func(context.Context, defer s.setReceiveActive(false) // Initialize a subscriber instance. - subInstance, err := newSubscriberInstance(ctx, s.wireSubFactory, s.settings, f) + subInstance, err := newSubscriberInstance(ctx, s.clientCtx, s.wireSubFactory, s.settings, f) if err != nil { return err } diff --git a/pubsublite/pscompat/subscriber_test.go b/pubsublite/pscompat/subscriber_test.go index 429c9f551416..5c737f8ff424 100644 --- a/pubsublite/pscompat/subscriber_test.go +++ b/pubsublite/pscompat/subscriber_test.go @@ -113,7 +113,7 @@ func (ms *mockWireSubscriber) WaitStopped() error { type mockWireSubscriberFactory struct{} -func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { +func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc) (wire.Subscriber, error) { return &mockWireSubscriber{ receiver: receiver, msgsC: make(chan *wire.ReceivedMessage, 10), @@ -122,7 +122,7 @@ func (f *mockWireSubscriberFactory) New(receiver wire.MessageReceiverFunc) (wire } func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver messageReceiverFunc) *subscriberInstance { - sub, _ := newSubscriberInstance(ctx, new(mockWireSubscriberFactory), settings, receiver) + sub, _ := newSubscriberInstance(ctx, context.Background(), new(mockWireSubscriberFactory), settings, receiver) return sub }