diff --git a/mongo/integration/unified/admin_helpers_test.go b/mongo/integration/unified/admin_helpers.go similarity index 85% rename from mongo/integration/unified/admin_helpers_test.go rename to mongo/integration/unified/admin_helpers.go index f76b95d7dc..355422dc63 100644 --- a/mongo/integration/unified/admin_helpers_test.go +++ b/mongo/integration/unified/admin_helpers.go @@ -21,9 +21,9 @@ const ( errorInterrupted int32 = 11601 ) -// TerminateOpenSessions executes a killAllSessions command to ensure that sesssions left open on the server by a test +// terminateOpenSessions executes a killAllSessions command to ensure that sesssions left open on the server by a test // do not cause future tests to hang. -func TerminateOpenSessions(ctx context.Context) error { +func terminateOpenSessions(ctx context.Context) error { if mtest.CompareServerVersions(mtest.ServerVersion(), "3.6") < 0 { return nil } @@ -48,10 +48,10 @@ func TerminateOpenSessions(ctx context.Context) error { return runAgainstAllMongoses(ctx, commandFn) } -// PerformDistinctWorkaround executes a non-transactional "distinct" command against each mongos in a sharded cluster. -func PerformDistinctWorkaround(ctx context.Context) error { +// performDistinctWorkaround executes a non-transactional "distinct" command against each mongos in a sharded cluster. +func performDistinctWorkaround(ctx context.Context) error { commandFn := func(ctx context.Context, client *mongo.Client) error { - for _, coll := range Entities(ctx).Collections() { + for _, coll := range entities(ctx).collections() { newColl := client.Database(coll.Database().Name()).Collection(coll.Name()) _, err := newColl.Distinct(ctx, "x", bson.D{}) if err != nil { @@ -66,7 +66,7 @@ func PerformDistinctWorkaround(ctx context.Context) error { return runAgainstAllMongoses(ctx, commandFn) } -func RunCommandOnHost(ctx context.Context, host string, commandFn func(context.Context, *mongo.Client) error) error { +func runCommandOnHost(ctx context.Context, host string, commandFn func(context.Context, *mongo.Client) error) error { clientOpts := options.Client(). ApplyURI(mtest.ClusterURI()). SetHosts([]string{host}) @@ -83,7 +83,7 @@ func RunCommandOnHost(ctx context.Context, host string, commandFn func(context.C func runAgainstAllMongoses(ctx context.Context, commandFn func(context.Context, *mongo.Client) error) error { for _, host := range mtest.ClusterConnString().Hosts { - if err := RunCommandOnHost(ctx, host, commandFn); err != nil { + if err := runCommandOnHost(ctx, host, commandFn); err != nil { return fmt.Errorf("error executing callback against host %q: %v", host, err) } } diff --git a/mongo/integration/unified/bsonutil_test.go b/mongo/integration/unified/bsonutil.go similarity index 86% rename from mongo/integration/unified/bsonutil_test.go rename to mongo/integration/unified/bsonutil.go index b64c3a67e6..375311be29 100644 --- a/mongo/integration/unified/bsonutil_test.go +++ b/mongo/integration/unified/bsonutil.go @@ -20,14 +20,14 @@ var ( emptyRawValue = bson.RawValue{} ) -func DocumentToRawValue(doc bson.Raw) bson.RawValue { +func documentToRawValue(doc bson.Raw) bson.RawValue { return bson.RawValue{ Type: bsontype.EmbeddedDocument, Value: doc, } } -func RemoveFieldsFromDocument(doc bson.Raw, keys ...string) bson.Raw { +func removeFieldsFromDocument(doc bson.Raw, keys ...string) bson.Raw { newDoc := bsoncore.NewDocumentBuilder() elems, _ := doc.Elements() @@ -47,7 +47,7 @@ func RemoveFieldsFromDocument(doc bson.Raw, keys ...string) bson.Raw { return bson.Raw(newDoc.Build()) } -func SortDocument(doc bson.Raw) bson.Raw { +func sortDocument(doc bson.Raw) bson.Raw { elems, _ := doc.Elements() keys := make([]string, 0, len(elems)) valuesMap := make(map[string]bson.RawValue) @@ -66,11 +66,11 @@ func SortDocument(doc bson.Raw) bson.Raw { return bson.Raw(sorted.Build()) } -func LookupString(doc bson.Raw, key string) string { +func lookupString(doc bson.Raw, key string) string { return doc.Lookup(key).StringValue() } -func MapKeys(m map[string]interface{}) []string { +func mapKeys(m map[string]interface{}) []string { keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) diff --git a/mongo/integration/unified/bucket_options_test.go b/mongo/integration/unified/bucket_options.go similarity index 81% rename from mongo/integration/unified/bucket_options_test.go rename to mongo/integration/unified/bucket_options.go index ef2bc548ad..742f25da87 100644 --- a/mongo/integration/unified/bucket_options_test.go +++ b/mongo/integration/unified/bucket_options.go @@ -13,15 +13,15 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// GridFSBucketOptions is a wrapper for *options.BucketOptions. This type implements the bson.Unmarshaler interface to +// gridFSBucketOptions is a wrapper for *options.BucketOptions. This type implements the bson.Unmarshaler interface to // convert BSON documents to a BucketOptions instance. -type GridFSBucketOptions struct { +type gridFSBucketOptions struct { *options.BucketOptions } -var _ bson.Unmarshaler = (*GridFSBucketOptions)(nil) +var _ bson.Unmarshaler = (*gridFSBucketOptions)(nil) -func (bo GridFSBucketOptions) UnmarshalBSON(data []byte) error { +func (bo gridFSBucketOptions) UnmarshalBSON(data []byte) error { var temp struct { Name *string `bson:"name"` ChunkSize *int32 `bson:"chunkSizeBytes"` @@ -31,10 +31,10 @@ func (bo GridFSBucketOptions) UnmarshalBSON(data []byte) error { Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary GridFSBucketOptions object: %v", err) + return fmt.Errorf("error unmarshalling to temporary gridFSBucketOptions object: %v", err) } if len(temp.Extra) > 0 { - return fmt.Errorf("unrecognized fields for GridFSBucketOptions: %v", MapKeys(temp.Extra)) + return fmt.Errorf("unrecognized fields for gridFSBucketOptions: %v", mapKeys(temp.Extra)) } bo.BucketOptions = options.GridFSBucket() diff --git a/mongo/integration/unified/bulkwrite_helpers_test.go b/mongo/integration/unified/bulkwrite_helpers.go similarity index 100% rename from mongo/integration/unified/bulkwrite_helpers_test.go rename to mongo/integration/unified/bulkwrite_helpers.go diff --git a/mongo/integration/unified/change_stream_operation_execution_test.go b/mongo/integration/unified/change_stream_operation_execution.go similarity index 70% rename from mongo/integration/unified/change_stream_operation_execution_test.go rename to mongo/integration/unified/change_stream_operation_execution.go index dbff9716ba..c8fefcde05 100644 --- a/mongo/integration/unified/change_stream_operation_execution_test.go +++ b/mongo/integration/unified/change_stream_operation_execution.go @@ -8,18 +8,18 @@ package unified import "context" -func executeIterateUntilDocumentOrError(ctx context.Context, operation *Operation) (*OperationResult, error) { - stream, err := Entities(ctx).ChangeStream(operation.Object) +func executeIterateUntilDocumentOrError(ctx context.Context, operation *operation) (*operationResult, error) { + stream, err := entities(ctx).changeStream(operation.Object) if err != nil { return nil, err } for { if stream.TryNext(ctx) { - return NewDocumentResult(stream.Current, nil), nil + return newDocumentResult(stream.Current, nil), nil } if stream.Err() != nil { - return NewErrorResult(stream.Err()), nil + return newErrorResult(stream.Err()), nil } } } diff --git a/mongo/integration/unified/client_entity_test.go b/mongo/integration/unified/client_entity.go similarity index 88% rename from mongo/integration/unified/client_entity_test.go rename to mongo/integration/unified/client_entity.go index b03628a07d..fcb9632344 100644 --- a/mongo/integration/unified/client_entity_test.go +++ b/mongo/integration/unified/client_entity.go @@ -21,9 +21,9 @@ import ( "go.mongodb.org/mongo-driver/mongo/readconcern" ) -// ClientEntity is a wrapper for a mongo.Client object that also holds additional information required during test +// clientEntity is a wrapper for a mongo.Client object that also holds additional information required during test // execution. -type ClientEntity struct { +type clientEntity struct { *mongo.Client recordEvents atomic.Value @@ -33,8 +33,8 @@ type ClientEntity struct { ignoredCommands map[string]struct{} } -func NewClientEntity(ctx context.Context, entityOptions *EntityOptions) (*ClientEntity, error) { - entity := &ClientEntity{ +func newClientEntity(ctx context.Context, entityOptions *entityOptions) (*clientEntity, error) { + entity := &clientEntity{ // The "configureFailPoint" command should always be ignored. ignoredCommands: map[string]struct{}{ "configureFailPoint": {}, @@ -96,11 +96,11 @@ func NewClientEntity(ctx context.Context, entityOptions *EntityOptions) (*Client return entity, nil } -func (c *ClientEntity) StopListeningForEvents() { +func (c *clientEntity) StopListeningForEvents() { c.setRecordEvents(false) } -func (c *ClientEntity) StartedEvents() []*event.CommandStartedEvent { +func (c *clientEntity) startedEvents() []*event.CommandStartedEvent { var events []*event.CommandStartedEvent for _, evt := range c.started { if _, ok := c.ignoredCommands[evt.CommandName]; !ok { @@ -111,7 +111,7 @@ func (c *ClientEntity) StartedEvents() []*event.CommandStartedEvent { return events } -func (c *ClientEntity) SucceededEvents() []*event.CommandSucceededEvent { +func (c *clientEntity) succeededEvents() []*event.CommandSucceededEvent { var events []*event.CommandSucceededEvent for _, evt := range c.succeeded { if _, ok := c.ignoredCommands[evt.CommandName]; !ok { @@ -122,7 +122,7 @@ func (c *ClientEntity) SucceededEvents() []*event.CommandSucceededEvent { return events } -func (c *ClientEntity) FailedEvents() []*event.CommandFailedEvent { +func (c *clientEntity) failedEvents() []*event.CommandFailedEvent { var events []*event.CommandFailedEvent for _, evt := range c.failed { if _, ok := c.ignoredCommands[evt.CommandName]; !ok { @@ -133,29 +133,29 @@ func (c *ClientEntity) FailedEvents() []*event.CommandFailedEvent { return events } -func (c *ClientEntity) processStartedEvent(_ context.Context, evt *event.CommandStartedEvent) { +func (c *clientEntity) processStartedEvent(_ context.Context, evt *event.CommandStartedEvent) { if c.getRecordEvents() { c.started = append(c.started, evt) } } -func (c *ClientEntity) processSucceededEvent(_ context.Context, evt *event.CommandSucceededEvent) { +func (c *clientEntity) processSucceededEvent(_ context.Context, evt *event.CommandSucceededEvent) { if c.getRecordEvents() { c.succeeded = append(c.succeeded, evt) } } -func (c *ClientEntity) processFailedEvent(_ context.Context, evt *event.CommandFailedEvent) { +func (c *clientEntity) processFailedEvent(_ context.Context, evt *event.CommandFailedEvent) { if c.getRecordEvents() { c.failed = append(c.failed, evt) } } -func (c *ClientEntity) setRecordEvents(record bool) { +func (c *clientEntity) setRecordEvents(record bool) { c.recordEvents.Store(record) } -func (c *ClientEntity) getRecordEvents() bool { +func (c *clientEntity) getRecordEvents() bool { return c.recordEvents.Load().(bool) } diff --git a/mongo/integration/unified/client_operation_execution_test.go b/mongo/integration/unified/client_operation_execution.go similarity index 85% rename from mongo/integration/unified/client_operation_execution_test.go rename to mongo/integration/unified/client_operation_execution.go index 4b664913d4..e4399dbe9a 100644 --- a/mongo/integration/unified/client_operation_execution_test.go +++ b/mongo/integration/unified/client_operation_execution.go @@ -20,18 +20,18 @@ import ( // This file contains helpers to execute client operations. -func executeCreateChangeStream(ctx context.Context, operation *Operation) (*OperationResult, error) { +func executeCreateChangeStream(ctx context.Context, operation *operation) (*operationResult, error) { var watcher interface { Watch(context.Context, interface{}, ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) } var err error - watcher, err = Entities(ctx).Client(operation.Object) + watcher, err = entities(ctx).client(operation.Object) if err != nil { - watcher, err = Entities(ctx).Database(operation.Object) + watcher, err = entities(ctx).database(operation.Object) } if err != nil { - watcher, err = Entities(ctx).Collection(operation.Object) + watcher, err = entities(ctx).collection(operation.Object) } if err != nil { return nil, fmt.Errorf("no client, database, or collection entity found with ID %q", operation.Object) @@ -84,20 +84,20 @@ func executeCreateChangeStream(ctx context.Context, operation *Operation) (*Oper stream, err := watcher.Watch(ctx, pipeline, opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } if operation.ResultEntityID == nil { return nil, fmt.Errorf("no entity name provided to store executeChangeStream result") } - if err := Entities(ctx).AddChangeStreamEntity(*operation.ResultEntityID, stream); err != nil { + if err := entities(ctx).addChangeStreamEntity(*operation.ResultEntityID, stream); err != nil { return nil, fmt.Errorf("error storing result as changeStream entity: %v", err) } - return NewEmptyResult(), nil + return newEmptyResult(), nil } -func executeListDatabases(ctx context.Context, operation *Operation) (*OperationResult, error) { - client, err := Entities(ctx).Client(operation.Object) +func executeListDatabases(ctx context.Context, operation *operation) (*operationResult, error) { + client, err := entities(ctx).client(operation.Object) if err != nil { return nil, err } @@ -126,7 +126,7 @@ func executeListDatabases(ctx context.Context, operation *Operation) (*Operation res, err := client.ListDatabases(ctx, filter, opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } specsArray := bsoncore.NewArrayBuilder() @@ -143,5 +143,5 @@ func executeListDatabases(ctx context.Context, operation *Operation) (*Operation AppendArray("databases", specsArray.Build()). AppendInt64("totalSize", res.TotalSize). Build() - return NewDocumentResult(raw, nil), nil + return newDocumentResult(raw, nil), nil } diff --git a/mongo/integration/unified/collection_data_test.go b/mongo/integration/unified/collection_data.go similarity index 86% rename from mongo/integration/unified/collection_data_test.go rename to mongo/integration/unified/collection_data.go index 933278d5cf..93dcb0e6e9 100644 --- a/mongo/integration/unified/collection_data_test.go +++ b/mongo/integration/unified/collection_data.go @@ -19,15 +19,15 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" ) -type CollectionData struct { +type collectionData struct { DatabaseName string `bson:"databaseName"` CollectionName string `bson:"collectionName"` Documents []bson.Raw `bson:"documents"` } -// CreateCollection configures the collection represented by the receiver using the internal client. This function +// createCollection configures the collection represented by the receiver using the internal client. This function // first drops the collection and then creates it and inserts the seed data if needed. -func (c *CollectionData) CreateCollection(ctx context.Context) error { +func (c *collectionData) createCollection(ctx context.Context) error { db := mtest.GlobalClient().Database(c.DatabaseName) coll := db.Collection(c.CollectionName, options.Collection().SetWriteConcern(mtest.MajorityWc)) if err := coll.Drop(ctx); err != nil { @@ -57,9 +57,9 @@ func (c *CollectionData) CreateCollection(ctx context.Context) error { return nil } -// VerifyContents asserts that the collection on the server represented by this CollectionData instance contains the +// verifyContents asserts that the collection on the server represented by this collectionData instance contains the // expected documents. -func (c *CollectionData) VerifyContents(ctx context.Context) error { +func (c *collectionData) verifyContents(ctx context.Context) error { collOpts := options.Collection(). SetReadPreference(readpref.Primary()). SetReadConcern(readconcern.Local()) @@ -82,12 +82,12 @@ func (c *CollectionData) VerifyContents(ctx context.Context) error { return fmt.Errorf("expected %d documents but found %d: %v", len(c.Documents), len(docs), docs) } - // We can't use VerifyValuesMatch here because the rules for evaluating matches (e.g. flexible numeric comparisons + // We can't use verifyValuesMatch here because the rules for evaluating matches (e.g. flexible numeric comparisons // and special $$ operators) do not apply when verifying collection outcomes. We have to permit variations in key // order, though, so we sort documents before doing a byte-wise comparison. for idx, expected := range c.Documents { - expected = SortDocument(expected) - actual := SortDocument(docs[idx]) + expected = sortDocument(expected) + actual := sortDocument(docs[idx]) if !bytes.Equal(expected, actual) { return fmt.Errorf("document comparison error at index %d: expected %s, got %s", idx, expected, actual) @@ -96,6 +96,6 @@ func (c *CollectionData) VerifyContents(ctx context.Context) error { return nil } -func (c *CollectionData) Namespace() string { +func (c *collectionData) namespace() string { return fmt.Sprintf("%s.%s", c.DatabaseName, c.CollectionName) } diff --git a/mongo/integration/unified/collection_operation_execution_test.go b/mongo/integration/unified/collection_operation_execution.go similarity index 85% rename from mongo/integration/unified/collection_operation_execution_test.go rename to mongo/integration/unified/collection_operation_execution.go index 4ca96dd39e..657a8a5627 100644 --- a/mongo/integration/unified/collection_operation_execution_test.go +++ b/mongo/integration/unified/collection_operation_execution.go @@ -21,15 +21,15 @@ import ( // This file contains helpers to execute collection operations. -func executeAggregate(ctx context.Context, operation *Operation) (*OperationResult, error) { +func executeAggregate(ctx context.Context, operation *operation) (*operationResult, error) { var aggregator interface { Aggregate(context.Context, interface{}, ...*options.AggregateOptions) (*mongo.Cursor, error) } var err error - aggregator, err = Entities(ctx).Collection(operation.Object) + aggregator, err = entities(ctx).collection(operation.Object) if err != nil { - aggregator, err = Entities(ctx).Database(operation.Object) + aggregator, err = entities(ctx).database(operation.Object) } if err != nil { return nil, fmt.Errorf("no database or collection entity found with ID %q", operation.Object) @@ -80,19 +80,19 @@ func executeAggregate(ctx context.Context, operation *Operation) (*OperationResu cursor, err := aggregator.Aggregate(ctx, pipeline, opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } defer cursor.Close(ctx) var docs []bson.Raw if err := cursor.All(ctx, &docs); err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } - return NewCursorResult(docs), nil + return newCursorResult(docs), nil } -func executeBulkWrite(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeBulkWrite(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -141,11 +141,11 @@ func executeBulkWrite(ctx context.Context, operation *Operation) (*OperationResu AppendDocument("upsertedIds", rawUpsertedIDs). Build() } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } -func executeCountDocuments(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeCountDocuments(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -189,13 +189,13 @@ func executeCountDocuments(ctx context.Context, operation *Operation) (*Operatio count, err := coll.CountDocuments(ctx, filter, opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } - return NewValueResult(bsontype.Int64, bsoncore.AppendInt64(nil, count), nil), nil + return newValueResult(bsontype.Int64, bsoncore.AppendInt64(nil, count), nil), nil } -func executeCreateIndex(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeCreateIndex(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -268,11 +268,11 @@ func executeCreateIndex(ctx context.Context, operation *Operation) (*OperationRe Options: indexOpts, } name, err := coll.Indexes().CreateOne(ctx, model) - return NewValueResult(bsontype.String, bsoncore.AppendString(nil, name), nil), nil + return newValueResult(bsontype.String, bsoncore.AppendString(nil, name), nil), nil } -func executeDeleteOne(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeDeleteOne(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -315,11 +315,11 @@ func executeDeleteOne(ctx context.Context, operation *Operation) (*OperationResu AppendInt64("deletedCount", res.DeletedCount). Build() } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } -func executeDeleteMany(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeDeleteMany(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -362,11 +362,11 @@ func executeDeleteMany(ctx context.Context, operation *Operation) (*OperationRes AppendInt64("deletedCount", res.DeletedCount). Build() } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } -func executeDistinct(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeDistinct(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -406,17 +406,17 @@ func executeDistinct(ctx context.Context, operation *Operation) (*OperationResul res, err := coll.Distinct(ctx, fieldName, filter, opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } _, rawRes, err := bson.MarshalValue(res) if err != nil { return nil, fmt.Errorf("error converting Distinct result to raw BSON: %v", err) } - return NewValueResult(bsontype.Array, rawRes, nil), nil + return newValueResult(bsontype.Array, rawRes, nil), nil } -func executeEstimatedDocumentCount(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeEstimatedDocumentCount(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -437,13 +437,13 @@ func executeEstimatedDocumentCount(ctx context.Context, operation *Operation) (* count, err := coll.EstimatedDocumentCount(ctx, opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } - return NewValueResult(bsontype.Int64, bsoncore.AppendInt64(nil, count), nil), nil + return newValueResult(bsontype.Int64, bsoncore.AppendInt64(nil, count), nil), nil } -func executeFind(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeFind(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -513,19 +513,19 @@ func executeFind(ctx context.Context, operation *Operation) (*OperationResult, e cursor, err := coll.Find(ctx, filter, opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } defer cursor.Close(ctx) var docs []bson.Raw if err := cursor.All(ctx, &docs); err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } - return NewCursorResult(docs), nil + return newCursorResult(docs), nil } -func executeFindOneAndDelete(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeFindOneAndDelete(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -568,11 +568,11 @@ func executeFindOneAndDelete(ctx context.Context, operation *Operation) (*Operat } res, err := coll.FindOneAndDelete(ctx, filter, opts).DecodeBytes() - return NewDocumentResult(res, err), nil + return newDocumentResult(res, err), nil } -func executeFindOneAndReplace(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeFindOneAndReplace(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -634,11 +634,11 @@ func executeFindOneAndReplace(ctx context.Context, operation *Operation) (*Opera } res, err := coll.FindOneAndReplace(ctx, filter, replacement, opts).DecodeBytes() - return NewDocumentResult(res, err), nil + return newDocumentResult(res, err), nil } -func executeFindOneAndUpdate(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeFindOneAndUpdate(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -707,11 +707,11 @@ func executeFindOneAndUpdate(ctx context.Context, operation *Operation) (*Operat } res, err := coll.FindOneAndUpdate(ctx, filter, update, opts).DecodeBytes() - return NewDocumentResult(res, err), nil + return newDocumentResult(res, err), nil } -func executeInsertMany(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeInsertMany(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -752,11 +752,11 @@ func executeInsertMany(ctx context.Context, operation *Operation) (*OperationRes AppendDocument("upsertedIds", bsoncore.NewDocumentBuilder().Build()). Build() } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } -func executeInsertOne(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeInsertOne(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -793,11 +793,11 @@ func executeInsertOne(ctx context.Context, operation *Operation) (*OperationResu AppendValue("insertedId", bsoncore.Value{Type: t, Data: data}). Build() } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } -func executeReplaceOne(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeReplaceOne(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -842,11 +842,11 @@ func executeReplaceOne(ctx context.Context, operation *Operation) (*OperationRes if buildErr != nil { return nil, buildErr } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } -func executeUpdateOne(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeUpdateOne(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -861,11 +861,11 @@ func executeUpdateOne(ctx context.Context, operation *Operation) (*OperationResu if buildErr != nil { return nil, buildErr } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } -func executeUpdateMany(ctx context.Context, operation *Operation) (*OperationResult, error) { - coll, err := Entities(ctx).Collection(operation.Object) +func executeUpdateMany(ctx context.Context, operation *operation) (*operationResult, error) { + coll, err := entities(ctx).collection(operation.Object) if err != nil { return nil, err } @@ -880,7 +880,7 @@ func executeUpdateMany(ctx context.Context, operation *Operation) (*OperationRes if buildErr != nil { return nil, buildErr } - return NewDocumentResult(raw, err), nil + return newDocumentResult(raw, err), nil } func buildUpdateResultDocument(res *mongo.UpdateResult) (bsoncore.Document, error) { diff --git a/mongo/integration/unified/command_monitoring_test.go b/mongo/integration/unified/command_monitoring.go similarity index 81% rename from mongo/integration/unified/command_monitoring_test.go rename to mongo/integration/unified/command_monitoring.go index 7c95c7bba8..53a63847a0 100644 --- a/mongo/integration/unified/command_monitoring_test.go +++ b/mongo/integration/unified/command_monitoring.go @@ -14,7 +14,7 @@ import ( "go.mongodb.org/mongo-driver/bson" ) -type CommandMonitoringEvent struct { +type commandMonitoringEvent struct { CommandStartedEvent *struct { Command bson.Raw `bson:"command"` CommandName *string `bson:"commandName"` @@ -31,13 +31,13 @@ type CommandMonitoringEvent struct { } `bson:"commandFailedEvent"` } -type ExpectedEvents struct { +type expectedEvents struct { ClientID string `bson:"client"` - Events []CommandMonitoringEvent `bson:"events"` + Events []commandMonitoringEvent `bson:"events"` } -func VerifyEvents(ctx context.Context, expectedEvents *ExpectedEvents) error { - client, err := Entities(ctx).Client(expectedEvents.ClientID) +func verifyEvents(ctx context.Context, expectedEvents *expectedEvents) error { + client, err := entities(ctx).client(expectedEvents.ClientID) if err != nil { return err } @@ -46,9 +46,9 @@ func VerifyEvents(ctx context.Context, expectedEvents *ExpectedEvents) error { return nil } - started := client.StartedEvents() - succeeded := client.SucceededEvents() - failed := client.FailedEvents() + started := client.startedEvents() + succeeded := client.succeededEvents() + failed := client.failedEvents() // If the Events array is nil, verify that no events were sent. if len(expectedEvents.Events) == 0 && (len(started)+len(succeeded)+len(failed) != 0) { @@ -75,9 +75,9 @@ func VerifyEvents(ctx context.Context, expectedEvents *ExpectedEvents) error { actual.DatabaseName) } if expected.Command != nil { - expectedDoc := DocumentToRawValue(expected.Command) - actualDoc := DocumentToRawValue(actual.Command) - if err := VerifyValuesMatch(ctx, expectedDoc, actualDoc, true); err != nil { + expectedDoc := documentToRawValue(expected.Command) + actualDoc := documentToRawValue(actual.Command) + if err := verifyValuesMatch(ctx, expectedDoc, actualDoc, true); err != nil { return newEventVerificationError(idx, client, "error comparing command documents: %v", err) } } @@ -95,9 +95,9 @@ func VerifyEvents(ctx context.Context, expectedEvents *ExpectedEvents) error { actual.CommandName) } if expected.Reply != nil { - expectedDoc := DocumentToRawValue(expected.Reply) - actualDoc := DocumentToRawValue(actual.Reply) - if err := VerifyValuesMatch(ctx, expectedDoc, actualDoc, true); err != nil { + expectedDoc := documentToRawValue(expected.Reply) + actualDoc := documentToRawValue(actual.Reply) + if err := verifyValuesMatch(ctx, expectedDoc, actualDoc, true); err != nil { return newEventVerificationError(idx, client, "error comparing reply documents: %v", err) } } @@ -115,7 +115,7 @@ func VerifyEvents(ctx context.Context, expectedEvents *ExpectedEvents) error { actual.CommandName) } default: - return newEventVerificationError(idx, client, "no expected event set on CommandMonitoringEvent instance") + return newEventVerificationError(idx, client, "no expected event set on commandMonitoringEvent instance") } } @@ -126,27 +126,27 @@ func VerifyEvents(ctx context.Context, expectedEvents *ExpectedEvents) error { return nil } -func newEventVerificationError(idx int, client *ClientEntity, msg string, args ...interface{}) error { +func newEventVerificationError(idx int, client *clientEntity, msg string, args ...interface{}) error { fullMsg := fmt.Sprintf(msg, args...) return fmt.Errorf("event comparison failed at index %d: %s; all events found for client: %s", idx, fullMsg, stringifyEventsForClient(client)) } -func stringifyEventsForClient(client *ClientEntity) string { +func stringifyEventsForClient(client *clientEntity) string { str := bytes.NewBuffer(nil) str.WriteString("\n\nStarted Events\n\n") - for _, evt := range client.StartedEvents() { + for _, evt := range client.startedEvents() { str.WriteString(fmt.Sprintf("[%s] %s\n", evt.ConnectionID, evt.Command)) } str.WriteString("\nSucceeded Events\n\n") - for _, evt := range client.SucceededEvents() { + for _, evt := range client.succeededEvents() { str.WriteString(fmt.Sprintf("[%s] CommandName: %s, Reply: %s\n", evt.ConnectionID, evt.CommandName, evt.Reply)) } str.WriteString("\nFailed Events\n\n") - for _, evt := range client.FailedEvents() { + for _, evt := range client.failedEvents() { str.WriteString(fmt.Sprintf("[%s] CommandName: %s, Failure: %s\n", evt.ConnectionID, evt.CommandName, evt.Failure)) } diff --git a/mongo/integration/unified/common_options_test.go b/mongo/integration/unified/common_options.go similarity index 100% rename from mongo/integration/unified/common_options_test.go rename to mongo/integration/unified/common_options.go diff --git a/mongo/integration/unified/context_test.go b/mongo/integration/unified/context.go similarity index 74% rename from mongo/integration/unified/context_test.go rename to mongo/integration/unified/context.go index 492dfca050..4583b47ed4 100644 --- a/mongo/integration/unified/context_test.go +++ b/mongo/integration/unified/context.go @@ -17,7 +17,7 @@ import ( type ctxKey string const ( - // entitiesKey is used to store an EntityMap instance in a Context. + // entitiesKey is used to store an entityMap instance in a Context. entitiesKey ctxKey = "test-entities" // failPointsKey is used to store a map from a fail point name to the Client instance used to configure it. failPointsKey ctxKey = "test-failpoints" @@ -25,16 +25,16 @@ const ( targetedFailPointsKey ctxKey = "test-targeted-failpoints" ) -// NewTestContext creates a new Context derived from ctx with values initialized to store the state required for test +// newTestContext creates a new Context derived from ctx with values initialized to store the state required for test // execution. -func NewTestContext(ctx context.Context) context.Context { - ctx = context.WithValue(ctx, entitiesKey, NewEntityMap()) +func newTestContext(ctx context.Context) context.Context { + ctx = context.WithValue(ctx, entitiesKey, newEntityMap()) ctx = context.WithValue(ctx, failPointsKey, make(map[string]*mongo.Client)) ctx = context.WithValue(ctx, targetedFailPointsKey, make(map[string]string)) return ctx } -func AddFailPoint(ctx context.Context, failPoint string, client *mongo.Client) error { +func addFailPoint(ctx context.Context, failPoint string, client *mongo.Client) error { failPoints := ctx.Value(failPointsKey).(map[string]*mongo.Client) if _, ok := failPoints[failPoint]; ok { return fmt.Errorf("fail point %q already exists in tracked fail points map", failPoint) @@ -44,7 +44,7 @@ func AddFailPoint(ctx context.Context, failPoint string, client *mongo.Client) e return nil } -func AddTargetedFailPoint(ctx context.Context, failPoint string, host string) error { +func addTargetedFailPoint(ctx context.Context, failPoint string, host string) error { failPoints := ctx.Value(targetedFailPointsKey).(map[string]string) if _, ok := failPoints[failPoint]; ok { return fmt.Errorf("fail point %q already exists in tracked targeted fail points map", failPoint) @@ -54,14 +54,14 @@ func AddTargetedFailPoint(ctx context.Context, failPoint string, host string) er return nil } -func FailPoints(ctx context.Context) map[string]*mongo.Client { +func failPoints(ctx context.Context) map[string]*mongo.Client { return ctx.Value(failPointsKey).(map[string]*mongo.Client) } -func TargetedFailPoints(ctx context.Context) map[string]string { +func targetedFailPoints(ctx context.Context) map[string]string { return ctx.Value(targetedFailPointsKey).(map[string]string) } -func Entities(ctx context.Context) *EntityMap { - return ctx.Value(entitiesKey).(*EntityMap) +func entities(ctx context.Context) *entityMap { + return ctx.Value(entitiesKey).(*entityMap) } diff --git a/mongo/integration/unified/crud_helpers_test.go b/mongo/integration/unified/crud_helpers.go similarity index 100% rename from mongo/integration/unified/crud_helpers_test.go rename to mongo/integration/unified/crud_helpers.go diff --git a/mongo/integration/unified/database_operation_execution_test.go b/mongo/integration/unified/database_operation_execution.go similarity index 78% rename from mongo/integration/unified/database_operation_execution_test.go rename to mongo/integration/unified/database_operation_execution.go index 091e231968..208fb3e1e2 100644 --- a/mongo/integration/unified/database_operation_execution_test.go +++ b/mongo/integration/unified/database_operation_execution.go @@ -17,8 +17,8 @@ import ( // This file contains helpers to execute database operations. -func executeCreateCollection(ctx context.Context, operation *Operation) (*OperationResult, error) { - db, err := Entities(ctx).Database(operation.Object) +func executeCreateCollection(ctx context.Context, operation *operation) (*operationResult, error) { + db, err := entities(ctx).database(operation.Object) if err != nil { return nil, err } @@ -41,11 +41,11 @@ func executeCreateCollection(ctx context.Context, operation *Operation) (*Operat } err = db.CreateCollection(ctx, collName) - return NewErrorResult(err), nil + return newErrorResult(err), nil } -func executeDropCollection(ctx context.Context, operation *Operation) (*OperationResult, error) { - db, err := Entities(ctx).Database(operation.Object) +func executeDropCollection(ctx context.Context, operation *operation) (*operationResult, error) { + db, err := entities(ctx).database(operation.Object) if err != nil { return nil, err } @@ -68,11 +68,11 @@ func executeDropCollection(ctx context.Context, operation *Operation) (*Operatio } err = db.Collection(collName).Drop(ctx) - return NewErrorResult(err), nil + return newErrorResult(err), nil } -func executeListCollections(ctx context.Context, operation *Operation) (*OperationResult, error) { - db, err := Entities(ctx).Database(operation.Object) +func executeListCollections(ctx context.Context, operation *operation) (*operationResult, error) { + db, err := entities(ctx).database(operation.Object) if err != nil { return nil, err } @@ -84,19 +84,19 @@ func executeListCollections(ctx context.Context, operation *Operation) (*Operati cursor, err := db.ListCollections(ctx, listCollArgs.filter, listCollArgs.opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } defer cursor.Close(ctx) var docs []bson.Raw if err := cursor.All(ctx, &cursor); err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } - return NewCursorResult(docs), nil + return newCursorResult(docs), nil } -func executeListCollectionNames(ctx context.Context, operation *Operation) (*OperationResult, error) { - db, err := Entities(ctx).Database(operation.Object) +func executeListCollectionNames(ctx context.Context, operation *operation) (*operationResult, error) { + db, err := entities(ctx).database(operation.Object) if err != nil { return nil, err } @@ -108,17 +108,17 @@ func executeListCollectionNames(ctx context.Context, operation *Operation) (*Ope names, err := db.ListCollectionNames(ctx, listCollArgs.filter, listCollArgs.opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } _, data, err := bson.MarshalValue(names) if err != nil { return nil, fmt.Errorf("error converting collection names slice to BSON: %v", err) } - return NewValueResult(bsontype.Array, data, nil), nil + return newValueResult(bsontype.Array, data, nil), nil } -func executeRunCommand(ctx context.Context, operation *Operation) (*OperationResult, error) { - db, err := Entities(ctx).Database(operation.Object) +func executeRunCommand(ctx context.Context, operation *operation) (*operationResult, error) { + db, err := entities(ctx).database(operation.Object) if err != nil { return nil, err } @@ -163,5 +163,5 @@ func executeRunCommand(ctx context.Context, operation *Operation) (*OperationRes } res, err := db.RunCommand(ctx, command, opts).DecodeBytes() - return NewDocumentResult(res, err), nil + return newDocumentResult(res, err), nil } diff --git a/mongo/integration/unified/db_collection_options_test.go b/mongo/integration/unified/db_collection_options.go similarity index 86% rename from mongo/integration/unified/db_collection_options_test.go rename to mongo/integration/unified/db_collection_options.go index a98ed75c3c..a1831a6f97 100644 --- a/mongo/integration/unified/db_collection_options_test.go +++ b/mongo/integration/unified/db_collection_options.go @@ -13,14 +13,14 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -type DBOrCollectionOptions struct { +type dbOrCollectionOptions struct { DBOptions *options.DatabaseOptions CollectionOptions *options.CollectionOptions } // UnmarshalBSON specifies custom BSON unmarshalling behavior to convert db/collection options from BSON/JSON documents // to their corresponding Go objects. -func (d *DBOrCollectionOptions) UnmarshalBSON(data []byte) error { +func (d dbOrCollectionOptions) UnmarshalBSON(data []byte) error { var temp struct { RC *readConcern `bson:"readConcern"` RP *readPreference `bson:"readPreference"` @@ -28,10 +28,10 @@ func (d *DBOrCollectionOptions) UnmarshalBSON(data []byte) error { Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary DBOrCollectionOptions object: %v", err) + return fmt.Errorf("error unmarshalling to temporary dbOrCollectionOptions object: %v", err) } if len(temp.Extra) > 0 { - return fmt.Errorf("unrecognized fields for DBOrCollectionOptions: %v", MapKeys(temp.Extra)) + return fmt.Errorf("unrecognized fields for dbOrCollectionOptions: %v", mapKeys(temp.Extra)) } d.DBOptions = options.Database() diff --git a/mongo/integration/unified/entity_test.go b/mongo/integration/unified/entity.go similarity index 53% rename from mongo/integration/unified/entity_test.go rename to mongo/integration/unified/entity.go index 3fb3323343..3a8ab46347 100644 --- a/mongo/integration/unified/entity_test.go +++ b/mongo/integration/unified/entity.go @@ -16,9 +16,9 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// EntityOptions represents all options that can be used to configure an entity. Because there are multiple entity +// entityOptions represents all options that can be used to configure an entity. Because there are multiple entity // types, only a subset of the options that this type contains apply to any given entity. -type EntityOptions struct { +type entityOptions struct { // Options that apply to all entity types. ID string `bson:"id"` @@ -27,55 +27,55 @@ type EntityOptions struct { UseMultipleMongoses *bool `bson:"useMultipleMongoses"` ObserveEvents []string `bson:"observeEvents"` IgnoredCommands []string `bson:"ignoreCommandMonitoringEvents"` - ServerAPIOptions *ServerAPIOptions `bson:"serverApi"` + ServerAPIOptions *serverAPIOptions `bson:"serverApi"` // Options for database entities. DatabaseName string `bson:"databaseName"` - DatabaseOptions *DBOrCollectionOptions `bson:"databaseOptions"` + DatabaseOptions *dbOrCollectionOptions `bson:"databaseOptions"` // Options for collection entities. CollectionName string `bson:"collectionName"` - CollectionOptions *DBOrCollectionOptions `bson:"collectionOptions"` + CollectionOptions *dbOrCollectionOptions `bson:"collectionOptions"` // Options for session entities. - SessionOptions *SessionOptions `bson:"sessionOptions"` + SessionOptions *sessionOptions `bson:"sessionOptions"` // Options for GridFS bucket entities. - GridFSBucketOptions *GridFSBucketOptions `bson:"bucketOptions"` + GridFSBucketOptions *gridFSBucketOptions `bson:"bucketOptions"` // Options that reference other entities. ClientID string `bson:"client"` DatabaseID string `bson:"database"` } -// EntityMap is used to store entities during tests. This type enforces uniqueness so no two entities can have the same +// entityMap is used to store entities during tests. This type enforces uniqueness so no two entities can have the same // ID, even if they are of different types. It also enforces referential integrity so construction of an entity that // references another (e.g. a database entity references a client) will fail if the referenced entity does not exist. -type EntityMap struct { - allEntities map[string]struct{} - changeStreams map[string]*mongo.ChangeStream - clients map[string]*ClientEntity - dbs map[string]*mongo.Database - collections map[string]*mongo.Collection - sessions map[string]mongo.Session - gridfsBuckets map[string]*gridfs.Bucket - bsonValues map[string]bson.RawValue +type entityMap struct { + allEntities map[string]struct{} + changeStreams map[string]*mongo.ChangeStream + clientEntities map[string]*clientEntity + dbEntites map[string]*mongo.Database + collEntities map[string]*mongo.Collection + sessions map[string]mongo.Session + gridfsBuckets map[string]*gridfs.Bucket + bsonValues map[string]bson.RawValue } -func NewEntityMap() *EntityMap { - return &EntityMap{ - allEntities: make(map[string]struct{}), - gridfsBuckets: make(map[string]*gridfs.Bucket), - bsonValues: make(map[string]bson.RawValue), - changeStreams: make(map[string]*mongo.ChangeStream), - clients: make(map[string]*ClientEntity), - collections: make(map[string]*mongo.Collection), - dbs: make(map[string]*mongo.Database), - sessions: make(map[string]mongo.Session), +func newEntityMap() *entityMap { + return &entityMap{ + allEntities: make(map[string]struct{}), + gridfsBuckets: make(map[string]*gridfs.Bucket), + bsonValues: make(map[string]bson.RawValue), + changeStreams: make(map[string]*mongo.ChangeStream), + clientEntities: make(map[string]*clientEntity), + collEntities: make(map[string]*mongo.Collection), + dbEntites: make(map[string]*mongo.Database), + sessions: make(map[string]mongo.Session), } } -func (em *EntityMap) AddBSONEntity(id string, val bson.RawValue) error { +func (em *entityMap) addBSONEntity(id string, val bson.RawValue) error { if err := em.verifyEntityDoesNotExist(id); err != nil { return err } @@ -85,7 +85,7 @@ func (em *EntityMap) AddBSONEntity(id string, val bson.RawValue) error { return nil } -func (em *EntityMap) AddChangeStreamEntity(id string, stream *mongo.ChangeStream) error { +func (em *entityMap) addChangeStreamEntity(id string, stream *mongo.ChangeStream) error { if err := em.verifyEntityDoesNotExist(id); err != nil { return err } @@ -95,7 +95,7 @@ func (em *EntityMap) AddChangeStreamEntity(id string, stream *mongo.ChangeStream return nil } -func (em *EntityMap) AddEntity(ctx context.Context, entityType string, entityOptions *EntityOptions) error { +func (em *entityMap) addEntity(ctx context.Context, entityType string, entityOptions *entityOptions) error { if err := em.verifyEntityDoesNotExist(entityOptions.ID); err != nil { return err } @@ -123,7 +123,7 @@ func (em *EntityMap) AddEntity(ctx context.Context, entityType string, entityOpt return nil } -func (em *EntityMap) GridFSBucket(id string) (*gridfs.Bucket, error) { +func (em *entityMap) gridFSBucket(id string) (*gridfs.Bucket, error) { bucket, ok := em.gridfsBuckets[id] if !ok { return nil, newEntityNotFoundError("gridfs bucket", id) @@ -131,7 +131,7 @@ func (em *EntityMap) GridFSBucket(id string) (*gridfs.Bucket, error) { return bucket, nil } -func (em *EntityMap) BSONValue(id string) (bson.RawValue, error) { +func (em *entityMap) bsonValue(id string) (bson.RawValue, error) { val, ok := em.bsonValues[id] if !ok { return emptyRawValue, newEntityNotFoundError("BSON", id) @@ -139,7 +139,7 @@ func (em *EntityMap) BSONValue(id string) (bson.RawValue, error) { return val, nil } -func (em *EntityMap) ChangeStream(id string) (*mongo.ChangeStream, error) { +func (em *entityMap) changeStream(id string) (*mongo.ChangeStream, error) { client, ok := em.changeStreams[id] if !ok { return nil, newEntityNotFoundError("change stream", id) @@ -147,39 +147,39 @@ func (em *EntityMap) ChangeStream(id string) (*mongo.ChangeStream, error) { return client, nil } -func (em *EntityMap) Client(id string) (*ClientEntity, error) { - client, ok := em.clients[id] +func (em *entityMap) client(id string) (*clientEntity, error) { + client, ok := em.clientEntities[id] if !ok { return nil, newEntityNotFoundError("client", id) } return client, nil } -func (em *EntityMap) Clients() map[string]*ClientEntity { - return em.clients +func (em *entityMap) clients() map[string]*clientEntity { + return em.clientEntities } -func (em *EntityMap) Collections() map[string]*mongo.Collection { - return em.collections +func (em *entityMap) collections() map[string]*mongo.Collection { + return em.collEntities } -func (em *EntityMap) Collection(id string) (*mongo.Collection, error) { - coll, ok := em.collections[id] +func (em *entityMap) collection(id string) (*mongo.Collection, error) { + coll, ok := em.collEntities[id] if !ok { return nil, newEntityNotFoundError("collection", id) } return coll, nil } -func (em *EntityMap) Database(id string) (*mongo.Database, error) { - db, ok := em.dbs[id] +func (em *entityMap) database(id string) (*mongo.Database, error) { + db, ok := em.dbEntites[id] if !ok { return nil, newEntityNotFoundError("database", id) } return db, nil } -func (em *EntityMap) Session(id string) (mongo.Session, error) { +func (em *entityMap) session(id string) (mongo.Session, error) { sess, ok := em.sessions[id] if !ok { return nil, newEntityNotFoundError("session", id) @@ -187,14 +187,14 @@ func (em *EntityMap) Session(id string) (mongo.Session, error) { return sess, nil } -// Close disposes of the session and client entities associated with this map. -func (em *EntityMap) Close(ctx context.Context) []error { +// close disposes of the session and client entities associated with this map. +func (em *entityMap) close(ctx context.Context) []error { for _, sess := range em.sessions { sess.EndSession(ctx) } var errs []error - for id, client := range em.clients { + for id, client := range em.clientEntities { if err := client.Disconnect(ctx); err != nil { errs = append(errs, fmt.Errorf("error closing client with ID %q: %v", id, err)) } @@ -202,56 +202,56 @@ func (em *EntityMap) Close(ctx context.Context) []error { return errs } -func (em *EntityMap) addClientEntity(ctx context.Context, EntityOptions *EntityOptions) error { - var client *ClientEntity - client, err := NewClientEntity(ctx, EntityOptions) +func (em *entityMap) addClientEntity(ctx context.Context, entityOptions *entityOptions) error { + var client *clientEntity + client, err := newClientEntity(ctx, entityOptions) if err != nil { return fmt.Errorf("error creating client entity: %v", err) } - em.clients[EntityOptions.ID] = client + em.clientEntities[entityOptions.ID] = client return nil } -func (em *EntityMap) addDatabaseEntity(EntityOptions *EntityOptions) error { - client, ok := em.clients[EntityOptions.ClientID] +func (em *entityMap) addDatabaseEntity(entityOptions *entityOptions) error { + client, ok := em.clientEntities[entityOptions.ClientID] if !ok { - return newEntityNotFoundError("client", EntityOptions.ClientID) + return newEntityNotFoundError("client", entityOptions.ClientID) } dbOpts := options.Database() - if EntityOptions.DatabaseOptions != nil { - dbOpts = EntityOptions.DatabaseOptions.DBOptions + if entityOptions.DatabaseOptions != nil { + dbOpts = entityOptions.DatabaseOptions.DBOptions } - em.dbs[EntityOptions.ID] = client.Database(EntityOptions.DatabaseName, dbOpts) + em.dbEntites[entityOptions.ID] = client.Database(entityOptions.DatabaseName, dbOpts) return nil } -func (em *EntityMap) addCollectionEntity(EntityOptions *EntityOptions) error { - db, ok := em.dbs[EntityOptions.DatabaseID] +func (em *entityMap) addCollectionEntity(entityOptions *entityOptions) error { + db, ok := em.dbEntites[entityOptions.DatabaseID] if !ok { - return newEntityNotFoundError("database", EntityOptions.DatabaseID) + return newEntityNotFoundError("database", entityOptions.DatabaseID) } collOpts := options.Collection() - if EntityOptions.CollectionOptions != nil { - collOpts = EntityOptions.CollectionOptions.CollectionOptions + if entityOptions.CollectionOptions != nil { + collOpts = entityOptions.CollectionOptions.CollectionOptions } - em.collections[EntityOptions.ID] = db.Collection(EntityOptions.CollectionName, collOpts) + em.collEntities[entityOptions.ID] = db.Collection(entityOptions.CollectionName, collOpts) return nil } -func (em *EntityMap) addSessionEntity(EntityOptions *EntityOptions) error { - client, ok := em.clients[EntityOptions.ClientID] +func (em *entityMap) addSessionEntity(entityOptions *entityOptions) error { + client, ok := em.clientEntities[entityOptions.ClientID] if !ok { - return newEntityNotFoundError("client", EntityOptions.ClientID) + return newEntityNotFoundError("client", entityOptions.ClientID) } sessionOpts := options.Session() - if EntityOptions.SessionOptions != nil { - sessionOpts = EntityOptions.SessionOptions.SessionOptions + if entityOptions.SessionOptions != nil { + sessionOpts = entityOptions.SessionOptions.SessionOptions } sess, err := client.StartSession(sessionOpts) @@ -259,19 +259,19 @@ func (em *EntityMap) addSessionEntity(EntityOptions *EntityOptions) error { return fmt.Errorf("error starting session: %v", err) } - em.sessions[EntityOptions.ID] = sess + em.sessions[entityOptions.ID] = sess return nil } -func (em *EntityMap) addGridFSBucketEntity(EntityOptions *EntityOptions) error { - db, ok := em.dbs[EntityOptions.DatabaseID] +func (em *entityMap) addGridFSBucketEntity(entityOptions *entityOptions) error { + db, ok := em.dbEntites[entityOptions.DatabaseID] if !ok { - return newEntityNotFoundError("database", EntityOptions.DatabaseID) + return newEntityNotFoundError("database", entityOptions.DatabaseID) } bucketOpts := options.GridFSBucket() - if EntityOptions.GridFSBucketOptions != nil { - bucketOpts = EntityOptions.GridFSBucketOptions.BucketOptions + if entityOptions.GridFSBucketOptions != nil { + bucketOpts = entityOptions.GridFSBucketOptions.BucketOptions } bucket, err := gridfs.NewBucket(db, bucketOpts) @@ -279,11 +279,11 @@ func (em *EntityMap) addGridFSBucketEntity(EntityOptions *EntityOptions) error { return fmt.Errorf("error creating GridFS bucket: %v", err) } - em.gridfsBuckets[EntityOptions.ID] = bucket + em.gridfsBuckets[entityOptions.ID] = bucket return nil } -func (em *EntityMap) verifyEntityDoesNotExist(id string) error { +func (em *entityMap) verifyEntityDoesNotExist(id string) error { if _, ok := em.allEntities[id]; ok { return fmt.Errorf("entity with ID %q already exists", id) } diff --git a/mongo/integration/unified/error_test.go b/mongo/integration/unified/error.go similarity index 93% rename from mongo/integration/unified/error_test.go rename to mongo/integration/unified/error.go index 625914927f..c5a6c994a4 100644 --- a/mongo/integration/unified/error_test.go +++ b/mongo/integration/unified/error.go @@ -15,9 +15,9 @@ import ( "go.mongodb.org/mongo-driver/mongo" ) -// ExpectedError represents an error that is expected to occur during a test. This type ignores the "isError" field in +// expectedError represents an error that is expected to occur during a test. This type ignores the "isError" field in // test files because it is always true if it is specified, so the runner can simply assert that an error occurred. -type ExpectedError struct { +type expectedError struct { IsClientError *bool `bson:"isClientError"` ErrorSubstring *string `bson:"errorContains"` Code *int32 `bson:"errorCode"` @@ -27,10 +27,10 @@ type ExpectedError struct { ExpectedResult *bson.RawValue `bson:"expectResult"` } -// VerifyOperationError compares the expected error to the actual operation result. If the expected parameter is nil, +// verifyOperationError compares the expected error to the actual operation result. If the expected parameter is nil, // this function will only check that result.Err is also nil. Otherwise, it will check that result.Err is non-nil and -// will perform any other assertions required by the ExpectedError object. An error is returned if any checks fail. -func VerifyOperationError(ctx context.Context, expected *ExpectedError, result *OperationResult) error { +// will perform any other assertions required by the expectedError object. An error is returned if any checks fail. +func verifyOperationError(ctx context.Context, expected *expectedError, result *operationResult) error { if expected == nil { if result.Err != nil { return fmt.Errorf("expected no error, but got %v", result.Err) @@ -101,7 +101,7 @@ func VerifyOperationError(ctx context.Context, expected *ExpectedError, result * } if expected.ExpectedResult != nil { - if err := VerifyOperationResult(ctx, *expected.ExpectedResult, result); err != nil { + if err := verifyOperationResult(ctx, *expected.ExpectedResult, result); err != nil { return fmt.Errorf("result comparison error: %v", err) } } diff --git a/mongo/integration/unified/gridfs_bucket_operation_execution_test.go b/mongo/integration/unified/gridfs_bucket_operation_execution.go similarity index 76% rename from mongo/integration/unified/gridfs_bucket_operation_execution_test.go rename to mongo/integration/unified/gridfs_bucket_operation_execution.go index f4aea3ee37..05888307a4 100644 --- a/mongo/integration/unified/gridfs_bucket_operation_execution_test.go +++ b/mongo/integration/unified/gridfs_bucket_operation_execution.go @@ -19,8 +19,8 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -func executeBucketDelete(ctx context.Context, operation *Operation) (*OperationResult, error) { - bucket, err := Entities(ctx).GridFSBucket(operation.Object) +func executeBucketDelete(ctx context.Context, operation *operation) (*operationResult, error) { + bucket, err := entities(ctx).gridFSBucket(operation.Object) if err != nil { return nil, err } @@ -42,11 +42,11 @@ func executeBucketDelete(ctx context.Context, operation *Operation) (*OperationR return nil, newMissingArgumentError("id") } - return NewErrorResult(bucket.Delete(*id)), nil + return newErrorResult(bucket.Delete(*id)), nil } -func executeBucketDownload(ctx context.Context, operation *Operation) (*OperationResult, error) { - bucket, err := Entities(ctx).GridFSBucket(operation.Object) +func executeBucketDownload(ctx context.Context, operation *operation) (*operationResult, error) { + bucket, err := entities(ctx).gridFSBucket(operation.Object) if err != nil { return nil, err } @@ -70,19 +70,19 @@ func executeBucketDownload(ctx context.Context, operation *Operation) (*Operatio stream, err := bucket.OpenDownloadStream(*id) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } var buffer bytes.Buffer if _, err := io.Copy(&buffer, stream); err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } - return NewValueResult(bsontype.Binary, bsoncore.AppendBinary(nil, 0, buffer.Bytes()), nil), nil + return newValueResult(bsontype.Binary, bsoncore.AppendBinary(nil, 0, buffer.Bytes()), nil), nil } -func executeBucketUpload(ctx context.Context, operation *Operation) (*OperationResult, error) { - bucket, err := Entities(ctx).GridFSBucket(operation.Object) +func executeBucketUpload(ctx context.Context, operation *operation) (*operationResult, error) { + bucket, err := entities(ctx).gridFSBucket(operation.Object) if err != nil { return nil, err } @@ -121,7 +121,7 @@ func executeBucketUpload(ctx context.Context, operation *Operation) (*OperationR fileID, err := bucket.UploadFromStream(filename, bytes.NewReader(fileBytes), opts) if err != nil { - return NewErrorResult(err), nil + return newErrorResult(err), nil } if operation.ResultEntityID != nil { @@ -129,10 +129,10 @@ func executeBucketUpload(ctx context.Context, operation *Operation) (*OperationR Type: bsontype.ObjectID, Value: fileID[:], } - if err := Entities(ctx).AddBSONEntity(*operation.ResultEntityID, fileIDValue); err != nil { + if err := entities(ctx).addBSONEntity(*operation.ResultEntityID, fileIDValue); err != nil { return nil, fmt.Errorf("error storing result as BSON entity: %v", err) } } - return NewValueResult(bsontype.ObjectID, fileID[:], nil), nil + return newValueResult(bsontype.ObjectID, fileID[:], nil), nil } diff --git a/mongo/integration/unified/matches.go b/mongo/integration/unified/matches.go new file mode 100644 index 0000000000..b256d6e238 --- /dev/null +++ b/mongo/integration/unified/matches.go @@ -0,0 +1,335 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package unified + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "strings" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" +) + +// keyPathCtxKey is used as a key for a Context object. The value conveys the BSON key path that is currently being +// compared. +type keyPathCtxKey struct{} + +// extraKeysAllowedCtxKey is used as a key for a Context object. The value conveys whether or not the document under +// test can contain extra keys. For example, if the expected document is {x: 1}, the document {x: 1, y: 1} would match +// if the value for this key is true. +type extraKeysAllowedCtxKey struct{} + +func makeMatchContext(ctx context.Context, keyPath string, extraKeysAllowed bool) context.Context { + ctx = context.WithValue(ctx, keyPathCtxKey{}, keyPath) + return context.WithValue(ctx, extraKeysAllowedCtxKey{}, extraKeysAllowed) +} + +// verifyValuesMatch compares the provided BSON values and returns an error if they do not match. If the values are +// documents and extraKeysAllowed is true, the actual value will be allowed to have additional keys at the top-level. +// For example, an expected document {x: 1} would match the actual document {x: 1, y: 1}. +func verifyValuesMatch(ctx context.Context, expected, actual bson.RawValue, extraKeysAllowed bool) error { + return verifyValuesMatchInner(makeMatchContext(ctx, "", extraKeysAllowed), expected, actual) +} + +func verifyValuesMatchInner(ctx context.Context, expected, actual bson.RawValue) error { + keyPath := ctx.Value(keyPathCtxKey{}).(string) + extraKeysAllowed := ctx.Value(extraKeysAllowedCtxKey{}).(bool) + + if expectedDoc, ok := expected.DocumentOK(); ok { + // If the root document only has one element and the key is a special matching operator, the actual value might + // not actually be a document. In this case, evaluate the special operator with the actual value rather than + // doing an element-wise document comparison. + if requiresSpecialMatching(expectedDoc) { + if err := evaluateSpecialComparison(ctx, expectedDoc, actual, ""); err != nil { + return newMatchingError(keyPath, "error doing special matching assertion: %v", err) + } + return nil + } + + actualDoc, ok := actual.DocumentOK() + if !ok { + return newMatchingError(keyPath, "expected value to be a document but got a %s", actual.Type) + } + + // Perform element-wise comparisons. + expectedElems, _ := expectedDoc.Elements() + for _, expectedElem := range expectedElems { + expectedKey := expectedElem.Key() + expectedValue := expectedElem.Value() + + fullKeyPath := expectedKey + if keyPath != "" { + fullKeyPath = keyPath + "." + expectedKey + } + + // Get the value from actualDoc here but don't check the error until later because some of the special + // matching operators can assert that the value isn't present in the document (e.g. $$exists). + actualValue, err := actualDoc.LookupErr(expectedKey) + if specialDoc, ok := expectedValue.DocumentOK(); ok && requiresSpecialMatching(specialDoc) { + // Reset the key path so any errors returned from the function will only have the key path for the + // target value. Also unconditionally set extraKeysAllowed to false because an assertion like + // $$unsetOrMatches could recurse back into this function. In that case, the target document is nested + // and should not have extra keys. + ctx = makeMatchContext(ctx, "", false) + if err := evaluateSpecialComparison(ctx, specialDoc, actualValue, expectedKey); err != nil { + return newMatchingError(fullKeyPath, "error doing special matching assertion: %v", err) + } + continue + } + + // This isn't a special comparison. Assert that the value exists in the actual document. + if err != nil { + return newMatchingError(fullKeyPath, "key not found in actual document") + } + + // Nested documents cannot have extra keys, so we unconditionally pass false for extraKeysAllowed. + comparisonCtx := makeMatchContext(ctx, fullKeyPath, false) + if err := verifyValuesMatchInner(comparisonCtx, expectedValue, actualValue); err != nil { + return err + } + } + // If required, verify that the actual document does not have extra elements. We do this by iterating over the + // actual and checking for each key in the expected rather than comparing element counts because the presence of + // special operators can cause incorrect counts. For example, the document {y: {$$exists: false}} has one + // element, but should match the document {}, which has none. + if !extraKeysAllowed { + actualElems, _ := actualDoc.Elements() + for _, actualElem := range actualElems { + if _, err := expectedDoc.LookupErr(actualElem.Key()); err != nil { + return newMatchingError(keyPath, "extra key %q found in actual document %s", actualElem.Key(), + actualDoc) + } + } + } + + return nil + } + if expectedArr, ok := expected.ArrayOK(); ok { + actualArr, ok := actual.ArrayOK() + if !ok { + return newMatchingError(keyPath, "expected value to be an array but got a %s", actual.Type) + } + + expectedValues, _ := expectedArr.Values() + actualValues, _ := actualArr.Values() + + // Arrays must always have the same number of elements. + if len(expectedValues) != len(actualValues) { + return newMatchingError(keyPath, "expected array length %d, got %d", len(expectedValues), + len(actualValues)) + } + + for idx, expectedValue := range expectedValues { + // Use the index as the key to augment the key path. + fullKeyPath := fmt.Sprintf("%d", idx) + if keyPath != "" { + fullKeyPath = keyPath + "." + fullKeyPath + } + + comparisonCtx := makeMatchContext(ctx, fullKeyPath, extraKeysAllowed) + err := verifyValuesMatchInner(comparisonCtx, expectedValue, actualValues[idx]) + if err != nil { + return err + } + } + + return nil + } + + // Numeric values must be considered equal even if their types are different (e.g. if expected is an int32 and + // actual is an int64). + if expected.IsNumber() { + if !actual.IsNumber() { + return newMatchingError(keyPath, "expected value to be a number but got a %s", actual.Type) + } + + expectedInt64 := expected.AsInt64() + actualInt64 := actual.AsInt64() + if expectedInt64 != actualInt64 { + return newMatchingError(keyPath, "expected numeric value %d, got %d", expectedInt64, actualInt64) + } + return nil + } + + // If expected is not a recursive or numeric type, we can directly call Equal to do the comparison. + if !expected.Equal(actual) { + return newMatchingError(keyPath, "expected value %s, got %s", expected, actual) + } + return nil +} + +func evaluateSpecialComparison(ctx context.Context, assertionDoc bson.Raw, actual bson.RawValue, fieldName string) error { + assertionElem := assertionDoc.Index(0) + assertion := assertionElem.Key() + assertionVal := assertionElem.Value() + + switch assertion { + case "$$exists": + shouldExist := assertionVal.Boolean() + exists := actual.Validate() == nil + if shouldExist != exists { + return fmt.Errorf("expected value to exist: %v; value actually exists: %v", shouldExist, exists) + } + case "$$type": + possibleTypes, err := getTypesArray(assertionVal) + if err != nil { + return fmt.Errorf("error getting possible types for a $$type assertion: %v", err) + } + + for _, possibleType := range possibleTypes { + if actual.Type == possibleType { + return nil + } + } + return fmt.Errorf("expected type to be one of %v but was %s", possibleTypes, actual.Type) + case "$$matchesEntity": + expected, err := entities(ctx).bsonValue(assertionVal.StringValue()) + if err != nil { + return err + } + + // $$matchesEntity doesn't modify the nesting level of the key path so we can propagate ctx without changes. + return verifyValuesMatchInner(ctx, expected, actual) + case "$$matchesHexBytes": + expectedBytes, err := hex.DecodeString(assertionVal.StringValue()) + if err != nil { + return fmt.Errorf("error converting $$matcesHexBytes value to bytes: %v", err) + } + + _, actualBytes, ok := actual.BinaryOK() + if !ok { + return fmt.Errorf("expected binary value for a $$matchesHexBytes assertion, but got a %s", actual.Type) + } + if !bytes.Equal(expectedBytes, actualBytes) { + return fmt.Errorf("expected bytes %v, got %v", expectedBytes, actualBytes) + } + case "$$unsetOrMatches": + if actual.Validate() != nil { + return nil + } + + // $$unsetOrMatches doesn't modify the nesting level or the key path so we can propagate the context to the + // comparison function without changing anything. + return verifyValuesMatchInner(ctx, assertionVal, actual) + case "$$sessionLsid": + sess, err := entities(ctx).session(assertionVal.StringValue()) + if err != nil { + return err + } + + expectedID := sess.ID() + actualID, ok := actual.DocumentOK() + if !ok { + return fmt.Errorf("expected document value for a $$sessionLsid assertion, but got a %s", actual.Type) + } + if !bytes.Equal(expectedID, actualID) { + return fmt.Errorf("expected lsid %v, got %v", expectedID, actualID) + } + default: + return fmt.Errorf("unrecognized special matching assertion %q", assertion) + } + + return nil +} + +func requiresSpecialMatching(doc bson.Raw) bool { + elems, _ := doc.Elements() + return len(elems) == 1 && strings.HasPrefix(elems[0].Key(), "$$") +} + +func getTypesArray(val bson.RawValue) ([]bsontype.Type, error) { + switch val.Type { + case bsontype.String: + convertedType, err := convertStringToBSONType(val.StringValue()) + if err != nil { + return nil, err + } + + return []bsontype.Type{convertedType}, nil + case bsontype.Array: + var typeStrings []string + if err := val.Unmarshal(&typeStrings); err != nil { + return nil, fmt.Errorf("error unmarshalling to slice of strings: %v", err) + } + + var types []bsontype.Type + for _, typeStr := range typeStrings { + convertedType, err := convertStringToBSONType(typeStr) + if err != nil { + return nil, err + } + + types = append(types, convertedType) + } + return types, nil + default: + return nil, fmt.Errorf("invalid type to convert to bsontype.Type slice: %s", val.Type) + } +} + +func convertStringToBSONType(typeStr string) (bsontype.Type, error) { + switch typeStr { + case "double": + return bsontype.Double, nil + case "string": + return bsontype.String, nil + case "object": + return bsontype.EmbeddedDocument, nil + case "array": + return bsontype.Array, nil + case "binData": + return bsontype.Binary, nil + case "undefined": + return bsontype.Undefined, nil + case "objectId": + return bsontype.ObjectID, nil + case "bool": + return bsontype.Boolean, nil + case "date": + return bsontype.DateTime, nil + case "null": + return bsontype.Null, nil + case "regex": + return bsontype.Regex, nil + case "dbPointer": + return bsontype.DBPointer, nil + case "javascript": + return bsontype.JavaScript, nil + case "symbol": + return bsontype.Symbol, nil + case "javascriptWithScope": + return bsontype.CodeWithScope, nil + case "int": + return bsontype.Int32, nil + case "timestamp": + return bsontype.Timestamp, nil + case "long": + return bsontype.Int64, nil + case "decimal": + return bsontype.Decimal128, nil + case "minKey": + return bsontype.MinKey, nil + case "maxKey": + return bsontype.MaxKey, nil + default: + return bsontype.Type(0), fmt.Errorf("unrecognized BSON type string %q", typeStr) + } +} + +// newMatchingError creates an error to convey that BSON value comparison failed at the provided key path. If the +// key path is empty (e.g. because the values being compared were not documents), the error message will contain the +// phrase "top-level" instead of the path. +func newMatchingError(keyPath, msg string, args ...interface{}) error { + fullMsg := fmt.Sprintf(msg, args...) + if keyPath == "" { + return fmt.Errorf("comparison error at top-level: %s", fullMsg) + } + return fmt.Errorf("comparison error at key %q: %s", keyPath, fullMsg) +} diff --git a/mongo/integration/unified/matches_test.go b/mongo/integration/unified/matches_test.go index 26070774a3..5c096a299f 100644 --- a/mongo/integration/unified/matches_test.go +++ b/mongo/integration/unified/matches_test.go @@ -7,11 +7,8 @@ package unified import ( - "bytes" "context" "encoding/hex" - "fmt" - "strings" "testing" "go.mongodb.org/mongo-driver/bson" @@ -20,323 +17,6 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -// keyPathCtxKey is used as a key for a Context object. The value conveys the BSON key path that is currently being -// compared. -type keyPathCtxKey struct{} - -// extraKeysAllowedCtxKey is used as a key for a Context object. The value conveys whether or not the document under -// test can contain extra keys. For example, if the expected document is {x: 1}, the document {x: 1, y: 1} would match -// if the value for this key is true. -type extraKeysAllowedCtxKey struct{} - -func makeMatchContext(ctx context.Context, keyPath string, extraKeysAllowed bool) context.Context { - ctx = context.WithValue(ctx, keyPathCtxKey{}, keyPath) - return context.WithValue(ctx, extraKeysAllowedCtxKey{}, extraKeysAllowed) -} - -// VerifyValuesMatch compares the provided BSON values and returns an error if they do not match. If the values are -// documents and extraKeysAllowed is true, the actual value will be allowed to have additional keys at the top-level. -// For example, an expected document {x: 1} would match the actual document {x: 1, y: 1}. -func VerifyValuesMatch(ctx context.Context, expected, actual bson.RawValue, extraKeysAllowed bool) error { - return verifyValuesMatch(makeMatchContext(ctx, "", extraKeysAllowed), expected, actual) -} - -func verifyValuesMatch(ctx context.Context, expected, actual bson.RawValue) error { - keyPath := ctx.Value(keyPathCtxKey{}).(string) - extraKeysAllowed := ctx.Value(extraKeysAllowedCtxKey{}).(bool) - - if expectedDoc, ok := expected.DocumentOK(); ok { - // If the root document only has one element and the key is a special matching operator, the actual value might - // not actually be a document. In this case, evaluate the special operator with the actual value rather than - // doing an element-wise document comparison. - if requiresSpecialMatching(expectedDoc) { - if err := evaluateSpecialComparison(ctx, expectedDoc, actual, ""); err != nil { - return newMatchingError(keyPath, "error doing special matching assertion: %v", err) - } - return nil - } - - actualDoc, ok := actual.DocumentOK() - if !ok { - return newMatchingError(keyPath, "expected value to be a document but got a %s", actual.Type) - } - - // Perform element-wise comparisons. - expectedElems, _ := expectedDoc.Elements() - for _, expectedElem := range expectedElems { - expectedKey := expectedElem.Key() - expectedValue := expectedElem.Value() - - fullKeyPath := expectedKey - if keyPath != "" { - fullKeyPath = keyPath + "." + expectedKey - } - - // Get the value from actualDoc here but don't check the error until later because some of the special - // matching operators can assert that the value isn't present in the document (e.g. $$exists). - actualValue, err := actualDoc.LookupErr(expectedKey) - if specialDoc, ok := expectedValue.DocumentOK(); ok && requiresSpecialMatching(specialDoc) { - // Reset the key path so any errors returned from the function will only have the key path for the - // target value. Also unconditionally set extraKeysAllowed to false because an assertion like - // $$unsetOrMatches could recurse back into this function. In that case, the target document is nested - // and should not have extra keys. - ctx = makeMatchContext(ctx, "", false) - if err := evaluateSpecialComparison(ctx, specialDoc, actualValue, expectedKey); err != nil { - return newMatchingError(fullKeyPath, "error doing special matching assertion: %v", err) - } - continue - } - - // This isn't a special comparison. Assert that the value exists in the actual document. - if err != nil { - return newMatchingError(fullKeyPath, "key not found in actual document") - } - - // Nested documents cannot have extra keys, so we unconditionally pass false for extraKeysAllowed. - comparisonCtx := makeMatchContext(ctx, fullKeyPath, false) - if err := verifyValuesMatch(comparisonCtx, expectedValue, actualValue); err != nil { - return err - } - } - // If required, verify that the actual document does not have extra elements. We do this by iterating over the - // actual and checking for each key in the expected rather than comparing element counts because the presence of - // special operators can cause incorrect counts. For example, the document {y: {$$exists: false}} has one - // element, but should match the document {}, which has none. - if !extraKeysAllowed { - actualElems, _ := actualDoc.Elements() - for _, actualElem := range actualElems { - if _, err := expectedDoc.LookupErr(actualElem.Key()); err != nil { - return newMatchingError(keyPath, "extra key %q found in actual document %s", actualElem.Key(), - actualDoc) - } - } - } - - return nil - } - if expectedArr, ok := expected.ArrayOK(); ok { - actualArr, ok := actual.ArrayOK() - if !ok { - return newMatchingError(keyPath, "expected value to be an array but got a %s", actual.Type) - } - - expectedValues, _ := expectedArr.Values() - actualValues, _ := actualArr.Values() - - // Arrays must always have the same number of elements. - if len(expectedValues) != len(actualValues) { - return newMatchingError(keyPath, "expected array length %d, got %d", len(expectedValues), - len(actualValues)) - } - - for idx, expectedValue := range expectedValues { - // Use the index as the key to augment the key path. - fullKeyPath := fmt.Sprintf("%d", idx) - if keyPath != "" { - fullKeyPath = keyPath + "." + fullKeyPath - } - - comparisonCtx := makeMatchContext(ctx, fullKeyPath, extraKeysAllowed) - err := verifyValuesMatch(comparisonCtx, expectedValue, actualValues[idx]) - if err != nil { - return err - } - } - - return nil - } - - // Numeric values must be considered equal even if their types are different (e.g. if expected is an int32 and - // actual is an int64). - if expected.IsNumber() { - if !actual.IsNumber() { - return newMatchingError(keyPath, "expected value to be a number but got a %s", actual.Type) - } - - expectedInt64 := expected.AsInt64() - actualInt64 := actual.AsInt64() - if expectedInt64 != actualInt64 { - return newMatchingError(keyPath, "expected numeric value %d, got %d", expectedInt64, actualInt64) - } - return nil - } - - // If expected is not a recursive or numeric type, we can directly call Equal to do the comparison. - if !expected.Equal(actual) { - return newMatchingError(keyPath, "expected value %s, got %s", expected, actual) - } - return nil -} - -func evaluateSpecialComparison(ctx context.Context, assertionDoc bson.Raw, actual bson.RawValue, fieldName string) error { - assertionElem := assertionDoc.Index(0) - assertion := assertionElem.Key() - assertionVal := assertionElem.Value() - - switch assertion { - case "$$exists": - shouldExist := assertionVal.Boolean() - exists := actual.Validate() == nil - if shouldExist != exists { - return fmt.Errorf("expected value to exist: %v; value actually exists: %v", shouldExist, exists) - } - case "$$type": - possibleTypes, err := getTypesArray(assertionVal) - if err != nil { - return fmt.Errorf("error getting possible types for a $$type assertion: %v", err) - } - - for _, possibleType := range possibleTypes { - if actual.Type == possibleType { - return nil - } - } - return fmt.Errorf("expected type to be one of %v but was %s", possibleTypes, actual.Type) - case "$$matchesEntity": - expected, err := Entities(ctx).BSONValue(assertionVal.StringValue()) - if err != nil { - return err - } - - // $$matchesEntity doesn't modify the nesting level of the key path so we can propagate ctx without changes. - return verifyValuesMatch(ctx, expected, actual) - case "$$matchesHexBytes": - expectedBytes, err := hex.DecodeString(assertionVal.StringValue()) - if err != nil { - return fmt.Errorf("error converting $$matcesHexBytes value to bytes: %v", err) - } - - _, actualBytes, ok := actual.BinaryOK() - if !ok { - return fmt.Errorf("expected binary value for a $$matchesHexBytes assertion, but got a %s", actual.Type) - } - if !bytes.Equal(expectedBytes, actualBytes) { - return fmt.Errorf("expected bytes %v, got %v", expectedBytes, actualBytes) - } - case "$$unsetOrMatches": - if actual.Validate() != nil { - return nil - } - - // $$unsetOrMatches doesn't modify the nesting level or the key path so we can propagate the context to the - // comparison function without changing anything. - return verifyValuesMatch(ctx, assertionVal, actual) - case "$$sessionLsid": - sess, err := Entities(ctx).Session(assertionVal.StringValue()) - if err != nil { - return err - } - - expectedID := sess.ID() - actualID, ok := actual.DocumentOK() - if !ok { - return fmt.Errorf("expected document value for a $$sessionLsid assertion, but got a %s", actual.Type) - } - if !bytes.Equal(expectedID, actualID) { - return fmt.Errorf("expected lsid %v, got %v", expectedID, actualID) - } - default: - return fmt.Errorf("unrecognized special matching assertion %q", assertion) - } - - return nil -} - -func requiresSpecialMatching(doc bson.Raw) bool { - elems, _ := doc.Elements() - return len(elems) == 1 && strings.HasPrefix(elems[0].Key(), "$$") -} - -func getTypesArray(val bson.RawValue) ([]bsontype.Type, error) { - switch val.Type { - case bsontype.String: - convertedType, err := convertStringToBSONType(val.StringValue()) - if err != nil { - return nil, err - } - - return []bsontype.Type{convertedType}, nil - case bsontype.Array: - var typeStrings []string - if err := val.Unmarshal(&typeStrings); err != nil { - return nil, fmt.Errorf("error unmarshalling to slice of strings: %v", err) - } - - var types []bsontype.Type - for _, typeStr := range typeStrings { - convertedType, err := convertStringToBSONType(typeStr) - if err != nil { - return nil, err - } - - types = append(types, convertedType) - } - return types, nil - default: - return nil, fmt.Errorf("invalid type to convert to bsontype.Type slice: %s", val.Type) - } -} - -func convertStringToBSONType(typeStr string) (bsontype.Type, error) { - switch typeStr { - case "double": - return bsontype.Double, nil - case "string": - return bsontype.String, nil - case "object": - return bsontype.EmbeddedDocument, nil - case "array": - return bsontype.Array, nil - case "binData": - return bsontype.Binary, nil - case "undefined": - return bsontype.Undefined, nil - case "objectId": - return bsontype.ObjectID, nil - case "bool": - return bsontype.Boolean, nil - case "date": - return bsontype.DateTime, nil - case "null": - return bsontype.Null, nil - case "regex": - return bsontype.Regex, nil - case "dbPointer": - return bsontype.DBPointer, nil - case "javascript": - return bsontype.JavaScript, nil - case "symbol": - return bsontype.Symbol, nil - case "javascriptWithScope": - return bsontype.CodeWithScope, nil - case "int": - return bsontype.Int32, nil - case "timestamp": - return bsontype.Timestamp, nil - case "long": - return bsontype.Int64, nil - case "decimal": - return bsontype.Decimal128, nil - case "minKey": - return bsontype.MinKey, nil - case "maxKey": - return bsontype.MaxKey, nil - default: - return bsontype.Type(0), fmt.Errorf("unrecognized BSON type string %q", typeStr) - } -} - -// newMatchingError creates an error to convey that BSON value comparison failed at the provided key path. If the -// key path is empty (e.g. because the values being compared were not documents), the error message will contain the -// phrase "top-level" instead of the path. -func newMatchingError(keyPath, msg string, args ...interface{}) error { - fullMsg := fmt.Sprintf(msg, args...) - if keyPath == "" { - return fmt.Errorf("comparison error at top-level: %s", fullMsg) - } - return fmt.Errorf("comparison error at key %q: %s", keyPath, fullMsg) -} - func TestMatches(t *testing.T) { ctx := context.Background() unmarshalExtJSONValue := func(t *testing.T, str string) bson.RawValue { @@ -365,7 +45,7 @@ func TestMatches(t *testing.T) { assertMatches := func(t *testing.T, expected, actual bson.RawValue, shouldMatch bool) { t.Helper() - err := VerifyValuesMatch(ctx, expected, actual, true) + err := verifyValuesMatch(ctx, expected, actual, true) if shouldMatch { assert.Nil(t, err, "expected values to match, but got comparison error %v", err) return @@ -399,7 +79,7 @@ func TestMatches(t *testing.T) { t.Run("documents with extra keys not allowed", func(t *testing.T) { expected := unmarshalExtJSONValue(t, `{"x": 1}`) actual := unmarshalExtJSONValue(t, `{"x": 1, "y": 1}`) - err := VerifyValuesMatch(ctx, expected, actual, false) + err := verifyValuesMatch(ctx, expected, actual, false) assert.NotNil(t, err, "expected values to not match, but got no error") }) t.Run("exists operator", func(t *testing.T) { diff --git a/mongo/integration/unified/operation_test.go b/mongo/integration/unified/operation.go similarity index 83% rename from mongo/integration/unified/operation_test.go rename to mongo/integration/unified/operation.go index 01a8efecde..64b6b56b91 100644 --- a/mongo/integration/unified/operation_test.go +++ b/mongo/integration/unified/operation.go @@ -14,44 +14,44 @@ import ( "go.mongodb.org/mongo-driver/mongo" ) -type Operation struct { +type operation struct { Name string `bson:"name"` Object string `bson:"object"` Arguments bson.Raw `bson:"arguments"` - ExpectedError *ExpectedError `bson:"expectError"` + ExpectedError *expectedError `bson:"expectError"` ExpectedResult *bson.RawValue `bson:"expectResult"` ResultEntityID *string `bson:"saveResultAsEntity"` } -// Execute runs the operation and verifies the returned result and/or error. If the result needs to be saved as -// an entity, it also updates the EntityMap associated with ctx to do so. -func (op *Operation) Execute(ctx context.Context) error { +// execute runs the operation and verifies the returned result and/or error. If the result needs to be saved as +// an entity, it also updates the entityMap associated with ctx to do so. +func (op *operation) execute(ctx context.Context) error { res, err := op.run(ctx) if err != nil { return fmt.Errorf("execution failed: %v", err) } - if err := VerifyOperationError(ctx, op.ExpectedError, res); err != nil { + if err := verifyOperationError(ctx, op.ExpectedError, res); err != nil { return fmt.Errorf("error verification failed: %v", err) } if op.ExpectedResult != nil { - if err := VerifyOperationResult(ctx, *op.ExpectedResult, res); err != nil { + if err := verifyOperationResult(ctx, *op.ExpectedResult, res); err != nil { return fmt.Errorf("result verification failed: %v", err) } } return nil } -func (op *Operation) run(ctx context.Context) (*OperationResult, error) { +func (op *operation) run(ctx context.Context) (*operationResult, error) { if op.Object == "testRunner" { - // testRunner operations don't have results or expected errors, so we use NewEmptyResult to fake a result. - return NewEmptyResult(), executeTestRunnerOperation(ctx, op) + // testRunner operations don't have results or expected errors, so we use newEmptyResult to fake a result. + return newEmptyResult(), executeTestRunnerOperation(ctx, op) } // Special handling for the "session" field because it applies to all operations. if id, ok := op.Arguments.Lookup("session").StringValueOK(); ok { - sess, err := Entities(ctx).Session(id) + sess, err := entities(ctx).session(id) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func (op *Operation) run(ctx context.Context) (*OperationResult, error) { // Set op.Arguments to a new document that has the "session" field removed so individual operations do // not have to account for it. - op.Arguments = RemoveFieldsFromDocument(op.Arguments, "session") + op.Arguments = removeFieldsFromDocument(op.Arguments, "session") } switch op.Name { @@ -70,12 +70,12 @@ func (op *Operation) run(ctx context.Context) (*OperationResult, error) { return executeCommitTransaction(ctx, op) case "endSession": // The EndSession() method doesn't return a result, so we return a non-nil empty result. - return NewEmptyResult(), executeEndSession(ctx, op) + return newEmptyResult(), executeEndSession(ctx, op) case "startTransaction": return executeStartTransaction(ctx, op) case "withTransaction": // executeWithTransaction internally verifies results/errors for each operation, so it doesn't return a result. - return NewEmptyResult(), executeWithTransaction(ctx, op) + return newEmptyResult(), executeWithTransaction(ctx, op) // Client operations case "createChangeStream": diff --git a/mongo/integration/unified/result_test.go b/mongo/integration/unified/result.go similarity index 67% rename from mongo/integration/unified/result_test.go rename to mongo/integration/unified/result.go index 111c6458ad..6c7a71a565 100644 --- a/mongo/integration/unified/result_test.go +++ b/mongo/integration/unified/result.go @@ -14,8 +14,8 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -// OperationResult holds the result and/or error returned by an op. -type OperationResult struct { +// operationResult holds the result and/or error returned by an op. +type operationResult struct { // For operations that return a single result, this field holds a BSON representation. Result bson.RawValue @@ -27,29 +27,29 @@ type OperationResult struct { Err error } -// NewEmptyResult returns an OperationResult with no fields set. This should be used if the operation does not check +// newEmptyResult returns an operationResult with no fields set. This should be used if the operation does not check // results or errors. -func NewEmptyResult() *OperationResult { - return &OperationResult{} +func newEmptyResult() *operationResult { + return &operationResult{} } -// NewDocumentResult is a helper to create a value result where the value is a BSON document. -func NewDocumentResult(result []byte, err error) *OperationResult { - return NewValueResult(bsontype.EmbeddedDocument, result, err) +// newDocumentResult is a helper to create a value result where the value is a BSON document. +func newDocumentResult(result []byte, err error) *operationResult { + return newValueResult(bsontype.EmbeddedDocument, result, err) } -// NewValueResult creates an OperationResult where the result is a BSON value of an arbitrary type. Because some +// newValueResult creates an operationResult where the result is a BSON value of an arbitrary type. Because some // operations can return both a result and an error (e.g. bulkWrite), the err parameter should be the error returned // by the op, if any. -func NewValueResult(valueType bsontype.Type, data []byte, err error) *OperationResult { - return &OperationResult{ +func newValueResult(valueType bsontype.Type, data []byte, err error) *operationResult { + return &operationResult{ Result: bson.RawValue{Type: valueType, Value: data}, Err: err, } } -// NewCursorResult creates an OperationResult that contains documents retrieved by fully iterating a cursor. -func NewCursorResult(arr []bson.Raw) *OperationResult { +// newCursorResult creates an operationResult that contains documents retrieved by fully iterating a cursor. +func newCursorResult(arr []bson.Raw) *operationResult { // If the operation returned no documents, the array might be nil. It isn't possible to distinguish between this // case and the case where there is no cursor result, so we overwrite the result with an non-nil empty slice. result := arr @@ -57,20 +57,21 @@ func NewCursorResult(arr []bson.Raw) *OperationResult { result = make([]bson.Raw, 0) } - return &OperationResult{ + return &operationResult{ CursorResult: result, } } -// NewErrorResult creates an OperationResult that only holds an error. This should only be used when executing an +// newErrorResult creates an operationResult that only holds an error. This should only be used when executing an // operation that can return a result or an error, but not both. -func NewErrorResult(err error) *OperationResult { - return &OperationResult{ +func newErrorResult(err error) *operationResult { + return &operationResult{ Err: err, } } -func VerifyOperationResult(ctx context.Context, expected bson.RawValue, actual *OperationResult) error { +// verifyOperationResult checks that the actual and expected results match +func verifyOperationResult(ctx context.Context, expected bson.RawValue, actual *operationResult) error { actualVal := actual.Result if actual.CursorResult != nil { _, data, err := bson.MarshalValue(actual.CursorResult) @@ -88,5 +89,5 @@ func VerifyOperationResult(ctx context.Context, expected bson.RawValue, actual * // top-level keys. Single-value array results (e.g. from distinct) must match exactly, so we set extraKeysAllowed to // false only for that case. extraKeysAllowed := actual.Result.Type != bsontype.Array - return VerifyValuesMatch(ctx, expected, actualVal, extraKeysAllowed) + return verifyValuesMatch(ctx, expected, actualVal, extraKeysAllowed) } diff --git a/mongo/integration/unified/schema_version_test.go b/mongo/integration/unified/schema_version.go similarity index 93% rename from mongo/integration/unified/schema_version_test.go rename to mongo/integration/unified/schema_version.go index 1a56ac61c6..1721f5f509 100644 --- a/mongo/integration/unified/schema_version_test.go +++ b/mongo/integration/unified/schema_version.go @@ -20,8 +20,8 @@ var ( } ) -// CheckSchemaVersion determines if the provided schema version is supported and returns an error if it is not. -func CheckSchemaVersion(version string) error { +// checkSchemaVersion determines if the provided schema version is supported and returns an error if it is not. +func checkSchemaVersion(version string) error { // First get the major version number from the schema. The schema version string should be in the format // "major.minor.patch", "major.minor", or "major". diff --git a/mongo/integration/unified/server_api_options_test.go b/mongo/integration/unified/server_api_options.go similarity index 67% rename from mongo/integration/unified/server_api_options_test.go rename to mongo/integration/unified/server_api_options.go index e986b78ca6..233203b9d4 100644 --- a/mongo/integration/unified/server_api_options_test.go +++ b/mongo/integration/unified/server_api_options.go @@ -13,28 +13,28 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// ServerAPIOptions is a wrapper for *options.ServerAPIOptions. This type implements the bson.Unmarshaler interface -// to convert BSON documents to a ServerAPIOptions instance. -type ServerAPIOptions struct { +// serverAPIOptions is a wrapper for *options.ServerAPIOptions. This type implements the bson.Unmarshaler interface +// to convert BSON documents to a serverAPIOptions instance. +type serverAPIOptions struct { *options.ServerAPIOptions } -type ServerAPIVersion = options.ServerAPIVersion +type serverAPIVersion = options.ServerAPIVersion -var _ bson.Unmarshaler = (*ServerAPIOptions)(nil) +var _ bson.Unmarshaler = (*serverAPIOptions)(nil) -func (s *ServerAPIOptions) UnmarshalBSON(data []byte) error { +func (s *serverAPIOptions) UnmarshalBSON(data []byte) error { var temp struct { - ServerAPIVersion ServerAPIVersion `bson:"version"` + ServerAPIVersion serverAPIVersion `bson:"version"` DeprecationErrors *bool `bson:"deprecationErrors"` Strict *bool `bson:"strict"` Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary ServerAPIOptions object: %v", err) + return fmt.Errorf("error unmarshalling to temporary serverAPIOptions object: %v", err) } if len(temp.Extra) > 0 { - return fmt.Errorf("unrecognized fields for ServerAPIOptions: %v", MapKeys(temp.Extra)) + return fmt.Errorf("unrecognized fields for serverAPIOptions: %v", mapKeys(temp.Extra)) } if err := temp.ServerAPIVersion.Validate(); err != nil { diff --git a/mongo/integration/unified/session_operation_execution_test.go b/mongo/integration/unified/session_operation_execution.go similarity index 60% rename from mongo/integration/unified/session_operation_execution_test.go rename to mongo/integration/unified/session_operation_execution.go index 8896cc8e61..7fe5f57b05 100644 --- a/mongo/integration/unified/session_operation_execution_test.go +++ b/mongo/integration/unified/session_operation_execution.go @@ -15,8 +15,8 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -func executeAbortTransaction(ctx context.Context, operation *Operation) (*OperationResult, error) { - sess, err := Entities(ctx).Session(operation.Object) +func executeAbortTransaction(ctx context.Context, operation *operation) (*operationResult, error) { + sess, err := entities(ctx).session(operation.Object) if err != nil { return nil, err } @@ -27,11 +27,11 @@ func executeAbortTransaction(ctx context.Context, operation *Operation) (*Operat return nil, fmt.Errorf("unrecognized abortTransaction options %v", operation.Arguments) } - return NewErrorResult(sess.AbortTransaction(ctx)), nil + return newErrorResult(sess.AbortTransaction(ctx)), nil } -func executeEndSession(ctx context.Context, operation *Operation) error { - sess, err := Entities(ctx).Session(operation.Object) +func executeEndSession(ctx context.Context, operation *operation) error { + sess, err := entities(ctx).session(operation.Object) if err != nil { return err } @@ -46,8 +46,8 @@ func executeEndSession(ctx context.Context, operation *Operation) error { return nil } -func executeCommitTransaction(ctx context.Context, operation *Operation) (*OperationResult, error) { - sess, err := Entities(ctx).Session(operation.Object) +func executeCommitTransaction(ctx context.Context, operation *operation) (*operationResult, error) { + sess, err := entities(ctx).session(operation.Object) if err != nil { return nil, err } @@ -58,55 +58,55 @@ func executeCommitTransaction(ctx context.Context, operation *Operation) (*Opera return nil, fmt.Errorf("unrecognized commitTransaction options %v", operation.Arguments) } - return NewErrorResult(sess.CommitTransaction(ctx)), nil + return newErrorResult(sess.CommitTransaction(ctx)), nil } -func executeStartTransaction(ctx context.Context, operation *Operation) (*OperationResult, error) { - sess, err := Entities(ctx).Session(operation.Object) +func executeStartTransaction(ctx context.Context, operation *operation) (*operationResult, error) { + sess, err := entities(ctx).session(operation.Object) if err != nil { return nil, err } opts := options.Transaction() if operation.Arguments != nil { - var temp TransactionOptions + var temp transactionOptions if err := bson.Unmarshal(operation.Arguments, &temp); err != nil { - return nil, fmt.Errorf("error unmarshalling arguments to TransactionOptions: %v", err) + return nil, fmt.Errorf("error unmarshalling arguments to transactionOptions: %v", err) } opts = temp.TransactionOptions } - return NewErrorResult(sess.StartTransaction(opts)), nil + return newErrorResult(sess.StartTransaction(opts)), nil } -func executeWithTransaction(ctx context.Context, operation *Operation) error { - sess, err := Entities(ctx).Session(operation.Object) +func executeWithTransaction(ctx context.Context, op *operation) error { + sess, err := entities(ctx).session(op.Object) if err != nil { return err } - // Process the "callback" argument. This is an array of Operation objects, each of which should be executed inside + // Process the "callback" argument. This is an array of operation objects, each of which should be executed inside // the transaction. - callback, err := operation.Arguments.LookupErr("callback") + callback, err := op.Arguments.LookupErr("callback") if err != nil { return newMissingArgumentError("callback") } - var operations []*Operation + var operations []*operation if err := callback.Unmarshal(&operations); err != nil { return fmt.Errorf("error transforming callback option to slice of operations: %v", err) } // Remove the "callback" field and process the other options. - var temp TransactionOptions - if err := bson.Unmarshal(RemoveFieldsFromDocument(operation.Arguments, "callback"), &temp); err != nil { - return fmt.Errorf("error unmarshalling arguments to TransactionOptions: %v", err) + var temp transactionOptions + if err := bson.Unmarshal(removeFieldsFromDocument(op.Arguments, "callback"), &temp); err != nil { + return fmt.Errorf("error unmarshalling arguments to transactionOptions: %v", err) } _, err = sess.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) { - for idx, op := range operations { - if err := op.Execute(ctx); err != nil { - return nil, fmt.Errorf("error executing operation %q at index %d: %v", op.Name, idx, err) + for idx, oper := range operations { + if err := oper.execute(ctx); err != nil { + return nil, fmt.Errorf("error executing operation %q at index %d: %v", oper.Name, idx, err) } } return nil, nil diff --git a/mongo/integration/unified/session_options_test.go b/mongo/integration/unified/session_options.go similarity index 73% rename from mongo/integration/unified/session_options_test.go rename to mongo/integration/unified/session_options.go index 6b9e230cef..1ee1daaefb 100644 --- a/mongo/integration/unified/session_options_test.go +++ b/mongo/integration/unified/session_options.go @@ -14,15 +14,15 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// TransactionOptions is a wrapper for *options.TransactionOptions. This type implements the bson.Unmarshaler interface -// to convert BSON documents to a TransactionOptions instance. -type TransactionOptions struct { +// transactionOptions is a wrapper for *options.transactionOptions. This type implements the bson.Unmarshaler interface +// to convert BSON documents to a transactionOptions instance. +type transactionOptions struct { *options.TransactionOptions } -var _ bson.Unmarshaler = (*TransactionOptions)(nil) +var _ bson.Unmarshaler = (*transactionOptions)(nil) -func (to *TransactionOptions) UnmarshalBSON(data []byte) error { +func (to *transactionOptions) UnmarshalBSON(data []byte) error { var temp struct { RC *readConcern `bson:"readConcern"` RP *readPreference `bson:"readPreference"` @@ -31,10 +31,10 @@ func (to *TransactionOptions) UnmarshalBSON(data []byte) error { Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary TransactionOptions object: %v", err) + return fmt.Errorf("error unmarshalling to temporary transactionOptions object: %v", err) } if len(temp.Extra) > 0 { - return fmt.Errorf("unrecognized fields for TransactionOptions: %v", MapKeys(temp.Extra)) + return fmt.Errorf("unrecognized fields for transactionOptions: %v", mapKeys(temp.Extra)) } to.TransactionOptions = options.Transaction() @@ -62,26 +62,26 @@ func (to *TransactionOptions) UnmarshalBSON(data []byte) error { return nil } -// SessionOptions is a wrapper for *options.SessionOptions. This type implements the bson.Unmarshaler interface to -// convert BSON documents to a SessionOptions instance. -type SessionOptions struct { +// sessionOptions is a wrapper for *options.sessionOptions. This type implements the bson.Unmarshaler interface to +// convert BSON documents to a sessionOptions instance. +type sessionOptions struct { *options.SessionOptions } -var _ bson.Unmarshaler = (*SessionOptions)(nil) +var _ bson.Unmarshaler = (*sessionOptions)(nil) -func (so *SessionOptions) UnmarshalBSON(data []byte) error { +func (so *sessionOptions) UnmarshalBSON(data []byte) error { var temp struct { Causal *bool `bson:"causalConsistency"` MaxCommitTimeMS *int64 `bson:"maxCommitTimeMS"` - TxnOptions *TransactionOptions `bson:"defaultTransactionOptions"` + TxnOptions *transactionOptions `bson:"defaultTransactionOptions"` Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary SessionOptions object: %v", err) + return fmt.Errorf("error unmarshalling to temporary sessionOptions object: %v", err) } if len(temp.Extra) > 0 { - return fmt.Errorf("unrecognized fields for SessionOptions: %v", MapKeys(temp.Extra)) + return fmt.Errorf("unrecognized fields for sessionOptions: %v", mapKeys(temp.Extra)) } so.SessionOptions = options.Session() diff --git a/mongo/integration/unified/testrunner_operation_test.go b/mongo/integration/unified/testrunner_operation.go similarity index 78% rename from mongo/integration/unified/testrunner_operation_test.go rename to mongo/integration/unified/testrunner_operation.go index 5ed63dd906..b7785a08e8 100644 --- a/mongo/integration/unified/testrunner_operation_test.go +++ b/mongo/integration/unified/testrunner_operation.go @@ -16,13 +16,13 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -func executeTestRunnerOperation(ctx context.Context, operation *Operation) error { +func executeTestRunnerOperation(ctx context.Context, operation *operation) error { args := operation.Arguments switch operation.Name { case "failPoint": - clientID := LookupString(args, "client") - client, err := Entities(ctx).Client(clientID) + clientID := lookupString(args, "client") + client, err := entities(ctx).client(clientID) if err != nil { return err } @@ -31,10 +31,10 @@ func executeTestRunnerOperation(ctx context.Context, operation *Operation) error if err := mtest.SetRawFailPoint(fpDoc, client.Client); err != nil { return err } - return AddFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), client.Client) + return addFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), client.Client) case "targetedFailPoint": - sessID := LookupString(args, "session") - sess, err := Entities(ctx).Session(sessID) + sessID := lookupString(args, "session") + sess, err := entities(ctx).session(sessID) if err != nil { return err } @@ -50,19 +50,19 @@ func executeTestRunnerOperation(ctx context.Context, operation *Operation) error return mtest.SetRawFailPoint(fpDoc, client) } - if err := RunCommandOnHost(ctx, targetHost, commandFn); err != nil { + if err := runCommandOnHost(ctx, targetHost, commandFn); err != nil { return err } - return AddTargetedFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), targetHost) + return addTargetedFailPoint(ctx, fpDoc.Index(0).Value().StringValue(), targetHost) case "assertSessionTransactionState": - sessID := LookupString(args, "session") - sess, err := Entities(ctx).Session(sessID) + sessID := lookupString(args, "session") + sess, err := entities(ctx).session(sessID) if err != nil { return err } var expectedState session.TransactionState - switch stateStr := LookupString(args, "state"); stateStr { + switch stateStr := lookupString(args, "state"); stateStr { case "none": expectedState = session.None case "starting": @@ -82,34 +82,34 @@ func executeTestRunnerOperation(ctx context.Context, operation *Operation) error } return nil case "assertSessionPinned": - return verifySessionPinnedState(ctx, LookupString(args, "session"), true) + return verifySessionPinnedState(ctx, lookupString(args, "session"), true) case "assertSessionUnpinned": - return verifySessionPinnedState(ctx, LookupString(args, "session"), false) + return verifySessionPinnedState(ctx, lookupString(args, "session"), false) case "assertSameLsidOnLastTwoCommands": - return verifyLastTwoLsidsEqual(ctx, LookupString(args, "client"), true) + return verifyLastTwoLsidsEqual(ctx, lookupString(args, "client"), true) case "assertDifferentLsidOnLastTwoCommands": - return verifyLastTwoLsidsEqual(ctx, LookupString(args, "client"), false) + return verifyLastTwoLsidsEqual(ctx, lookupString(args, "client"), false) case "assertSessionDirty": - return verifySessionDirtyState(ctx, LookupString(args, "session"), true) + return verifySessionDirtyState(ctx, lookupString(args, "session"), true) case "assertSessionNotDirty": - return verifySessionDirtyState(ctx, LookupString(args, "session"), false) + return verifySessionDirtyState(ctx, lookupString(args, "session"), false) case "assertCollectionExists": - db := LookupString(args, "databaseName") - coll := LookupString(args, "collectionName") + db := lookupString(args, "databaseName") + coll := lookupString(args, "collectionName") return verifyCollectionExists(ctx, db, coll, true) case "assertCollectionNotExists": - db := LookupString(args, "databaseName") - coll := LookupString(args, "collectionName") + db := lookupString(args, "databaseName") + coll := lookupString(args, "collectionName") return verifyCollectionExists(ctx, db, coll, false) case "assertIndexExists": - db := LookupString(args, "databaseName") - coll := LookupString(args, "collectionName") - index := LookupString(args, "indexName") + db := lookupString(args, "databaseName") + coll := lookupString(args, "collectionName") + index := lookupString(args, "indexName") return verifyIndexExists(ctx, db, coll, index, true) case "assertIndexNotExists": - db := LookupString(args, "databaseName") - coll := LookupString(args, "collectionName") - index := LookupString(args, "indexName") + db := lookupString(args, "databaseName") + coll := lookupString(args, "collectionName") + index := lookupString(args, "indexName") return verifyIndexExists(ctx, db, coll, index, false) default: return fmt.Errorf("unrecognized testRunner operation %q", operation.Name) @@ -121,7 +121,7 @@ func extractClientSession(sess mongo.Session) *session.Client { } func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPinned bool) error { - sess, err := Entities(ctx).Session(sessionID) + sess, err := entities(ctx).session(sessionID) if err != nil { return err } @@ -133,12 +133,12 @@ func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPin } func verifyLastTwoLsidsEqual(ctx context.Context, clientID string, expectedEqual bool) error { - client, err := Entities(ctx).Client(clientID) + client, err := entities(ctx).client(clientID) if err != nil { return err } - allEvents := client.StartedEvents() + allEvents := client.startedEvents() if len(allEvents) < 2 { return fmt.Errorf("client has recorded fewer than two command started events") } @@ -164,7 +164,7 @@ func verifyLastTwoLsidsEqual(ctx context.Context, clientID string, expectedEqual } func verifySessionDirtyState(ctx context.Context, sessionID string, expectedDirty bool) error { - sess, err := Entities(ctx).Session(sessionID) + sess, err := entities(ctx).session(sessionID) if err != nil { return err } @@ -200,7 +200,7 @@ func verifyIndexExists(ctx context.Context, dbName, collName, indexName string, var exists bool for cursor.Next(ctx) { - if LookupString(cursor.Current, "name") == indexName { + if lookupString(cursor.Current, "name") == indexName { exists = true break } diff --git a/mongo/integration/unified/unified_spec_runner.go b/mongo/integration/unified/unified_spec_runner.go new file mode 100644 index 0000000000..a545810aff --- /dev/null +++ b/mongo/integration/unified/unified_spec_runner.go @@ -0,0 +1,241 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package unified + +import ( + "context" + "fmt" + "io/ioutil" + "path" + "testing" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/internal/testutil/assert" + testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/integration/mtest" +) + +var ( + skippedTestDescriptions = map[string]struct{}{ + // GODRIVER-1773: This test runs a "find" with limit=4 and batchSize=3. It expects batchSize values of three for + // the "find" and one for the "getMore", but we send three for both. + "A successful find event with a getmore and the server kills the cursor": {}, + + // This test expects the driver to raise a client-side error when inserting a document with a key that contains + // a "." or "$". We don't do this validation and the server is moving towards supporting this, so we don't have + // any plans to add it. This test will need to be changed once the server support lands anyway. + "Client side error in command starting transaction": {}, + } +) + +const ( + lowHeartbeatFrequency int32 = 50 +) + +type testCase struct { + Description string `bson:"description"` + RunOnRequirements []mtest.RunOnBlock `bson:"runOnRequirements"` + SkipReason *string `bson:"skipReason"` + Operations []*operation `bson:"operations"` + expectedEvents []*expectedEvents `bson:"expectEvents"` + Outcome []*collectionData `bson:"outcome"` +} + +func (t *testCase) performsDistinct() bool { + return t.performsOperation("distinct") +} + +func (t *testCase) setsFailPoint() bool { + return t.performsOperation("failPoint") +} + +func (t *testCase) startsTransaction() bool { + return t.performsOperation("startTransaction") +} + +func (t *testCase) performsOperation(name string) bool { + for _, op := range t.Operations { + if op.Name == name { + return true + } + } + return false +} + +// TestFile holds the contents of a unified spec test file +type TestFile struct { + Description string `bson:"description"` + SchemaVersion string `bson:"schemaVersion"` + RunOnRequirements []mtest.RunOnBlock `bson:"runOnRequirements"` + CreateEntities []map[string]*entityOptions `bson:"createEntities"` + InitialData []*collectionData `bson:"initialData"` + TestCases []*testCase `bson:"tests"` +} + +// runTestDirectory runs the files in the given directory, which must be in the unifed spec format +func runTestDirectory(t *testing.T, directoryPath string) { + for _, filename := range testhelpers.FindJSONFilesInDir(t, directoryPath) { + t.Run(filename, func(t *testing.T) { + RunTestFile(t, path.Join(directoryPath, filename)) + }) + } +} + +// RunTestFile runs the tests in the given file, which must be in the unifed spec format +func RunTestFile(t *testing.T, filepath string) { + content, err := ioutil.ReadFile(filepath) + assert.Nil(t, err, "ReadFile error for file %q: %v", filepath, err) + + var testFile TestFile + err = bson.UnmarshalExtJSON(content, false, &testFile) + assert.Nil(t, err, "UnmarshalExtJSON error for file %q: %v", filepath, err) + + // Validate that we support the schema declared by the test file before attempting to use its contents. + err = checkSchemaVersion(testFile.SchemaVersion) + assert.Nil(t, err, "schema version %q not supported: %v", testFile.SchemaVersion, err) + + // Create mtest wrapper, which will skip the test if needed. + mtOpts := mtest.NewOptions(). + RunOn(testFile.RunOnRequirements...). + CreateClient(false) + mt := mtest.New(t, mtOpts) + defer mt.Close() + + for _, testCase := range testFile.TestCases { + mtOpts := mtest.NewOptions(). + RunOn(testCase.RunOnRequirements...). + CreateClient(false) + mt.RunOpts(testCase.Description, mtOpts, func(mt *mtest.T) { + runTestCase(mt, testFile, testCase) + }) + } +} + +func runTestCase(mt *mtest.T, testFile TestFile, testCase *testCase) { + if testCase.SkipReason != nil { + mt.Skipf("skipping for reason: %q", *testCase.SkipReason) + } + if _, ok := skippedTestDescriptions[testCase.Description]; ok { + mt.Skip("skipping due to known failure") + } + + testCtx := newTestContext(mtest.Background) + + defer func() { + // If anything fails while doing test cleanup, we only log the error because the actual test may have already + // failed and that failure should be preserved. + + for _, err := range disableUntargetedFailPoints(testCtx) { + mt.Log(err) + } + for _, err := range disableTargetedFailPoints(testCtx) { + mt.Log(err) + } + for _, err := range entities(testCtx).close(testCtx) { + mt.Log(err) + } + // Tests that started a transaction should terminate any sessions left open on the server. This is required even + // if the test attempted to commit/abort the transaction because an abortTransaction command can fail if it's + // sent to a mongos that isn't aware of the transaction. + if testCase.startsTransaction() { + if err := terminateOpenSessions(mtest.Background); err != nil { + mt.Logf("error terminating open transactions after failed test: %v", err) + } + } + }() + + // Set up collections based on the file-level initialData field. + for _, collData := range testFile.InitialData { + if err := collData.createCollection(testCtx); err != nil { + mt.Fatalf("error setting up collection %q: %v", collData.namespace(), err) + } + } + + // Set up entities based on the file-level createEntities field. For client entities, if the test will configure + // a fail point, set a low heartbeatFrequencyMS value into the URI options map if one is not already present. + // This speeds up recovery time for the client if the fail point forces the server to return a state change + // error. + shouldSetHeartbeatFrequency := testCase.setsFailPoint() + for idx, entity := range testFile.CreateEntities { + for entityType, entityOptions := range entity { + if shouldSetHeartbeatFrequency && entityType == "client" { + if entityOptions.URIOptions == nil { + entityOptions.URIOptions = make(bson.M) + } + if _, ok := entityOptions.URIOptions["heartbeatFrequencyMS"]; !ok { + entityOptions.URIOptions["heartbeatFrequencyMS"] = lowHeartbeatFrequency + } + } + + if err := entities(testCtx).addEntity(testCtx, entityType, entityOptions); err != nil { + mt.Fatalf("error creating entity at index %d: %v", idx, err) + } + } + } + + // Work around SERVER-39704. + if mtest.ClusterTopologyKind() == mtest.Sharded && testCase.performsDistinct() { + if err := performDistinctWorkaround(testCtx); err != nil { + mt.Fatalf("error performing \"distinct\" workaround: %v", err) + } + } + + for idx, operation := range testCase.Operations { + err := operation.execute(testCtx) + assert.Nil(mt, err, "error running operation %q at index %d: %v", operation.Name, idx, err) + } + + for _, client := range entities(testCtx).clients() { + client.StopListeningForEvents() + } + + for idx, expectedEvents := range testCase.expectedEvents { + err := verifyEvents(testCtx, expectedEvents) + assert.Nil(mt, err, "events verification failed at index %d: %v", idx, err) + } + + for idx, collData := range testCase.Outcome { + err := collData.verifyContents(testCtx) + assert.Nil(mt, err, "error verifying outcome for collection %q at index %d: %v", + collData.namespace(), idx, err) + } +} + +func disableUntargetedFailPoints(ctx context.Context) []error { + var errs []error + for fpName, client := range failPoints(ctx) { + if err := disableFailPointWithClient(ctx, fpName, client); err != nil { + errs = append(errs, fmt.Errorf("error disabling fail point %q: %v", fpName, err)) + } + } + return errs +} + +func disableTargetedFailPoints(ctx context.Context) []error { + var errs []error + for fpName, host := range targetedFailPoints(ctx) { + commandFn := func(ctx context.Context, client *mongo.Client) error { + return disableFailPointWithClient(ctx, fpName, client) + } + if err := runCommandOnHost(ctx, host, commandFn); err != nil { + errs = append(errs, fmt.Errorf("error disabling targeted fail point %q on host %q: %v", fpName, host, err)) + } + } + return errs +} + +func disableFailPointWithClient(ctx context.Context, fpName string, client *mongo.Client) error { + cmd := bson.D{ + {"configureFailPoint", fpName}, + {"mode", "off"}, + } + if err := client.Database("admin").RunCommand(ctx, cmd).Err(); err != nil { + return err + } + return nil +} diff --git a/mongo/integration/unified/unified_spec_test.go b/mongo/integration/unified/unified_spec_test.go index cd92e21da5..33792fe726 100644 --- a/mongo/integration/unified/unified_spec_test.go +++ b/mongo/integration/unified/unified_spec_test.go @@ -7,16 +7,9 @@ package unified import ( - "context" - "fmt" - "io/ioutil" "path" "testing" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/internal/testutil/assert" - testhelpers "go.mongodb.org/mongo-driver/internal/testutil/helpers" - "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/integration/mtest" ) @@ -27,229 +20,21 @@ var ( "crud/unified", "change-streams/unified", } - - skippedTestDescriptions = map[string]struct{}{ - // GODRIVER-1773: This test runs a "find" with limit=4 and batchSize=3. It expects batchSize values of three for - // the "find" and one for the "getMore", but we send three for both. - "A successful find event with a getmore and the server kills the cursor": {}, - - // This test expects the driver to raise a client-side error when inserting a document with a key that contains - // a "." or "$". We don't do this validation and the server is moving towards supporting this, so we don't have - // any plans to add it. This test will need to be changed once the server support lands anyway. - "Client side error in command starting transaction": {}, - } ) const ( - dataDirectory = "../../../data" - lowHeartbeatFrequency int32 = 50 + dataDirectory = "../../../data" ) -type TestCase struct { - Description string `bson:"description"` - RunOnRequirements []mtest.RunOnBlock `bson:"runOnRequirements"` - SkipReason *string `bson:"skipReason"` - Operations []*Operation `bson:"operations"` - ExpectedEvents []*ExpectedEvents `bson:"expectEvents"` - Outcome []*CollectionData `bson:"outcome"` -} - -func (t *TestCase) PerformsDistinct() bool { - return t.performsOperation("distinct") -} - -func (t *TestCase) SetsFailPoint() bool { - return t.performsOperation("failPoint") -} - -func (t *TestCase) StartsTransaction() bool { - return t.performsOperation("startTransaction") -} - -func (t *TestCase) performsOperation(name string) bool { - for _, op := range t.Operations { - if op.Name == name { - return true - } - } - return false -} - -type TestFile struct { - Description string `bson:"description"` - SchemaVersion string `bson:"schemaVersion"` - RunOnRequirements []mtest.RunOnBlock `bson:"runOnRequirements"` - CreateEntities []map[string]*EntityOptions `bson:"createEntities"` - InitialData []*CollectionData `bson:"initialData"` - TestCases []*TestCase `bson:"tests"` -} - func TestUnifiedSpec(t *testing.T) { // Ensure the cluster is in a clean state before test execution begins. - if err := TerminateOpenSessions(mtest.Background); err != nil { + if err := terminateOpenSessions(mtest.Background); err != nil { t.Fatalf("error terminating open transactions: %v", err) } for _, testDir := range directories { t.Run(testDir, func(t *testing.T) { - for _, filename := range testhelpers.FindJSONFilesInDir(t, path.Join(dataDirectory, testDir)) { - t.Run(filename, func(t *testing.T) { - runTestFile(t, path.Join(dataDirectory, testDir, filename)) - }) - } - }) - } -} - -func runTestFile(t *testing.T, filepath string) { - content, err := ioutil.ReadFile(filepath) - assert.Nil(t, err, "ReadFile error for file %q: %v", filepath, err) - - var testFile TestFile - err = bson.UnmarshalExtJSON(content, false, &testFile) - assert.Nil(t, err, "UnmarshalExtJSON error for file %q: %v", filepath, err) - - // Validate that we support the schema declared by the test file before attempting to use its contents. - err = CheckSchemaVersion(testFile.SchemaVersion) - assert.Nil(t, err, "schema version %q not supported: %v", testFile.SchemaVersion, err) - - // Create mtest wrapper, which will skip the test if needed. - mtOpts := mtest.NewOptions(). - RunOn(testFile.RunOnRequirements...). - CreateClient(false) - mt := mtest.New(t, mtOpts) - defer mt.Close() - - for _, testCase := range testFile.TestCases { - mtOpts := mtest.NewOptions(). - RunOn(testCase.RunOnRequirements...). - CreateClient(false) - mt.RunOpts(testCase.Description, mtOpts, func(mt *mtest.T) { - runTestCase(mt, testFile, testCase) + runTestDirectory(t, path.Join(dataDirectory, testDir)) }) } } - -func runTestCase(mt *mtest.T, testFile TestFile, testCase *TestCase) { - if testCase.SkipReason != nil { - mt.Skipf("skipping for reason: %q", *testCase.SkipReason) - } - if _, ok := skippedTestDescriptions[testCase.Description]; ok { - mt.Skip("skipping due to known failure") - } - - testCtx := NewTestContext(mtest.Background) - - defer func() { - // If anything fails while doing test cleanup, we only log the error because the actual test may have already - // failed and that failure should be preserved. - - for _, err := range DisableUntargetedFailPoints(testCtx) { - mt.Log(err) - } - for _, err := range DisableTargetedFailPoints(testCtx) { - mt.Log(err) - } - for _, err := range Entities(testCtx).Close(testCtx) { - mt.Log(err) - } - // Tests that started a transaction should terminate any sessions left open on the server. This is required even - // if the test attempted to commit/abort the transaction because an abortTransaction command can fail if it's - // sent to a mongos that isn't aware of the transaction. - if testCase.StartsTransaction() { - if err := TerminateOpenSessions(mtest.Background); err != nil { - mt.Logf("error terminating open transactions after failed test: %v", err) - } - } - }() - - // Set up collections based on the file-level initialData field. - for _, collData := range testFile.InitialData { - if err := collData.CreateCollection(testCtx); err != nil { - mt.Fatalf("error setting up collection %q: %v", collData.Namespace(), err) - } - } - - // Set up entities based on the file-level createEntities field. For client entities, if the test will configure - // a fail point, set a low heartbeatFrequencyMS value into the URI options map if one is not already present. - // This speeds up recovery time for the client if the fail point forces the server to return a state change - // error. - shouldSetHeartbeatFrequency := testCase.SetsFailPoint() - for idx, entity := range testFile.CreateEntities { - for entityType, entityOptions := range entity { - if shouldSetHeartbeatFrequency && entityType == "client" { - if entityOptions.URIOptions == nil { - entityOptions.URIOptions = make(bson.M) - } - if _, ok := entityOptions.URIOptions["heartbeatFrequencyMS"]; !ok { - entityOptions.URIOptions["heartbeatFrequencyMS"] = lowHeartbeatFrequency - } - } - - if err := Entities(testCtx).AddEntity(testCtx, entityType, entityOptions); err != nil { - mt.Fatalf("error creating entity at index %d: %v", idx, err) - } - } - } - - // Work around SERVER-39704. - if mtest.ClusterTopologyKind() == mtest.Sharded && testCase.PerformsDistinct() { - if err := PerformDistinctWorkaround(testCtx); err != nil { - mt.Fatalf("error performing \"distinct\" workaround: %v", err) - } - } - - for idx, operation := range testCase.Operations { - err := operation.Execute(testCtx) - assert.Nil(mt, err, "error running operation %q at index %d: %v", operation.Name, idx, err) - } - - for _, client := range Entities(testCtx).Clients() { - client.StopListeningForEvents() - } - - for idx, expectedEvents := range testCase.ExpectedEvents { - err := VerifyEvents(testCtx, expectedEvents) - assert.Nil(mt, err, "events verification failed at index %d: %v", idx, err) - } - - for idx, collData := range testCase.Outcome { - err := collData.VerifyContents(testCtx) - assert.Nil(mt, err, "error verifying outcome for collection %q at index %d: %v", - collData.Namespace(), idx, err) - } -} - -func DisableUntargetedFailPoints(ctx context.Context) []error { - var errs []error - for fpName, client := range FailPoints(ctx) { - if err := disableFailPointWithClient(ctx, fpName, client); err != nil { - errs = append(errs, fmt.Errorf("error disabling fail point %q: %v", fpName, err)) - } - } - return errs -} - -func DisableTargetedFailPoints(ctx context.Context) []error { - var errs []error - for fpName, host := range TargetedFailPoints(ctx) { - commandFn := func(ctx context.Context, client *mongo.Client) error { - return disableFailPointWithClient(ctx, fpName, client) - } - if err := RunCommandOnHost(ctx, host, commandFn); err != nil { - errs = append(errs, fmt.Errorf("error disabling targeted fail point %q on host %q: %v", fpName, host, err)) - } - } - return errs -} - -func disableFailPointWithClient(ctx context.Context, fpName string, client *mongo.Client) error { - cmd := bson.D{ - {"configureFailPoint", fpName}, - {"mode", "off"}, - } - if err := client.Database("admin").RunCommand(ctx, cmd).Err(); err != nil { - return err - } - return nil -}