diff --git a/src/MongoDB.Driver/CreateCollectionOptions.cs b/src/MongoDB.Driver/CreateCollectionOptions.cs index 1b3cbb5929f..7a79c9cd3b7 100644 --- a/src/MongoDB.Driver/CreateCollectionOptions.cs +++ b/src/MongoDB.Driver/CreateCollectionOptions.cs @@ -196,6 +196,27 @@ public TimeSeriesOptions TimeSeriesOptions get { return _validationLevel; } set { _validationLevel = value; } } + + internal virtual CreateCollectionOptions Clone() => + new CreateCollectionOptions + { + _autoIndexId = _autoIndexId, + _capped = _capped, + _changeStreamPreAndPostImagesOptions = _changeStreamPreAndPostImagesOptions, + _collation = _collation, + _encryptedFields = _encryptedFields, + _expireAfter = _expireAfter, + _indexOptionDefaults = _indexOptionDefaults, + _maxDocuments = _maxDocuments, + _maxSize = _maxSize, + _noPadding = _noPadding, + _serializerRegistry = _serializerRegistry, + _storageEngine = _storageEngine, + _timeSeriesOptions = _timeSeriesOptions, + _usePowerOf2Sizes = _usePowerOf2Sizes, + _validationAction = _validationAction, + _validationLevel = _validationLevel + }; } /// @@ -282,5 +303,32 @@ public FilterDefinition Validator get { return _validator; } set { _validator = value; } } + + internal override CreateCollectionOptions Clone() => + new CreateCollectionOptions + { + #pragma warning disable CS0618 // Type or member is obsolete + AutoIndexId = base.AutoIndexId, + #pragma warning restore CS0618 // Type or member is obsolete + Capped = base.Capped, + ChangeStreamPreAndPostImagesOptions = base.ChangeStreamPreAndPostImagesOptions, + Collation = base.Collation, + EncryptedFields = base.EncryptedFields, + ExpireAfter = base.ExpireAfter, + IndexOptionDefaults = base.IndexOptionDefaults, + MaxDocuments = base.MaxDocuments, + MaxSize = base.MaxSize, + NoPadding = base.NoPadding, + SerializerRegistry = base.SerializerRegistry, + StorageEngine = base.StorageEngine, + TimeSeriesOptions = base.TimeSeriesOptions, + UsePowerOf2Sizes = base.UsePowerOf2Sizes, + ValidationAction = base.ValidationAction, + ValidationLevel = base.ValidationLevel, + + _clusteredIndex = _clusteredIndex, + _documentSerializer = _documentSerializer, + _validator = _validator + }; } } diff --git a/src/MongoDB.Driver/Encryption/ClientEncryption.cs b/src/MongoDB.Driver/Encryption/ClientEncryption.cs index 7dfadcb36b5..ef436a282a7 100644 --- a/src/MongoDB.Driver/Encryption/ClientEncryption.cs +++ b/src/MongoDB.Driver/Encryption/ClientEncryption.cs @@ -82,59 +82,85 @@ public ClientEncryption(ClientEncryptionOptions clientEncryptionOptions) /// /// Create encrypted collection. /// - /// The collection namespace. + /// The database. + /// The collection name. /// The create collection options. /// The kms provider. /// The datakey options. /// The cancellation token. + /// The operation result. /// /// if EncryptionFields contains a keyId with a null value, a data key will be automatically generated and assigned to keyId value. /// - public void CreateEncryptedCollection(CollectionNamespace collectionNamespace, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default) + public CreateEncryptedCollectionResult CreateEncryptedCollection(IMongoDatabase database, string collectionName, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default) { - Ensure.IsNotNull(collectionNamespace, nameof(collectionNamespace)); + Ensure.IsNotNull(database, nameof(database)); + Ensure.IsNotNull(collectionName, nameof(collectionName)); Ensure.IsNotNull(createCollectionOptions, nameof(createCollectionOptions)); Ensure.IsNotNull(dataKeyOptions, nameof(dataKeyOptions)); Ensure.IsNotNull(kmsProvider, nameof(kmsProvider)); - foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(collectionNamespace, createCollectionOptions.EncryptedFields)) + var encryptedFields = createCollectionOptions.EncryptedFields?.DeepClone()?.AsBsonDocument; + try { - var dataKey = CreateDataKey(kmsProvider, dataKeyOptions, cancellationToken); - EncryptedCollectionHelper.ModifyEncryptedFields(fieldDocument, dataKey); + foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(new CollectionNamespace(database.DatabaseNamespace.DatabaseName, collectionName), encryptedFields)) + { + var dataKey = CreateDataKey(kmsProvider, dataKeyOptions, cancellationToken); + EncryptedCollectionHelper.ModifyEncryptedFields(fieldDocument, dataKey); + } + + var effectiveCreateEncryptionOptions = createCollectionOptions.Clone(); + effectiveCreateEncryptionOptions.EncryptedFields = encryptedFields; + database.CreateCollection(collectionName, effectiveCreateEncryptionOptions, cancellationToken); + } + catch (Exception ex) + { + throw new MongoEncryptionCreateCollectionException(ex, encryptedFields); } - var database = _libMongoCryptController.KeyVaultClient.GetDatabase(collectionNamespace.DatabaseNamespace.DatabaseName); - - database.CreateCollection(collectionNamespace.CollectionName, createCollectionOptions, cancellationToken); + return new CreateEncryptedCollectionResult(encryptedFields); } /// /// Create encrypted collection. /// - /// The collection namespace. + /// The database. + /// The collection name. /// The create collection options. /// The kms provider. /// The datakey options. /// The cancellation token. + /// The operation result. /// /// if EncryptionFields contains a keyId with a null value, a data key will be automatically generated and assigned to keyId value. /// - public async Task CreateEncryptedCollectionAsync(CollectionNamespace collectionNamespace, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default) + public async Task CreateEncryptedCollectionAsync(IMongoDatabase database, string collectionName, CreateCollectionOptions createCollectionOptions, string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken = default) { - Ensure.IsNotNull(collectionNamespace, nameof(collectionNamespace)); + Ensure.IsNotNull(database, nameof(database)); + Ensure.IsNotNull(collectionName, nameof(collectionName)); Ensure.IsNotNull(createCollectionOptions, nameof(createCollectionOptions)); Ensure.IsNotNull(dataKeyOptions, nameof(dataKeyOptions)); Ensure.IsNotNull(kmsProvider, nameof(kmsProvider)); - foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(collectionNamespace, createCollectionOptions.EncryptedFields)) + var encryptedFields = createCollectionOptions.EncryptedFields?.DeepClone()?.AsBsonDocument; + try { - var dataKey = await CreateDataKeyAsync(kmsProvider, dataKeyOptions, cancellationToken).ConfigureAwait(false); - EncryptedCollectionHelper.ModifyEncryptedFields(fieldDocument, dataKey); + foreach (var fieldDocument in EncryptedCollectionHelper.IterateEmptyKeyIds(new CollectionNamespace(database.DatabaseNamespace.DatabaseName, collectionName), encryptedFields)) + { + var dataKey = await CreateDataKeyAsync(kmsProvider, dataKeyOptions, cancellationToken).ConfigureAwait(false); + EncryptedCollectionHelper.ModifyEncryptedFields(fieldDocument, dataKey); + } + + var effectiveCreateEncryptionOptions = createCollectionOptions.Clone(); + effectiveCreateEncryptionOptions.EncryptedFields = encryptedFields; + await database.CreateCollectionAsync(collectionName, effectiveCreateEncryptionOptions, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + throw new MongoEncryptionCreateCollectionException(ex, encryptedFields); } - var database = _libMongoCryptController.KeyVaultClient.GetDatabase(collectionNamespace.DatabaseNamespace.DatabaseName); - - await database.CreateCollectionAsync(collectionNamespace.CollectionName, createCollectionOptions, cancellationToken).ConfigureAwait(false); + return new CreateEncryptedCollectionResult(encryptedFields); } /// diff --git a/src/MongoDB.Driver/Encryption/CreateEncryptedCollectionResult.cs b/src/MongoDB.Driver/Encryption/CreateEncryptedCollectionResult.cs new file mode 100644 index 00000000000..30476cf8f0a --- /dev/null +++ b/src/MongoDB.Driver/Encryption/CreateEncryptedCollectionResult.cs @@ -0,0 +1,38 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using MongoDB.Bson; + +namespace MongoDB.Driver.Encryption +{ + /// + /// Represents the result of a create encrypted collection. + /// + public sealed class CreateEncryptedCollectionResult + { + private readonly BsonDocument _encryptedFields; + + /// + /// Initializes a new instance of the class. + /// + /// The encrypted fields document. + public CreateEncryptedCollectionResult(BsonDocument encryptedFields) => _encryptedFields = encryptedFields; + + /// + /// The encrypted fields document. + /// + public BsonDocument EncryptedFields => _encryptedFields; + } +} diff --git a/src/MongoDB.Driver/Encryption/MongoEncryptionCreateCollectionException.cs b/src/MongoDB.Driver/Encryption/MongoEncryptionCreateCollectionException.cs new file mode 100644 index 00000000000..9a581233538 --- /dev/null +++ b/src/MongoDB.Driver/Encryption/MongoEncryptionCreateCollectionException.cs @@ -0,0 +1,69 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Runtime.Serialization; +using MongoDB.Bson; + +namespace MongoDB.Driver.Encryption +{ + /// + /// Represents an encryption exception. + /// + [Serializable] + public class MongoEncryptionCreateCollectionException : MongoEncryptionException + { + private readonly BsonDocument _encryptedFields; + + /// + /// Initializes a new instance of the class. + /// + /// The inner exception. + /// The encrypted fields. + public MongoEncryptionCreateCollectionException(Exception innerException, BsonDocument encryptedFields) + : base(innerException) + { + _encryptedFields = encryptedFields; + } + + /// + /// Initializes a new instance of the class (this overload used by deserialization). + /// + /// The SerializationInfo. + /// The StreamingContext. + protected MongoEncryptionCreateCollectionException(SerializationInfo info, StreamingContext context) + : base(info, context) + { + _encryptedFields = (BsonDocument)info.GetValue(nameof(_encryptedFields), typeof(BsonDocument)); + } + + /// + /// The encrypted fields. + /// + public BsonDocument EncryptedFields => _encryptedFields; + + // public methods + /// + /// Gets the object data. + /// + /// The information. + /// The context. + public override void GetObjectData(SerializationInfo info, StreamingContext context) + { + base.GetObjectData(info, context); + info.AddValue(nameof(_encryptedFields), _encryptedFields); + } + } +} diff --git a/src/MongoDB.Driver/Encryption/MongoEncryptionException.cs b/src/MongoDB.Driver/Encryption/MongoEncryptionException.cs index 5e264539c07..78b6ca522db 100644 --- a/src/MongoDB.Driver/Encryption/MongoEncryptionException.cs +++ b/src/MongoDB.Driver/Encryption/MongoEncryptionException.cs @@ -15,7 +15,6 @@ using System; using System.Runtime.Serialization; -using MongoDB.Driver.Core.Misc; namespace MongoDB.Driver.Encryption { diff --git a/tests/MongoDB.Bson.TestHelpers/BsonValueEquivalencyComparer.cs b/tests/MongoDB.Bson.TestHelpers/BsonValueEquivalencyComparer.cs index 692a1e1e927..b15a42168d9 100644 --- a/tests/MongoDB.Bson.TestHelpers/BsonValueEquivalencyComparer.cs +++ b/tests/MongoDB.Bson.TestHelpers/BsonValueEquivalencyComparer.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Collections.Generic; namespace MongoDB.Bson.TestHelpers @@ -22,15 +23,17 @@ public class BsonValueEquivalencyComparer : IEqualityComparer #region static public static BsonValueEquivalencyComparer Instance { get; } = new BsonValueEquivalencyComparer(); - public static bool Compare(BsonValue a, BsonValue b) + public static bool Compare(BsonValue a, BsonValue b, Action massageAction = null) { + massageAction?.Invoke(a, b); + if (a.BsonType == BsonType.Document && b.BsonType == BsonType.Document) { - return CompareDocuments((BsonDocument)a, (BsonDocument)b); + return CompareDocuments((BsonDocument)a, (BsonDocument)b, massageAction); } else if (a.BsonType == BsonType.Array && b.BsonType == BsonType.Array) { - return CompareArrays((BsonArray)a, (BsonArray)b); + return CompareArrays((BsonArray)a, (BsonArray)b, massageAction); } else if (a.BsonType == b.BsonType) { @@ -50,7 +53,7 @@ public static bool Compare(BsonValue a, BsonValue b) } } - private static bool CompareArrays(BsonArray a, BsonArray b) + private static bool CompareArrays(BsonArray a, BsonArray b, Action massageAction = null) { if (a.Count != b.Count) { @@ -59,7 +62,7 @@ private static bool CompareArrays(BsonArray a, BsonArray b) for (var i = 0; i < a.Count; i++) { - if (!Compare(a[i], b[i])) + if (!Compare(a[i], b[i], massageAction)) { return false; } @@ -68,7 +71,7 @@ private static bool CompareArrays(BsonArray a, BsonArray b) return true; } - private static bool CompareDocuments(BsonDocument a, BsonDocument b) + private static bool CompareDocuments(BsonDocument a, BsonDocument b, Action massageAction = null) { if (a.ElementCount != b.ElementCount) { @@ -83,7 +86,7 @@ private static bool CompareDocuments(BsonDocument a, BsonDocument b) return false; } - if (!Compare(aElement.Value, bElement.Value)) + if (!Compare(aElement.Value, bElement.Value, massageAction)) { return false; } diff --git a/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs b/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs index 70b9607ba3b..fe3170485bf 100644 --- a/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs +++ b/tests/MongoDB.Driver.Tests/Encryption/ClientEncryptionTests.cs @@ -25,6 +25,9 @@ using MongoDB.Driver.Tests.Specifications.client_side_encryption; using MongoDB.Libmongocrypt; using Xunit; +using Moq; +using System.Collections.Generic; +using System.Threading; namespace MongoDB.Driver.Tests.Encryption { @@ -64,6 +67,149 @@ public async Task CreateDataKey_should_correctly_handle_input_arguments() } } + [Fact] + public async Task CreateEncryptedCollection_should_handle_input_arguments() + { + const string kmsProvider = "local"; + const string collectionName = "collName"; + var createCollectionOptions = new CreateCollectionOptions(); + var database = Mock.Of(); + + var dataKeyOptions = new DataKeyOptions(); + + using (var subject = CreateSubject()) + { + ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database: null, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "database"); + ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database: null, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "database"); + + ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: null, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "collectionName"); + ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName: null, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedParamName: "collectionName"); + + ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions: null, kmsProvider, dataKeyOptions)), expectedParamName: "createCollectionOptions"); + ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions: null, kmsProvider, dataKeyOptions)), expectedParamName: "createCollectionOptions"); + + ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions, kmsProvider: null, dataKeyOptions)), expectedParamName: "kmsProvider"); + ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider: null, dataKeyOptions)), expectedParamName: "kmsProvider"); + + ShouldBeArgumentException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName: collectionName, createCollectionOptions, kmsProvider, dataKeyOptions: null)), expectedParamName: "dataKeyOptions"); + ShouldBeArgumentException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions: null)), expectedParamName: "dataKeyOptions"); + } + } + + [Fact] + public async Task CreateEncryptedCollection_should_handle_save_generated_key_when_second_key_failed() + { + const string kmsProvider = "local"; + const string collectionName = "collName"; + const string encryptedFieldsStr = "{ fields : [{ keyId : null }, { keyId : null }] }"; + var database = Mock.Of(d => d.DatabaseNamespace == new DatabaseNamespace("db")); + + var dataKeyOptions = new DataKeyOptions(); + + var mockCollection = new Mock>(); + mockCollection + .SetupSequence(c => c.InsertOne(It.IsAny(), It.IsAny(), It.IsAny())) + .Pass() + .Throws(new Exception("test")); + mockCollection + .SetupSequence(c => c.InsertOneAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask) + .Throws(new Exception("test")); + var mockDatabase = new Mock(); + mockDatabase.Setup(c => c.GetCollection(It.IsAny(), It.IsAny())).Returns(mockCollection.Object); + var client = new Mock(); + client.Setup(c => c.GetDatabase(It.IsAny(), It.IsAny())).Returns(mockDatabase.Object); + + using (var subject = CreateSubject(client.Object)) + { + var createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = BsonDocument.Parse(encryptedFieldsStr) }; + var exception = Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)); + AssertResults(exception, createCollectionOptions); + + exception = await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)); + AssertResults(exception, createCollectionOptions); + } + + void AssertResults(Exception ex, CreateCollectionOptions createCollectionOptions) + { + var createCollectionException = ex.Should().BeOfType().Subject; + createCollectionException + .InnerException + .Should().BeOfType().Subject.InnerException + .Should().BeOfType().Which.Message + .Should().Be("test"); + var fields = createCollectionException.EncryptedFields["fields"].AsBsonArray; + fields[0].AsBsonDocument["keyId"].Should().BeOfType(); // pass + /* + - If generating `D` resulted in an error `E`, the entire + `CreateEncryptedCollection` must now fail with error `E`. Return the + partially-formed `EF'` with the error so that the caller may know what + datakeys have already been created by the helper. + */ + fields[1].AsBsonDocument["keyId"].Should().BeOfType(); // throw + } + } + + [Theory] + [InlineData(null, "There are no encrypted fields defined for the collection.")] + [InlineData("{}", "{}")] + [InlineData("{ a : 1 }", "{ a : 1 }")] + [InlineData("{ fields : { } }", "{ fields: { } }")] + [InlineData("{ fields : [] }", "{ fields: [] }")] + [InlineData("{ fields : [{ a : 1 }] }", "{ fields: [{ a : 1 }] }")] + [InlineData("{ fields : [{ keyId : 1 }] }", "{ fields: [{ keyId : 1 }] }")] + [InlineData("{ fields : [{ keyId : null }] }", "{ fields: [{ keyId : '#binary_generated#' }] }")] + [InlineData("{ fields : [{ keyId : null }, { keyId : null }] }", "{ fields: [{ keyId : '#binary_generated#' }, { keyId : '#binary_generated#' }] }")] + [InlineData("{ fields : [{ keyId : 3 }, { keyId : null }] }", "{ fields: [{ keyId : 3 }, { keyId : '#binary_generated#' }] }")] + public async Task CreateEncryptedCollection_should_handle_various_encryptedFields(string encryptedFieldsStr, string expectedResult) + { + const string kmsProvider = "local"; + const string collectionName = "collName"; + var database = Mock.Of(d => d.DatabaseNamespace == new DatabaseNamespace("db")); + + var dataKeyOptions = new DataKeyOptions(); + + using (var subject = CreateSubject()) + { + var createCollectionOptions = new CreateCollectionOptions() { EncryptedFields = encryptedFieldsStr != null ? BsonDocument.Parse(encryptedFieldsStr) : null }; + + if (BsonDocument.TryParse(expectedResult, out var encryptedFields)) + { + var createCollectionResult = subject.CreateEncryptedCollection(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions); + createCollectionResult.EncryptedFields.WithComparer(new EncryptedFieldsComparer()).Should().Be(encryptedFields.DeepClone()); + + createCollectionResult = await subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions); + createCollectionResult.EncryptedFields.WithComparer(new EncryptedFieldsComparer()).Should().Be(encryptedFields.DeepClone()); + } + else + { + AssertInvalidOperationException(Record.Exception(() => subject.CreateEncryptedCollection(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedResult); + AssertInvalidOperationException(await Record.ExceptionAsync(() => subject.CreateEncryptedCollectionAsync(database, collectionName, createCollectionOptions, kmsProvider, dataKeyOptions)), expectedResult); + } + } + + void AssertInvalidOperationException(Exception ex, string message) => + ex + .Should().BeOfType().Subject.InnerException + .Should().BeOfType().Which.Message.Should().Be(message); + } + + private sealed class EncryptedFieldsComparer : IEqualityComparer + { + public bool Equals(BsonDocument x, BsonDocument y) => + BsonValueEquivalencyComparer.Compare( + x, y, + massageAction: (a, b) => + { + if (a is BsonDocument aDocument && aDocument.TryGetValue("keyId", out var aKeyId) && aKeyId.IsBsonBinaryData && + b is BsonDocument bDocument && bDocument.TryGetValue("keyId", out var bKeyId) && bKeyId == "#binary_generated#") + { + bDocument["keyId"] = aDocument["keyId"]; + } + }); + + public int GetHashCode(BsonDocument obj) => obj.GetHashCode(); + } [Fact] public void CryptClient_should_be_initialized() @@ -167,10 +313,10 @@ public async Task RewrapManyDataKey_should_correctly_handle_input_arguments() } // private methods - private ClientEncryption CreateSubject() + private ClientEncryption CreateSubject(IMongoClient client = null) { var clientEncryptionOptions = new ClientEncryptionOptions( - DriverTestConfiguration.Client, + client ?? DriverTestConfiguration.Client, __keyVaultCollectionNamespace, kmsProviders: EncryptionTestHelper.GetKmsProviders(filter: "local")); diff --git a/tests/MongoDB.Driver.Tests/Encryption/MongoEncryptionCreateCollectionExceptionTests.cs b/tests/MongoDB.Driver.Tests/Encryption/MongoEncryptionCreateCollectionExceptionTests.cs new file mode 100644 index 00000000000..3223eb3a7a9 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Encryption/MongoEncryptionCreateCollectionExceptionTests.cs @@ -0,0 +1,47 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.IO; +using System.Runtime.Serialization.Formatters.Binary; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Driver.Encryption; +using Xunit; + +namespace MongoDB.Driver.Tests.Encryption +{ + public class MongoEncryptionCreateCollectionExceptionTests + { + [Fact] + public void Serialization_should_work() + { + var subject = new MongoEncryptionCreateCollectionException(new Exception("inner"), new BsonDocument("value", 1)); + + var formatter = new BinaryFormatter(); + using (var stream = new MemoryStream()) + { +#pragma warning disable SYSLIB0011 // BinaryFormatter serialization is obsolete + formatter.Serialize(stream, subject); + stream.Position = 0; + var rehydrated = (MongoEncryptionCreateCollectionException)formatter.Deserialize(stream); +#pragma warning restore SYSLIB0011 // BinaryFormatter serialization is obsolete + + rehydrated.InnerException.Message.Should().Be(subject.InnerException.Message); + rehydrated.EncryptedFields.Should().Be(subject.EncryptedFields).And.Should().NotBeNull(); + } + } + } +} diff --git a/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs b/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs index cc71ce2294a..20220ce6f87 100644 --- a/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs +++ b/tests/MongoDB.Driver.Tests/Specifications/client-side-encryption/prose-tests/ClientEncryptionProseTests.cs @@ -119,7 +119,7 @@ void RunTestCase(int testCase) { case 1: // Case 1: Simple Creation and Validation { - var collection = CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, encryptedFields, kmsProvider, async); + var collection = CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, encryptedFields, kmsProvider, async, out _); var exception = Record.Exception(() => Insert(collection, async, new BsonDocument("ssn", "123-45-6789"))); exception.Should().BeOfType>().Which.Message.Should().Contain("Document failed validation"); @@ -127,24 +127,28 @@ void RunTestCase(int testCase) break; case 2: // Case 2: Missing ``encryptedFields`` { - var exception = Record.Exception(() => CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, encryptedFields: null, kmsProvider, async)); + var exception = Record.Exception(() => CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, encryptedFields: null, kmsProvider, async, out _)); - exception.Should().BeOfType().Which.Message.Should().Contain("There are no encrypted fields defined for the collection.") ; + exception + .Should().BeOfType().Which.InnerException + .Should().BeOfType().Which.Message.Should().Contain("There are no encrypted fields defined for the collection.") ; } break; case 3: // Case 3: Invalid ``keyId`` { var effectiveEncryptedFields = encryptedFields.DeepClone(); effectiveEncryptedFields["fields"].AsBsonArray[0].AsBsonDocument["keyId"] = false; - var exception = Record.Exception(() => CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, effectiveEncryptedFields.AsBsonDocument, kmsProvider, async)); - exception.Should().BeOfType().Which.Message.Should().Contain("BSON field 'create.encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData'"); + var exception = Record.Exception(() => CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, effectiveEncryptedFields.AsBsonDocument, kmsProvider, async, out _)); + exception + .Should().BeOfType().Which.InnerException + .Should().BeOfType().Which.Message.Should().Contain("BSON field 'create.encryptedFields.fields.keyId' is the wrong type 'bool', expected type 'binData'"); } break; case 4: // Case 4: Insert encrypted value { var createCollectionOptions = new CreateCollectionOptions { EncryptedFields = encryptedFields }; - var collection = CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, createCollectionOptions, kmsProvider, async); - var dataKey = createCollectionOptions.EncryptedFields["fields"].AsBsonArray[0].AsBsonDocument["keyId"].AsGuid; // get generated datakey + var collection = CreateEncryptedCollection(client, clientEncryption, __collCollectionNamespace, createCollectionOptions, kmsProvider, async, out var effectiveEncryptedFields); + var dataKey = effectiveEncryptedFields["fields"].AsBsonArray[0].AsBsonDocument["keyId"].AsGuid; // get generated datakey var encryptedValue = ExplicitEncrypt(clientEncryption, new EncryptOptions(algorithm: EncryptionAlgorithm.Unindexed, keyId: dataKey), "123-45-6789", async); // use explicit encryption to encrypt data before inserting Insert(collection, async, new BsonDocument("ssn", encryptedValue)); } @@ -2340,24 +2344,23 @@ private void CreateCollection(IMongoClient client, CollectionNamespace collectio }); } - private IMongoCollection CreateEncryptedCollection(IMongoClient client, ClientEncryption clientEncryption, CollectionNamespace collectionNamespace, BsonDocument encryptedFields, string kmsProvider, bool async) + private IMongoCollection CreateEncryptedCollection(IMongoClient client, ClientEncryption clientEncryption, CollectionNamespace collectionNamespace, BsonDocument encryptedFields, string kmsProvider, bool async, out BsonDocument effectiveEncryptedFields) { var createCollectionOptions = new CreateCollectionOptions { EncryptedFields = encryptedFields }; - return CreateEncryptedCollection(client, clientEncryption, collectionNamespace, createCollectionOptions, kmsProvider, async); + return CreateEncryptedCollection(client, clientEncryption, collectionNamespace, createCollectionOptions, kmsProvider, async, out effectiveEncryptedFields); } - private IMongoCollection CreateEncryptedCollection(IMongoClient client, ClientEncryption clientEncryption, CollectionNamespace collectionNamespace, CreateCollectionOptions createCollectionOptions, string kmsProvider, bool async) + private IMongoCollection CreateEncryptedCollection(IMongoClient client, ClientEncryption clientEncryption, CollectionNamespace collectionNamespace, CreateCollectionOptions createCollectionOptions, string kmsProvider, bool async, out BsonDocument effectiveEncryptedFields) { var datakeyOptions = CreateDataKeyOptions(kmsProvider); + var database = client.GetDatabase(collectionNamespace.DatabaseNamespace.DatabaseName); - if (async) - { - clientEncryption.CreateEncryptedCollectionAsync(collectionNamespace, createCollectionOptions, kmsProvider, datakeyOptions, cancellationToken: default).GetAwaiter().GetResult(); - } - else - { - clientEncryption.CreateEncryptedCollection(collectionNamespace, createCollectionOptions, kmsProvider, datakeyOptions, cancellationToken: default); - } + + var result = async + ? clientEncryption.CreateEncryptedCollectionAsync(database, collectionNamespace.CollectionName, createCollectionOptions, kmsProvider, datakeyOptions, cancellationToken: default).GetAwaiter().GetResult() + : clientEncryption.CreateEncryptedCollection(database, collectionNamespace.CollectionName, createCollectionOptions, kmsProvider, datakeyOptions, cancellationToken: default); + + effectiveEncryptedFields = result.EncryptedFields; return client.GetDatabase(collectionNamespace.DatabaseNamespace.DatabaseName).GetCollection(collectionNamespace.CollectionName); }