From 6ce1cf0af13b2a7f59aff29e1c4ff81e69b61a3f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Mar 2021 06:14:20 +0000 Subject: [PATCH] Changelog:All: Bump go.mongodb.org/mongo-driver from 1.4.6 to 1.5.0 Bumps [go.mongodb.org/mongo-driver](https://github.com/mongodb/mongo-go-driver) from 1.4.6 to 1.5.0. - [Release notes](https://github.com/mongodb/mongo-go-driver/releases) - [Commits](https://github.com/mongodb/mongo-go-driver/compare/1.4.6...v1.5.0) Signed-off-by: dependabot[bot] --- LIC_FILES_CHKSUM.sha256 | 1 + go.mod | 2 +- go.sum | 11 +- vendor/github.com/youmark/pkcs8/.gitignore | 23 + vendor/github.com/youmark/pkcs8/.travis.yml | 9 + vendor/github.com/youmark/pkcs8/LICENSE | 21 + vendor/github.com/youmark/pkcs8/README | 1 + vendor/github.com/youmark/pkcs8/README.md | 21 + vendor/github.com/youmark/pkcs8/pkcs8.go | 305 +++++ .../bson/bsoncodec/array_codec.go | 50 + .../mongo-driver/bson/bsoncodec/bsoncodec.go | 53 + .../bson/bsoncodec/byte_slice_codec.go | 55 +- .../bson/bsoncodec/default_value_decoders.go | 913 ++++++++++----- .../bson/bsoncodec/default_value_encoders.go | 1 + .../bson/bsoncodec/empty_interface_codec.go | 43 +- .../mongo-driver/bson/bsoncodec/map_codec.go | 4 +- .../bson/bsoncodec/string_codec.go | 53 +- .../bson/bsoncodec/struct_tag_parser.go | 20 + .../mongo-driver/bson/bsoncodec/time_codec.go | 52 +- .../mongo-driver/bson/bsoncodec/types.go | 1 + .../mongo-driver/bson/bsoncodec/uint_codec.go | 79 +- .../mongo-driver/bson/bsonrw/copier.go | 60 +- .../bson/bsonrw/extjson_parser.go | 68 ++ .../go.mongodb.org/mongo-driver/bson/doc.go | 22 +- .../mongo-driver/bson/primitive/decimal.go | 44 + .../mongo-driver/bson/primitive/objectid.go | 6 + .../mongo-driver/bson/raw_value.go | 16 + .../go.mongodb.org/mongo-driver/event/doc.go | 56 + .../mongo-driver/event/monitoring.go | 81 ++ .../internal/background_context.go | 34 + .../internal/cancellation_listener.go | 47 + .../mongo-driver/internal/string_util.go | 45 + .../{x/mongo/driver => mongo}/address/addr.go | 2 +- .../mongo-driver/mongo/bulk_write.go | 16 +- .../mongo-driver/mongo/bulk_write_models.go | 15 +- .../mongo-driver/mongo/change_stream.go | 8 +- .../mongo/change_stream_deployment.go | 8 +- .../mongo-driver/mongo/client.go | 181 ++- .../mongo-driver/mongo/client_encryption.go | 11 +- .../mongo-driver/mongo/collection.go | 116 +- .../mongo-driver/mongo/database.go | 65 +- .../description/description.go | 2 +- .../driver => mongo}/description/server.go | 229 ++-- .../description/server_kind.go | 4 +- .../description/server_selector.go | 29 +- .../mongo/description/topology.go | 142 +++ .../description/topology_kind.go | 0 .../description/topology_version.go | 21 +- .../description/version_range.go | 11 + .../go.mongodb.org/mongo-driver/mongo/doc.go | 40 +- .../mongo-driver/mongo/errors.go | 191 +++ .../mongo-driver/mongo/index_view.go | 37 +- .../mongo-driver/mongo/mongo.go | 29 +- .../mongo/options/aggregateoptions.go | 4 +- .../mongo/options/autoencryptionoptions.go | 18 +- .../mongo/options/clientoptions.go | 45 +- .../mongo/options/clientoptions_1_10.go | 4 + .../mongo/options/clientoptions_1_9.go | 11 +- .../mongo/options/countoptions.go | 3 +- .../mongo/options/datakeyoptions.go | 33 +- .../mongo/options/deleteoptions.go | 3 +- .../mongo-driver/mongo/options/findoptions.go | 58 +- .../mongo/options/gridfsoptions.go | 3 +- .../mongo/options/listcollectionsoptions.go | 12 + .../mongo/options/replaceoptions.go | 3 +- .../mongo/options/updateoptions.go | 3 +- .../mongo-driver/mongo/readpref/mode.go | 14 + .../mongo-driver/mongo/results.go | 102 ++ .../mongo-driver/mongo/session.go | 24 +- .../mongo-driver/version/version.go | 2 +- .../mongo-driver/x/bsonx/bsoncore/array.go | 164 +++ .../x/bsonx/bsoncore/bson_arraybuilder.go | 201 ++++ .../x/bsonx/bsoncore/bson_documentbuilder.go | 189 +++ .../mongo-driver/x/bsonx/bsoncore/document.go | 69 +- .../x/bsonx/reflectionfree_d_codec.go | 1026 +++++++++++++++++ .../mongo-driver/x/mongo/driver/auth/auth.go | 46 +- .../x/mongo/driver/auth/default.go | 21 +- .../mongo-driver/x/mongo/driver/auth/x509.go | 2 +- .../x/mongo/driver/batch_cursor.go | 2 +- .../x/mongo/driver/connstring/connstring.go | 13 +- .../mongo-driver/x/mongo/driver/crypt.go | 40 +- .../x/mongo/driver/description/feature.go | 36 - .../x/mongo/driver/description/version.go | 44 - .../mongo-driver/x/mongo/driver/driver.go | 42 +- .../mongo-driver/x/mongo/driver/errors.go | 4 +- .../x/mongo/driver/mongocrypt/mongocrypt.go | 88 +- .../mongocrypt/options/mongocrypt_options.go | 17 +- .../mongocrypt/options/provider_options.go | 46 - .../mongo-driver/x/mongo/driver/operation.go | 28 +- .../driver/operation/abort_transaction.go | 2 +- .../x/mongo/driver/operation/aggregate.go | 2 +- .../x/mongo/driver/operation/command.go | 2 +- .../driver/operation/commit_transaction.go | 2 +- .../x/mongo/driver/operation/count.go | 2 +- .../x/mongo/driver/operation/create.go | 2 +- .../x/mongo/driver/operation/createIndexes.go | 2 +- .../x/mongo/driver/operation/delete.go | 2 +- .../x/mongo/driver/operation/distinct.go | 2 +- .../mongo/driver/operation/drop_collection.go | 2 +- .../x/mongo/driver/operation/drop_database.go | 2 +- .../x/mongo/driver/operation/drop_indexes.go | 2 +- .../x/mongo/driver/operation/end_sessions.go | 2 +- .../x/mongo/driver/operation/find.go | 2 +- .../mongo/driver/operation/find_and_modify.go | 2 +- .../x/mongo/driver/operation/insert.go | 2 +- .../x/mongo/driver/operation/ismaster.go | 30 +- .../x/mongo/driver/operation/listDatabases.go | 2 +- .../driver/operation/list_collections.go | 19 +- .../x/mongo/driver/operation/list_indexes.go | 2 +- .../x/mongo/driver/operation/update.go | 2 +- .../x/mongo/driver/operation_exhaust.go | 2 +- .../x/mongo/driver/operation_legacy.go | 6 +- .../x/mongo/driver/session/client_session.go | 68 +- .../x/mongo/driver/session/session_pool.go | 2 +- .../driver/topology/cancellation_listener.go | 14 + .../x/mongo/driver/topology/connection.go | 183 ++- .../driver/topology/connection_options.go | 11 +- .../topology.go => topology/diff.go} | 80 +- .../x/mongo/driver/topology/errors.go | 36 +- .../x/mongo/driver/topology/fsm.go | 33 +- .../x/mongo/driver/topology/pool.go | 5 +- .../x/mongo/driver/topology/server.go | 118 +- .../x/mongo/driver/topology/server_options.go | 9 + .../driver/topology/tls_connection_source.go | 18 +- .../x/mongo/driver/topology/topology.go | 151 ++- .../mongo/driver/topology/topology_options.go | 10 + vendor/modules.txt | 8 +- 127 files changed, 5410 insertions(+), 1205 deletions(-) create mode 100644 vendor/github.com/youmark/pkcs8/.gitignore create mode 100644 vendor/github.com/youmark/pkcs8/.travis.yml create mode 100644 vendor/github.com/youmark/pkcs8/LICENSE create mode 100644 vendor/github.com/youmark/pkcs8/README create mode 100644 vendor/github.com/youmark/pkcs8/README.md create mode 100644 vendor/github.com/youmark/pkcs8/pkcs8.go create mode 100644 vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go create mode 100644 vendor/go.mongodb.org/mongo-driver/event/doc.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/background_context.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go create mode 100644 vendor/go.mongodb.org/mongo-driver/internal/string_util.go rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/address/addr.go (93%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/description/description.go (79%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/description/server.go (67%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/description/server_kind.go (85%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/description/server_selector.go (87%) create mode 100644 vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/description/topology_kind.go (100%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/description/topology_version.go (62%) rename vendor/go.mongodb.org/mongo-driver/{x/mongo/driver => mongo}/description/version_range.go (75%) create mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/array.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_arraybuilder.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_documentbuilder.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/bsonx/reflectionfree_d_codec.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/feature.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/version.go delete mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/provider_options.go create mode 100644 vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/cancellation_listener.go rename vendor/go.mongodb.org/mongo-driver/x/mongo/driver/{description/topology.go => topology/diff.go} (51%) diff --git a/LIC_FILES_CHKSUM.sha256 b/LIC_FILES_CHKSUM.sha256 index ec3f6a44..d6d86085 100644 --- a/LIC_FILES_CHKSUM.sha256 +++ b/LIC_FILES_CHKSUM.sha256 @@ -49,6 +49,7 @@ f566a9f97bacdaf00d9f21dd991e81dc11201c4e016c86b470799429a1c9a79c vendor/github. 03458b6d5828e1be1127ca2adf122572eb574fc47b56190c3b38203b8b2a98d0 vendor/github.com/gin-contrib/sse/LICENSE bafe84dafda45b69105faab4dcd00c4c66736441f0dbe8deb836c48434f1cadc vendor/github.com/gin-contrib/cors/LICENSE caa932df46551b53643952fe03b351d55db97be73b5393b986dcf6b05d3c416a vendor/github.com/go-ozzo/ozzo-validation/v4/LICENSE +6d13d89b6c33df37f92553cc95cec278d8770ddc78c9de84620cc16dfb48cef6 vendor/github.com/youmark/pkcs8/LICENSE # BSD-3-Clause 8407b13e462f755c06db3db3a034dc1fdc9157af19c6ea8986e7d5aecf4126b3 vendor/gopkg.in/tomb.v2/LICENSE diff --git a/go.mod b/go.mod index d3a04255..6e51ca70 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/stretchr/testify v1.7.0 github.com/urfave/cli v1.22.5 github.com/vmihailenco/msgpack/v5 v5.1.0 - go.mongodb.org/mongo-driver v1.4.6 + go.mongodb.org/mongo-driver v1.5.0 golang.org/x/sys v0.0.0-20201214095126-aec9a390925b google.golang.org/protobuf v1.25.0 // indirect ) diff --git a/go.sum b/go.sum index 7f3a4e17..dcc444da 100644 --- a/go.sum +++ b/go.sum @@ -213,10 +213,6 @@ github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2y github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mendersoftware/go-lib-micro v0.0.0-20210219095151-13466b5958fb h1:TxxeigPztPVCMiKstl20acFSxcL06C11ktNEW6xJFbU= -github.com/mendersoftware/go-lib-micro v0.0.0-20210219095151-13466b5958fb/go.mod h1:XeXkX/pfX7FOwG/y+NIWBcbtmhnN/sAzuuq0RZ6GcPQ= -github.com/mendersoftware/go-lib-micro v0.0.0-20210301092156-349d3ae1e399 h1:RlM224ofXg2Wh8jD2z93m3yI/eqRyrnwISWSOgxcGls= -github.com/mendersoftware/go-lib-micro v0.0.0-20210301092156-349d3ae1e399/go.mod h1:XeXkX/pfX7FOwG/y+NIWBcbtmhnN/sAzuuq0RZ6GcPQ= github.com/mendersoftware/go-lib-micro v0.0.0-20210311084510-06d0c5918746 h1:P6zxuOT3ATh8AnYZTI95iLnH0kEo7BWno/U6iytqIuI= github.com/mendersoftware/go-lib-micro v0.0.0-20210311084510-06d0c5918746/go.mod h1:AYMV1BNNHDcAt25V7KY6cd5XhAZXIlOVlCV9yZt51zM= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -287,8 +283,6 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= -github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.8.0 h1:nfhvjKcUMhBMVqbKHJlk5RPrrfYr/NMo3692g0dwfWU= github.com/sirupsen/logrus v1.8.0/go.mod h1:4GuYW9TZmE769R5STWrRakJc4UqQ3+QQ95fyz7ENv1A= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= @@ -338,9 +332,13 @@ github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhe github.com/xdg/stringprep v0.0.0-20180714160509-73f8eece6fdc h1:n+nNi93yXLkJvKwXNP9d55HC7lGK4H/SRcwB5IaUZLo= github.com/xdg/stringprep v0.0.0-20180714160509-73f8eece6fdc/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.mongodb.org/mongo-driver v1.4.6 h1:rh7GdYmDrb8AQSkF8yteAus8qYOgOASWDOv1BWqBXkU= go.mongodb.org/mongo-driver v1.4.6/go.mod h1:WcMNYLx/IlOxLe6JRJiv2uXuCz6zBLndR4SoGjYphSc= +go.mongodb.org/mongo-driver v1.5.0 h1:REddm85e1Nl0JPXGGhgZkgJdG/yOe6xvpXUcYK5WLt0= +go.mongodb.org/mongo-driver v1.5.0/go.mod h1:boiGPFqyBs5R0R5qf2ErokGRekMfwn+MqKaUyHs7wy0= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -354,6 +352,7 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= diff --git a/vendor/github.com/youmark/pkcs8/.gitignore b/vendor/github.com/youmark/pkcs8/.gitignore new file mode 100644 index 00000000..83656241 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/.gitignore @@ -0,0 +1,23 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test diff --git a/vendor/github.com/youmark/pkcs8/.travis.yml b/vendor/github.com/youmark/pkcs8/.travis.yml new file mode 100644 index 00000000..0bceef6f --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/.travis.yml @@ -0,0 +1,9 @@ +language: go + +go: + - "1.9.x" + - "1.10.x" + - master + +script: + - go test -v ./... diff --git a/vendor/github.com/youmark/pkcs8/LICENSE b/vendor/github.com/youmark/pkcs8/LICENSE new file mode 100644 index 00000000..c939f448 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 youmark + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/vendor/github.com/youmark/pkcs8/README b/vendor/github.com/youmark/pkcs8/README new file mode 100644 index 00000000..376fcaf6 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/README @@ -0,0 +1 @@ +pkcs8 package: implement PKCS#8 private key parsing and conversion as defined in RFC5208 and RFC5958 diff --git a/vendor/github.com/youmark/pkcs8/README.md b/vendor/github.com/youmark/pkcs8/README.md new file mode 100644 index 00000000..f2167dbf --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/README.md @@ -0,0 +1,21 @@ +pkcs8 +=== +OpenSSL can generate private keys in both "traditional format" and PKCS#8 format. Newer applications are advised to use more secure PKCS#8 format. Go standard crypto package provides a [function](http://golang.org/pkg/crypto/x509/#ParsePKCS8PrivateKey) to parse private key in PKCS#8 format. There is a limitation to this function. It can only handle unencrypted PKCS#8 private keys. To use this function, the user has to save the private key in file without encryption, which is a bad practice to leave private keys unprotected on file systems. In addition, Go standard package lacks the functions to convert RSA/ECDSA private keys into PKCS#8 format. + +pkcs8 package fills the gap here. It implements functions to process private keys in PKCS#8 format, as defined in [RFC5208](https://tools.ietf.org/html/rfc5208) and [RFC5958](https://tools.ietf.org/html/rfc5958). It can handle both unencrypted PKCS#8 PrivateKeyInfo format and EncryptedPrivateKeyInfo format with PKCS#5 (v2.0) algorithms. + + +[**Godoc**](http://godoc.org/github.com/youmark/pkcs8) + +## Installation +Supports Go 1.9+ + +```text +go get github.com/youmark/pkcs8 +``` +## dependency +This package depends on golang.org/x/crypto/pbkdf2 package. Use the following command to retrive pbkdf2 package +```text +go get golang.org/x/crypto/pbkdf2 +``` + diff --git a/vendor/github.com/youmark/pkcs8/pkcs8.go b/vendor/github.com/youmark/pkcs8/pkcs8.go new file mode 100644 index 00000000..9270a797 --- /dev/null +++ b/vendor/github.com/youmark/pkcs8/pkcs8.go @@ -0,0 +1,305 @@ +// Package pkcs8 implements functions to parse and convert private keys in PKCS#8 format, as defined in RFC5208 and RFC5958 +package pkcs8 + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "errors" + + "golang.org/x/crypto/pbkdf2" +) + +// Copy from crypto/x509 +var ( + oidPublicKeyRSA = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} + oidPublicKeyDSA = asn1.ObjectIdentifier{1, 2, 840, 10040, 4, 1} + oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} +) + +// Copy from crypto/x509 +var ( + oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33} + oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7} + oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34} + oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35} +) + +// Copy from crypto/x509 +func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) { + switch curve { + case elliptic.P224(): + return oidNamedCurveP224, true + case elliptic.P256(): + return oidNamedCurveP256, true + case elliptic.P384(): + return oidNamedCurveP384, true + case elliptic.P521(): + return oidNamedCurveP521, true + } + + return nil, false +} + +// Unecrypted PKCS8 +var ( + oidPKCS5PBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12} + oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13} + oidAES256CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 42} + oidAES128CBC = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 1, 2} + oidHMACWithSHA256 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 9} + oidDESEDE3CBC = asn1.ObjectIdentifier{1, 2, 840, 113549, 3, 7} +) + +type ecPrivateKey struct { + Version int + PrivateKey []byte + NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"` + PublicKey asn1.BitString `asn1:"optional,explicit,tag:1"` +} + +type privateKeyInfo struct { + Version int + PrivateKeyAlgorithm []asn1.ObjectIdentifier + PrivateKey []byte +} + +// Encrypted PKCS8 +type prfParam struct { + IdPRF asn1.ObjectIdentifier + NullParam asn1.RawValue +} + +type pbkdf2Params struct { + Salt []byte + IterationCount int + PrfParam prfParam `asn1:"optional"` +} + +type pbkdf2Algorithms struct { + IdPBKDF2 asn1.ObjectIdentifier + PBKDF2Params pbkdf2Params +} + +type pbkdf2Encs struct { + EncryAlgo asn1.ObjectIdentifier + IV []byte +} + +type pbes2Params struct { + KeyDerivationFunc pbkdf2Algorithms + EncryptionScheme pbkdf2Encs +} + +type pbes2Algorithms struct { + IdPBES2 asn1.ObjectIdentifier + PBES2Params pbes2Params +} + +type encryptedPrivateKeyInfo struct { + EncryptionAlgorithm pbes2Algorithms + EncryptedData []byte +} + +// ParsePKCS8PrivateKeyRSA parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. +// +// The function can decrypt the private key encrypted with AES-256-CBC mode, and stored in PKCS #5 v2.0 format. +func ParsePKCS8PrivateKeyRSA(der []byte, v ...[]byte) (*rsa.PrivateKey, error) { + key, err := ParsePKCS8PrivateKey(der, v...) + if err != nil { + return nil, err + } + typedKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("key block is not of type RSA") + } + return typedKey, nil +} + +// ParsePKCS8PrivateKeyECDSA parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. +// +// The function can decrypt the private key encrypted with AES-256-CBC mode, and stored in PKCS #5 v2.0 format. +func ParsePKCS8PrivateKeyECDSA(der []byte, v ...[]byte) (*ecdsa.PrivateKey, error) { + key, err := ParsePKCS8PrivateKey(der, v...) + if err != nil { + return nil, err + } + typedKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + return nil, errors.New("key block is not of type ECDSA") + } + return typedKey, nil +} + +// ParsePKCS8PrivateKey parses encrypted/unencrypted private keys in PKCS#8 format. To parse encrypted private keys, a password of []byte type should be provided to the function as the second parameter. +// +// The function can decrypt the private key encrypted with AES-256-CBC mode, and stored in PKCS #5 v2.0 format. +func ParsePKCS8PrivateKey(der []byte, v ...[]byte) (interface{}, error) { + // No password provided, assume the private key is unencrypted + if v == nil { + return x509.ParsePKCS8PrivateKey(der) + } + + // Use the password provided to decrypt the private key + password := v[0] + var privKey encryptedPrivateKeyInfo + if _, err := asn1.Unmarshal(der, &privKey); err != nil { + return nil, errors.New("pkcs8: only PKCS #5 v2.0 supported") + } + + if !privKey.EncryptionAlgorithm.IdPBES2.Equal(oidPBES2) { + return nil, errors.New("pkcs8: only PBES2 supported") + } + + if !privKey.EncryptionAlgorithm.PBES2Params.KeyDerivationFunc.IdPBKDF2.Equal(oidPKCS5PBKDF2) { + return nil, errors.New("pkcs8: only PBKDF2 supported") + } + + encParam := privKey.EncryptionAlgorithm.PBES2Params.EncryptionScheme + kdfParam := privKey.EncryptionAlgorithm.PBES2Params.KeyDerivationFunc.PBKDF2Params + + iv := encParam.IV + salt := kdfParam.Salt + iter := kdfParam.IterationCount + keyHash := sha1.New + if kdfParam.PrfParam.IdPRF.Equal(oidHMACWithSHA256) { + keyHash = sha256.New + } + + encryptedKey := privKey.EncryptedData + var symkey []byte + var block cipher.Block + var err error + switch { + case encParam.EncryAlgo.Equal(oidAES128CBC): + symkey = pbkdf2.Key(password, salt, iter, 16, keyHash) + block, err = aes.NewCipher(symkey) + case encParam.EncryAlgo.Equal(oidAES256CBC): + symkey = pbkdf2.Key(password, salt, iter, 32, keyHash) + block, err = aes.NewCipher(symkey) + case encParam.EncryAlgo.Equal(oidDESEDE3CBC): + symkey = pbkdf2.Key(password, salt, iter, 24, keyHash) + block, err = des.NewTripleDESCipher(symkey) + default: + return nil, errors.New("pkcs8: only AES-256-CBC, AES-128-CBC and DES-EDE3-CBC are supported") + } + if err != nil { + return nil, err + } + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(encryptedKey, encryptedKey) + + key, err := x509.ParsePKCS8PrivateKey(encryptedKey) + if err != nil { + return nil, errors.New("pkcs8: incorrect password") + } + return key, nil +} + +func convertPrivateKeyToPKCS8(priv interface{}) ([]byte, error) { + var pkey privateKeyInfo + + switch priv := priv.(type) { + case *ecdsa.PrivateKey: + eckey, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, err + } + + oidNamedCurve, ok := oidFromNamedCurve(priv.Curve) + if !ok { + return nil, errors.New("pkcs8: unknown elliptic curve") + } + + // Per RFC5958, if publicKey is present, then version is set to v2(1) else version is set to v1(0). + // But openssl set to v1 even publicKey is present + pkey.Version = 1 + pkey.PrivateKeyAlgorithm = make([]asn1.ObjectIdentifier, 2) + pkey.PrivateKeyAlgorithm[0] = oidPublicKeyECDSA + pkey.PrivateKeyAlgorithm[1] = oidNamedCurve + pkey.PrivateKey = eckey + case *rsa.PrivateKey: + + // Per RFC5958, if publicKey is present, then version is set to v2(1) else version is set to v1(0). + // But openssl set to v1 even publicKey is present + pkey.Version = 0 + pkey.PrivateKeyAlgorithm = make([]asn1.ObjectIdentifier, 1) + pkey.PrivateKeyAlgorithm[0] = oidPublicKeyRSA + pkey.PrivateKey = x509.MarshalPKCS1PrivateKey(priv) + } + + return asn1.Marshal(pkey) +} + +func convertPrivateKeyToPKCS8Encrypted(priv interface{}, password []byte) ([]byte, error) { + // Convert private key into PKCS8 format + pkey, err := convertPrivateKeyToPKCS8(priv) + if err != nil { + return nil, err + } + + // Calculate key from password based on PKCS5 algorithm + // Use 8 byte salt, 16 byte IV, and 2048 iteration + iter := 2048 + salt := make([]byte, 8) + iv := make([]byte, 16) + _, err = rand.Read(salt) + if err != nil { + return nil, err + } + _, err = rand.Read(iv) + if err != nil { + return nil, err + } + + key := pbkdf2.Key(password, salt, iter, 32, sha256.New) + + // Use AES256-CBC mode, pad plaintext with PKCS5 padding scheme + padding := aes.BlockSize - len(pkey)%aes.BlockSize + if padding > 0 { + n := len(pkey) + pkey = append(pkey, make([]byte, padding)...) + for i := 0; i < padding; i++ { + pkey[n+i] = byte(padding) + } + } + + encryptedKey := make([]byte, len(pkey)) + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(encryptedKey, pkey) + + // pbkdf2algo := pbkdf2Algorithms{oidPKCS5PBKDF2, pbkdf2Params{salt, iter, prfParam{oidHMACWithSHA256}}} + + pbkdf2algo := pbkdf2Algorithms{oidPKCS5PBKDF2, pbkdf2Params{salt, iter, prfParam{oidHMACWithSHA256, asn1.RawValue{Tag: asn1.TagNull}}}} + pbkdf2encs := pbkdf2Encs{oidAES256CBC, iv} + pbes2algo := pbes2Algorithms{oidPBES2, pbes2Params{pbkdf2algo, pbkdf2encs}} + + encryptedPkey := encryptedPrivateKeyInfo{pbes2algo, encryptedKey} + + return asn1.Marshal(encryptedPkey) +} + +// ConvertPrivateKeyToPKCS8 converts the private key into PKCS#8 format. +// To encrypt the private key, the password of []byte type should be provided as the second parameter. +// +// The only supported key types are RSA and ECDSA (*rsa.PublicKey or *ecdsa.PublicKey for priv) +func ConvertPrivateKeyToPKCS8(priv interface{}, v ...[]byte) ([]byte, error) { + if v == nil { + return convertPrivateKeyToPKCS8(priv) + } + + password := string(v[0]) + return convertPrivateKeyToPKCS8Encrypted(priv, []byte(password)) +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go new file mode 100644 index 00000000..4e24f9ee --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/array_codec.go @@ -0,0 +1,50 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "reflect" + + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +// ArrayCodec is the Codec used for bsoncore.Array values. +type ArrayCodec struct{} + +var defaultArrayCodec = NewArrayCodec() + +// NewArrayCodec returns an ArrayCodec. +func NewArrayCodec() *ArrayCodec { + return &ArrayCodec{} +} + +// EncodeValue is the ValueEncoder for bsoncore.Array values. +func (ac *ArrayCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tCoreArray { + return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} + } + + arr := val.Interface().(bsoncore.Array) + return bsonrw.Copier{}.CopyArrayFromBytes(vw, arr) +} + +// DecodeValue is the ValueDecoder for bsoncore.Array values. +func (ac *ArrayCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tCoreArray { + return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} + } + + if val.IsNil() { + val.Set(reflect.MakeSlice(val.Type(), 0, 0)) + } + + val.SetLen(0) + arr, err := bsonrw.Copier{}.AppendArrayBytes(val.Interface().(bsoncore.Array), vr) + val.Set(reflect.ValueOf(arr)) + return err +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go index 0ebc9a15..2c861b5c 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/bsoncodec.go @@ -15,6 +15,10 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) +var ( + emptyValue = reflect.Value{} +) + // Marshaler is an interface implemented by types that can marshal themselves // into a BSON document represented as bytes. The bytes returned must be a valid // BSON document if the error is nil. @@ -156,6 +160,55 @@ func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, return fn(dc, vr, val) } +// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type. +type typeDecoder interface { + decodeType(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error) +} + +// typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder. +type typeDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error) + +func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + return fn(dc, vr, t) +} + +// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder. +type decodeAdapter struct { + ValueDecoderFunc + typeDecoderFunc +} + +var _ ValueDecoder = decodeAdapter{} +var _ typeDecoder = decodeAdapter{} + +// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type +// t and calls decoder.DecodeValue on it. +func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + td, _ := decoder.(typeDecoder) + return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true) +} + +func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { + if td != nil { + val, err := td.decodeType(dc, vr, t) + if err == nil && convert && val.Type() != t { + // This conversion step is necessary for slices and maps. If a user declares variables like: + // + // type myBool bool + // var m map[string]myBool + // + // and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present + // because we'll try to assign a value of type bool to one of type myBool. + val = val.Convert(t) + } + return val, err + } + + val := reflect.New(t).Elem() + err := vd.DecodeValue(dc, vr, val) + return val, err +} + // CodecZeroer is the interface implemented by Codecs that can also determine if // a value of the type that would be encoded is zero. type CodecZeroer interface { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go index 8219748d..5a916cc1 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/byte_slice_codec.go @@ -15,14 +15,17 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -var defaultByteSliceCodec = NewByteSliceCodec() - // ByteSliceCodec is the Codec used for []byte values. type ByteSliceCodec struct { EncodeNilAsEmpty bool } -var _ ValueCodec = &ByteSliceCodec{} +var ( + defaultByteSliceCodec = NewByteSliceCodec() + + _ ValueCodec = defaultByteSliceCodec + _ typeDecoder = defaultByteSliceCodec +) // NewByteSliceCodec returns a StringCodec with options opts. func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec { @@ -45,10 +48,13 @@ func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, return vw.WriteBinary(val.Interface().([]byte)) } -// DecodeValue is the ValueDecoder for []byte. -func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tByteSlice { - return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} +func (bsc *ByteSliceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tByteSlice { + return emptyValue, ValueDecoderError{ + Name: "ByteSliceDecodeValue", + Types: []reflect.Type{tByteSlice}, + Received: reflect.Zero(t), + } } var data []byte @@ -57,34 +63,49 @@ func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, case bsontype.String: str, err := vr.ReadString() if err != nil { - return err + return emptyValue, err } data = []byte(str) case bsontype.Symbol: sym, err := vr.ReadSymbol() if err != nil { - return err + return emptyValue, err } data = []byte(sym) case bsontype.Binary: var subtype byte data, subtype, err = vr.ReadBinary() if err != nil { - return err + return emptyValue, err } if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { - return fmt.Errorf("ByteSliceDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", bsontype.Binary, subtype) + return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"} } case bsontype.Null: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() + err = vr.ReadNull() case bsontype.Undefined: - val.Set(reflect.Zero(val.Type())) - return vr.ReadUndefined() + err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a []byte", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(data), nil +} + +// DecodeValue is the ValueDecoder for []byte. +func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tByteSlice { + return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} + } + + elem, err := bsc.decodeType(dc, vr, tByteSlice) + if err != nil { + return err } - val.Set(reflect.ValueOf(data)) + val.Set(elem) return nil } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go index a2e2d425..0402265d 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go @@ -22,7 +22,19 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var defaultValueDecoders DefaultValueDecoders +var ( + defaultValueDecoders DefaultValueDecoders + errCannotTruncate = errors.New("float64 can only be truncated to an integer type when truncation is enabled") +) + +type decodeBinaryError struct { + subtype byte + typeName string +} + +func (d decodeBinaryError) Error() string { + return fmt.Sprintf("only binary values with subtype 0x00 or 0x02 can be decoded into %s, but got subtype %v", d.typeName, d.subtype) +} func newDefaultStructCodec() *StructCodec { codec, err := NewStructCodec(DefaultStructTagParser) @@ -49,40 +61,45 @@ func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } + intDecoder := decodeAdapter{dvd.IntDecodeValue, dvd.intDecodeType} + floatDecoder := decodeAdapter{dvd.FloatDecodeValue, dvd.floatDecodeType} + rb. - RegisterTypeDecoder(tBinary, ValueDecoderFunc(dvd.BinaryDecodeValue)). - RegisterTypeDecoder(tUndefined, ValueDecoderFunc(dvd.UndefinedDecodeValue)). - RegisterTypeDecoder(tDateTime, ValueDecoderFunc(dvd.DateTimeDecodeValue)). - RegisterTypeDecoder(tNull, ValueDecoderFunc(dvd.NullDecodeValue)). - RegisterTypeDecoder(tRegex, ValueDecoderFunc(dvd.RegexDecodeValue)). - RegisterTypeDecoder(tDBPointer, ValueDecoderFunc(dvd.DBPointerDecodeValue)). - RegisterTypeDecoder(tTimestamp, ValueDecoderFunc(dvd.TimestampDecodeValue)). - RegisterTypeDecoder(tMinKey, ValueDecoderFunc(dvd.MinKeyDecodeValue)). - RegisterTypeDecoder(tMaxKey, ValueDecoderFunc(dvd.MaxKeyDecodeValue)). - RegisterTypeDecoder(tJavaScript, ValueDecoderFunc(dvd.JavaScriptDecodeValue)). - RegisterTypeDecoder(tSymbol, ValueDecoderFunc(dvd.SymbolDecodeValue)). + RegisterTypeDecoder(tD, ValueDecoderFunc(dvd.DDecodeValue)). + RegisterTypeDecoder(tBinary, decodeAdapter{dvd.BinaryDecodeValue, dvd.binaryDecodeType}). + RegisterTypeDecoder(tUndefined, decodeAdapter{dvd.UndefinedDecodeValue, dvd.undefinedDecodeType}). + RegisterTypeDecoder(tDateTime, decodeAdapter{dvd.DateTimeDecodeValue, dvd.dateTimeDecodeType}). + RegisterTypeDecoder(tNull, decodeAdapter{dvd.NullDecodeValue, dvd.nullDecodeType}). + RegisterTypeDecoder(tRegex, decodeAdapter{dvd.RegexDecodeValue, dvd.regexDecodeType}). + RegisterTypeDecoder(tDBPointer, decodeAdapter{dvd.DBPointerDecodeValue, dvd.dBPointerDecodeType}). + RegisterTypeDecoder(tTimestamp, decodeAdapter{dvd.TimestampDecodeValue, dvd.timestampDecodeType}). + RegisterTypeDecoder(tMinKey, decodeAdapter{dvd.MinKeyDecodeValue, dvd.minKeyDecodeType}). + RegisterTypeDecoder(tMaxKey, decodeAdapter{dvd.MaxKeyDecodeValue, dvd.maxKeyDecodeType}). + RegisterTypeDecoder(tJavaScript, decodeAdapter{dvd.JavaScriptDecodeValue, dvd.javaScriptDecodeType}). + RegisterTypeDecoder(tSymbol, decodeAdapter{dvd.SymbolDecodeValue, dvd.symbolDecodeType}). RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec). RegisterTypeDecoder(tTime, defaultTimeCodec). RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec). - RegisterTypeDecoder(tOID, ValueDecoderFunc(dvd.ObjectIDDecodeValue)). - RegisterTypeDecoder(tDecimal, ValueDecoderFunc(dvd.Decimal128DecodeValue)). - RegisterTypeDecoder(tJSONNumber, ValueDecoderFunc(dvd.JSONNumberDecodeValue)). - RegisterTypeDecoder(tURL, ValueDecoderFunc(dvd.URLDecodeValue)). + RegisterTypeDecoder(tCoreArray, defaultArrayCodec). + RegisterTypeDecoder(tOID, decodeAdapter{dvd.ObjectIDDecodeValue, dvd.objectIDDecodeType}). + RegisterTypeDecoder(tDecimal, decodeAdapter{dvd.Decimal128DecodeValue, dvd.decimal128DecodeType}). + RegisterTypeDecoder(tJSONNumber, decodeAdapter{dvd.JSONNumberDecodeValue, dvd.jsonNumberDecodeType}). + RegisterTypeDecoder(tURL, decodeAdapter{dvd.URLDecodeValue, dvd.urlDecodeType}). RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(dvd.CoreDocumentDecodeValue)). - RegisterTypeDecoder(tCodeWithScope, ValueDecoderFunc(dvd.CodeWithScopeDecodeValue)). - RegisterDefaultDecoder(reflect.Bool, ValueDecoderFunc(dvd.BooleanDecodeValue)). - RegisterDefaultDecoder(reflect.Int, ValueDecoderFunc(dvd.IntDecodeValue)). - RegisterDefaultDecoder(reflect.Int8, ValueDecoderFunc(dvd.IntDecodeValue)). - RegisterDefaultDecoder(reflect.Int16, ValueDecoderFunc(dvd.IntDecodeValue)). - RegisterDefaultDecoder(reflect.Int32, ValueDecoderFunc(dvd.IntDecodeValue)). - RegisterDefaultDecoder(reflect.Int64, ValueDecoderFunc(dvd.IntDecodeValue)). + RegisterTypeDecoder(tCodeWithScope, decodeAdapter{dvd.CodeWithScopeDecodeValue, dvd.codeWithScopeDecodeType}). + RegisterDefaultDecoder(reflect.Bool, decodeAdapter{dvd.BooleanDecodeValue, dvd.booleanDecodeType}). + RegisterDefaultDecoder(reflect.Int, intDecoder). + RegisterDefaultDecoder(reflect.Int8, intDecoder). + RegisterDefaultDecoder(reflect.Int16, intDecoder). + RegisterDefaultDecoder(reflect.Int32, intDecoder). + RegisterDefaultDecoder(reflect.Int64, intDecoder). RegisterDefaultDecoder(reflect.Uint, defaultUIntCodec). RegisterDefaultDecoder(reflect.Uint8, defaultUIntCodec). RegisterDefaultDecoder(reflect.Uint16, defaultUIntCodec). RegisterDefaultDecoder(reflect.Uint32, defaultUIntCodec). RegisterDefaultDecoder(reflect.Uint64, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Float32, ValueDecoderFunc(dvd.FloatDecodeValue)). - RegisterDefaultDecoder(reflect.Float64, ValueDecoderFunc(dvd.FloatDecodeValue)). + RegisterDefaultDecoder(reflect.Float32, floatDecoder). + RegisterDefaultDecoder(reflect.Float64, floatDecoder). RegisterDefaultDecoder(reflect.Array, ValueDecoderFunc(dvd.ArrayDecodeValue)). RegisterDefaultDecoder(reflect.Map, defaultMapCodec). RegisterDefaultDecoder(reflect.Slice, defaultSliceCodec). @@ -114,10 +131,70 @@ func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { RegisterHookDecoder(tUnmarshaler, ValueDecoderFunc(dvd.UnmarshalerDecodeValue)) } -// BooleanDecodeValue is the ValueDecoderFunc for bool types. -func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { - return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} +// DDecodeValue is the ValueDecoderFunc for primitive.D instances. +func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || !val.CanSet() || val.Type() != tD { + return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} + } + + switch vrType := vr.Type(); vrType { + case bsontype.Type(0), bsontype.EmbeddedDocument: + dc.Ancestor = tD + case bsontype.Null: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + default: + return fmt.Errorf("cannot decode %v into a primitive.D", vrType) + } + + dr, err := vr.ReadDocument() + if err != nil { + return err + } + + decoder, err := dc.LookupDecoder(tEmpty) + if err != nil { + return err + } + tEmptyTypeDecoder, _ := decoder.(typeDecoder) + + // Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance. + var elems primitive.D + if !val.IsNil() { + val.SetLen(0) + elems = val.Interface().(primitive.D) + } else { + elems = make(primitive.D, 0) + } + + for { + key, elemVr, err := dr.ReadElement() + if err == bsonrw.ErrEOD { + break + } else if err != nil { + return err + } + + // Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty. + elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false) + if err != nil { + return err + } + + elems = append(elems, primitive.E{Key: key, Value: elem.Interface()}) + } + + val.Set(reflect.ValueOf(elems)) + return nil +} + +func (dvd DefaultValueDecoders) booleanDecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t.Kind() != reflect.Bool { + return emptyValue, ValueDecoderError{ + Name: "BooleanDecodeValue", + Kinds: []reflect.Kind{reflect.Bool}, + Received: reflect.Zero(t), + } } var b bool @@ -126,116 +203,138 @@ func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr bsonrw case bsontype.Int32: i32, err := vr.ReadInt32() if err != nil { - return err + return emptyValue, err } b = (i32 != 0) case bsontype.Int64: i64, err := vr.ReadInt64() if err != nil { - return err + return emptyValue, err } b = (i64 != 0) case bsontype.Double: f64, err := vr.ReadDouble() if err != nil { - return err + return emptyValue, err } b = (f64 != 0) case bsontype.Boolean: b, err = vr.ReadBoolean() - if err != nil { - return err - } case bsontype.Null: - if err = vr.ReadNull(); err != nil { - return err - } + err = vr.ReadNull() case bsontype.Undefined: - if err = vr.ReadUndefined(); err != nil { - return err - } + err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a boolean", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a boolean", vrType) } - val.SetBool(b) - return nil + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(b), nil } -// IntDecodeValue is the ValueDecoderFunc for int types. -func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: val, - } +// BooleanDecodeValue is the ValueDecoderFunc for bool types. +func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { + return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } + elem, err := dvd.booleanDecodeType(dctx, vr, val.Type()) + if err != nil { + return err + } + + val.SetBool(elem.Bool()) + return nil +} + +func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 var err error switch vrType := vr.Type(); vrType { case bsontype.Int32: i32, err := vr.ReadInt32() if err != nil { - return err + return emptyValue, err } i64 = int64(i32) case bsontype.Int64: i64, err = vr.ReadInt64() if err != nil { - return err + return emptyValue, err } case bsontype.Double: f64, err := vr.ReadDouble() if err != nil { - return err + return emptyValue, err } if !dc.Truncate && math.Floor(f64) != f64 { - return errors.New("IntDecodeValue can only truncate float64 to an integer type when truncation is enabled") + return emptyValue, errCannotTruncate } if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) + return emptyValue, fmt.Errorf("%g overflows int64", f64) } i64 = int64(f64) case bsontype.Boolean: b, err := vr.ReadBoolean() if err != nil { - return err + return emptyValue, err } if b { i64 = 1 } case bsontype.Null: if err = vr.ReadNull(); err != nil { - return err + return emptyValue, err } case bsontype.Undefined: if err = vr.ReadUndefined(); err != nil { - return err + return emptyValue, err } default: - return fmt.Errorf("cannot decode %v into an integer type", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) } - switch val.Kind() { + switch t.Kind() { case reflect.Int8: if i64 < math.MinInt8 || i64 > math.MaxInt8 { - return fmt.Errorf("%d overflows int8", i64) + return emptyValue, fmt.Errorf("%d overflows int8", i64) } + + return reflect.ValueOf(int8(i64)), nil case reflect.Int16: if i64 < math.MinInt16 || i64 > math.MaxInt16 { - return fmt.Errorf("%d overflows int16", i64) + return emptyValue, fmt.Errorf("%d overflows int16", i64) } + + return reflect.ValueOf(int16(i64)), nil case reflect.Int32: if i64 < math.MinInt32 || i64 > math.MaxInt32 { - return fmt.Errorf("%d overflows int32", i64) + return emptyValue, fmt.Errorf("%d overflows int32", i64) } + + return reflect.ValueOf(int32(i64)), nil case reflect.Int64: + return reflect.ValueOf(i64), nil case reflect.Int: if int64(int(i64)) != i64 { // Can we fit this inside of an int - return fmt.Errorf("%d overflows int", i64) + return emptyValue, fmt.Errorf("%d overflows int", i64) } + + return reflect.ValueOf(int(i64)), nil default: + return emptyValue, ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Received: reflect.Zero(t), + } + } +} + +// IntDecodeValue is the ValueDecoderFunc for int types. +func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() { return ValueDecoderError{ Name: "IntDecodeValue", Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, @@ -243,7 +342,12 @@ func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr bsonrw.Value } } - val.SetInt(i64) + elem, err := dvd.intDecodeType(dc, vr, val.Type()) + if err != nil { + return err + } + + val.SetInt(elem.Int()) return nil } @@ -330,67 +434,81 @@ func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr bsonrw.Valu return nil } -// FloatDecodeValue is the ValueDecoderFunc for float types. -func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, - Received: val, - } - } - +func (dvd DefaultValueDecoders) floatDecodeType(ec DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { var f float64 var err error switch vrType := vr.Type(); vrType { case bsontype.Int32: i32, err := vr.ReadInt32() if err != nil { - return err + return emptyValue, err } f = float64(i32) case bsontype.Int64: i64, err := vr.ReadInt64() if err != nil { - return err + return emptyValue, err } f = float64(i64) case bsontype.Double: f, err = vr.ReadDouble() if err != nil { - return err + return emptyValue, err } case bsontype.Boolean: b, err := vr.ReadBoolean() if err != nil { - return err + return emptyValue, err } if b { f = 1 } case bsontype.Null: if err = vr.ReadNull(); err != nil { - return err + return emptyValue, err } case bsontype.Undefined: if err = vr.ReadUndefined(); err != nil { - return err + return emptyValue, err } default: - return fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) } - switch val.Kind() { + switch t.Kind() { case reflect.Float32: if !ec.Truncate && float64(float32(f)) != f { - return errors.New("FloatDecodeValue can only convert float64 to float32 when truncation is allowed") + return emptyValue, errCannotTruncate } + + return reflect.ValueOf(float32(f)), nil case reflect.Float64: + return reflect.ValueOf(f), nil default: - return ValueDecoderError{Name: "FloatDecodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} + return emptyValue, ValueDecoderError{ + Name: "FloatDecodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: reflect.Zero(t), + } } +} - val.SetFloat(f) +// FloatDecodeValue is the ValueDecoderFunc for float types. +func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "FloatDecodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: val, + } + } + + elem, err := dvd.floatDecodeType(ec, vr, val.Type()) + if err != nil { + return err + } + + val.SetFloat(elem.Float()) return nil } @@ -418,10 +536,13 @@ func (dvd DefaultValueDecoders) StringDecodeValue(dctx DecodeContext, vr bsonrw. return nil } -// JavaScriptDecodeValue is the ValueDecoderFunc for the primitive.JavaScript type. -func (DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tJavaScript { - return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} +func (DefaultValueDecoders) javaScriptDecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tJavaScript { + return emptyValue, ValueDecoderError{ + Name: "JavaScriptDecodeValue", + Types: []reflect.Type{tJavaScript}, + Received: reflect.Zero(t), + } } var js string @@ -434,20 +555,37 @@ func (DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr bsonrw. case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a primitive.JavaScript", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.JavaScript", vrType) } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.JavaScript(js)), nil +} +// JavaScriptDecodeValue is the ValueDecoderFunc for the primitive.JavaScript type. +func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tJavaScript { + return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} + } + + elem, err := dvd.javaScriptDecodeType(dctx, vr, tJavaScript) if err != nil { return err } - val.SetString(js) + + val.SetString(elem.String()) return nil } -// SymbolDecodeValue is the ValueDecoderFunc for the primitive.Symbol type. -func (DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tSymbol { - return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} +func (DefaultValueDecoders) symbolDecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tSymbol { + return emptyValue, ValueDecoderError{ + Name: "SymbolDecodeValue", + Types: []reflect.Type{tSymbol}, + Received: reflect.Zero(t), + } } var symbol string @@ -455,43 +593,54 @@ func (DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr bsonrw.Valu switch vrType := vr.Type(); vrType { case bsontype.String: symbol, err = vr.ReadString() - if err != nil { - return err - } case bsontype.Symbol: symbol, err = vr.ReadSymbol() - if err != nil { - return err - } case bsontype.Binary: data, subtype, err := vr.ReadBinary() if err != nil { - return err + return emptyValue, err } + if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { - return fmt.Errorf("SymbolDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", bsontype.Binary, subtype) + return emptyValue, decodeBinaryError{subtype: subtype, typeName: "primitive.Symbol"} } symbol = string(data) case bsontype.Null: - if err = vr.ReadNull(); err != nil { - return err - } + err = vr.ReadNull() case bsontype.Undefined: - if err = vr.ReadUndefined(); err != nil { - return err - } + err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a primitive.Symbol", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.Symbol", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Symbol(symbol)), nil +} + +// SymbolDecodeValue is the ValueDecoderFunc for the primitive.Symbol type. +func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tSymbol { + return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} + } + + elem, err := dvd.symbolDecodeType(dctx, vr, tSymbol) + if err != nil { + return err } - val.SetString(symbol) + val.SetString(elem.String()) return nil } -// BinaryDecodeValue is the ValueDecoderFunc for Binary. -func (DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tBinary { - return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val} +func (DefaultValueDecoders) binaryDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tBinary { + return emptyValue, ValueDecoderError{ + Name: "BinaryDecodeValue", + Types: []reflect.Type{tBinary}, + Received: reflect.Zero(t), + } } var data []byte @@ -505,20 +654,37 @@ func (DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr bsonrw.ValueR case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a Binary", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a Binary", vrType) + } + if err != nil { + return emptyValue, err } + return reflect.ValueOf(primitive.Binary{Subtype: subtype, Data: data}), nil +} + +// BinaryDecodeValue is the ValueDecoderFunc for Binary. +func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tBinary { + return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val} + } + + elem, err := dvd.binaryDecodeType(dc, vr, tBinary) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.Binary{Subtype: subtype, Data: data})) + + val.Set(elem) return nil } -// UndefinedDecodeValue is the ValueDecoderFunc for Undefined. -func (DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tUndefined { - return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val} +func (DefaultValueDecoders) undefinedDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tUndefined { + return emptyValue, ValueDecoderError{ + Name: "UndefinedDecodeValue", + Types: []reflect.Type{tUndefined}, + Received: reflect.Zero(t), + } } var err error @@ -528,20 +694,37 @@ func (DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr bsonrw.Val case bsontype.Null: err = vr.ReadNull() default: - return fmt.Errorf("cannot decode %v into an Undefined", vr.Type()) + return emptyValue, fmt.Errorf("cannot decode %v into an Undefined", vr.Type()) + } + if err != nil { + return emptyValue, err } + return reflect.ValueOf(primitive.Undefined{}), nil +} + +// UndefinedDecodeValue is the ValueDecoderFunc for Undefined. +func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tUndefined { + return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val} + } + + elem, err := dvd.undefinedDecodeType(dc, vr, tUndefined) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.Undefined{})) + + val.Set(elem) return nil } -// ObjectIDDecodeValue is the ValueDecoderFunc for primitive.ObjectID. -func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tOID { - return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} +func (dvd DefaultValueDecoders) objectIDDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tOID { + return emptyValue, ValueDecoderError{ + Name: "ObjectIDDecodeValue", + Types: []reflect.Type{tOID}, + Received: reflect.Zero(t), + } } var oid primitive.ObjectID @@ -550,38 +733,55 @@ func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr bsonrw. case bsontype.ObjectID: oid, err = vr.ReadObjectID() if err != nil { - return err + return emptyValue, err } case bsontype.String: str, err := vr.ReadString() if err != nil { - return err + return emptyValue, err } if len(str) != 12 { - return fmt.Errorf("an ObjectID string must be exactly 12 bytes long (got %v)", len(str)) + return emptyValue, fmt.Errorf("an ObjectID string must be exactly 12 bytes long (got %v)", len(str)) } byteArr := []byte(str) copy(oid[:], byteArr) case bsontype.Null: if err = vr.ReadNull(); err != nil { - return err + return emptyValue, err } case bsontype.Undefined: if err = vr.ReadUndefined(); err != nil { - return err + return emptyValue, err } default: - return fmt.Errorf("cannot decode %v into an ObjectID", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into an ObjectID", vrType) } - val.Set(reflect.ValueOf(oid)) + return reflect.ValueOf(oid), nil +} + +// ObjectIDDecodeValue is the ValueDecoderFunc for primitive.ObjectID. +func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tOID { + return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} + } + + elem, err := dvd.objectIDDecodeType(dc, vr, tOID) + if err != nil { + return err + } + + val.Set(elem) return nil } -// DateTimeDecodeValue is the ValueDecoderFunc for DateTime. -func (DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tDateTime { - return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} +func (DefaultValueDecoders) dateTimeDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tDateTime { + return emptyValue, ValueDecoderError{ + Name: "DateTimeDecodeValue", + Types: []reflect.Type{tDateTime}, + Received: reflect.Zero(t), + } } var dt int64 @@ -594,20 +794,37 @@ func (DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr bsonrw.Valu case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a DateTime", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a DateTime", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.DateTime(dt)), nil +} + +// DateTimeDecodeValue is the ValueDecoderFunc for DateTime. +func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tDateTime { + return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} } + elem, err := dvd.dateTimeDecodeType(dc, vr, tDateTime) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.DateTime(dt))) + + val.Set(elem) return nil } -// NullDecodeValue is the ValueDecoderFunc for Null. -func (DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tNull { - return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} +func (DefaultValueDecoders) nullDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tNull { + return emptyValue, ValueDecoderError{ + Name: "NullDecodeValue", + Types: []reflect.Type{tNull}, + Received: reflect.Zero(t), + } } var err error @@ -617,20 +834,37 @@ func (DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr bsonrw.ValueRea case bsontype.Null: err = vr.ReadNull() default: - return fmt.Errorf("cannot decode %v into a Null", vr.Type()) + return emptyValue, fmt.Errorf("cannot decode %v into a Null", vr.Type()) + } + if err != nil { + return emptyValue, err } + return reflect.ValueOf(primitive.Null{}), nil +} + +// NullDecodeValue is the ValueDecoderFunc for Null. +func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tNull { + return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} + } + + elem, err := dvd.nullDecodeType(dc, vr, tNull) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.Null{})) + + val.Set(elem) return nil } -// RegexDecodeValue is the ValueDecoderFunc for Regex. -func (DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tRegex { - return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} +func (DefaultValueDecoders) regexDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tRegex { + return emptyValue, ValueDecoderError{ + Name: "RegexDecodeValue", + Types: []reflect.Type{tRegex}, + Received: reflect.Zero(t), + } } var pattern, options string @@ -643,20 +877,37 @@ func (DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr bsonrw.ValueRe case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a Regex", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a Regex", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Regex{Pattern: pattern, Options: options}), nil +} + +// RegexDecodeValue is the ValueDecoderFunc for Regex. +func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tRegex { + return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} } + elem, err := dvd.regexDecodeType(dc, vr, tRegex) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.Regex{Pattern: pattern, Options: options})) + + val.Set(elem) return nil } -// DBPointerDecodeValue is the ValueDecoderFunc for DBPointer. -func (DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tDBPointer { - return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} +func (DefaultValueDecoders) dBPointerDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tDBPointer { + return emptyValue, ValueDecoderError{ + Name: "DBPointerDecodeValue", + Types: []reflect.Type{tDBPointer}, + Received: reflect.Zero(t), + } } var ns string @@ -670,20 +921,37 @@ func (DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr bsonrw.Val case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a DBPointer", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a DBPointer", vrType) + } + if err != nil { + return emptyValue, err } + return reflect.ValueOf(primitive.DBPointer{DB: ns, Pointer: pointer}), nil +} + +// DBPointerDecodeValue is the ValueDecoderFunc for DBPointer. +func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tDBPointer { + return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} + } + + elem, err := dvd.dBPointerDecodeType(dc, vr, tDBPointer) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.DBPointer{DB: ns, Pointer: pointer})) + + val.Set(elem) return nil } -// TimestampDecodeValue is the ValueDecoderFunc for Timestamp. -func (DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tTimestamp { - return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} +func (DefaultValueDecoders) timestampDecodeType(dc DecodeContext, vr bsonrw.ValueReader, reflectType reflect.Type) (reflect.Value, error) { + if reflectType != tTimestamp { + return emptyValue, ValueDecoderError{ + Name: "TimestampDecodeValue", + Types: []reflect.Type{tTimestamp}, + Received: reflect.Zero(reflectType), + } } var t, incr uint32 @@ -696,20 +964,37 @@ func (DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr bsonrw.Val case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a Timestamp", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a Timestamp", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Timestamp{T: t, I: incr}), nil +} + +// TimestampDecodeValue is the ValueDecoderFunc for Timestamp. +func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tTimestamp { + return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } + elem, err := dvd.timestampDecodeType(dc, vr, tTimestamp) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.Timestamp{T: t, I: incr})) + + val.Set(elem) return nil } -// MinKeyDecodeValue is the ValueDecoderFunc for MinKey. -func (DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tMinKey { - return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} +func (DefaultValueDecoders) minKeyDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tMinKey { + return emptyValue, ValueDecoderError{ + Name: "MinKeyDecodeValue", + Types: []reflect.Type{tMinKey}, + Received: reflect.Zero(t), + } } var err error @@ -721,20 +1006,37 @@ func (DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueR case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a MinKey", vr.Type()) + return emptyValue, fmt.Errorf("cannot decode %v into a MinKey", vr.Type()) + } + if err != nil { + return emptyValue, err } + return reflect.ValueOf(primitive.MinKey{}), nil +} + +// MinKeyDecodeValue is the ValueDecoderFunc for MinKey. +func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tMinKey { + return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} + } + + elem, err := dvd.minKeyDecodeType(dc, vr, tMinKey) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.MinKey{})) + + val.Set(elem) return nil } -// MaxKeyDecodeValue is the ValueDecoderFunc for MaxKey. -func (DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tMaxKey { - return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} +func (DefaultValueDecoders) maxKeyDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tMaxKey { + return emptyValue, ValueDecoderError{ + Name: "MaxKeyDecodeValue", + Types: []reflect.Type{tMaxKey}, + Received: reflect.Zero(t), + } } var err error @@ -746,20 +1048,37 @@ func (DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueR case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a MaxKey", vr.Type()) + return emptyValue, fmt.Errorf("cannot decode %v into a MaxKey", vr.Type()) + } + if err != nil { + return emptyValue, err } + return reflect.ValueOf(primitive.MaxKey{}), nil +} + +// MaxKeyDecodeValue is the ValueDecoderFunc for MaxKey. +func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tMaxKey { + return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} + } + + elem, err := dvd.maxKeyDecodeType(dc, vr, tMaxKey) if err != nil { return err } - val.Set(reflect.ValueOf(primitive.MaxKey{})) + + val.Set(elem) return nil } -// Decimal128DecodeValue is the ValueDecoderFunc for primitive.Decimal128. -func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tDecimal { - return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} +func (dvd DefaultValueDecoders) decimal128DecodeType(dctx DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tDecimal { + return emptyValue, ValueDecoderError{ + Name: "Decimal128DecodeValue", + Types: []reflect.Type{tDecimal}, + Received: reflect.Zero(t), + } } var d128 primitive.Decimal128 @@ -772,92 +1091,136 @@ func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr bso case bsontype.Undefined: err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a primitive.Decimal128", vr.Type()) + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.Decimal128", vr.Type()) + } + if err != nil { + return emptyValue, err } + return reflect.ValueOf(d128), nil +} + +// Decimal128DecodeValue is the ValueDecoderFunc for primitive.Decimal128. +func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tDecimal { + return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} + } + + elem, err := dvd.decimal128DecodeType(dctx, vr, tDecimal) if err != nil { return err } - val.Set(reflect.ValueOf(d128)) - return err + + val.Set(elem) + return nil } -// JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. -func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tJSONNumber { - return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} +func (dvd DefaultValueDecoders) jsonNumberDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tJSONNumber { + return emptyValue, ValueDecoderError{ + Name: "JSONNumberDecodeValue", + Types: []reflect.Type{tJSONNumber}, + Received: reflect.Zero(t), + } } + var jsonNum json.Number + var err error switch vrType := vr.Type(); vrType { case bsontype.Double: f64, err := vr.ReadDouble() if err != nil { - return err + return emptyValue, err } - val.Set(reflect.ValueOf(json.Number(strconv.FormatFloat(f64, 'f', -1, 64)))) + jsonNum = json.Number(strconv.FormatFloat(f64, 'f', -1, 64)) case bsontype.Int32: i32, err := vr.ReadInt32() if err != nil { - return err + return emptyValue, err } - val.Set(reflect.ValueOf(json.Number(strconv.FormatInt(int64(i32), 10)))) + jsonNum = json.Number(strconv.FormatInt(int64(i32), 10)) case bsontype.Int64: i64, err := vr.ReadInt64() if err != nil { - return err + return emptyValue, err } - val.Set(reflect.ValueOf(json.Number(strconv.FormatInt(i64, 10)))) + jsonNum = json.Number(strconv.FormatInt(i64, 10)) case bsontype.Null: - if err := vr.ReadNull(); err != nil { - return err - } - val.SetString("") + err = vr.ReadNull() case bsontype.Undefined: - if err := vr.ReadUndefined(); err != nil { - return err - } - val.SetString("") + err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a json.Number", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a json.Number", vrType) } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(jsonNum), nil +} +// JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. +func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tJSONNumber { + return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} + } + + elem, err := dvd.jsonNumberDecodeType(dc, vr, tJSONNumber) + if err != nil { + return err + } + + val.Set(elem) return nil } -// URLDecodeValue is the ValueDecoderFunc for url.URL. -func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tURL { - return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} +func (dvd DefaultValueDecoders) urlDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tURL { + return emptyValue, ValueDecoderError{ + Name: "URLDecodeValue", + Types: []reflect.Type{tURL}, + Received: reflect.Zero(t), + } } + urlPtr := &url.URL{} + var err error switch vrType := vr.Type(); vrType { case bsontype.String: - str, err := vr.ReadString() + var str string // Declare str here to avoid shadowing err during the ReadString call. + str, err = vr.ReadString() if err != nil { - return err + return emptyValue, err } - parsedURL, err := url.Parse(str) - if err != nil { - return err - } - val.Set(reflect.ValueOf(parsedURL).Elem()) - return nil + urlPtr, err = url.Parse(str) case bsontype.Null: - if err := vr.ReadNull(); err != nil { - return err - } - val.Set(reflect.ValueOf(url.URL{})) - return nil + err = vr.ReadNull() case bsontype.Undefined: - if err := vr.ReadUndefined(); err != nil { - return err - } - val.Set(reflect.ValueOf(url.URL{})) - return nil + err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a *url.URL", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a *url.URL", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(urlPtr).Elem(), nil +} + +// URLDecodeValue is the ValueDecoderFunc for url.URL. +func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tURL { + return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} + } + + elem, err := dvd.urlDecodeType(dc, vr, tURL) + if err != nil { + return err } + + val.Set(elem) + return nil } // TimeDecodeValue is the ValueDecoderFunc for time.Time. @@ -1216,6 +1579,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR if err != nil { return nil, err } + eTypeDecoder, _ := decoder.(typeDecoder) idx := 0 for { @@ -1227,9 +1591,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR return nil, err } - elem := reflect.New(eType).Elem() - - err = decoder.DecodeValue(dc, vr, elem) + elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) if err != nil { return nil, newDecodeError(strconv.Itoa(idx), err) } @@ -1240,48 +1602,71 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR return elems, nil } -// CodeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. -func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tCodeWithScope { - return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} +func (dvd DefaultValueDecoders) readCodeWithScope(dc DecodeContext, vr bsonrw.ValueReader) (primitive.CodeWithScope, error) { + var cws primitive.CodeWithScope + + code, dr, err := vr.ReadCodeWithScope() + if err != nil { + return cws, err } - switch vrType := vr.Type(); vrType { - case bsontype.CodeWithScope: - code, dr, err := vr.ReadCodeWithScope() - if err != nil { - return err - } + scope := reflect.New(tD).Elem() + elems, err := dvd.decodeElemsFromDocumentReader(dc, dr) + if err != nil { + return cws, err + } - scope := reflect.New(tD).Elem() - elems, err := dvd.decodeElemsFromDocumentReader(dc, dr) - if err != nil { - return err - } + scope.Set(reflect.MakeSlice(tD, 0, len(elems))) + scope.Set(reflect.Append(scope, elems...)) - scope.Set(reflect.MakeSlice(tD, 0, len(elems))) - scope.Set(reflect.Append(scope, elems...)) + cws = primitive.CodeWithScope{ + Code: primitive.JavaScript(code), + Scope: scope.Interface().(primitive.D), + } + return cws, nil +} - val.Set(reflect.ValueOf(primitive.CodeWithScope{ - Code: primitive.JavaScript(code), - Scope: scope.Interface().(primitive.D), - })) - return nil - case bsontype.Null: - if err := vr.ReadNull(); err != nil { - return err +func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tCodeWithScope { + return emptyValue, ValueDecoderError{ + Name: "CodeWithScopeDecodeValue", + Types: []reflect.Type{tCodeWithScope}, + Received: reflect.Zero(t), } - val.Set(reflect.ValueOf(primitive.CodeWithScope{})) - return nil + } + + var cws primitive.CodeWithScope + var err error + switch vrType := vr.Type(); vrType { + case bsontype.CodeWithScope: + cws, err = dvd.readCodeWithScope(dc, vr) + case bsontype.Null: + err = vr.ReadNull() case bsontype.Undefined: - if err := vr.ReadUndefined(); err != nil { - return err - } - val.Set(reflect.ValueOf(primitive.CodeWithScope{})) - return nil + err = vr.ReadUndefined() default: - return fmt.Errorf("cannot decode %v into a primitive.CodeWithScope", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.CodeWithScope", vrType) + } + if err != nil { + return emptyValue, err } + + return reflect.ValueOf(cws), nil +} + +// CodeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. +func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tCodeWithScope { + return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} + } + + elem, err := dvd.codeWithScopeDecodeType(dc, vr, tCodeWithScope) + if err != nil { + return err + } + + val.Set(elem) + return nil } func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr bsonrw.ValueReader, _ reflect.Value) ([]reflect.Value, error) { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go index 01ddbbb6..49a0c3f1 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go @@ -70,6 +70,7 @@ func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec). RegisterTypeEncoder(tTime, defaultTimeCodec). RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec). + RegisterTypeEncoder(tCoreArray, defaultArrayCodec). RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)). RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)). RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)). diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go index c215ec38..a15636d0 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/empty_interface_codec.go @@ -15,14 +15,17 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) -var defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec() - // EmptyInterfaceCodec is the Codec used for interface{} values. type EmptyInterfaceCodec struct { DecodeBinaryAsSlice bool } -var _ ValueCodec = &EmptyInterfaceCodec{} +var ( + defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec() + + _ ValueCodec = defaultEmptyInterfaceCodec + _ typeDecoder = defaultEmptyInterfaceCodec +) // NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts. func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec { @@ -86,33 +89,31 @@ func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val return nil, err } -// DecodeValue is the ValueDecoderFunc for interface{}. -func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tEmpty { - return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} +func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tEmpty { + return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)} } rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type()) if err != nil { switch vr.Type() { case bsontype.Null: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() + return reflect.Zero(t), vr.ReadNull() default: - return err + return emptyValue, err } } decoder, err := dc.LookupDecoder(rtype) if err != nil { - return err + return emptyValue, err } - elem := reflect.New(rtype).Elem() - err = decoder.DecodeValue(dc, vr, elem) + elem, err := decodeTypeOrValue(decoder, dc, vr, rtype) if err != nil { - return err + return emptyValue, err } + if eic.DecodeBinaryAsSlice && rtype == tBinary { binElem := elem.Interface().(primitive.Binary) if binElem.Subtype == bsontype.BinaryGeneric || binElem.Subtype == bsontype.BinaryBinaryOld { @@ -120,6 +121,20 @@ func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueRead } } + return elem, nil +} + +// DecodeValue is the ValueDecoderFunc for interface{}. +func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tEmpty { + return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} + } + + elem, err := eic.decodeType(dc, vr, val.Type()) + if err != nil { + return err + } + val.Set(elem) return nil } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go index d641960c..fbb8ef42 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go @@ -178,6 +178,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref if err != nil { return err } + eTypeDecoder, _ := decoder.(typeDecoder) if eType == tEmpty { dc.Ancestor = val.Type() @@ -199,8 +200,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref return err } - elem := reflect.New(eType).Elem() - err = decoder.DecodeValue(dc, vr, elem) + elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) if err != nil { return newDecodeError(key, err) } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go index 910f2049..5332b7c3 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/string_codec.go @@ -15,14 +15,17 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -var defaultStringCodec = NewStringCodec() - // StringCodec is the Codec used for struct values. type StringCodec struct { DecodeObjectIDAsHex bool } -var _ ValueCodec = &StringCodec{} +var ( + defaultStringCodec = NewStringCodec() + + _ ValueCodec = defaultStringCodec + _ typeDecoder = defaultStringCodec +) // NewStringCodec returns a StringCodec with options opts. func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec { @@ -43,23 +46,27 @@ func (sc *StringCodec) EncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, va return vw.WriteString(val.String()) } -// DecodeValue is the ValueDecoder for string types. -func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.String { - return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} +func (sc *StringCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t.Kind() != reflect.String { + return emptyValue, ValueDecoderError{ + Name: "StringDecodeValue", + Kinds: []reflect.Kind{reflect.String}, + Received: reflect.Zero(t), + } } + var str string var err error switch vr.Type() { case bsontype.String: str, err = vr.ReadString() if err != nil { - return err + return emptyValue, err } case bsontype.ObjectID: oid, err := vr.ReadObjectID() if err != nil { - return err + return emptyValue, err } if sc.DecodeObjectIDAsHex { str = oid.Hex() @@ -70,29 +77,43 @@ func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, va case bsontype.Symbol: str, err = vr.ReadSymbol() if err != nil { - return err + return emptyValue, err } case bsontype.Binary: data, subtype, err := vr.ReadBinary() if err != nil { - return err + return emptyValue, err } if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { - return fmt.Errorf("SliceDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", bsontype.Binary, subtype) + return emptyValue, decodeBinaryError{subtype: subtype, typeName: "string"} } str = string(data) case bsontype.Null: if err = vr.ReadNull(); err != nil { - return err + return emptyValue, err } case bsontype.Undefined: if err = vr.ReadUndefined(); err != nil { - return err + return emptyValue, err } default: - return fmt.Errorf("cannot decode %v into a string type", vr.Type()) + return emptyValue, fmt.Errorf("cannot decode %v into a string type", vr.Type()) + } + + return reflect.ValueOf(str), nil +} + +// DecodeValue is the ValueDecoder for string types. +func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Kind() != reflect.String { + return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} + } + + elem, err := sc.decodeType(dctx, vr, val.Type()) + if err != nil { + return err } - val.SetString(str) + val.SetString(elem.String()) return nil } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go index 69d0ae4d..6f406c16 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/struct_tag_parser.go @@ -91,6 +91,10 @@ var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (S if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { tag = string(sf.Tag) } + return parseTags(key, tag) +} + +func parseTags(key string, tag string) (StructTags, error) { var st StructTags if tag == "-" { st.Skip = true @@ -117,3 +121,19 @@ var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (S return st, nil } + +// JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser +// but will also fallback to parsing the json tag instead on a field where the +// bson tag isn't available. +var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { + key := strings.ToLower(sf.Name) + tag, ok := sf.Tag.Lookup("bson") + if !ok { + tag, ok = sf.Tag.Lookup("json") + } + if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { + tag = string(sf.Tag) + } + + return parseTags(key, tag) +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go index a7df44db..ec7e30f7 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/time_codec.go @@ -21,14 +21,17 @@ const ( timeFormatString = "2006-01-02T15:04:05.999Z07:00" ) -var defaultTimeCodec = NewTimeCodec() - // TimeCodec is the Codec used for time.Time values. type TimeCodec struct { UseLocalTimeZone bool } -var _ ValueCodec = &TimeCodec{} +var ( + defaultTimeCodec = NewTimeCodec() + + _ ValueCodec = defaultTimeCodec + _ typeDecoder = defaultTimeCodec +) // NewTimeCodec returns a TimeCodec with options opts. func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec { @@ -41,10 +44,13 @@ func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec { return &codec } -// DecodeValue is the ValueDecoderFunc for time.Time. -func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tTime { - return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} +func (tc *TimeCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tTime { + return emptyValue, ValueDecoderError{ + Name: "TimeDecodeValue", + Types: []reflect.Type{tTime}, + Received: reflect.Zero(t), + } } var timeVal time.Time @@ -52,47 +58,61 @@ func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val re case bsontype.DateTime: dt, err := vr.ReadDateTime() if err != nil { - return err + return emptyValue, err } timeVal = time.Unix(dt/1000, dt%1000*1000000) case bsontype.String: // assume strings are in the isoTimeFormat timeStr, err := vr.ReadString() if err != nil { - return err + return emptyValue, err } timeVal, err = time.Parse(timeFormatString, timeStr) if err != nil { - return err + return emptyValue, err } case bsontype.Int64: i64, err := vr.ReadInt64() if err != nil { - return err + return emptyValue, err } timeVal = time.Unix(i64/1000, i64%1000*1000000) case bsontype.Timestamp: t, _, err := vr.ReadTimestamp() if err != nil { - return err + return emptyValue, err } timeVal = time.Unix(int64(t), 0) case bsontype.Null: if err := vr.ReadNull(); err != nil { - return err + return emptyValue, err } case bsontype.Undefined: if err := vr.ReadUndefined(); err != nil { - return err + return emptyValue, err } default: - return fmt.Errorf("cannot decode %v into a time.Time", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType) } if !tc.UseLocalTimeZone { timeVal = timeVal.UTC() } - val.Set(reflect.ValueOf(timeVal)) + return reflect.ValueOf(timeVal), nil +} + +// DecodeValue is the ValueDecoderFunc for time.Time. +func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tTime { + return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} + } + + elem, err := tc.decodeType(dc, vr, tTime) + if err != nil { + return err + } + + val.Set(elem) return nil } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go index bbb6bb9c..fb5b5108 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/types.go @@ -79,3 +79,4 @@ var tA = reflect.TypeOf(primitive.A{}) var tE = reflect.TypeOf(primitive.E{}) var tCoreDocument = reflect.TypeOf(bsoncore.Document{}) +var tCoreArray = reflect.TypeOf(bsoncore.Array{}) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go index 3c991264..0b21ce99 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/uint_codec.go @@ -7,7 +7,6 @@ package bsoncodec import ( - "errors" "fmt" "math" "reflect" @@ -17,14 +16,17 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -var defaultUIntCodec = NewUIntCodec() - // UIntCodec is the Codec used for uint values. type UIntCodec struct { EncodeToMinSize bool } -var _ ValueCodec = &UIntCodec{} +var ( + defaultUIntCodec = NewUIntCodec() + + _ ValueCodec = defaultUIntCodec + _ typeDecoder = defaultUIntCodec +) // NewUIntCodec returns a UIntCodec with options opts. func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec { @@ -64,84 +66,96 @@ func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val r } } -// DecodeValue is the ValueDecoder for uint types. -func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - +func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 var err error switch vrType := vr.Type(); vrType { case bsontype.Int32: i32, err := vr.ReadInt32() if err != nil { - return err + return emptyValue, err } i64 = int64(i32) case bsontype.Int64: i64, err = vr.ReadInt64() if err != nil { - return err + return emptyValue, err } case bsontype.Double: f64, err := vr.ReadDouble() if err != nil { - return err + return emptyValue, err } if !dc.Truncate && math.Floor(f64) != f64 { - return errors.New("UintDecodeValue can only truncate float64 to an integer type when truncation is enabled") + return emptyValue, errCannotTruncate } if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) + return emptyValue, fmt.Errorf("%g overflows int64", f64) } i64 = int64(f64) case bsontype.Boolean: b, err := vr.ReadBoolean() if err != nil { - return err + return emptyValue, err } if b { i64 = 1 } case bsontype.Null: if err = vr.ReadNull(); err != nil { - return err + return emptyValue, err } case bsontype.Undefined: if err = vr.ReadUndefined(); err != nil { - return err + return emptyValue, err } default: - return fmt.Errorf("cannot decode %v into an integer type", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) } - switch val.Kind() { + switch t.Kind() { case reflect.Uint8: if i64 < 0 || i64 > math.MaxUint8 { - return fmt.Errorf("%d overflows uint8", i64) + return emptyValue, fmt.Errorf("%d overflows uint8", i64) } + + return reflect.ValueOf(uint8(i64)), nil case reflect.Uint16: if i64 < 0 || i64 > math.MaxUint16 { - return fmt.Errorf("%d overflows uint16", i64) + return emptyValue, fmt.Errorf("%d overflows uint16", i64) } + + return reflect.ValueOf(uint16(i64)), nil case reflect.Uint32: if i64 < 0 || i64 > math.MaxUint32 { - return fmt.Errorf("%d overflows uint32", i64) + return emptyValue, fmt.Errorf("%d overflows uint32", i64) } + + return reflect.ValueOf(uint32(i64)), nil case reflect.Uint64: if i64 < 0 { - return fmt.Errorf("%d overflows uint64", i64) + return emptyValue, fmt.Errorf("%d overflows uint64", i64) } + + return reflect.ValueOf(uint64(i64)), nil case reflect.Uint: if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return fmt.Errorf("%d overflows uint", i64) + return emptyValue, fmt.Errorf("%d overflows uint", i64) } + + return reflect.ValueOf(uint(i64)), nil default: + return emptyValue, ValueDecoderError{ + Name: "UintDecodeValue", + Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Received: reflect.Zero(t), + } + } +} + +// DecodeValue is the ValueDecoder for uint types. +func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() { return ValueDecoderError{ Name: "UintDecodeValue", Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, @@ -149,6 +163,11 @@ func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val r } } - val.SetUint(uint64(i64)) + elem, err := uic.decodeType(dc, vr, val.Type()) + if err != nil { + return err + } + + val.SetUint(elem.Uint()) return nil } diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go index 02e3a7e3..5cdf6460 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/copier.go @@ -45,6 +45,22 @@ func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error { return c.copyDocumentCore(dw, dr) } +// CopyArrayFromBytes copies the values from a BSON array represented as a +// []byte to a ValueWriter. +func (c Copier) CopyArrayFromBytes(dst ValueWriter, src []byte) error { + aw, err := dst.WriteArray() + if err != nil { + return err + } + + err = c.CopyBytesToArrayWriter(aw, src) + if err != nil { + return err + } + + return aw.WriteArrayEnd() +} + // CopyDocumentFromBytes copies the values from a BSON document represented as a // []byte to a ValueWriter. func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error { @@ -61,9 +77,29 @@ func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error { return dw.WriteDocumentEnd() } +type writeElementFn func(key string) (ValueWriter, error) + +// CopyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an +// ArrayWriter. +func (c Copier) CopyBytesToArrayWriter(dst ArrayWriter, src []byte) error { + wef := func(_ string) (ValueWriter, error) { + return dst.WriteArrayElement() + } + + return c.copyBytesToValueWriter(src, wef) +} + // CopyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a // DocumentWriter. func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error { + wef := func(key string) (ValueWriter, error) { + return dst.WriteDocumentElement(key) + } + + return c.copyBytesToValueWriter(src, wef) +} + +func (c Copier) copyBytesToValueWriter(src []byte, wef writeElementFn) error { // TODO(skriptble): Create errors types here. Anything thats a tag should be a property. length, rem, ok := bsoncore.ReadLength(src) if !ok { @@ -93,15 +129,18 @@ func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error if !ok { return fmt.Errorf("invalid key found. remaining bytes=%v", rem) } - dvw, err := dst.WriteDocumentElement(key) + + // write as either array element or document element using writeElementFn + vw, err := wef(key) if err != nil { return err } + val, rem, ok = bsoncore.ReadValue(rem, t) if !ok { return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t) } - err = c.CopyValueFromBytes(dvw, t, val.Data) + err = c.CopyValueFromBytes(vw, t, val.Data) if err != nil { return err } @@ -133,6 +172,23 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) return dst, err } +// AppendArrayBytes copies an array from the ValueReader to dst. +func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { + if br, ok := src.(BytesReader); ok { + _, dst, err := br.ReadValueBytes(dst) + return dst, err + } + + vw := vwPool.Get().(*valueWriter) + defer vwPool.Put(vw) + + vw.reset(dst) + + err := c.copyArray(vw, src) + dst = vw.buf + return dst, err +} + // CopyValueFromBytes will write the value represtend by t and src to dst. func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) error { if wvb, ok := dst.(BytesWriter); ok { diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go index 3ff17c19..8a690e37 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsonrw/extjson_parser.go @@ -7,9 +7,12 @@ package bsonrw import ( + "encoding/base64" + "encoding/hex" "errors" "fmt" "io" + "strings" "go.mongodb.org/mongo-driver/bson/bsontype" ) @@ -66,6 +69,7 @@ type extJSONParser struct { maxDepth int emptyObject bool + relaxedUUID bool } // newExtJSONParser returns a new extended JSON parser, ready to to begin @@ -119,6 +123,12 @@ func (ejp *extJSONParser) peekType() (bsontype.Type, error) { } t = wrapperKeyBSONType(ejp.k) + // if $uuid is encountered, parse as binary subtype 4 + if ejp.k == "$uuid" { + ejp.relaxedUUID = true + t = bsontype.Binary + } + switch t { case bsontype.JavaScript: // just saw $code, need to check for $scope at same level @@ -273,6 +283,64 @@ func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) { ejp.advanceState() if t == bsontype.Binary && ejp.s == jpsSawValue { + // convert relaxed $uuid format + if ejp.relaxedUUID { + defer func() { ejp.relaxedUUID = false }() + uuid, err := ejp.v.parseSymbol() + if err != nil { + return nil, err + } + + // RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing + // in the 8th, 13th, 18th, and 23rd characters. + // + // See https://tools.ietf.org/html/rfc4122#section-3 + valid := len(uuid) == 36 && + string(uuid[8]) == "-" && + string(uuid[13]) == "-" && + string(uuid[18]) == "-" && + string(uuid[23]) == "-" + if !valid { + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens") + } + + // remove hyphens + uuidNoHyphens := strings.Replace(uuid, "-", "", -1) + if len(uuidNoHyphens) != 32 { + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens") + } + + // convert hex to bytes + bytes, err := hex.DecodeString(uuidNoHyphens) + if err != nil { + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err) + } + + ejp.advanceState() + if ejp.s != jpsSawEndObject { + return nil, invalidJSONErrorForType("$uuid and value and then }", bsontype.Binary) + } + + base64 := &extJSONValue{ + t: bsontype.String, + v: base64.StdEncoding.EncodeToString(bytes), + } + subType := &extJSONValue{ + t: bsontype.String, + v: "04", + } + + v = &extJSONValue{ + t: bsontype.EmbeddedDocument, + v: &extJSONObject{ + keys: []string{"base64", "subType"}, + values: []*extJSONValue{base64, subType}, + }, + } + + break + } + // convert legacy $binary format base64 := ejp.v diff --git a/vendor/go.mongodb.org/mongo-driver/bson/doc.go b/vendor/go.mongodb.org/mongo-driver/bson/doc.go index 5f411b62..16341568 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/doc.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/doc.go @@ -43,7 +43,7 @@ // 6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M). // 7. BSON array unmarshals to a bson.A. // 8. BSON ObjectId unmarshals to a primitive.ObjectID. -// 9. BSON datetime unmarshals to a primitive.Datetime. +// 9. BSON datetime unmarshals to a primitive.DateTime. // 10. BSON binary unmarshals to a primitive.Binary. // 11. BSON regular expression unmarshals to a primitive.Regex. // 12. BSON JavaScript unmarshals to a primitive.JavaScript. @@ -90,14 +90,26 @@ // unmarshalled into an interface{} field will be unmarshalled as a D. // // The encoding of each struct field can be customized by the "bson" struct tag. -// The tag gives the name of the field, possibly followed by a comma-separated list of options. +// +// This tag behavior is configurable, and different struct tag behavior can be configured by initializing a new +// bsoncodec.StructCodec with the desired tag parser and registering that StructCodec onto the Registry. By default, JSON tags +// are not honored, but that can be enabled by creating a StructCodec with JSONFallbackStructTagParser, like below: +// +// Example: +// structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.JSONFallbackStructTagParser) +// +// The bson tag gives the name of the field, possibly followed by a comma-separated list of options. // The name may be empty in order to specify options without overriding the default field name. The following options can be used // to configure behavior: // // 1. omitempty: If the omitempty struct tag is specified on a field, the field will not be marshalled if it is set to -// the zero value. By default, a struct field is only considered empty if the field's type implements the Zeroer -// interface and the IsZero method returns true. Struct fields of types that do not implement Zeroer are always -// marshalled as embedded documents. This tag should be used for all slice and map values. +// the zero value. Fields with language primitive types such as integers, booleans, and strings are considered empty if +// their value is equal to the zero value for the type (i.e. 0 for integers, false for booleans, and "" for strings). +// Slices, maps, and arrays are considered empty if they are of length zero. Interfaces and pointers are considered +// empty if their value is nil. By default, structs are only considered empty if the struct type implements the +// bsoncodec.Zeroer interface and the IsZero method returns true. Struct fields whose types do not implement Zeroer are +// never considered empty and will be marshalled as embedded documents. +// NOTE: It is recommended that this tag be used for all slice and map fields. // // 2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of // the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For other diff --git a/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go b/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go index fdd90d89..a57e1d69 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/primitive/decimal.go @@ -10,6 +10,7 @@ package primitive import ( + "encoding/json" "errors" "fmt" "math/big" @@ -211,6 +212,49 @@ func (d Decimal128) IsZero() bool { return d.h == 0 && d.l == 0 } +// MarshalJSON returns Decimal128 as a string. +func (d Decimal128) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +// UnmarshalJSON creates a primitive.Decimal128 from a JSON string, an extended JSON $numberDecimal value, or the string +// "null". If b is a JSON string or extended JSON value, d will have the value of that string, and if b is "null", d will +// be unchanged. +func (d *Decimal128) UnmarshalJSON(b []byte) error { + // Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field + // will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not + // enter the UnmarshalJSON hook. + if string(b) == "null" { + return nil + } + + var res interface{} + err := json.Unmarshal(b, &res) + if err != nil { + return err + } + str, ok := res.(string) + + // Extended JSON + if !ok { + m, ok := res.(map[string]interface{}) + if !ok { + return errors.New("not an extended JSON Decimal128: expected document") + } + d128, ok := m["$numberDecimal"] + if !ok { + return errors.New("not an extended JSON Decimal128: expected key $numberDecimal") + } + str, ok = d128.(string) + if !ok { + return errors.New("not an extended JSON Decimal128: expected decimal to be string") + } + } + + *d, err = ParseDecimal128(str) + return err +} + func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) { div64 := uint64(div) a := h >> 32 diff --git a/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go b/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go index a0eb5378..30aaafe6 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/primitive/objectid.go @@ -88,6 +88,12 @@ func ObjectIDFromHex(s string) (ObjectID, error) { return oid, nil } +// IsValidObjectID returns true if the provided hex string represents a valid ObjectID and false if not. +func IsValidObjectID(s string) bool { + _, err := ObjectIDFromHex(s) + return err == nil +} + // MarshalJSON returns the ObjectID as a string func (id ObjectID) MarshalJSON() ([]byte, error) { return json.Marshal(id.Hex()) diff --git a/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go b/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go index bd4c0503..75297f30 100644 --- a/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go +++ b/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go @@ -266,6 +266,14 @@ func (rv RawValue) Int32() int32 { return convertToCoreValue(rv).Int32() } // panicking. func (rv RawValue) Int32OK() (int32, bool) { return convertToCoreValue(rv).Int32OK() } +// AsInt32 returns a BSON number as an int32. If the BSON type is not a numeric one, this method +// will panic. +func (rv RawValue) AsInt32() int32 { return convertToCoreValue(rv).AsInt32() } + +// AsInt32OK is the same as AsInt32, except that it returns a boolean instead of +// panicking. +func (rv RawValue) AsInt32OK() (int32, bool) { return convertToCoreValue(rv).AsInt32OK() } + // Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a // BSON type other than timestamp. func (rv RawValue) Timestamp() (t, i uint32) { return convertToCoreValue(rv).Timestamp() } @@ -282,6 +290,14 @@ func (rv RawValue) Int64() int64 { return convertToCoreValue(rv).Int64() } // panicking. func (rv RawValue) Int64OK() (int64, bool) { return convertToCoreValue(rv).Int64OK() } +// AsInt64 returns a BSON number as an int64. If the BSON type is not a numeric one, this method +// will panic. +func (rv RawValue) AsInt64() int64 { return convertToCoreValue(rv).AsInt64() } + +// AsInt64OK is the same as AsInt64, except that it returns a boolean instead of +// panicking. +func (rv RawValue) AsInt64OK() (int64, bool) { return convertToCoreValue(rv).AsInt64OK() } + // Decimal128 returns the decimal the Value represents. It panics if the value is a BSON type other than // decimal. func (rv RawValue) Decimal128() primitive.Decimal128 { return convertToCoreValue(rv).Decimal128() } diff --git a/vendor/go.mongodb.org/mongo-driver/event/doc.go b/vendor/go.mongodb.org/mongo-driver/event/doc.go new file mode 100644 index 00000000..93b5ede0 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/event/doc.go @@ -0,0 +1,56 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package event is a library for monitoring events from the MongoDB Go +// driver. Monitors can be set for commands sent to the MongoDB cluster, +// connection pool changes, or changes on the MongoDB cluster. +// +// Monitoring commands requires specifying a CommandMonitor when constructing +// a mongo.Client. A CommandMonitor can be set to monitor started, succeeded, +// and/or failed events. A CommandStartedEvent can be correlated to its matching +// CommandSucceededEvent or CommandFailedEvent through the RequestID field. For +// example, the following code collects the names of started events: +// +// var commandStarted []string +// cmdMonitor := &event.CommandMonitor{ +// Started: func(_ context.Context, evt *event.CommandStartedEvent) { +// commandStarted = append(commandStarted, evt.CommandName) +// }, +// } +// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetMonitor(cmdMonitor) +// client, err := mongo.Connect(context.Background(), clientOpts) +// +// Monitoring the connection pool requires specifying a PoolMonitor when constructing +// a mongo.Client. The following code tracks the number of checked out connections: +// +// var int connsCheckedOut +// poolMonitor := &event.PoolMonitor{ +// Event: func(evt *event.PoolEvent) { +// switch evt.Type { +// case event.GetSucceeded: +// connsCheckedOut++ +// case event.ConnectionReturned: +// connsCheckedOut-- +// } +// }, +// } +// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetPoolMonitor(poolMonitor) +// client, err := mongo.Connect(context.Background(), clientOpts) +// +// Monitoring server changes specifying a ServerMonitor object when constructing +// a mongo.Client. Different functions can be set on the ServerMonitor to +// monitor different kinds of events. See ServerMonitor for more details. +// The following code appends ServerHeartbeatStartedEvents to a slice: +// +// var heartbeatStarted []*event.ServerHeartbeatStartedEvent +// svrMonitor := &event.ServerMonitor{ +// ServerHeartbeatStarted: func(e *event.ServerHeartbeatStartedEvent) { +// heartbeatStarted = append(heartbeatStarted, e) +// } +// } +// clientOpts := options.Client().ApplyURI("mongodb://localhost:27017").SetServerMonitor(svrMonitor) +// client, err := mongo.Connect(context.Background(), clientOpts) +package event diff --git a/vendor/go.mongodb.org/mongo-driver/event/monitoring.go b/vendor/go.mongodb.org/mongo-driver/event/monitoring.go index 240f2398..be891fbf 100644 --- a/vendor/go.mongodb.org/mongo-driver/event/monitoring.go +++ b/vendor/go.mongodb.org/mongo-driver/event/monitoring.go @@ -10,6 +10,9 @@ import ( "context" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" ) // CommandStartedEvent represents an event generated when a command is sent to a server. @@ -62,6 +65,7 @@ const ( ConnectionClosed = "ConnectionClosed" PoolCreated = "ConnectionPoolCreated" ConnectionCreated = "ConnectionCreated" + ConnectionReady = "ConnectionReady" GetFailed = "ConnectionCheckOutFailed" GetSucceeded = "ConnectionCheckedOut" ConnectionReturned = "ConnectionCheckedIn" @@ -89,3 +93,80 @@ type PoolEvent struct { type PoolMonitor struct { Event func(*PoolEvent) } + +// ServerDescriptionChangedEvent represents a server description change. +type ServerDescriptionChangedEvent struct { + Address address.Address + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of + PreviousDescription description.Server + NewDescription description.Server +} + +// ServerOpeningEvent is an event generated when the server is initialized. +type ServerOpeningEvent struct { + Address address.Address + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// ServerClosedEvent is an event generated when the server is closed. +type ServerClosedEvent struct { + Address address.Address + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// TopologyDescriptionChangedEvent represents a topology description change. +type TopologyDescriptionChangedEvent struct { + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of + PreviousDescription description.Topology + NewDescription description.Topology +} + +// TopologyOpeningEvent is an event generated when the topology is initialized. +type TopologyOpeningEvent struct { + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// TopologyClosedEvent is an event generated when the topology is closed. +type TopologyClosedEvent struct { + TopologyID primitive.ObjectID // A unique identifier for the topology this server is a part of +} + +// ServerHeartbeatStartedEvent is an event generated when the heartbeat is started. +type ServerHeartbeatStartedEvent struct { + ConnectionID string // The address this heartbeat was sent to with a unique identifier + Awaited bool // If this heartbeat was awaitable +} + +// ServerHeartbeatSucceededEvent is an event generated when the heartbeat succeeds. +type ServerHeartbeatSucceededEvent struct { + DurationNanos int64 + Reply description.Server + ConnectionID string // The address this heartbeat was sent to with a unique identifier + Awaited bool // If this heartbeat was awaitable +} + +// ServerHeartbeatFailedEvent is an event generated when the heartbeat fails. +type ServerHeartbeatFailedEvent struct { + DurationNanos int64 + Failure error + ConnectionID string // The address this heartbeat was sent to with a unique identifier + Awaited bool // If this heartbeat was awaitable +} + +// ServerMonitor represents a monitor that is triggered for different server events. The client +// will monitor changes on the MongoDB deployment it is connected to, and this monitor reports +// the changes in the client's representation of the deployment. The topology represents the +// overall deployment, and heartbeats are sent to individual servers to check their current status. +type ServerMonitor struct { + ServerDescriptionChanged func(*ServerDescriptionChangedEvent) + ServerOpening func(*ServerOpeningEvent) + ServerClosed func(*ServerClosedEvent) + // TopologyDescriptionChanged is called when the topology is locked, so the callback should + // not attempt any operation that requires server selection on the same client. + TopologyDescriptionChanged func(*TopologyDescriptionChangedEvent) + TopologyOpening func(*TopologyOpeningEvent) + TopologyClosed func(*TopologyClosedEvent) + ServerHeartbeatStarted func(*ServerHeartbeatStartedEvent) + ServerHeartbeatSucceeded func(*ServerHeartbeatSucceededEvent) + ServerHeartbeatFailed func(*ServerHeartbeatFailedEvent) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/background_context.go b/vendor/go.mongodb.org/mongo-driver/internal/background_context.go new file mode 100644 index 00000000..6f190edb --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/background_context.go @@ -0,0 +1,34 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package internal + +import "context" + +// backgroundContext is an implementation of the context.Context interface that wraps a child Context. Value requests +// are forwarded to the child Context but the Done and Err functions are overridden to ensure the new context does not +// time out or get cancelled. +type backgroundContext struct { + context.Context + childValuesCtx context.Context +} + +// NewBackgroundContext creates a new Context whose behavior matches that of context.Background(), but Value calls are +// forwarded to the provided ctx parameter. If ctx is nil, context.Background() is returned. +func NewBackgroundContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + + return &backgroundContext{ + Context: context.Background(), + childValuesCtx: ctx, + } +} + +func (b *backgroundContext) Value(key interface{}) interface{} { + return b.childValuesCtx.Value(key) +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go b/vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go new file mode 100644 index 00000000..a7fa163b --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/cancellation_listener.go @@ -0,0 +1,47 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package internal + +import "context" + +// CancellationListener listens for context cancellation in a loop until the context expires or the listener is aborted. +type CancellationListener struct { + aborted bool + done chan struct{} +} + +// NewCancellationListener constructs a CancellationListener. +func NewCancellationListener() *CancellationListener { + return &CancellationListener{ + done: make(chan struct{}), + } +} + +// Listen blocks until the provided context is cancelled or listening is aborted via the StopListening function. If this +// detects that the context has been cancelled (i.e. ctx.Err() == context.Canceled), the provided callback is called to +// abort in-progress work. Even if the context expires, this function will block until StopListening is called. +func (c *CancellationListener) Listen(ctx context.Context, abortFn func()) { + c.aborted = false + + select { + case <-ctx.Done(): + if ctx.Err() == context.Canceled { + c.aborted = true + abortFn() + } + + <-c.done + case <-c.done: + } +} + +// StopListening stops the in-progress Listen call. This blocks if there is no in-progress Listen call. This function +// will return true if the provided abort callback was called when listening for cancellation on the previous context. +func (c *CancellationListener) StopListening() bool { + c.done <- struct{}{} + return c.aborted +} diff --git a/vendor/go.mongodb.org/mongo-driver/internal/string_util.go b/vendor/go.mongodb.org/mongo-driver/internal/string_util.go new file mode 100644 index 00000000..db1e1890 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/internal/string_util.go @@ -0,0 +1,45 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package internal + +import ( + "fmt" + + "go.mongodb.org/mongo-driver/bson" +) + +// StringSliceFromRawElement decodes the provided BSON element into a []string. This internally calls +// StringSliceFromRawValue on the element's value. The error conditions outlined in that function's documentation +// apply for this function as well. +func StringSliceFromRawElement(element bson.RawElement) ([]string, error) { + return StringSliceFromRawValue(element.Key(), element.Value()) +} + +// StringSliceFromRawValue decodes the provided BSON value into a []string. This function returns an error if the value +// is not an array or any of the elements in the array are not strings. The name parameter is used to add context to +// error messages. +func StringSliceFromRawValue(name string, val bson.RawValue) ([]string, error) { + arr, ok := val.ArrayOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be an array but it's a BSON %s", name, val.Type) + } + + arrayValues, err := arr.Values() + if err != nil { + return nil, err + } + + var strs []string + for _, arrayVal := range arrayValues { + str, ok := arrayVal.StringValueOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be an array of strings, but found a BSON %s", name, arrayVal.Type) + } + strs = append(strs, str) + } + return strs, nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/address/addr.go b/vendor/go.mongodb.org/mongo-driver/mongo/address/addr.go similarity index 93% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/address/addr.go rename to vendor/go.mongodb.org/mongo-driver/mongo/address/addr.go index ac2c981c..5655b346 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/address/addr.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/address/addr.go @@ -4,7 +4,7 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package address // import "go.mongodb.org/mongo-driver/x/mongo/driver/address" +package address // import "go.mongodb.org/mongo-driver/mongo/address" import ( "net" diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go index a2d17155..532f544e 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write.go @@ -10,11 +10,11 @@ import ( "context" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -181,7 +181,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). - Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt) + Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE) if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation { op = op.BypassDocumentValidation(*bw.bypassDocumentValidation) } @@ -230,7 +230,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). - Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint) + Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint) if bw.ordered != nil { op = op.Ordered(*bw.ordered) } @@ -248,7 +248,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera func createDeleteDoc(filter interface{}, collation *options.Collation, hint interface{}, deleteOne bool, registry *bsoncodec.Registry) (bsoncore.Document, error) { - f, err := transformBsoncoreDocument(registry, filter) + f, err := transformBsoncoreDocument(registry, filter, true, "filter") if err != nil { return nil, err } @@ -264,7 +264,7 @@ func createDeleteDoc(filter interface{}, collation *options.Collation, hint inte doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument()) } if hint != nil { - hintVal, err := transformValue(registry, hint) + hintVal, err := transformValue(registry, hint, false, "hint") if err != nil { return nil, err } @@ -310,7 +310,7 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor). ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock). Database(bw.collection.db.name).Collection(bw.collection.name). - Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint). + Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ArrayFilters(hasArrayFilters) if bw.ordered != nil { op = op.Ordered(*bw.ordered) @@ -339,7 +339,7 @@ func createUpdateDoc( checkDollarKey bool, registry *bsoncodec.Registry, ) (bsoncore.Document, error) { - f, err := transformBsoncoreDocument(registry, filter) + f, err := transformBsoncoreDocument(registry, filter, true, "filter") if err != nil { return nil, err } @@ -375,7 +375,7 @@ func createUpdateDoc( } if hint != nil { - hintVal, err := transformValue(registry, hint) + hintVal, err := transformValue(registry, hint, false, "hint") if err != nil { return nil, err } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go index f11be830..b4b8e3ef 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/bulk_write_models.go @@ -70,7 +70,8 @@ func (dom *DeleteOneModel) SetCollation(collation *options.Collation) *DeleteOne // specification as a document. This option is only valid for MongoDB versions >= 4.4. Server versions >= 3.4 will // return an error if this option is specified. For server versions < 3.4, the driver will return a client-side error if // this option is specified. The driver will return an error if this option is specified during an unacknowledged write -// operation. The default value is nil, which means that no hint will be sent. +// operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, which +// means that no hint will be sent. func (dom *DeleteOneModel) SetHint(hint interface{}) *DeleteOneModel { dom.Hint = hint return dom @@ -108,7 +109,8 @@ func (dmm *DeleteManyModel) SetCollation(collation *options.Collation) *DeleteMa // specification as a document. This option is only valid for MongoDB versions >= 4.4. Server versions >= 3.4 will // return an error if this option is specified. For server versions < 3.4, the driver will return a client-side error if // this option is specified. The driver will return an error if this option is specified during an unacknowledged write -// operation. The default value is nil, which means that no hint will be sent. +// operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, which +// means that no hint will be sent. func (dmm *DeleteManyModel) SetHint(hint interface{}) *DeleteManyModel { dmm.Hint = hint return dmm @@ -134,7 +136,8 @@ func NewReplaceOneModel() *ReplaceOneModel { // specification as a document. This option is only valid for MongoDB versions >= 4.2. Server versions >= 3.4 will // return an error if this option is specified. For server versions < 3.4, the driver will return a client-side error if // this option is specified. The driver will return an error if this option is specified during an unacknowledged write -// operation. The default value is nil, which means that no hint will be sent. +// operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, which +// means that no hint will be sent. func (rom *ReplaceOneModel) SetHint(hint interface{}) *ReplaceOneModel { rom.Hint = hint return rom @@ -191,7 +194,8 @@ func NewUpdateOneModel() *UpdateOneModel { // specification as a document. This option is only valid for MongoDB versions >= 4.2. Server versions >= 3.4 will // return an error if this option is specified. For server versions < 3.4, the driver will return a client-side error if // this option is specified. The driver will return an error if this option is specified during an unacknowledged write -// operation. The default value is nil, which means that no hint will be sent. +// operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, which +// means that no hint will be sent. func (uom *UpdateOneModel) SetHint(hint interface{}) *UpdateOneModel { uom.Hint = hint return uom @@ -255,7 +259,8 @@ func NewUpdateManyModel() *UpdateManyModel { // specification as a document. This option is only valid for MongoDB versions >= 4.2. Server versions >= 3.4 will // return an error if this option is specified. For server versions < 3.4, the driver will return a client-side error if // this option is specified. The driver will return an error if this option is specified during an unacknowledged write -// operation. The default value is nil, which means that no hint will be sent. +// operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, which +// means that no hint will be sent. func (umm *UpdateManyModel) SetHint(hint interface{}) *UpdateManyModel { umm.Hint = hint return umm diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go index d3927ebc..0bfd9972 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go @@ -17,12 +17,12 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -343,7 +343,7 @@ func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error { for i := 0; i < val.Len(); i++ { var elem []byte - elem, cs.err = transformBsoncoreDocument(cs.registry, val.Index(i).Interface()) + elem, cs.err = transformBsoncoreDocument(cs.registry, val.Index(i).Interface(), true, fmt.Sprintf("pipeline stage :%v", i)) if cs.err != nil { return cs.err } @@ -367,7 +367,7 @@ func (cs *ChangeStream) createPipelineOptionsDoc() bsoncore.Document { if cs.options.ResumeAfter != nil { var raDoc bsoncore.Document - raDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.ResumeAfter) + raDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.ResumeAfter, true, "resumeAfter") if cs.err != nil { return nil } @@ -377,7 +377,7 @@ func (cs *ChangeStream) createPipelineOptionsDoc() bsoncore.Document { if cs.options.StartAfter != nil { var saDoc bsoncore.Document - saDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.StartAfter) + saDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.StartAfter, true, "startAfter") if cs.err != nil { return nil } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go index 706b3601..5ff82c5f 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream_deployment.go @@ -9,8 +9,8 @@ package mongo import ( "context" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) type changeStreamDeployment struct { @@ -35,11 +35,11 @@ func (c *changeStreamDeployment) Connection(context.Context) (driver.Connection, return c.conn, nil } -func (c *changeStreamDeployment) ProcessError(err error, conn driver.Connection) { +func (c *changeStreamDeployment) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult { ep, ok := c.server.(driver.ErrorProcessor) if !ok { - return + return driver.NoChange } - ep.ProcessError(err, conn) + return ep.ProcessError(err, conn) } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/client.go b/vendor/go.mongodb.org/mongo-driver/mongo/client.go index 88607934..d266da46 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/client.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/client.go @@ -10,12 +10,14 @@ import ( "context" "crypto/tls" "errors" + "fmt" "strings" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -24,7 +26,6 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" @@ -62,13 +63,16 @@ type Client struct { registry *bsoncodec.Registry marshaller BSONAppender monitor *event.CommandMonitor + serverMonitor *event.ServerMonitor sessionPool *session.Pool // client-side encryption fields - keyVaultClient *Client - keyVaultColl *Collection - mongocryptd *mcryptClient - crypt *driver.Crypt + keyVaultClientFLE *Client + keyVaultCollFLE *Collection + mongocryptdFLE *mcryptClient + cryptFLE *driver.Crypt + metadataClientFLE *Client + internalClientFLE *Client } // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling @@ -154,13 +158,26 @@ func (c *Client) Connect(ctx context.Context) error { } } - if c.mongocryptd != nil { - if err := c.mongocryptd.connect(ctx); err != nil { + if c.mongocryptdFLE != nil { + if err := c.mongocryptdFLE.connect(ctx); err != nil { return err } } - if c.keyVaultClient != nil { - if err := c.keyVaultClient.Connect(ctx); err != nil { + + if c.internalClientFLE != nil { + if err := c.internalClientFLE.Connect(ctx); err != nil { + return err + } + } + + if c.keyVaultClientFLE != nil && c.keyVaultClientFLE != c.internalClientFLE && c.keyVaultClientFLE != c { + if err := c.keyVaultClientFLE.Connect(ctx); err != nil { + return err + } + } + + if c.metadataClientFLE != nil && c.metadataClientFLE != c.internalClientFLE && c.metadataClientFLE != c { + if err := c.metadataClientFLE.Connect(ctx); err != nil { return err } } @@ -191,18 +208,30 @@ func (c *Client) Disconnect(ctx context.Context) error { } c.endSessions(ctx) - if c.mongocryptd != nil { - if err := c.mongocryptd.disconnect(ctx); err != nil { + if c.mongocryptdFLE != nil { + if err := c.mongocryptdFLE.disconnect(ctx); err != nil { + return err + } + } + + if c.internalClientFLE != nil { + if err := c.internalClientFLE.Disconnect(ctx); err != nil { + return err + } + } + + if c.keyVaultClientFLE != nil && c.keyVaultClientFLE != c.internalClientFLE && c.keyVaultClientFLE != c { + if err := c.keyVaultClientFLE.Disconnect(ctx); err != nil { return err } } - if c.keyVaultClient != nil { - if err := c.keyVaultClient.Disconnect(ctx); err != nil { + if c.metadataClientFLE != nil && c.metadataClientFLE != c.internalClientFLE && c.metadataClientFLE != c { + if err := c.metadataClientFLE.Disconnect(ctx); err != nil { return err } } - if c.crypt != nil { - c.crypt.Close() + if c.cryptFLE != nil { + c.cryptFLE.Close() } if disconnector, ok := c.deployment.(driver.Disconnector); ok { @@ -294,7 +323,7 @@ func (c *Client) endSessions(ctx context.Context) { sessionIDs := c.sessionPool.IDSlice() op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment). ServerSelector(description.ReadPrefSelector(readpref.PrimaryPreferred())).CommandMonitor(c.monitor). - Database("admin").Crypt(c.crypt) + Database("admin").Crypt(c.cryptFLE) totalNumIDs := len(sessionIDs) var currentBatch []bsoncore.Document @@ -495,6 +524,19 @@ func (c *Client) configure(opts *options.ClientOptions) error { func(*event.CommandMonitor) *event.CommandMonitor { return opts.Monitor }, )) } + // ServerMonitor + if opts.ServerMonitor != nil { + c.serverMonitor = opts.ServerMonitor + serverOpts = append( + serverOpts, + topology.WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return opts.ServerMonitor }), + ) + + topologyOpts = append( + topologyOpts, + topology.WithTopologyServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return opts.ServerMonitor }), + ) + } // ReadConcern c.readConcern = readconcern.New() if opts.ReadConcern != nil { @@ -553,7 +595,7 @@ func (c *Client) configure(opts *options.ClientOptions) error { } // AutoEncryptionOptions if opts.AutoEncryptionOptions != nil { - if err := c.configureAutoEncryption(opts.AutoEncryptionOptions); err != nil { + if err := c.configureAutoEncryption(opts); err != nil { return err } } @@ -595,70 +637,115 @@ func (c *Client) configure(opts *options.ClientOptions) error { return nil } -func (c *Client) configureAutoEncryption(opts *options.AutoEncryptionOptions) error { - if err := c.configureKeyVault(opts); err != nil { +func (c *Client) configureAutoEncryption(clientOpts *options.ClientOptions) error { + if err := c.configureKeyVaultClientFLE(clientOpts); err != nil { + return err + } + if err := c.configureMetadataClientFLE(clientOpts); err != nil { return err } - if err := c.configureMongocryptd(opts); err != nil { + if err := c.configureMongocryptdClientFLE(clientOpts.AutoEncryptionOptions); err != nil { return err } - return c.configureCrypt(opts) + return c.configureCryptFLE(clientOpts.AutoEncryptionOptions) } -func (c *Client) configureKeyVault(opts *options.AutoEncryptionOptions) error { - // parse key vault options and create new client if necessary - if opts.KeyVaultClientOptions != nil { - var err error - c.keyVaultClient, err = NewClient(opts.KeyVaultClientOptions) - if err != nil { - return err - } +func (c *Client) getOrCreateInternalClient(clientOpts *options.ClientOptions) (*Client, error) { + if c.internalClientFLE != nil { + return c.internalClientFLE, nil + } + + internalClientOpts := options.MergeClientOptions(clientOpts) + internalClientOpts.AutoEncryptionOptions = nil + internalClientOpts.SetMinPoolSize(0) + var err error + c.internalClientFLE, err = NewClient(internalClientOpts) + return c.internalClientFLE, err +} + +func (c *Client) configureKeyVaultClientFLE(clientOpts *options.ClientOptions) error { + // parse key vault options and create new key vault client + var err error + aeOpts := clientOpts.AutoEncryptionOptions + switch { + case aeOpts.KeyVaultClientOptions != nil: + c.keyVaultClientFLE, err = NewClient(aeOpts.KeyVaultClientOptions) + case clientOpts.MaxPoolSize != nil && *clientOpts.MaxPoolSize == 0: + c.keyVaultClientFLE = c + default: + c.keyVaultClientFLE, err = c.getOrCreateInternalClient(clientOpts) } - dbName, collName := splitNamespace(opts.KeyVaultNamespace) - client := c.keyVaultClient - if client == nil { - client = c + if err != nil { + return err } - c.keyVaultColl = client.Database(dbName).Collection(collName, keyVaultCollOpts) + + dbName, collName := splitNamespace(aeOpts.KeyVaultNamespace) + c.keyVaultCollFLE = c.keyVaultClientFLE.Database(dbName).Collection(collName, keyVaultCollOpts) return nil } -func (c *Client) configureMongocryptd(opts *options.AutoEncryptionOptions) error { +func (c *Client) configureMetadataClientFLE(clientOpts *options.ClientOptions) error { + // parse key vault options and create new key vault client + aeOpts := clientOpts.AutoEncryptionOptions + if aeOpts.BypassAutoEncryption != nil && *aeOpts.BypassAutoEncryption { + // no need for a metadata client. + return nil + } + if clientOpts.MaxPoolSize != nil && *clientOpts.MaxPoolSize == 0 { + c.metadataClientFLE = c + return nil + } + + var err error + c.metadataClientFLE, err = c.getOrCreateInternalClient(clientOpts) + return err +} + +func (c *Client) configureMongocryptdClientFLE(opts *options.AutoEncryptionOptions) error { var err error - c.mongocryptd, err = newMcryptClient(opts) + c.mongocryptdFLE, err = newMcryptClient(opts) return err } -func (c *Client) configureCrypt(opts *options.AutoEncryptionOptions) error { +func (c *Client) configureCryptFLE(opts *options.AutoEncryptionOptions) error { // convert schemas in SchemaMap to bsoncore documents cryptSchemaMap := make(map[string]bsoncore.Document) for k, v := range opts.SchemaMap { - schema, err := transformBsoncoreDocument(c.registry, v) + schema, err := transformBsoncoreDocument(c.registry, v, true, "schemaMap") if err != nil { return err } cryptSchemaMap[k] = schema } + kmsProviders, err := transformBsoncoreDocument(c.registry, opts.KmsProviders, true, "kmsProviders") + if err != nil { + return fmt.Errorf("error creating KMS providers document: %v", err) + } // configure options var bypass bool if opts.BypassAutoEncryption != nil { bypass = *opts.BypassAutoEncryption } - kr := keyRetriever{coll: c.keyVaultColl} - cir := collInfoRetriever{client: c} + kr := keyRetriever{coll: c.keyVaultCollFLE} + var cir collInfoRetriever + // If bypass is true, c.metadataClientFLE is nil and the collInfoRetriever + // will not be used. If bypass is false, to the parent client or the + // internal client. + if !bypass { + cir = collInfoRetriever{client: c.metadataClientFLE} + } cryptOpts := &driver.CryptOptions{ CollInfoFn: cir.cryptCollInfo, KeyFn: kr.cryptKeys, - MarkFn: c.mongocryptd.markCommand, - KmsProviders: opts.KmsProviders, + MarkFn: c.mongocryptdFLE.markCommand, + KmsProviders: kmsProviders, BypassAutoEncryption: bypass, SchemaMap: cryptSchemaMap, } - var err error - c.crypt, err = driver.NewCrypt(cryptOpts) + c.cryptFLE, err = driver.NewCrypt(cryptOpts) return err } @@ -705,7 +792,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... return ListDatabasesResult{}, err } - filterDoc, err := transformBsoncoreDocument(c.registry, filter) + filterDoc, err := transformBsoncoreDocument(c.registry, filter, true, "filter") if err != nil { return ListDatabasesResult{}, err } @@ -719,7 +806,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... ldo := options.MergeListDatabasesOptions(opts...) op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). - ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.crypt) + ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.cryptFLE) if ldo.NameOnly != nil { op = op.NameOnly(*ldo.NameOnly) @@ -828,7 +915,7 @@ func (c *Client) Watch(ctx context.Context, pipeline interface{}, client: c, registry: c.registry, streamType: ClientStream, - crypt: c.crypt, + crypt: c.cryptFLE, } return newChangeStream(ctx, csConfig, pipeline, opts...) diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go b/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go index 4ec0b177..4b1f12d3 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/client_encryption.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "fmt" "strings" "github.com/pkg/errors" @@ -41,14 +42,18 @@ func NewClientEncryption(keyVaultClient *Client, opts ...*options.ClientEncrypti db, coll := splitNamespace(ceo.KeyVaultNamespace) ce.keyVaultColl = ce.keyVaultClient.Database(db).Collection(coll, keyVaultCollOpts) + kmsProviders, err := transformBsoncoreDocument(bson.DefaultRegistry, ceo.KmsProviders, true, "kmsProviders") + if err != nil { + return nil, fmt.Errorf("error creating KMS providers map: %v", err) + } + // create Crypt - var err error kr := keyRetriever{coll: ce.keyVaultColl} cir := collInfoRetriever{client: ce.keyVaultClient} ce.crypt, err = driver.NewCrypt(&driver.CryptOptions{ KeyFn: kr.cryptKeys, CollInfoFn: cir.cryptCollInfo, - KmsProviders: ceo.KmsProviders, + KmsProviders: kmsProviders, }) if err != nil { return nil, err @@ -64,7 +69,7 @@ func (ce *ClientEncryption) CreateDataKey(ctx context.Context, kmsProvider strin dko := options.MergeDataKeyOptions(opts...) co := cryptOpts.DataKey().SetKeyAltNames(dko.KeyAltNames) if dko.MasterKey != nil { - keyDoc, err := transformBsoncoreDocument(ce.keyVaultClient.registry, dko.MasterKey) + keyDoc, err := transformBsoncoreDocument(ce.keyVaultClient.registry, dko.MasterKey, true, "masterKey") if err != nil { return primitive.Binary{}, err } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/collection.go b/vendor/go.mongodb.org/mongo-driver/mongo/collection.go index dbbf4952..6921678a 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/collection.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/collection.go @@ -16,13 +16,13 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -279,7 +279,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.deployment).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true) imo := options.MergeInsertManyOptions(opts...) if imo.BypassDocumentValidation != nil && *imo.BypassDocumentValidation { op = op.BypassDocumentValidation(*imo.BypassDocumentValidation) @@ -398,7 +398,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return nil, err } @@ -439,7 +439,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn doc = bsoncore.AppendDocumentElement(doc, "collation", do.Collation.ToDocument()) } if do.Hint != nil { - hint, err := transformValue(coll.registry, do.Hint) + hint, err := transformValue(coll.registry, do.Hint, false, "hint") if err != nil { return nil, err } @@ -452,7 +452,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.deployment).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true) if do.Hint != nil { op = op.Hint(true) } @@ -548,8 +548,8 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.deployment).Crypt(coll.client.crypt).Hint(uo.Hint != nil). - ArrayFilters(uo.ArrayFilters != nil) + Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Hint(uo.Hint != nil). + ArrayFilters(uo.ArrayFilters != nil).Ordered(true) if uo.BypassDocumentValidation != nil && *uo.BypassDocumentValidation { op = op.BypassDocumentValidation(*uo.BypassDocumentValidation) @@ -581,6 +581,27 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc return res, err } +// UpdateByID executes an update command to update the document whose _id value matches the provided ID in the collection. +// This is equivalent to running UpdateOne(ctx, bson.D{{"_id", id}}, update, opts...). +// +// The id parameter is the _id of the document to be updated. It cannot be nil. If the ID does not match any documents, +// the operation will succeed and an UpdateResult with a MatchedCount of 0 will be returned. +// +// The update parameter must be a document containing update operators +// (https://docs.mongodb.com/manual/reference/operator/update/) and can be used to specify the modifications to be +// made to the selected document. It cannot be nil or empty. +// +// The opts parameter can be used to specify options for the operation (see the options.UpdateOptions documentation). +// +// For more information about the command, see https://docs.mongodb.com/manual/reference/command/update/. +func (coll *Collection) UpdateByID(ctx context.Context, id interface{}, update interface{}, + opts ...*options.UpdateOptions) (*UpdateResult, error) { + if id == nil { + return nil, ErrNilValue + } + return coll.UpdateOne(ctx, bson.D{{"_id", id}}, update, opts...) +} + // UpdateOne executes an update command to update at most one document in the collection. // // The filter parameter must be a document containing query operators and can be used to select the document to be @@ -602,7 +623,7 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return nil, err } @@ -630,7 +651,7 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return nil, err } @@ -658,12 +679,12 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return nil, err } - r, err := transformBsoncoreDocument(coll.registry, replacement) + r, err := transformBsoncoreDocument(coll.registry, replacement, true, "replacement") if err != nil { return nil, err } @@ -761,7 +782,7 @@ func aggregate(a aggregateParams) (*Cursor, error) { ao := options.MergeAggregateOptions(a.opts...) cursorOpts := driver.CursorOptions{ CommandMonitor: a.client.monitor, - Crypt: a.client.crypt, + Crypt: a.client.cryptFLE, } op := operation.NewAggregate(pipelineArr). @@ -774,7 +795,7 @@ func aggregate(a aggregateParams) (*Cursor, error) { Database(a.db). Collection(a.col). Deployment(a.client.deployment). - Crypt(a.client.crypt) + Crypt(a.client.cryptFLE) if !hasOutputStage { // Only pass the user-specified read preference if the aggregation doesn't have a $out or $merge stage. // Otherwise, the read preference could be forwarded to a mongos, which would error if the aggregation were @@ -806,7 +827,7 @@ func aggregate(a aggregateParams) (*Cursor, error) { op.Comment(*ao.Comment) } if ao.Hint != nil { - hintVal, err := transformValue(a.registry, ao.Hint) + hintVal, err := transformValue(a.registry, ao.Hint, false, "hint") if err != nil { closeImplicitSession(sess) return nil, err @@ -880,7 +901,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, selector := makeReadPrefSelector(sess, coll.readSelector, coll.client.localThreshold) op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). - Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.crypt) + Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } @@ -888,7 +909,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op.MaxTimeMS(int64(*countOpts.MaxTime / time.Millisecond)) } if countOpts.Hint != nil { - hintVal, err := transformValue(coll.registry, countOpts.Hint) + hintVal, err := transformValue(coll.registry, countOpts.Hint, false, "hint") if err != nil { return 0, err } @@ -962,7 +983,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, op := operation.NewCount().Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). - ServerSelector(selector).Crypt(coll.client.crypt) + ServerSelector(selector).Crypt(coll.client.cryptFLE) co := options.MergeEstimatedDocumentCountOptions(opts...) if co.MaxTime != nil { @@ -996,7 +1017,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return nil, err } @@ -1028,7 +1049,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i Session(sess).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). - ServerSelector(selector).Crypt(coll.client.crypt) + ServerSelector(selector).Crypt(coll.client.cryptFLE) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) @@ -1085,7 +1106,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return nil, err } @@ -1115,12 +1136,12 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector). ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.deployment).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE) fo := options.MergeFindOptions(opts...) cursorOpts := driver.CursorOptions{ CommandMonitor: coll.client.monitor, - Crypt: coll.client.crypt, + Crypt: coll.client.cryptFLE, } if fo.AllowDiskUse != nil { @@ -1149,7 +1170,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, } } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint) + hint, err := transformValue(coll.registry, fo.Hint, false, "hint") if err != nil { closeImplicitSession(sess) return nil, err @@ -1166,7 +1187,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Limit(limit) } if fo.Max != nil { - max, err := transformBsoncoreDocument(coll.registry, fo.Max) + max, err := transformBsoncoreDocument(coll.registry, fo.Max, true, "max") if err != nil { closeImplicitSession(sess) return nil, err @@ -1180,7 +1201,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) } if fo.Min != nil { - min, err := transformBsoncoreDocument(coll.registry, fo.Min) + min, err := transformBsoncoreDocument(coll.registry, fo.Min, true, "min") if err != nil { closeImplicitSession(sess) return nil, err @@ -1194,7 +1215,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.OplogReplay(*fo.OplogReplay) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection) + proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") if err != nil { closeImplicitSession(sess) return nil, err @@ -1214,7 +1235,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Snapshot(*fo.Snapshot) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort) + sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") if err != nil { closeImplicitSession(sess) return nil, err @@ -1331,7 +1352,7 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd Collection(coll.name). Deployment(coll.client.deployment). Retry(retry). - Crypt(coll.client.crypt) + Crypt(coll.client.cryptFLE) _, err = processWriteError(op.Execute(ctx)) if err != nil { @@ -1355,7 +1376,7 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}, opts ...*options.FindOneAndDeleteOptions) *SingleResult { - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return &SingleResult{err: err} } @@ -1368,21 +1389,21 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} op = op.MaxTimeMS(int64(*fod.MaxTime / time.Millisecond)) } if fod.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fod.Projection) + proj, err := transformBsoncoreDocument(coll.registry, fod.Projection, true, "projection") if err != nil { return &SingleResult{err: err} } op = op.Fields(proj) } if fod.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fod.Sort) + sort, err := transformBsoncoreDocument(coll.registry, fod.Sort, false, "sort") if err != nil { return &SingleResult{err: err} } op = op.Sort(sort) } if fod.Hint != nil { - hint, err := transformValue(coll.registry, fod.Hint) + hint, err := transformValue(coll.registry, fod.Hint, false, "hint") if err != nil { return &SingleResult{err: err} } @@ -1409,11 +1430,11 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{}, replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *SingleResult { - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return &SingleResult{err: err} } - r, err := transformBsoncoreDocument(coll.registry, replacement) + r, err := transformBsoncoreDocument(coll.registry, replacement, true, "replacement") if err != nil { return &SingleResult{err: err} } @@ -1433,7 +1454,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection) + proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") if err != nil { return &SingleResult{err: err} } @@ -1443,7 +1464,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.NewDocument(*fo.ReturnDocument == options.After) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort) + sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") if err != nil { return &SingleResult{err: err} } @@ -1453,7 +1474,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.Upsert(*fo.Upsert) } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint) + hint, err := transformValue(coll.registry, fo.Hint, false, "hint") if err != nil { return &SingleResult{err: err} } @@ -1485,7 +1506,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter) + f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") if err != nil { return &SingleResult{err: err} } @@ -1516,7 +1537,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.MaxTimeMS(int64(*fo.MaxTime / time.Millisecond)) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection) + proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") if err != nil { return &SingleResult{err: err} } @@ -1526,7 +1547,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.NewDocument(*fo.ReturnDocument == options.After) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort) + sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") if err != nil { return &SingleResult{err: err} } @@ -1536,7 +1557,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.Upsert(*fo.Upsert) } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint) + hint, err := transformValue(coll.registry, fo.Hint, false, "hint") if err != nil { return &SingleResult{err: err} } @@ -1570,7 +1591,7 @@ func (coll *Collection) Watch(ctx context.Context, pipeline interface{}, streamType: CollectionStream, collectionName: coll.Name(), databaseName: coll.db.Name(), - crypt: coll.client.crypt, + crypt: coll.client.cryptFLE, } return newChangeStream(ctx, csConfig, pipeline, opts...) } @@ -1616,7 +1637,7 @@ func (coll *Collection) Drop(ctx context.Context) error { Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor). ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). - Deployment(coll.client.deployment).Crypt(coll.client.crypt) + Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE) err = op.Execute(ctx) // ignore namespace not found erorrs @@ -1632,7 +1653,14 @@ func (coll *Collection) Drop(ctx context.Context) error { func makePinnedSelector(sess *session.Client, defaultSelector description.ServerSelector) description.ServerSelectorFunc { return func(t description.Topology, svrs []description.Server) ([]description.Server, error) { if sess != nil && sess.PinnedServer != nil { - return sess.PinnedServer.SelectServer(t, svrs) + // If there is a pinned server, try to find it in the list of candidates. + for _, candidate := range svrs { + if candidate.Addr == sess.PinnedServer.Addr { + return []description.Server{candidate}, nil + } + } + + return nil, nil } return defaultSelector.SelectServer(t, svrs) diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/database.go b/vendor/go.mongodb.org/mongo-driver/mongo/database.go index b734063e..ebf70843 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/database.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/database.go @@ -13,6 +13,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -20,7 +21,6 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -153,7 +153,7 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, return nil, sess, errors.New("read preference in a transaction must be primary") } - runCmdDoc, err := transformBsoncoreDocument(db.registry, cmd) + runCmdDoc, err := transformBsoncoreDocument(db.registry, cmd, false, "cmd") if err != nil { return nil, sess, err } @@ -162,14 +162,14 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, description.LatencySelector(db.client.localThreshold), }) if sess != nil && sess.PinnedServer != nil { - readSelect = sess.PinnedServer + readSelect = makePinnedSelector(sess, readSelect) } return operation.NewCommand(runCmdDoc). Session(sess).CommandMonitor(db.client.monitor). ServerSelector(readSelect).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).ReadConcern(db.readConcern). - Crypt(db.client.crypt).ReadPreference(ro.ReadPreference), sess, nil + Crypt(db.client.cryptFLE).ReadPreference(ro.ReadPreference), sess, nil } // RunCommand executes the given command against the database. This function does not obey the Database's read @@ -271,7 +271,7 @@ func (db *Database) Drop(ctx context.Context) error { op := operation.NewDropDatabase(). Session(sess).WriteConcern(wc).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). - Database(db.name).Deployment(db.client.deployment).Crypt(db.client.crypt) + Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE) err = op.Execute(ctx) @@ -282,6 +282,41 @@ func (db *Database) Drop(ctx context.Context) error { return nil } +// ListCollectionSpecifications executes a listCollections command and returns a slice of CollectionSpecification +// instances representing the collections in the database. +// +// The filter parameter must be a document containing query operators and can be used to select which collections +// are included in the result. It cannot be nil. An empty document (e.g. bson.D{}) should be used to include all +// collections. +// +// The opts parameter can be used to specify options for the operation (see the options.ListCollectionsOptions +// documentation). +// +// For more information about the command, see https://docs.mongodb.com/manual/reference/command/listCollections/. +func (db *Database) ListCollectionSpecifications(ctx context.Context, filter interface{}, + opts ...*options.ListCollectionsOptions) ([]*CollectionSpecification, error) { + + cursor, err := db.ListCollections(ctx, filter, opts...) + if err != nil { + return nil, err + } + + var specs []*CollectionSpecification + err = cursor.All(ctx, &specs) + if err != nil { + return nil, err + } + + for _, spec := range specs { + // Pre-4.4 servers report a namespace in their responses, so we only set Namespace manually if it was not in + // the response. + if spec.IDIndex != nil && spec.IDIndex.Namespace == "" { + spec.IDIndex.Namespace = db.name + "." + spec.Name + } + } + return specs, nil +} + // ListCollections executes a listCollections command and returns a cursor over the collections in the database. // // The filter parameter must be a document containing query operators and can be used to select which collections @@ -297,7 +332,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt ctx = context.Background() } - filterDoc, err := transformBsoncoreDocument(db.registry, filter) + filterDoc, err := transformBsoncoreDocument(db.registry, filter, true, "filter") if err != nil { return nil, err } @@ -326,10 +361,14 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt op := operation.NewListCollections(filterDoc). Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). - Database(db.name).Deployment(db.client.deployment).Crypt(db.client.crypt) + Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE) if lco.NameOnly != nil { op = op.NameOnly(*lco.NameOnly) } + if lco.BatchSize != nil { + op = op.BatchSize(*lco.BatchSize) + } + retry := driver.RetryNone if db.client.retryReads { retry = driver.RetryOncePerCommand @@ -342,7 +381,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt return nil, replaceErrors(err) } - bc, err := op.Result(driver.CursorOptions{Crypt: db.client.crypt}) + bc, err := op.Result(driver.CursorOptions{Crypt: db.client.cryptFLE}) if err != nil { closeImplicitSession(sess) return nil, replaceErrors(err) @@ -435,7 +474,7 @@ func (db *Database) Watch(ctx context.Context, pipeline interface{}, registry: db.registry, streamType: DatabaseStream, databaseName: db.Name(), - crypt: db.client.crypt, + crypt: db.client.cryptFLE, } return newChangeStream(ctx, csConfig, pipeline, opts...) } @@ -459,7 +498,7 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* if cco.DefaultIndexOptions != nil { idx, doc := bsoncore.AppendDocumentStart(nil) if cco.DefaultIndexOptions.StorageEngine != nil { - storageEngine, err := transformBsoncoreDocument(db.registry, cco.DefaultIndexOptions.StorageEngine) + storageEngine, err := transformBsoncoreDocument(db.registry, cco.DefaultIndexOptions.StorageEngine, true, "storageEngine") if err != nil { return err } @@ -480,7 +519,7 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* op.Size(*cco.SizeInBytes) } if cco.StorageEngine != nil { - storageEngine, err := transformBsoncoreDocument(db.registry, cco.StorageEngine) + storageEngine, err := transformBsoncoreDocument(db.registry, cco.StorageEngine, true, "storageEngine") if err != nil { return err } @@ -493,7 +532,7 @@ func (db *Database) CreateCollection(ctx context.Context, name string, opts ...* op.ValidationLevel(*cco.ValidationLevel) } if cco.Validator != nil { - validator, err := transformBsoncoreDocument(db.registry, cco.Validator) + validator, err := transformBsoncoreDocument(db.registry, cco.Validator, true, "validator") if err != nil { return err } @@ -567,7 +606,7 @@ func (db *Database) executeCreateOperation(ctx context.Context, op *operation.Cr ClusterClock(db.client.clock). Database(db.name). Deployment(db.client.deployment). - Crypt(db.client.crypt) + Crypt(db.client.cryptFLE) return replaceErrors(op.Execute(ctx)) } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/description.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/description.go similarity index 79% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/description.go rename to vendor/go.mongodb.org/mongo-driver/mongo/description/description.go index 1f92953b..40b1af13 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/description.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/description.go @@ -4,7 +4,7 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package description // import "go.mongodb.org/mongo-driver/x/mongo/driver/description" +package description // import "go.mongodb.org/mongo-driver/mongo/description" // Unknown is an unknown server or topology kind. const Unknown = 0 diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/server.go similarity index 67% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server.go rename to vendor/go.mongodb.org/mongo-driver/mongo/description/server.go index ac7afabd..5034efd1 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/server.go @@ -11,54 +11,53 @@ import ( "fmt" "time" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal" + "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/tag" - "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" ) -// UnsetRTT is the unset value for a round trip time. -const UnsetRTT = -1 * time.Millisecond - -// SelectedServer represents a selected server that is a member of a topology. +// SelectedServer augments the Server type by also including the TopologyKind of the topology that includes the server. +// This type should be used to track the state of a server that was selected to perform an operation. type SelectedServer struct { Server Kind TopologyKind } -// Server represents a description of a server. This is created from an isMaster -// command. +// Server contains information about a node in a cluster. This is created from isMaster command responses. type Server struct { Addr address.Address - AverageRTT time.Duration - AverageRTTSet bool - Compression []string // compression methods returned by server - CanonicalAddr address.Address - ElectionID primitive.ObjectID - HeartbeatInterval time.Duration - LastError error - LastUpdateTime time.Time - LastWriteTime time.Time - MaxBatchCount uint32 - MaxDocumentSize uint32 - MaxMessageSize uint32 - Members []address.Address - ReadOnly bool - SessionTimeoutMinutes uint32 - SetName string - SetVersion uint32 - SpeculativeAuthenticate bsoncore.Document - Tags tag.Set - TopologyVersion *TopologyVersion - Kind ServerKind - WireVersion *VersionRange - - SaslSupportedMechs []string // user-specific from server handshake + Arbiters []string + AverageRTT time.Duration + AverageRTTSet bool + Compression []string // compression methods returned by server + CanonicalAddr address.Address + ElectionID primitive.ObjectID + HeartbeatInterval time.Duration + Hosts []string + LastError error + LastUpdateTime time.Time + LastWriteTime time.Time + MaxBatchCount uint32 + MaxDocumentSize uint32 + MaxMessageSize uint32 + Members []address.Address + Passives []string + Primary address.Address + ReadOnly bool + SessionTimeoutMinutes uint32 + SetName string + SetVersion uint32 + Tags tag.Set + TopologyVersion *TopologyVersion + Kind ServerKind + WireVersion *VersionRange } -// NewServer creates a new server description from the given parameters. -func NewServer(addr address.Address, response bsoncore.Document) Server { +// NewServer creates a new server description from the given isMaster command response. +func NewServer(addr address.Address, response bson.Raw) Server { desc := Server{Addr: addr, CanonicalAddr: addr, LastUpdateTime: time.Now().UTC()} elements, err := response.Elements() if err != nil { @@ -69,12 +68,11 @@ func NewServer(addr address.Address, response bsoncore.Document) Server { var isReplicaSet, isMaster, hidden, secondary, arbiterOnly bool var msg string var version VersionRange - var hosts, passives, arbiters []string for _, element := range elements { switch element.Key() { case "arbiters": var err error - arbiters, err = decodeStringSlice(element, "arbiters") + desc.Arbiters, err = internal.StringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc @@ -87,7 +85,7 @@ func NewServer(addr address.Address, response bsoncore.Document) Server { } case "compression": var err error - desc.Compression, err = decodeStringSlice(element, "compression") + desc.Compression, err = internal.StringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc @@ -106,7 +104,7 @@ func NewServer(addr address.Address, response bsoncore.Document) Server { } case "hosts": var err error - hosts, err = decodeStringSlice(element, "hosts") + desc.Hosts, err = internal.StringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc @@ -203,24 +201,24 @@ func NewServer(addr address.Address, response bsoncore.Document) Server { } case "passives": var err error - passives, err = decodeStringSlice(element, "passives") + desc.Passives, err = internal.StringSliceFromRawElement(element) if err != nil { desc.LastError = err return desc } + case "primary": + primary, ok := element.Value().StringValueOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'primary' to be a string but it's a BSON %s", element.Value().Type) + return desc + } + desc.Primary = address.Address(primary) case "readOnly": desc.ReadOnly, ok = element.Value().BooleanOK() if !ok { desc.LastError = fmt.Errorf("expected 'readOnly' to be a boolean but it's a BSON %s", element.Value().Type) return desc } - case "saslSupportedMechs": - var err error - desc.SaslSupportedMechs, err = decodeStringSlice(element, "saslSupportedMechs") - if err != nil { - desc.LastError = err - return desc - } case "secondary": secondary, ok = element.Value().BooleanOK() if !ok { @@ -240,13 +238,6 @@ func NewServer(addr address.Address, response bsoncore.Document) Server { return desc } desc.SetVersion = uint32(i64) - case "speculativeAuthenticate": - desc.SpeculativeAuthenticate, ok = element.Value().DocumentOK() - if !ok { - desc.LastError = fmt.Errorf("expected 'speculativeAuthenticate' to be a document but it's a BSON %s", - element.Value().Type) - return desc - } case "tags": m, err := decodeStringMap(element, "tags") if err != nil { @@ -269,15 +260,15 @@ func NewServer(addr address.Address, response bsoncore.Document) Server { } } - for _, host := range hosts { + for _, host := range desc.Hosts { desc.Members = append(desc.Members, address.Address(host).Canonicalize()) } - for _, passive := range passives { + for _, passive := range desc.Passives { desc.Members = append(desc.Members, address.Address(passive).Canonicalize()) } - for _, arbiter := range arbiters { + for _, arbiter := range desc.Arbiters { desc.Members = append(desc.Members, address.Address(arbiter).Canonicalize()) } @@ -324,12 +315,7 @@ func NewServerFromError(addr address.Address, err error, tv *TopologyVersion) Se // SetAverageRTT sets the average round trip time for this server description. func (s Server) SetAverageRTT(rtt time.Duration) Server { s.AverageRTT = rtt - if rtt == UnsetRTT { - s.AverageRTTSet = false - } else { - s.AverageRTTSet = true - } - + s.AverageRTTSet = true return s } @@ -341,37 +327,23 @@ func (s Server) DataBearing() bool { s.Kind == Standalone } -// SelectServer selects this server if it is in the list of given candidates. -func (s Server) SelectServer(_ Topology, candidates []Server) ([]Server, error) { - for _, candidate := range candidates { - if candidate.Addr == s.Addr { - return []Server{candidate}, nil - } +// String implements the Stringer interface +func (s Server) String() string { + str := fmt.Sprintf("Addr: %s, Type: %s", + s.Addr, s.Kind) + if len(s.Tags) != 0 { + str += fmt.Sprintf(", Tag sets: %s", s.Tags) } - return nil, nil -} -func decodeStringSlice(element bsoncore.Element, name string) ([]string, error) { - arr, ok := element.Value().ArrayOK() - if !ok { - return nil, fmt.Errorf("expected '%s' to be an array but it's a BSON %s", name, element.Value().Type) - } - vals, err := arr.Values() - if err != nil { - return nil, err - } - var strs []string - for _, val := range vals { - str, ok := val.StringValueOK() - if !ok { - return nil, fmt.Errorf("expected '%s' to be an array of strings, but found a BSON %s", name, val.Type) - } - strs = append(strs, str) + str += fmt.Sprintf(", Average RTT: %d", s.AverageRTT) + + if s.LastError != nil { + str += fmt.Sprintf(", Last error: %s", s.LastError) } - return strs, nil + return str } -func decodeStringMap(element bsoncore.Element, name string) (map[string]string, error) { +func decodeStringMap(element bson.RawElement, name string) (map[string]string, error) { doc, ok := element.Value().DocumentOK() if !ok { return nil, fmt.Errorf("expected '%s' to be a document but it's a BSON %s", name, element.Value().Type) @@ -392,7 +364,82 @@ func decodeStringMap(element bsoncore.Element, name string) (map[string]string, return m, nil } -// SupportsRetryWrites returns true if this description represents a server that supports retryable writes. -func (s Server) SupportsRetryWrites() bool { - return s.SessionTimeoutMinutes != 0 && s.Kind != Standalone +// Equal compares two server descriptions and returns true if they are equal +func (s Server) Equal(other Server) bool { + if s.CanonicalAddr.String() != other.CanonicalAddr.String() { + return false + } + + if !sliceStringEqual(s.Arbiters, other.Arbiters) { + return false + } + + if !sliceStringEqual(s.Hosts, other.Hosts) { + return false + } + + if !sliceStringEqual(s.Passives, other.Passives) { + return false + } + + if s.Primary != other.Primary { + return false + } + + if s.SetName != other.SetName { + return false + } + + if s.Kind != other.Kind { + return false + } + + if s.LastError != nil || other.LastError != nil { + if s.LastError == nil || other.LastError == nil { + return false + } + if s.LastError.Error() != other.LastError.Error() { + return false + } + } + + if !s.WireVersion.Equals(other.WireVersion) { + return false + } + + if len(s.Tags) != len(other.Tags) || !s.Tags.ContainsAll(other.Tags) { + return false + } + + if s.SetVersion != other.SetVersion { + return false + } + + if s.ElectionID != other.ElectionID { + return false + } + + if s.SessionTimeoutMinutes != other.SessionTimeoutMinutes { + return false + } + + // If TopologyVersion is nil for both servers, CompareToIncoming will return -1 because it assumes that the + // incoming response is newer. We want the descriptions to be considered equal in this case, though, so an + // explicit check is required. + if s.TopologyVersion == nil && other.TopologyVersion == nil { + return true + } + return s.TopologyVersion.CompareToIncoming(other.TopologyVersion) == 0 +} + +func sliceStringEqual(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server_kind.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/server_kind.go similarity index 85% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server_kind.go rename to vendor/go.mongodb.org/mongo-driver/mongo/description/server_kind.go index 657791be..933b38f7 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server_kind.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/server_kind.go @@ -6,7 +6,7 @@ package description -// ServerKind represents the type of a server. +// ServerKind represents the type of a single server in a topology. type ServerKind uint32 // These constants are the possible types of servers. @@ -20,7 +20,7 @@ const ( Mongos ServerKind = 256 ) -// String implements the fmt.Stringer interface. +// String returns a stringified version of the kind or "Unknown" if the kind is invalid. func (kind ServerKind) String() string { switch kind { case Standalone: diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server_selector.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/server_selector.go similarity index 87% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server_selector.go rename to vendor/go.mongodb.org/mongo-driver/mongo/description/server_selector.go index 076bfe16..81273e70 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/server_selector.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/server_selector.go @@ -15,8 +15,9 @@ import ( "go.mongodb.org/mongo-driver/tag" ) -// ServerSelector is an interface implemented by types that can select a server given a -// topology description. +// ServerSelector is an interface implemented by types that can perform server selection given a topology description +// and list of candidate servers. The selector should filter the provided candidates list and return a subset that +// matches some criteria. type ServerSelector interface { SelectServer(Topology, []Server) ([]Server, error) } @@ -33,7 +34,16 @@ type compositeSelector struct { selectors []ServerSelector } -// CompositeSelector combines multiple selectors into a single selector. +// CompositeSelector combines multiple selectors into a single selector by applying them in order to the candidates +// list. +// +// For example, if the initial candidates list is [s0, s1, s2, s3] and two selectors are provided where the first +// matches s0 and s1 and the second matches s1 and s2, the following would occur during server selection: +// +// 1. firstSelector([s0, s1, s2, s3]) -> [s0, s1] +// 2. secondSelector([s0, s1]) -> [s1] +// +// The final list of candidates returned by the composite selector would be [s1]. func CompositeSelector(selectors []ServerSelector) ServerSelector { return &compositeSelector{selectors: selectors} } @@ -53,7 +63,7 @@ type latencySelector struct { latency time.Duration } -// LatencySelector creates a ServerSelector which selects servers based on their latency. +// LatencySelector creates a ServerSelector which selects servers based on their average RTT values. func LatencySelector(latency time.Duration) ServerSelector { return &latencySelector{latency: latency} } @@ -120,7 +130,7 @@ func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector { if _, set := rp.MaxStaleness(); set { for _, s := range candidates { if s.Kind != Unknown { - if err := MaxStalenessSupported(s.WireVersion); err != nil { + if err := maxStalenessSupported(s.WireVersion); err != nil { return nil, err } } @@ -140,6 +150,15 @@ func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector { }) } +// maxStalenessSupported returns an error if the given server version does not support max staleness. +func maxStalenessSupported(wireVersion *VersionRange) error { + if wireVersion != nil && wireVersion.Max < 5 { + return fmt.Errorf("max staleness is only supported for servers 3.4 or newer") + } + + return nil +} + func selectForReplicaSet(rp *readpref.ReadPref, t Topology, candidates []Server) ([]Server, error) { if err := verifyMaxStaleness(rp, t); err != nil { return nil, err diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go new file mode 100644 index 00000000..8544548c --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/topology.go @@ -0,0 +1,142 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package description + +import ( + "fmt" + + "go.mongodb.org/mongo-driver/mongo/readpref" +) + +// Topology contains information about a MongoDB cluster. +type Topology struct { + Servers []Server + SetName string + Kind TopologyKind + SessionTimeoutMinutes uint32 + CompatibilityErr error +} + +// String implements the Stringer interface. +func (t Topology) String() string { + var serversStr string + for _, s := range t.Servers { + serversStr += "{ " + s.String() + " }, " + } + return fmt.Sprintf("Type: %s, Servers: [%s]", t.Kind, serversStr) +} + +// Equal compares two topology descriptions and returns true if they are equal. +func (t Topology) Equal(other Topology) bool { + if t.Kind != other.Kind { + return false + } + + topoServers := make(map[string]Server) + for _, s := range t.Servers { + topoServers[s.Addr.String()] = s + } + + otherServers := make(map[string]Server) + for _, s := range other.Servers { + otherServers[s.Addr.String()] = s + } + + if len(topoServers) != len(otherServers) { + return false + } + + for _, server := range topoServers { + otherServer := otherServers[server.Addr.String()] + + if !server.Equal(otherServer) { + return false + } + } + + return true +} + +// HasReadableServer returns true if the topology contains a server suitable for reading. +// +// If the Topology's kind is Single or Sharded, the mode parameter is ignored and the function contains true if any of +// the servers in the Topology are of a known type. +// +// For replica sets, the function returns true if the cluster contains a server that matches the provided read +// preference mode. +func (t Topology) HasReadableServer(mode readpref.Mode) bool { + switch t.Kind { + case Single, Sharded: + return hasAvailableServer(t.Servers, 0) + case ReplicaSetWithPrimary: + return hasAvailableServer(t.Servers, mode) + case ReplicaSetNoPrimary, ReplicaSet: + if mode == readpref.PrimaryMode { + return false + } + // invalid read preference + if !mode.IsValid() { + return false + } + + return hasAvailableServer(t.Servers, mode) + } + return false +} + +// HasWritableServer returns true if a topology has a server available for writing. +// +// If the Topology's kind is Single or Sharded, this function returns true if any of the servers in the Topology are of +// a known type. +// +// For replica sets, the function returns true if the replica set contains a primary. +func (t Topology) HasWritableServer() bool { + return t.HasReadableServer(readpref.PrimaryMode) +} + +// hasAvailableServer returns true if any servers are available based on the read preference. +func hasAvailableServer(servers []Server, mode readpref.Mode) bool { + switch mode { + case readpref.PrimaryMode: + for _, s := range servers { + if s.Kind == RSPrimary { + return true + } + } + return false + case readpref.PrimaryPreferredMode, readpref.SecondaryPreferredMode, readpref.NearestMode: + for _, s := range servers { + if s.Kind == RSPrimary || s.Kind == RSSecondary { + return true + } + } + return false + case readpref.SecondaryMode: + for _, s := range servers { + if s.Kind == RSSecondary { + return true + } + } + return false + } + + // read preference is not specified + for _, s := range servers { + switch s.Kind { + case Standalone, + RSMember, + RSPrimary, + RSSecondary, + RSArbiter, + RSGhost, + Mongos: + return true + } + } + + return false +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology_kind.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/topology_kind.go similarity index 100% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology_kind.go rename to vendor/go.mongodb.org/mongo-driver/mongo/description/topology_kind.go diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology_version.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/topology_version.go similarity index 62% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology_version.go rename to vendor/go.mongodb.org/mongo-driver/mongo/description/topology_version.go index cbc5a606..e6674ea7 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology_version.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/topology_version.go @@ -9,8 +9,8 @@ package description import ( "fmt" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) // TopologyVersion represents a software version. @@ -20,7 +20,7 @@ type TopologyVersion struct { } // NewTopologyVersion creates a TopologyVersion based on doc -func NewTopologyVersion(doc bsoncore.Document) (*TopologyVersion, error) { +func NewTopologyVersion(doc bson.Raw) (*TopologyVersion, error) { elements, err := doc.Elements() if err != nil { return nil, err @@ -44,19 +44,22 @@ func NewTopologyVersion(doc bsoncore.Document) (*TopologyVersion, error) { return &tv, nil } -// CompareTopologyVersion returns -1 if currentTVresponseTV. -// This comparsion is not commutative so the original TopologyVersion should be first. -func CompareTopologyVersion(currentTV, responseTV *TopologyVersion) int { - if currentTV == nil || responseTV == nil { +// CompareToIncoming compares the receiver, which represents the currently known TopologyVersion for a server, to an +// incoming TopologyVersion extracted from a server command response. +// +// This returns -1 if the receiver version is less than the response, 0 if the versions are equal, and 1 if the +// receiver version is greater than the response. This comparison is not commutative. +func (tv *TopologyVersion) CompareToIncoming(responseTV *TopologyVersion) int { + if tv == nil || responseTV == nil { return -1 } - if currentTV.ProcessID != responseTV.ProcessID { + if tv.ProcessID != responseTV.ProcessID { return -1 } - if currentTV.Counter == responseTV.Counter { + if tv.Counter == responseTV.Counter { return 0 } - if currentTV.Counter < responseTV.Counter { + if tv.Counter < responseTV.Counter { return -1 } return 1 diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/version_range.go b/vendor/go.mongodb.org/mongo-driver/mongo/description/version_range.go similarity index 75% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/version_range.go rename to vendor/go.mongodb.org/mongo-driver/mongo/description/version_range.go index 984dff89..5d6270c5 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/version_range.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/description/version_range.go @@ -25,6 +25,17 @@ func (vr VersionRange) Includes(v int32) bool { return v >= vr.Min && v <= vr.Max } +// Equals returns a bool indicating whether the supplied VersionRange is equal. +func (vr *VersionRange) Equals(other *VersionRange) bool { + if vr == nil && other == nil { + return true + } + if vr == nil || other == nil { + return false + } + return vr.Min == other.Min && vr.Max == other.Max +} + // String implements the fmt.Stringer interface. func (vr VersionRange) String() string { return fmt.Sprintf("[%d, %d]", vr.Min, vr.Max) diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/doc.go b/vendor/go.mongodb.org/mongo-driver/mongo/doc.go index b84927f7..d184a575 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/doc.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/doc.go @@ -9,14 +9,12 @@ // Package mongo provides a MongoDB Driver API for Go. // // Basic usage of the driver starts with creating a Client from a connection -// string. To do so, call the NewClient and Connect functions: +// string. To do so, call Connect: // -// client, err := NewClient(options.Client().ApplyURI("mongodb://foo:bar@localhost:27017")) -// if err != nil { return err } -// ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) -// defer cancel() -// err = client.Connect(ctx) -// if err != nil { return err } +// ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) +// defer cancel() +// client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://foo:bar@localhost:27017")) +// if err != nil { return err } // // This will create a new client and start monitoring the MongoDB server on localhost. // The Database and Collection types can be used to access the database: @@ -52,6 +50,17 @@ // return err // } // +// Cursor.All will decode all of the returned elements at once: +// +// var results []struct{ +// Foo string +// Bar int32 +// } +// if err = cur.All(context.Background(), &results); err != nil { +// log.Fatal(err) +// } +// // do something with results... +// // Methods that only return a single document will return a *SingleResult, which works // like a *sql.Row: // @@ -70,6 +79,18 @@ // Additional examples can be found under the examples directory in the driver's repository and // on the MongoDB website. // +// Error Handling +// +// Errors from the MongoDB server will implement the ServerError interface, which has functions to check for specific +// error codes, labels, and message substrings. These can be used to check for and handle specific errors. Some methods, +// like InsertMany and BulkWrite, can return an error representing multiple errors, and in those cases the ServerError +// functions will return true if any of the contained errors satisfy the check. +// +// There are also helper functions to check for certain specific types of errors: +// IsDuplicateKeyError(error) +// IsNetworkError(error) +// IsTimeout(error) +// // Potential DNS Issues // // Building with Go 1.11+ and using connection strings with the "mongodb+srv"[1] scheme is @@ -84,8 +105,9 @@ // // Note: Auto encryption is an enterprise-only feature. // -// The libmongocrypt C library is required when using client-side encryption. To install libmongocrypt, follow the -// instructions for your operating system: +// The libmongocrypt C library is required when using client-side encryption. libmongocrypt version 1.1.0 or higher is +// required when using driver version 1.5.0 or higher. To install libmongocrypt, follow the instructions for your +// operating system: // // 1. Linux: follow the instructions listed at // https://github.com/mongodb/libmongocrypt#installing-libmongocrypt-from-distribution-packages to install the correct diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/errors.go b/vendor/go.mongodb.org/mongo-driver/mongo/errors.go index ae38d200..329cadb0 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/errors.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/errors.go @@ -8,8 +8,11 @@ package mongo import ( "bytes" + "context" "errors" "fmt" + "net" + "strings" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -32,6 +35,16 @@ var ErrNilValue = errors.New("value is nil") // ErrEmptySlice is returned when an empty slice is passed to a CRUD method that requires a non-empty slice. var ErrEmptySlice = errors.New("must provide at least one element in input slice") +// ErrMapForOrderedArgument is returned when a map with multiple keys is passed to a CRUD method for an ordered parameter +type ErrMapForOrderedArgument struct { + ParamName string +} + +// Error implements the error interface. +func (e ErrMapForOrderedArgument) Error() string { + return fmt.Sprintf("multi-key map passed in for ordered parameter %v", e.ParamName) +} + func replaceErrors(err error) error { if err == topology.ErrTopologyClosed { return ErrClientDisconnected @@ -70,6 +83,60 @@ func replaceErrors(err error) error { return err } +// IsDuplicateKeyError returns true if err is a duplicate key error +func IsDuplicateKeyError(err error) bool { + // handles SERVER-7164 and SERVER-11493 + for ; err != nil; err = unwrap(err) { + if e, ok := err.(ServerError); ok { + return e.HasErrorCode(11000) || e.HasErrorCode(11001) || e.HasErrorCode(12582) || + e.HasErrorCodeWithMessage(16460, " E11000 ") + } + } + return false +} + +// IsTimeout returns true if err is from a timeout +func IsTimeout(err error) bool { + for ; err != nil; err = unwrap(err) { + // check unwrappable errors together + if err == context.DeadlineExceeded { + return true + } + if ne, ok := err.(net.Error); ok { + return ne.Timeout() + } + //timeout error labels + if se, ok := err.(ServerError); ok { + if se.HasErrorLabel("NetworkTimeoutError") || se.HasErrorLabel("ExceededTimeLimitError") { + return true + } + } + } + + return false +} + +// unwrap returns the inner error if err implements Unwrap(), otherwise it returns nil. +func unwrap(err error) error { + u, ok := err.(interface { + Unwrap() error + }) + if !ok { + return nil + } + return u.Unwrap() +} + +// IsNetworkError returns true if err is a network error +func IsNetworkError(err error) bool { + for ; err != nil; err = unwrap(err) { + if e, ok := err.(ServerError); ok { + return e.HasErrorLabel("NetworkError") + } + } + return false +} + // MongocryptError represents an libmongocrypt error during client-side encryption. type MongocryptError struct { Code int32 @@ -112,6 +179,26 @@ func (e MongocryptdError) Unwrap() error { return e.Wrapped } +// ServerError is the interface implemented by errors returned from the server. Custom implementations of this +// interface should not be used in production. +type ServerError interface { + error + // HasErrorCode returns true if the error has the specified code. + HasErrorCode(int) bool + // HasErrorLabel returns true if the error contains the specified label. + HasErrorLabel(string) bool + // HasErrorMessage returns true if the error contains the specified message. + HasErrorMessage(string) bool + // HasErrorCodeWithMessage returns true if any of the contained errors have the specified code and message. + HasErrorCodeWithMessage(int, string) bool + + serverError() +} + +var _ ServerError = CommandError{} +var _ ServerError = WriteException{} +var _ ServerError = BulkWriteException{} + // CommandError represents a server error during execution of a command. This can be returned by any operation. type CommandError struct { Code int32 @@ -134,6 +221,11 @@ func (e CommandError) Unwrap() error { return e.Wrapped } +// HasErrorCode returns true if the error has the specified code. +func (e CommandError) HasErrorCode(code int) bool { + return int(e.Code) == code +} + // HasErrorLabel returns true if the error contains the specified label. func (e CommandError) HasErrorLabel(label string) bool { if e.Labels != nil { @@ -146,11 +238,24 @@ func (e CommandError) HasErrorLabel(label string) bool { return false } +// HasErrorMessage returns true if the error contains the specified message. +func (e CommandError) HasErrorMessage(message string) bool { + return strings.Contains(e.Message, message) +} + +// HasErrorCodeWithMessage returns true if the error has the specified code and Message contains the specified message. +func (e CommandError) HasErrorCodeWithMessage(code int, message string) bool { + return int(e.Code) == code && strings.Contains(e.Message, message) +} + // IsMaxTimeMSExpiredError returns true if the error is a MaxTimeMSExpired error. func (e CommandError) IsMaxTimeMSExpiredError() bool { return e.Code == 50 || e.Name == "MaxTimeMSExpired" } +// serverError implements the ServerError interface. +func (e CommandError) serverError() {} + // WriteError is an error that occurred during execution of a write operation. This error type is only returned as part // of a WriteException or BulkWriteException. type WriteError struct { @@ -227,6 +332,19 @@ func (mwe WriteException) Error() string { return buf.String() } +// HasErrorCode returns true if the error has the specified code. +func (mwe WriteException) HasErrorCode(code int) bool { + if mwe.WriteConcernError != nil && mwe.WriteConcernError.Code == code { + return true + } + for _, we := range mwe.WriteErrors { + if we.Code == code { + return true + } + } + return false +} + // HasErrorLabel returns true if the error contains the specified label. func (mwe WriteException) HasErrorLabel(label string) bool { if mwe.Labels != nil { @@ -239,6 +357,36 @@ func (mwe WriteException) HasErrorLabel(label string) bool { return false } +// HasErrorMessage returns true if the error contains the specified message. +func (mwe WriteException) HasErrorMessage(message string) bool { + if mwe.WriteConcernError != nil && strings.Contains(mwe.WriteConcernError.Message, message) { + return true + } + for _, we := range mwe.WriteErrors { + if strings.Contains(we.Message, message) { + return true + } + } + return false +} + +// HasErrorCodeWithMessage returns true if any of the contained errors have the specified code and message. +func (mwe WriteException) HasErrorCodeWithMessage(code int, message string) bool { + if mwe.WriteConcernError != nil && + mwe.WriteConcernError.Code == code && strings.Contains(mwe.WriteConcernError.Message, message) { + return true + } + for _, we := range mwe.WriteErrors { + if we.Code == code && strings.Contains(we.Message, message) { + return true + } + } + return false +} + +// serverError implements the ServerError interface. +func (mwe WriteException) serverError() {} + func convertDriverWriteConcernError(wce *driver.WriteConcernError) *WriteConcernError { if wce == nil { return nil @@ -287,6 +435,19 @@ func (bwe BulkWriteException) Error() string { return buf.String() } +// HasErrorCode returns true if any of the errors have the specified code. +func (bwe BulkWriteException) HasErrorCode(code int) bool { + if bwe.WriteConcernError != nil && bwe.WriteConcernError.Code == code { + return true + } + for _, we := range bwe.WriteErrors { + if we.Code == code { + return true + } + } + return false +} + // HasErrorLabel returns true if the error contains the specified label. func (bwe BulkWriteException) HasErrorLabel(label string) bool { if bwe.Labels != nil { @@ -299,6 +460,36 @@ func (bwe BulkWriteException) HasErrorLabel(label string) bool { return false } +// HasErrorMessage returns true if the error contains the specified message. +func (bwe BulkWriteException) HasErrorMessage(message string) bool { + if bwe.WriteConcernError != nil && strings.Contains(bwe.WriteConcernError.Message, message) { + return true + } + for _, we := range bwe.WriteErrors { + if strings.Contains(we.Message, message) { + return true + } + } + return false +} + +// HasErrorCodeWithMessage returns true if any of the contained errors have the specified code and message. +func (bwe BulkWriteException) HasErrorCodeWithMessage(code int, message string) bool { + if bwe.WriteConcernError != nil && + bwe.WriteConcernError.Code == code && strings.Contains(bwe.WriteConcernError.Message, message) { + return true + } + for _, we := range bwe.WriteErrors { + if we.Code == code && strings.Contains(we.Message, message) { + return true + } + } + return false +} + +// serverError implements the ServerError interface. +func (bwe BulkWriteException) serverError() {} + // returnResult is used to determine if a function calling processWriteError should return // the result or return nil. Since the processWriteError function is used by many different // methods, both *One and *Many, we need a way to differentiate if the method should return diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go b/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go index 497790c3..4b0998bc 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/index_view.go @@ -16,12 +16,12 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -132,6 +132,29 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption return cursor, replaceErrors(err) } +// ListSpecifications executes a List command and returns a slice of returned IndexSpecifications +func (iv IndexView) ListSpecifications(ctx context.Context, opts ...*options.ListIndexesOptions) ([]*IndexSpecification, error) { + cursor, err := iv.List(ctx, opts...) + if err != nil { + return nil, err + } + + var results []*IndexSpecification + err = cursor.All(ctx, &results) + if err != nil { + return nil, err + } + + ns := iv.coll.db.Name() + "." + iv.coll.Name() + for _, res := range results { + // Pre-4.4 servers report a namespace in their responses, so we only set Namespace manually if it was not in + // the response. + res.Namespace = ns + } + + return results, nil +} + // CreateOne executes a createIndexes command to create an index on the collection and returns the name of the new // index. See the IndexView.CreateMany documentation for more information and an example. func (iv IndexView) CreateOne(ctx context.Context, model IndexModel, opts ...*options.CreateIndexesOptions) (string, error) { @@ -164,7 +187,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. return nil, fmt.Errorf("index model keys cannot be nil") } - keys, err := transformBsoncoreDocument(iv.coll.registry, model.Keys) + keys, err := transformBsoncoreDocument(iv.coll.registry, model.Keys, false, "keys") if err != nil { return nil, err } @@ -239,7 +262,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. op.MaxTimeMS(int64(*option.MaxTime / time.Millisecond)) } if option.CommitQuorum != nil { - commitQuorum, err := transformValue(iv.coll.registry, option.CommitQuorum) + commitQuorum, err := transformValue(iv.coll.registry, option.CommitQuorum, true, "commitQuorum") if err != nil { return nil, err } @@ -270,7 +293,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendBooleanElement(optsDoc, "sparse", *opts.Sparse) } if opts.StorageEngine != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.StorageEngine) + doc, err := transformBsoncoreDocument(iv.coll.registry, opts.StorageEngine, true, "storageEngine") if err != nil { return nil, err } @@ -293,7 +316,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendInt32Element(optsDoc, "textIndexVersion", *opts.TextVersion) } if opts.Weights != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.Weights) + doc, err := transformBsoncoreDocument(iv.coll.registry, opts.Weights, true, "weights") if err != nil { return nil, err } @@ -316,7 +339,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendInt32Element(optsDoc, "bucketSize", *opts.BucketSize) } if opts.PartialFilterExpression != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.PartialFilterExpression) + doc, err := transformBsoncoreDocument(iv.coll.registry, opts.PartialFilterExpression, true, "partialFilterExpression") if err != nil { return nil, err } @@ -327,7 +350,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendDocumentElement(optsDoc, "collation", bsoncore.Document(opts.Collation.ToDocument())) } if opts.WildcardProjection != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.WildcardProjection) + doc, err := transformBsoncoreDocument(iv.coll.registry, opts.WildcardProjection, true, "wildcardProjection") if err != nil { return nil, err } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go b/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go index 4dc8a951..09e2fa42 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/mongo.go @@ -188,14 +188,14 @@ func transformDocument(registry *bsoncodec.Registry, val interface{}) (bsonx.Doc if doc, ok := val.(bsonx.Doc); ok { return doc.Copy(), nil } - b, err := transformBsoncoreDocument(registry, val) + b, err := transformBsoncoreDocument(registry, val, true, "document") if err != nil { return nil, err } return bsonx.ReadDoc(b) } -func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, error) { +func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}, mapAllowed bool, paramName string) (bsoncore.Document, error) { if registry == nil { registry = bson.DefaultRegistry } @@ -206,6 +206,12 @@ func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}) (b // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. val = bson.Raw(bs) } + if !mapAllowed { + refValue := reflect.ValueOf(val) + if refValue.Kind() == reflect.Map && refValue.Len() > 1 { + return nil, ErrMapForOrderedArgument{paramName} + } + } // TODO(skriptble): Use a pool of these instead. buf := make([]byte, 0, 256) @@ -326,7 +332,7 @@ func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interfa var hasOutputStage bool valLen := val.Len() for idx := 0; idx < valLen; idx++ { - doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface()) + doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface(), true, fmt.Sprintf("pipeline stage :%v", idx)) if err != nil { return nil, false, err } @@ -356,7 +362,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll return u, ErrNilDocument case primitive.D, bsonx.Doc: u.Type = bsontype.EmbeddedDocument - u.Data, err = transformBsoncoreDocument(registry, update) + u.Data, err = transformBsoncoreDocument(registry, update, true, "update") if err != nil { return u, err } @@ -398,7 +404,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll } if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { u.Type = bsontype.EmbeddedDocument - u.Data, err = transformBsoncoreDocument(registry, update) + u.Data, err = transformBsoncoreDocument(registry, update, true, "update") if err != nil { return u, err } @@ -410,7 +416,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll aidx, arr := bsoncore.AppendArrayStart(nil) valLen := val.Len() for idx := 0; idx < valLen; idx++ { - doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface()) + doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface(), true, "update") if err != nil { return u, err } @@ -426,7 +432,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll } } -func transformValue(registry *bsoncodec.Registry, val interface{}) (bsoncore.Value, error) { +func transformValue(registry *bsoncodec.Registry, val interface{}, mapAllowed bool, paramName string) (bsoncore.Value, error) { if registry == nil { registry = bson.DefaultRegistry } @@ -434,6 +440,13 @@ func transformValue(registry *bsoncodec.Registry, val interface{}) (bsoncore.Val return bsoncore.Value{}, ErrNilValue } + if !mapAllowed { + refValue := reflect.ValueOf(val) + if refValue.Kind() == reflect.Map && refValue.Len() > 1 { + return bsoncore.Value{}, ErrMapForOrderedArgument{paramName} + } + } + buf := make([]byte, 0, 256) bsonType, bsonValue, err := bson.MarshalValueAppendWithRegistry(registry, buf[:0], val) if err != nil { @@ -445,7 +458,7 @@ func transformValue(registry *bsoncodec.Registry, val interface{}) (bsoncore.Val // Build the aggregation pipeline for the CountDocument command. func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts *options.CountOptions) (bsoncore.Document, error) { - filterDoc, err := transformBsoncoreDocument(registry, filter) + filterDoc, err := transformBsoncoreDocument(registry, filter, true, "filter") if err != nil { return nil, err } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go index cbefe624..33752c30 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/aggregateoptions.go @@ -41,8 +41,8 @@ type AggregateOptions struct { Comment *string // The index to use for the aggregation. This should either be the index name as a string or the index specification - // as a document. The hint does not apply to $lookup and $graphLookup aggregation stages. The default value is nil, - // which means that no hint will be sent. + // as a document. The hint does not apply to $lookup and $graphLookup aggregation stages. The driver will return an + // error if the hint parameter is a multi-key map. The default value is nil, which means that no hint will be sent. Hint interface{} } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go index 3cbb60cb..517b69e6 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/autoencryptionoptions.go @@ -34,8 +34,15 @@ func AutoEncryption() *AutoEncryptionOptions { return &AutoEncryptionOptions{} } -// SetKeyVaultClientOptions specifies options for the client used to communicate with the key vault collection. If this is -// not set, the client used to do encryption will be re-used for key vault communication. +// SetKeyVaultClientOptions specifies options for the client used to communicate with the key vault collection. +// +// If this is set, it is used to create an internal mongo.Client. +// Otherwise, if the target mongo.Client being configured has an unlimited connection pool size (i.e. maxPoolSize=0), +// it is reused to interact with the key vault collection. +// Otherwise, if the target mongo.Client has a limited connection pool size, a separate internal mongo.Client is used +// (and created if necessary). The internal mongo.Client may be shared during automatic encryption (if +// BypassAutomaticEncryption is false). The internal mongo.Client is configured with the same options as the target +// mongo.Client except minPoolSize is set to 0 and AutoEncryptionOptions is omitted. func (a *AutoEncryptionOptions) SetKeyVaultClientOptions(opts *ClientOptions) *AutoEncryptionOptions { a.KeyVaultClientOptions = opts return a @@ -66,6 +73,13 @@ func (a *AutoEncryptionOptions) SetSchemaMap(schemaMap map[string]interface{}) * } // SetBypassAutoEncryption specifies whether or not auto encryption should be done. +// +// If this is unset or false and target mongo.Client being configured has an unlimited connection pool size +// (i.e. maxPoolSize=0), it is reused in the process of auto encryption. +// Otherwise, if the target mongo.Client has a limited connection pool size, a separate internal mongo.Client is used +// (and created if necessary). The internal mongo.Client may be shared for key vault operations (if KeyVaultClient is +// unset). The internal mongo.Client is configured with the same options as the target mongo.Client except minPoolSize +// is set to 0 and AutoEncryptionOptions is omitted. func (a *AutoEncryptionOptions) SetBypassAutoEncryption(bypass bool) *AutoEncryptionOptions { a.BypassAutoEncryption = &bypass return a diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go index 93e1f153..354a5dac 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/youmark/pkcs8" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -108,6 +109,7 @@ type ClientOptions struct { MinPoolSize *uint64 PoolMonitor *event.PoolMonitor Monitor *event.CommandMonitor + ServerMonitor *event.ServerMonitor ReadConcern *readconcern.ReadConcern ReadPreference *readpref.ReadPref Registry *bsoncodec.Registry @@ -525,6 +527,12 @@ func (c *ClientOptions) SetMonitor(m *event.CommandMonitor) *ClientOptions { return c } +// SetServerMonitor specifies an SDAM monitor used to monitor SDAM events. +func (c *ClientOptions) SetServerMonitor(m *event.ServerMonitor) *ClientOptions { + c.ServerMonitor = m + return c +} + // SetReadConcern specifies the read concern to use for read operations. A read concern level can also be set through // the "readConcernLevel" URI option (e.g. "readConcernLevel=majority"). The default is nil, meaning the server will use // its configured default. @@ -705,7 +713,7 @@ func (c *ClientOptions) SetDisableOCSPEndpointCheck(disableCheck bool) *ClientOp } // MergeClientOptions combines the given *ClientOptions into a single *ClientOptions in a last one wins fashion. -// The specified options are merged with the existing options on the collection, with the specified options taking +// The specified options are merged with the existing options on the client, with the specified options taking // precedence. func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { c := Client() @@ -757,6 +765,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.Monitor != nil { c.Monitor = opt.Monitor } + if opt.ServerMonitor != nil { + c.ServerMonitor = opt.ServerMonitor + } if opt.ReadConcern != nil { c.ReadConcern = opt.ReadConcern } @@ -875,14 +886,34 @@ func addClientCertFromBytes(cfg *tls.Config, data []byte, keyPasswd string) (str certDecodedBlock = currentBlock.Bytes start += len(certBlock) } else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") { - if keyPasswd != "" && x509.IsEncryptedPEMBlock(currentBlock) { - var encoded bytes.Buffer - buf, err := x509.DecryptPEMBlock(currentBlock, []byte(keyPasswd)) - if err != nil { - return "", err + isEncrypted := x509.IsEncryptedPEMBlock(currentBlock) || strings.Contains(currentBlock.Type, "ENCRYPTED PRIVATE KEY") + if isEncrypted { + if keyPasswd == "" { + return "", fmt.Errorf("no password provided to decrypt private key") } - pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: buf}) + var keyBytes []byte + var err error + // Process the X.509-encrypted or PKCS-encrypted PEM block. + if x509.IsEncryptedPEMBlock(currentBlock) { + // Only covers encrypted PEM data with a DEK-Info header. + keyBytes, err = x509.DecryptPEMBlock(currentBlock, []byte(keyPasswd)) + if err != nil { + return "", err + } + } else if strings.Contains(currentBlock.Type, "ENCRYPTED") { + // The pkcs8 package only handles the PKCS #5 v2.0 scheme. + decrypted, err := pkcs8.ParsePKCS8PrivateKey(currentBlock.Bytes, []byte(keyPasswd)) + if err != nil { + return "", err + } + keyBytes, err = x509MarshalPKCS8PrivateKey(decrypted) + if err != nil { + return "", err + } + } + var encoded bytes.Buffer + pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: keyBytes}) keyBlock = encoded.Bytes() start = len(data) - len(remaining) } else { diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_10.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_10.go index 97c0045d..1943d9c5 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_10.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_10.go @@ -7,3 +7,7 @@ import "crypto/x509" func x509CertSubject(cert *x509.Certificate) string { return cert.Subject.String() } + +func x509MarshalPKCS8PrivateKey(pkcs8 interface{}) ([]byte, error) { + return x509.MarshalPKCS8PrivateKey(pkcs8) +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_9.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_9.go index 385d6d36..e1099229 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_9.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/clientoptions_1_9.go @@ -4,10 +4,17 @@ package options import ( "crypto/x509" + "fmt" ) -// We don't support version less then 1.10, but Evergreen needs to be able to compile the driver -// using version 1.8. +// We don't support Go versions less than 1.10, but Evergreen needs to be able to compile the driver +// using version 1.9 and cert.Subject func x509CertSubject(cert *x509.Certificate) string { return "" } + +// We don't support Go versions less than 1.10, but Evergreen needs to be able to compile the driver +// using version 1.9 and x509.MarshalPKCS8PrivateKey() +func x509MarshalPKCS8PrivateKey(pkcs8 interface{}) ([]byte, error) { + return nil, fmt.Errorf("PKCS8-encrypted client private keys are only supported with go1.10+") +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go index 1d1cd828..094524c1 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/countoptions.go @@ -16,7 +16,8 @@ type CountOptions struct { Collation *Collation // The index to use for the aggregation. This should either be the index name as a string or the index specification - // as a document. The default value is nil, which means that no hint will be sent. + // as a document. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, + // which means that no hint will be sent. Hint interface{} // The maximum number of documents to count. The default value is 0, which means that there is no limit and all diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/datakeyoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/datakeyoptions.go index 2165bf85..c6a17f9e 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/datakeyoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/datakeyoptions.go @@ -19,10 +19,37 @@ func DataKey() *DataKeyOptions { // SetMasterKey specifies a KMS-specific key used to encrypt the new data key. // -// If being used with the AWS KMS provider, this option is required and must be a document with the following format: -// {region: string, key: string}. -// // If being used with a local KMS provider, this option is not applicable and should not be specified. +// +// For the AWS, Azure, and GCP KMS providers, this option is required and must be a document. For each, the value of the +// "endpoint" or "keyVaultEndpoint" must be a host name with an optional port number (e.g. "foo.com" or "foo.com:443"). +// +// When using AWS, the document must have the format: +// { +// region: , +// key: , // The Amazon Resource Name (ARN) to the AWS customer master key (CMK). +// endpoint: Optional // An alternate host identifier to send KMS requests to. +// } +// If unset, the "endpoint" defaults to "kms..amazonaws.com". +// +// When using Azure, the document must have the format: +// { +// keyVaultEndpoint: , // A host identifier to send KMS requests to. +// keyName: , +// keyVersion: Optional // A specific version of the named key. +// } +// If unset, "keyVersion" defaults to the key's primary version. +// +// When using GCP, the document must have the format: +// { +// projectId: , +// location: , +// keyRing: , +// keyName: , +// keyVersion: Optional, // A specific version of the named key. +// endpoint: Optional // An alternate host identifier to send KMS requests to. +// } +// If unset, "keyVersion" defaults to the key's primary version and "endpoint" defaults to "cloudkms.googleapis.com". func (dk *DataKeyOptions) SetMasterKey(masterKey interface{}) *DataKeyOptions { dk.MasterKey = masterKey return dk diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/deleteoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/deleteoptions.go index a9780ca1..fcea40a5 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/deleteoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/deleteoptions.go @@ -17,7 +17,8 @@ type DeleteOptions struct { // as a document. This option is only valid for MongoDB versions >= 4.4. Server versions >= 3.4 will return an error // if this option is specified. For server versions < 3.4, the driver will return a client-side error if this option // is specified. The driver will return an error if this option is specified during an unacknowledged write - // operation. The default value is nil, which means that no hint will be sent. + // operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, + // which means that no hint will be sent. Hint interface{} } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/findoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/findoptions.go index bb69cd0f..ad9409c9 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/findoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/findoptions.go @@ -39,7 +39,8 @@ type FindOptions struct { CursorType *CursorType // The index to use for the operation. This should either be the index name as a string or the index specification - // as a document. The default value is nil, which means that no hint will be sent. + // as a document. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, + // which means that no hint will be sent. Hint interface{} // The maximum number of documents to return. The default value is 0, which means that all documents matching the @@ -95,7 +96,8 @@ type FindOptions struct { // Deprecated: This option has been deprecated in MongoDB version 3.6 and removed in MongoDB version 4.0. Snapshot *bool - // A document specifying the order in which documents should be returned. + // A document specifying the order in which documents should be returned. The driver will return an error if the + // sort parameter is a multi-key map. Sort interface{} } @@ -307,6 +309,8 @@ type FindOneOptions struct { AllowPartialResults *bool // The maximum number of documents to be included in each batch returned by the server. + // + // Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. BatchSize *int32 // Specifies a collation to use for string comparisons during the operation. This option is only valid for MongoDB @@ -320,10 +324,13 @@ type FindOneOptions struct { // Specifies the type of cursor that should be created for the operation. The default is NonTailable, which means // that the cursor will be closed by the server when the last batch of documents is retrieved. + // + // Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. CursorType *CursorType // The index to use for the aggregation. This should either be the index name as a string or the index specification - // as a document. The default value is nil, which means that no hint will be sent. + // as a document. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, + // which means that no hint will be sent. Hint interface{} // A document specifying the exclusive upper bound for a specific index. The default value is nil, which means that @@ -333,6 +340,8 @@ type FindOneOptions struct { // The maximum amount of time that the server should wait for new documents to satisfy a tailable cursor query. // This option is only valid for tailable await cursors (see the CursorType option for more information) and // MongoDB versions >= 3.2. For other cursor types or previous server versions, this option is ignored. + // + // Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. MaxAwaitTime *time.Duration // The maximum amount of time that the query can run on the server. The default value is nil, meaning that there @@ -345,9 +354,14 @@ type FindOneOptions struct { // If true, the cursor created by the operation will not timeout after a period of inactivity. The default value // is false. + // + // Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. NoCursorTimeout *bool // This option is for internal replication use only and should not be set. + // + // Deprecated: This option has been deprecated in MongoDB version 4.4 and will be ignored by the server if it is + // set. OplogReplay *bool // A document describing which fields will be included in the document returned by the operation. The default value @@ -372,7 +386,7 @@ type FindOneOptions struct { Snapshot *bool // A document specifying the sort order to apply to the query. The first document in the sorted order will be - // returned. + // returned. The driver will return an error if the sort parameter is a multi-key map. Sort interface{} } @@ -388,6 +402,8 @@ func (f *FindOneOptions) SetAllowPartialResults(b bool) *FindOneOptions { } // SetBatchSize sets the value for the BatchSize field. +// +// Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. func (f *FindOneOptions) SetBatchSize(i int32) *FindOneOptions { f.BatchSize = &i return f @@ -406,6 +422,8 @@ func (f *FindOneOptions) SetComment(comment string) *FindOneOptions { } // SetCursorType sets the value for the CursorType field. +// +// Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. func (f *FindOneOptions) SetCursorType(ct CursorType) *FindOneOptions { f.CursorType = &ct return f @@ -424,6 +442,8 @@ func (f *FindOneOptions) SetMax(max interface{}) *FindOneOptions { } // SetMaxAwaitTime sets the value for the MaxAwaitTime field. +// +// Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. func (f *FindOneOptions) SetMaxAwaitTime(d time.Duration) *FindOneOptions { f.MaxAwaitTime = &d return f @@ -442,12 +462,17 @@ func (f *FindOneOptions) SetMin(min interface{}) *FindOneOptions { } // SetNoCursorTimeout sets the value for the NoCursorTimeout field. +// +// Deprecated: This option is not valid for a findOne operation, as no cursor is actually created. func (f *FindOneOptions) SetNoCursorTimeout(b bool) *FindOneOptions { f.NoCursorTimeout = &b return f } // SetOplogReplay sets the value for the OplogReplay field. +// +// Deprecated: This option has been deprecated in MongoDB version 4.4 and will be ignored by the server if it is +// set. func (f *FindOneOptions) SetOplogReplay(b bool) *FindOneOptions { f.OplogReplay = &b return f @@ -584,8 +609,8 @@ type FindOneAndReplaceOptions struct { ReturnDocument *ReturnDocument // A document specifying which document should be replaced if the filter used by the operation matches multiple - // documents in the collection. If set, the first document in the sorted order will be replaced. - // The default value is nil. + // documents in the collection. If set, the first document in the sorted order will be replaced. The driver will + // return an error if the sort parameter is a multi-key map. The default value is nil. Sort interface{} // If true, a new document will be inserted if the filter does not match any documents in the collection. The @@ -595,8 +620,9 @@ type FindOneAndReplaceOptions struct { // The index to use for the operation. This should either be the index name as a string or the index specification // as a document. This option is only valid for MongoDB versions >= 4.4. MongoDB version 4.2 will report an error if // this option is specified. For server versions < 4.2, the driver will return an error if this option is specified. - // The driver will return an error if this option is used with during an unacknowledged write operation. The default - // value is nil, which means that no hint will be sent. + // The driver will return an error if this option is used with during an unacknowledged write operation. The driver + // will return an error if the hint parameter is a multi-key map. The default value is nil, which means that no hint + // will be sent. Hint interface{} } @@ -721,8 +747,8 @@ type FindOneAndUpdateOptions struct { ReturnDocument *ReturnDocument // A document specifying which document should be updated if the filter used by the operation matches multiple - // documents in the collection. If set, the first document in the sorted order will be updated. - // The default value is nil. + // documents in the collection. If set, the first document in the sorted order will be updated. The driver will + // return an error if the sort parameter is a multi-key map. The default value is nil. Sort interface{} // If true, a new document will be inserted if the filter does not match any documents in the collection. The @@ -732,8 +758,9 @@ type FindOneAndUpdateOptions struct { // The index to use for the operation. This should either be the index name as a string or the index specification // as a document. This option is only valid for MongoDB versions >= 4.4. MongoDB version 4.2 will report an error if // this option is specified. For server versions < 4.2, the driver will return an error if this option is specified. - // The driver will return an error if this option is used with during an unacknowledged write operation. The default - // value is nil, which means that no hint will be sent. + // The driver will return an error if this option is used with during an unacknowledged write operation. The driver + // will return an error if the hint parameter is a multi-key map. The default value is nil, which means that no hint + // will be sent. Hint interface{} } @@ -853,14 +880,15 @@ type FindOneAndDeleteOptions struct { // A document specifying which document should be replaced if the filter used by the operation matches multiple // documents in the collection. If set, the first document in the sorted order will be selected for replacement. - // The default value is nil. + // The driver will return an error if the sort parameter is a multi-key map. The default value is nil. Sort interface{} // The index to use for the operation. This should either be the index name as a string or the index specification // as a document. This option is only valid for MongoDB versions >= 4.4. MongoDB version 4.2 will report an error if // this option is specified. For server versions < 4.2, the driver will return an error if this option is specified. - // The driver will return an error if this option is used with during an unacknowledged write operation. The default - // value is nil, which means that no hint will be sent. + // The driver will return an error if this option is used with during an unacknowledged write operation. The driver + // will return an error if the hint parameter is a multi-key map. The default value is nil, which means that no hint + // will be sent. Hint interface{} } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/gridfsoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/gridfsoptions.go index a24ec7d9..493fe983 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/gridfsoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/gridfsoptions.go @@ -234,7 +234,8 @@ type GridFSFindOptions struct { // The number of documents to skip before adding documents to the result. The default value is 0. Skip *int32 - // A document specifying the order in which documents should be returned. + // A document specifying the order in which documents should be returned. The driver will return an error if the + // sort parameter is a multi-key map. Sort interface{} } diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/listcollectionsoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/listcollectionsoptions.go index aa603c1d..4c2ce3e6 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/listcollectionsoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/listcollectionsoptions.go @@ -10,6 +10,9 @@ package options type ListCollectionsOptions struct { // If true, each collection document will only contain a field for the collection name. The default value is false. NameOnly *bool + + // The maximum number of documents to be included in each batch returned by the server. + BatchSize *int32 } // ListCollections creates a new ListCollectionsOptions instance. @@ -23,6 +26,12 @@ func (lc *ListCollectionsOptions) SetNameOnly(b bool) *ListCollectionsOptions { return lc } +// SetBatchSize sets the value for the BatchSize field. +func (lc *ListCollectionsOptions) SetBatchSize(size int32) *ListCollectionsOptions { + lc.BatchSize = &size + return lc +} + // MergeListCollectionsOptions combines the given ListCollectionsOptions instances into a single *ListCollectionsOptions // in a last-one-wins fashion. func MergeListCollectionsOptions(opts ...*ListCollectionsOptions) *ListCollectionsOptions { @@ -34,6 +43,9 @@ func MergeListCollectionsOptions(opts ...*ListCollectionsOptions) *ListCollectio if opt.NameOnly != nil { lc.NameOnly = opt.NameOnly } + if opt.BatchSize != nil { + lc.BatchSize = opt.BatchSize + } } return lc diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/replaceoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/replaceoptions.go index 37bba0dd..543c81ce 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/replaceoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/replaceoptions.go @@ -23,7 +23,8 @@ type ReplaceOptions struct { // as a document. This option is only valid for MongoDB versions >= 4.2. Server versions >= 3.4 will return an error // if this option is specified. For server versions < 3.4, the driver will return a client-side error if this option // is specified. The driver will return an error if this option is specified during an unacknowledged write - // operation. The default value is nil, which means that no hint will be sent. + // operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, + // which means that no hint will be sent. Hint interface{} // If true, a new document will be inserted if the filter does not match any documents in the collection. The diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/updateoptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/updateoptions.go index a89ef664..4d1e7e47 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/options/updateoptions.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/updateoptions.go @@ -28,7 +28,8 @@ type UpdateOptions struct { // as a document. This option is only valid for MongoDB versions >= 4.2. Server versions >= 3.4 will return an error // if this option is specified. For server versions < 3.4, the driver will return a client-side error if this option // is specified. The driver will return an error if this option is specified during an unacknowledged write - // operation. The default value is nil, which means that no hint will be sent. + // operation. The driver will return an error if the hint parameter is a multi-key map. The default value is nil, + // which means that no hint will be sent. Hint interface{} // If true, a new document will be inserted if the filter does not match any documents in the collection. The diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/readpref/mode.go b/vendor/go.mongodb.org/mongo-driver/mongo/readpref/mode.go index deacf9f3..ce036504 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/readpref/mode.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/readpref/mode.go @@ -72,3 +72,17 @@ func (mode Mode) String() string { return "unknown" } } + +// IsValid checks whether the mode is valid. +func (mode Mode) IsValid() bool { + switch mode { + case PrimaryMode, + PrimaryPreferredMode, + SecondaryMode, + SecondaryPreferredMode, + NearestMode: + return true + default: + return false + } +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/results.go b/vendor/go.mongodb.org/mongo-driver/mongo/results.go index 8275f623..52e2dedd 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/results.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/results.go @@ -10,6 +10,7 @@ import ( "fmt" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -142,3 +143,104 @@ func (result *UpdateResult) UnmarshalBSON(b []byte) error { return nil } + +// IndexSpecification represents an index in a database. This type is returned by the IndexView.ListSpecifications +// function and is also used in the CollectionSpecification type. +type IndexSpecification struct { + // The index name. + Name string + + // The namespace for the index. This is a string in the format "databaseName.collectionName". + Namespace string + + // The keys specification document for the index. + KeysDocument bson.Raw + + // The index version. + Version int32 +} + +var _ bson.Unmarshaler = (*IndexSpecification)(nil) + +type unmarshalIndexSpecification struct { + Name string `bson:"name"` + Namespace string `bson:"ns"` + KeysDocument bson.Raw `bson:"key"` + Version int32 `bson:"v"` +} + +// UnmarshalBSON implements the bson.Unmarshaler interface. +func (i *IndexSpecification) UnmarshalBSON(data []byte) error { + var temp unmarshalIndexSpecification + if err := bson.Unmarshal(data, &temp); err != nil { + return err + } + + i.Name = temp.Name + i.Namespace = temp.Namespace + i.KeysDocument = temp.KeysDocument + i.Version = temp.Version + return nil +} + +// CollectionSpecification represents a collection in a database. This type is returned by the +// Database.ListCollectionSpecifications function. +type CollectionSpecification struct { + // The collection name. + Name string + + // The type of the collection. This will either be "collection" or "view". + Type string + + // Whether or not the collection is readOnly. This will be false for MongoDB versions < 3.4. + ReadOnly bool + + // The collection UUID. This field will be nil for MongoDB versions < 3.6. For versions 3.6 and higher, this will + // be a primitive.Binary with Subtype 4. + UUID *primitive.Binary + + // A document containing the options used to construct the collection. + Options bson.Raw + + // An IndexSpecification instance with details about the collection's _id index. This will be nil if the NameOnly + // option is used and for MongoDB versions < 3.4. + IDIndex *IndexSpecification +} + +var _ bson.Unmarshaler = (*CollectionSpecification)(nil) + +// unmarshalCollectionSpecification is used to unmarshal BSON bytes from a listCollections command into a +// CollectionSpecification. +type unmarshalCollectionSpecification struct { + Name string `bson:"name"` + Type string `bson:"type"` + Info *struct { + ReadOnly bool `bson:"readOnly"` + UUID *primitive.Binary `bson:"uuid"` + } `bson:"info"` + Options bson.Raw `bson:"options"` + IDIndex *IndexSpecification `bson:"idIndex"` +} + +// UnmarshalBSON implements the bson.Unmarshaler interface. +func (cs *CollectionSpecification) UnmarshalBSON(data []byte) error { + var temp unmarshalCollectionSpecification + if err := bson.Unmarshal(data, &temp); err != nil { + return err + } + + cs.Name = temp.Name + cs.Type = temp.Type + if cs.Type == "" { + // The "type" field is only present on 3.4+ because views were introduced in 3.4, so we implicitly set the + // value to "collection" if it's empty. + cs.Type = "collection" + } + if temp.Info != nil { + cs.ReadOnly = temp.Info.ReadOnly + cs.UUID = temp.Info.UUID + } + cs.Options = temp.Options + cs.IDIndex = temp.IDIndex + return nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/session.go b/vendor/go.mongodb.org/mongo-driver/mongo/session.go index 40ae4a0b..2c1c38b8 100644 --- a/vendor/go.mongodb.org/mongo-driver/mongo/session.go +++ b/vendor/go.mongodb.org/mongo-driver/mongo/session.go @@ -13,10 +13,11 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -94,8 +95,10 @@ func SessionFromContext(ctx context.Context) Session { // callback, sessCtx must be used as the Context parameter for any operations that should be part of the transaction. If // the ctx parameter already has a Session attached to it, it will be replaced by this session. The fn callback may be // run multiple times during WithTransaction due to retry attempts, so it must be idempotent. Non-retryable operation -// errors or any operation errors that occur after the timeout expires will be returned without retrying. For a usage -// example, see the Client.StartSession method documentation. +// errors or any operation errors that occur after the timeout expires will be returned without retrying. If the +// callback fails, the driver will call AbortTransaction. Because this method must succeed to ensure that server-side +// resources are properly cleaned up, context deadlines and cancellations will not be respected during this call. For a +// usage example, see the Client.StartSession method documentation. // // ClusterTime, OperationTime, Client, and ID return the session's current operation time, the session's current cluster // time, the Client associated with the session, and the ID document associated with the session, respectively. The ID @@ -179,7 +182,9 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi res, err := fn(NewSessionContext(ctx, s)) if err != nil { if s.clientSession.TransactionRunning() { - _ = s.AbortTransaction(ctx) + // Wrap the user-provided Context in a new one that behaves like context.Background() for deadlines and + // cancellations, but forwards Value requests to the original one. + _ = s.AbortTransaction(internal.NewBackgroundContext(ctx)) } select { @@ -204,8 +209,10 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi CommitLoop: for { err = s.CommitTransaction(ctx) - if err == nil { - return res, nil + // End when error is nil (transaction has been committed), or when context has timed out or been + // canceled, as retrying has no chance of success. + if err == nil || ctx.Err() != nil { + return res, err } select { @@ -302,6 +309,11 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { } err = op.Execute(ctx) + // Return error without updating transaction state if it is a timeout, as the transaction has not + // actually been committed. + if IsTimeout(err) { + return replaceErrors(err) + } s.clientSession.Committing = false commitErr := s.clientSession.CommitTransaction() diff --git a/vendor/go.mongodb.org/mongo-driver/version/version.go b/vendor/go.mongodb.org/mongo-driver/version/version.go index 459e5707..13d4b2e8 100644 --- a/vendor/go.mongodb.org/mongo-driver/version/version.go +++ b/vendor/go.mongodb.org/mongo-driver/version/version.go @@ -7,4 +7,4 @@ package version // import "go.mongodb.org/mongo-driver/version" // Driver is the current version of the driver. -var Driver = "v1.4.6" +var Driver = "v1.5.0" diff --git a/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/array.go b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/array.go new file mode 100644 index 00000000..8ea60ba3 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/array.go @@ -0,0 +1,164 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncore + +import ( + "bytes" + "fmt" + "io" + "strconv" +) + +// NewArrayLengthError creates and returns an error for when the length of an array exceeds the +// bytes available. +func NewArrayLengthError(length, rem int) error { + return lengthError("array", length, rem) +} + +// Array is a raw bytes representation of a BSON array. +type Array []byte + +// NewArrayFromReader reads an array from r. This function will only validate the length is +// correct and that the array ends with a null byte. +func NewArrayFromReader(r io.Reader) (Array, error) { + return newBufferFromReader(r) +} + +// Index searches for and retrieves the value at the given index. This method will panic if +// the array is invalid or if the index is out of bounds. +func (a Array) Index(index uint) Value { + value, err := a.IndexErr(index) + if err != nil { + panic(err) + } + return value +} + +// IndexErr searches for and retrieves the value at the given index. +func (a Array) IndexErr(index uint) (Value, error) { + elem, err := indexErr(a, index) + if err != nil { + return Value{}, err + } + return elem.Value(), err +} + +// DebugString outputs a human readable version of Array. It will attempt to stringify the +// valid components of the array even if the entire array is not valid. +func (a Array) DebugString() string { + if len(a) < 5 { + return "" + } + var buf bytes.Buffer + buf.WriteString("Array") + length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length + buf.WriteByte('(') + buf.WriteString(strconv.Itoa(int(length))) + length -= 4 + buf.WriteString(")[") + var elem Element + var ok bool + for length > 1 { + elem, rem, ok = ReadElement(rem) + length -= int32(len(elem)) + if !ok { + buf.WriteString(fmt.Sprintf("", length)) + break + } + fmt.Fprintf(&buf, "%s", elem.Value().DebugString()) + if length != 1 { + buf.WriteByte(',') + } + } + buf.WriteByte(']') + + return buf.String() +} + +// String outputs an ExtendedJSON version of Array. If the Array is not valid, this method +// returns an empty string. +func (a Array) String() string { + if len(a) < 5 { + return "" + } + var buf bytes.Buffer + buf.WriteByte('[') + + length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length + + length -= 4 + + var elem Element + var ok bool + for length > 1 { + elem, rem, ok = ReadElement(rem) + length -= int32(len(elem)) + if !ok { + return "" + } + fmt.Fprintf(&buf, "%s", elem.Value().String()) + if length > 1 { + buf.WriteByte(',') + } + } + if length != 1 { // Missing final null byte or inaccurate length + return "" + } + + buf.WriteByte(']') + return buf.String() +} + +// Values returns this array as a slice of values. The returned slice will contain valid values. +// If the array is not valid, the values up to the invalid point will be returned along with an +// error. +func (a Array) Values() ([]Value, error) { + return values(a) +} + +// Validate validates the array and ensures the elements contained within are valid. +func (a Array) Validate() error { + length, rem, ok := ReadLength(a) + if !ok { + return NewInsufficientBytesError(a, rem) + } + if int(length) > len(a) { + return NewArrayLengthError(int(length), len(a)) + } + if a[length-1] != 0x00 { + return ErrMissingNull + } + + length -= 4 + var elem Element + + var keyNum int64 + for length > 1 { + elem, rem, ok = ReadElement(rem) + length -= int32(len(elem)) + if !ok { + return NewInsufficientBytesError(a, rem) + } + + // validate element + err := elem.Validate() + if err != nil { + return err + } + + // validate keys increase numerically + if fmt.Sprint(keyNum) != elem.Key() { + return fmt.Errorf("array key %q is out of order or invalid", elem.Key()) + } + keyNum++ + } + + if len(rem) < 1 || rem[0] != 0x00 { + return ErrMissingNull + } + return nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_arraybuilder.go b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_arraybuilder.go new file mode 100644 index 00000000..7e6937d8 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_arraybuilder.go @@ -0,0 +1,201 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncore + +import ( + "strconv" + + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// ArrayBuilder builds a bson array +type ArrayBuilder struct { + arr []byte + indexes []int32 + keys []int +} + +// NewArrayBuilder creates a new ArrayBuilder +func NewArrayBuilder() *ArrayBuilder { + return (&ArrayBuilder{}).startArray() +} + +// startArray reserves the array's length and sets the index to where the length begins +func (a *ArrayBuilder) startArray() *ArrayBuilder { + var index int32 + index, a.arr = AppendArrayStart(a.arr) + a.indexes = append(a.indexes, index) + a.keys = append(a.keys, 0) + return a +} + +// Build updates the length of the array and index to the beginning of the documents length +// bytes, then returns the array (bson bytes) +func (a *ArrayBuilder) Build() Array { + lastIndex := len(a.indexes) - 1 + lastKey := len(a.keys) - 1 + a.arr, _ = AppendArrayEnd(a.arr, a.indexes[lastIndex]) + a.indexes = a.indexes[:lastIndex] + a.keys = a.keys[:lastKey] + return a.arr +} + +// incrementKey() increments the value keys and returns the key to be used to a.appendArray* functions +func (a *ArrayBuilder) incrementKey() string { + idx := len(a.keys) - 1 + key := strconv.Itoa(a.keys[idx]) + a.keys[idx]++ + return key +} + +// AppendInt32 will append i32 to ArrayBuilder.arr +func (a *ArrayBuilder) AppendInt32(i32 int32) *ArrayBuilder { + a.arr = AppendInt32Element(a.arr, a.incrementKey(), i32) + return a +} + +// AppendDocument will append doc to ArrayBuilder.arr +func (a *ArrayBuilder) AppendDocument(doc []byte) *ArrayBuilder { + a.arr = AppendDocumentElement(a.arr, a.incrementKey(), doc) + return a +} + +// AppendArray will append arr to ArrayBuilder.arr +func (a *ArrayBuilder) AppendArray(arr []byte) *ArrayBuilder { + a.arr = AppendArrayElement(a.arr, a.incrementKey(), arr) + return a +} + +// AppendDouble will append f to ArrayBuilder.doc +func (a *ArrayBuilder) AppendDouble(f float64) *ArrayBuilder { + a.arr = AppendDoubleElement(a.arr, a.incrementKey(), f) + return a +} + +// AppendString will append str to ArrayBuilder.doc +func (a *ArrayBuilder) AppendString(str string) *ArrayBuilder { + a.arr = AppendStringElement(a.arr, a.incrementKey(), str) + return a +} + +// AppendObjectID will append oid to ArrayBuilder.doc +func (a *ArrayBuilder) AppendObjectID(oid primitive.ObjectID) *ArrayBuilder { + a.arr = AppendObjectIDElement(a.arr, a.incrementKey(), oid) + return a +} + +// AppendBinary will append a BSON binary element using subtype, and +// b to a.arr +func (a *ArrayBuilder) AppendBinary(subtype byte, b []byte) *ArrayBuilder { + a.arr = AppendBinaryElement(a.arr, a.incrementKey(), subtype, b) + return a +} + +// AppendUndefined will append a BSON undefined element using key to a.arr +func (a *ArrayBuilder) AppendUndefined() *ArrayBuilder { + a.arr = AppendUndefinedElement(a.arr, a.incrementKey()) + return a +} + +// AppendBoolean will append a boolean element using b to a.arr +func (a *ArrayBuilder) AppendBoolean(b bool) *ArrayBuilder { + a.arr = AppendBooleanElement(a.arr, a.incrementKey(), b) + return a +} + +// AppendDateTime will append datetime element dt to a.arr +func (a *ArrayBuilder) AppendDateTime(dt int64) *ArrayBuilder { + a.arr = AppendDateTimeElement(a.arr, a.incrementKey(), dt) + return a +} + +// AppendNull will append a null element to a.arr +func (a *ArrayBuilder) AppendNull() *ArrayBuilder { + a.arr = AppendNullElement(a.arr, a.incrementKey()) + return a +} + +// AppendRegex will append pattern and options to a.arr +func (a *ArrayBuilder) AppendRegex(pattern, options string) *ArrayBuilder { + a.arr = AppendRegexElement(a.arr, a.incrementKey(), pattern, options) + return a +} + +// AppendDBPointer will append ns and oid to a.arr +func (a *ArrayBuilder) AppendDBPointer(ns string, oid primitive.ObjectID) *ArrayBuilder { + a.arr = AppendDBPointerElement(a.arr, a.incrementKey(), ns, oid) + return a +} + +// AppendJavaScript will append js to a.arr +func (a *ArrayBuilder) AppendJavaScript(js string) *ArrayBuilder { + a.arr = AppendJavaScriptElement(a.arr, a.incrementKey(), js) + return a +} + +// AppendSymbol will append symbol to a.arr +func (a *ArrayBuilder) AppendSymbol(symbol string) *ArrayBuilder { + a.arr = AppendSymbolElement(a.arr, a.incrementKey(), symbol) + return a +} + +// AppendCodeWithScope will append code and scope to a.arr +func (a *ArrayBuilder) AppendCodeWithScope(code string, scope Document) *ArrayBuilder { + a.arr = AppendCodeWithScopeElement(a.arr, a.incrementKey(), code, scope) + return a +} + +// AppendTimestamp will append t and i to a.arr +func (a *ArrayBuilder) AppendTimestamp(t, i uint32) *ArrayBuilder { + a.arr = AppendTimestampElement(a.arr, a.incrementKey(), t, i) + return a +} + +// AppendInt64 will append i64 to a.arr +func (a *ArrayBuilder) AppendInt64(i64 int64) *ArrayBuilder { + a.arr = AppendInt64Element(a.arr, a.incrementKey(), i64) + return a +} + +// AppendDecimal128 will append d128 to a.arr +func (a *ArrayBuilder) AppendDecimal128(d128 primitive.Decimal128) *ArrayBuilder { + a.arr = AppendDecimal128Element(a.arr, a.incrementKey(), d128) + return a +} + +// AppendMaxKey will append a max key element to a.arr +func (a *ArrayBuilder) AppendMaxKey() *ArrayBuilder { + a.arr = AppendMaxKeyElement(a.arr, a.incrementKey()) + return a +} + +// AppendMinKey will append a min key element to a.arr +func (a *ArrayBuilder) AppendMinKey() *ArrayBuilder { + a.arr = AppendMinKeyElement(a.arr, a.incrementKey()) + return a +} + +// AppendValue appends a BSON value to the array. +func (a *ArrayBuilder) AppendValue(val Value) *ArrayBuilder { + a.arr = AppendValueElement(a.arr, a.incrementKey(), val) + return a +} + +// StartArray starts building an inline Array. After this document is completed, +// the user must call a.FinishArray +func (a *ArrayBuilder) StartArray() *ArrayBuilder { + a.arr = AppendHeader(a.arr, bsontype.Array, a.incrementKey()) + a.startArray() + return a +} + +// FinishArray builds the most recent array created +func (a *ArrayBuilder) FinishArray() *ArrayBuilder { + a.arr = a.Build() + return a +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_documentbuilder.go b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_documentbuilder.go new file mode 100644 index 00000000..b0d45212 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bson_documentbuilder.go @@ -0,0 +1,189 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncore + +import ( + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// DocumentBuilder builds a bson document +type DocumentBuilder struct { + doc []byte + indexes []int32 +} + +// startDocument reserves the document's length and set the index to where the length begins +func (db *DocumentBuilder) startDocument() *DocumentBuilder { + var index int32 + index, db.doc = AppendDocumentStart(db.doc) + db.indexes = append(db.indexes, index) + return db +} + +// NewDocumentBuilder creates a new DocumentBuilder +func NewDocumentBuilder() *DocumentBuilder { + return (&DocumentBuilder{}).startDocument() +} + +// Build updates the length of the document and index to the beginning of the documents length +// bytes, then returns the document (bson bytes) +func (db *DocumentBuilder) Build() Document { + last := len(db.indexes) - 1 + db.doc, _ = AppendDocumentEnd(db.doc, db.indexes[last]) + db.indexes = db.indexes[:last] + return db.doc +} + +// AppendInt32 will append an int32 element using key and i32 to DocumentBuilder.doc +func (db *DocumentBuilder) AppendInt32(key string, i32 int32) *DocumentBuilder { + db.doc = AppendInt32Element(db.doc, key, i32) + return db +} + +// AppendDocument will append a bson embeded document element using key +// and doc to DocumentBuilder.doc +func (db *DocumentBuilder) AppendDocument(key string, doc []byte) *DocumentBuilder { + db.doc = AppendDocumentElement(db.doc, key, doc) + return db +} + +// AppendArray will append a bson array using key and arr to DocumentBuilder.doc +func (db *DocumentBuilder) AppendArray(key string, arr []byte) *DocumentBuilder { + db.doc = AppendHeader(db.doc, bsontype.Array, key) + db.doc = AppendArray(db.doc, arr) + return db +} + +// AppendDouble will append a double element using key and f to DocumentBuilder.doc +func (db *DocumentBuilder) AppendDouble(key string, f float64) *DocumentBuilder { + db.doc = AppendDoubleElement(db.doc, key, f) + return db +} + +// AppendString will append str to DocumentBuilder.doc with the given key +func (db *DocumentBuilder) AppendString(key string, str string) *DocumentBuilder { + db.doc = AppendStringElement(db.doc, key, str) + return db +} + +// AppendObjectID will append oid to DocumentBuilder.doc with the given key +func (db *DocumentBuilder) AppendObjectID(key string, oid primitive.ObjectID) *DocumentBuilder { + db.doc = AppendObjectIDElement(db.doc, key, oid) + return db +} + +// AppendBinary will append a BSON binary element using key, subtype, and +// b to db.doc +func (db *DocumentBuilder) AppendBinary(key string, subtype byte, b []byte) *DocumentBuilder { + db.doc = AppendBinaryElement(db.doc, key, subtype, b) + return db +} + +// AppendUndefined will append a BSON undefined element using key to db.doc +func (db *DocumentBuilder) AppendUndefined(key string) *DocumentBuilder { + db.doc = AppendUndefinedElement(db.doc, key) + return db +} + +// AppendBoolean will append a boolean element using key and b to db.doc +func (db *DocumentBuilder) AppendBoolean(key string, b bool) *DocumentBuilder { + db.doc = AppendBooleanElement(db.doc, key, b) + return db +} + +// AppendDateTime will append a datetime element using key and dt to db.doc +func (db *DocumentBuilder) AppendDateTime(key string, dt int64) *DocumentBuilder { + db.doc = AppendDateTimeElement(db.doc, key, dt) + return db +} + +// AppendNull will append a null element using key to db.doc +func (db *DocumentBuilder) AppendNull(key string) *DocumentBuilder { + db.doc = AppendNullElement(db.doc, key) + return db +} + +// AppendRegex will append pattern and options using key to db.doc +func (db *DocumentBuilder) AppendRegex(key, pattern, options string) *DocumentBuilder { + db.doc = AppendRegexElement(db.doc, key, pattern, options) + return db +} + +// AppendDBPointer will append ns and oid to using key to db.doc +func (db *DocumentBuilder) AppendDBPointer(key string, ns string, oid primitive.ObjectID) *DocumentBuilder { + db.doc = AppendDBPointerElement(db.doc, key, ns, oid) + return db +} + +// AppendJavaScript will append js using the provided key to db.doc +func (db *DocumentBuilder) AppendJavaScript(key, js string) *DocumentBuilder { + db.doc = AppendJavaScriptElement(db.doc, key, js) + return db +} + +// AppendSymbol will append a BSON symbol element using key and symbol db.doc +func (db *DocumentBuilder) AppendSymbol(key, symbol string) *DocumentBuilder { + db.doc = AppendSymbolElement(db.doc, key, symbol) + return db +} + +// AppendCodeWithScope will append code and scope using key to db.doc +func (db *DocumentBuilder) AppendCodeWithScope(key string, code string, scope Document) *DocumentBuilder { + db.doc = AppendCodeWithScopeElement(db.doc, key, code, scope) + return db +} + +// AppendTimestamp will append t and i to db.doc using provided key +func (db *DocumentBuilder) AppendTimestamp(key string, t, i uint32) *DocumentBuilder { + db.doc = AppendTimestampElement(db.doc, key, t, i) + return db +} + +// AppendInt64 will append i64 to dst using key to db.doc +func (db *DocumentBuilder) AppendInt64(key string, i64 int64) *DocumentBuilder { + db.doc = AppendInt64Element(db.doc, key, i64) + return db +} + +// AppendDecimal128 will append d128 to db.doc using provided key +func (db *DocumentBuilder) AppendDecimal128(key string, d128 primitive.Decimal128) *DocumentBuilder { + db.doc = AppendDecimal128Element(db.doc, key, d128) + return db +} + +// AppendMaxKey will append a max key element using key to db.doc +func (db *DocumentBuilder) AppendMaxKey(key string) *DocumentBuilder { + db.doc = AppendMaxKeyElement(db.doc, key) + return db +} + +// AppendMinKey will append a min key element using key to db.doc +func (db *DocumentBuilder) AppendMinKey(key string) *DocumentBuilder { + db.doc = AppendMinKeyElement(db.doc, key) + return db +} + +// AppendValue will append a BSON element with the provided key and value to the document. +func (db *DocumentBuilder) AppendValue(key string, val Value) *DocumentBuilder { + db.doc = AppendValueElement(db.doc, key, val) + return db +} + +// StartDocument starts building an inline document element with the provided key +// After this document is completed, the user must call finishDocument +func (db *DocumentBuilder) StartDocument(key string) *DocumentBuilder { + db.doc = AppendHeader(db.doc, bsontype.EmbeddedDocument, key) + db = db.startDocument() + return db +} + +// FinishDocument builds the most recent document created +func (db *DocumentBuilder) FinishDocument() *DocumentBuilder { + db.doc = db.Build() + return db +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/document.go b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/document.go index d397cde2..b77c593e 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/document.go +++ b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/document.go @@ -17,17 +17,20 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" ) -// DocumentValidationError is an error type returned when attempting to validate a document. -type DocumentValidationError string +// ValidationError is an error type returned when attempting to validate a document or array. +type ValidationError string -func (dve DocumentValidationError) Error() string { return string(dve) } +func (ve ValidationError) Error() string { return string(ve) } // NewDocumentLengthError creates and returns an error for when the length of a document exceeds the // bytes available. func NewDocumentLengthError(length, rem int) error { - return DocumentValidationError( - fmt.Sprintf("document length exceeds available bytes. length=%d remainingBytes=%d", length, rem), - ) + return lengthError("document", length, rem) +} + +func lengthError(bufferType string, length, rem int) error { + return ValidationError(fmt.Sprintf("%v length exceeds available bytes. length=%d remainingBytes=%d", + bufferType, length, rem)) } // InsufficientBytesError indicates that there were not enough bytes to read the next component. @@ -94,15 +97,16 @@ func (idte InvalidDepthTraversalError) Error() string { ) } -// ErrMissingNull is returned when a document's last byte is not null. -const ErrMissingNull DocumentValidationError = "document end is missing null byte" +// ErrMissingNull is returned when a document or array's last byte is not null. +const ErrMissingNull ValidationError = "document or array end is missing null byte" + +// ErrInvalidLength indicates that a length in a binary representation of a BSON document or array +// is invalid. +const ErrInvalidLength ValidationError = "document or array length is invalid" // ErrNilReader indicates that an operation was attempted on a nil io.Reader. var ErrNilReader = errors.New("nil reader") -// ErrInvalidLength indicates that a length in a binary representation of a BSON document is invalid. -var ErrInvalidLength = errors.New("document length is invalid") - // ErrEmptyKey indicates that no key was provided to a Lookup method. var ErrEmptyKey = errors.New("empty key provided") @@ -115,12 +119,13 @@ var ErrOutOfBounds = errors.New("out of bounds") // Document is a raw bytes representation of a BSON document. type Document []byte -// Array is a raw bytes representation of a BSON array. -type Array = Document - // NewDocumentFromReader reads a document from r. This function will only validate the length is // correct and that the document ends with a null byte. func NewDocumentFromReader(r io.Reader) (Document, error) { + return newBufferFromReader(r) +} + +func newBufferFromReader(r io.Reader) ([]byte, error) { if r == nil { return nil, ErrNilReader } @@ -137,20 +142,20 @@ func NewDocumentFromReader(r io.Reader) (Document, error) { if length < 0 { return nil, ErrInvalidLength } - document := make([]byte, length) + buffer := make([]byte, length) - copy(document, lengthBytes[:]) + copy(buffer, lengthBytes[:]) - _, err = io.ReadFull(r, document[4:]) + _, err = io.ReadFull(r, buffer[4:]) if err != nil { return nil, err } - if document[length-1] != 0x00 { + if buffer[length-1] != 0x00 { return nil, ErrMissingNull } - return document, nil + return buffer, nil } // Lookup searches the document, potentially recursively, for the given key. If there are multiple @@ -221,9 +226,13 @@ func (d Document) Index(index uint) Element { // IndexErr searches for and retrieves the element at the given index. func (d Document) IndexErr(index uint) (Element, error) { - length, rem, ok := ReadLength(d) + return indexErr(d, index) +} + +func indexErr(b []byte, index uint) (Element, error) { + length, rem, ok := ReadLength(b) if !ok { - return nil, NewInsufficientBytesError(d, rem) + return nil, NewInsufficientBytesError(b, rem) } length -= 4 @@ -234,7 +243,7 @@ func (d Document) IndexErr(index uint) (Element, error) { elem, rem, ok = ReadElement(rem) length -= int32(len(elem)) if !ok { - return nil, NewInsufficientBytesError(d, rem) + return nil, NewInsufficientBytesError(b, rem) } if current != index { current++ @@ -338,9 +347,13 @@ func (d Document) Elements() ([]Element, error) { // If the document is not valid, the values up to the invalid point will be returned along with an // error. func (d Document) Values() ([]Value, error) { - length, rem, ok := ReadLength(d) + return values(d) +} + +func values(b []byte) ([]Value, error) { + length, rem, ok := ReadLength(b) if !ok { - return nil, NewInsufficientBytesError(d, rem) + return nil, NewInsufficientBytesError(b, rem) } length -= 4 @@ -351,7 +364,7 @@ func (d Document) Values() ([]Value, error) { elem, rem, ok = ReadElement(rem) length -= int32(len(elem)) if !ok { - return vals, NewInsufficientBytesError(d, rem) + return vals, NewInsufficientBytesError(b, rem) } if err := elem.Value().Validate(); err != nil { return vals, err @@ -368,7 +381,7 @@ func (d Document) Validate() error { return NewInsufficientBytesError(d, rem) } if int(length) > len(d) { - return d.lengtherror(int(length), len(d)) + return NewDocumentLengthError(int(length), len(d)) } if d[length-1] != 0x00 { return ErrMissingNull @@ -394,7 +407,3 @@ func (d Document) Validate() error { } return nil } - -func (Document) lengtherror(length, rem int) error { - return DocumentValidationError(fmt.Sprintf("document length exceeds available bytes. length=%d remainingBytes=%d", length, rem)) -} diff --git a/vendor/go.mongodb.org/mongo-driver/x/bsonx/reflectionfree_d_codec.go b/vendor/go.mongodb.org/mongo-driver/x/bsonx/reflectionfree_d_codec.go new file mode 100644 index 00000000..9df93ad4 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/x/bsonx/reflectionfree_d_codec.go @@ -0,0 +1,1026 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonx + +import ( + "fmt" + "math" + "reflect" + "time" + + "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +var ( + tPrimitiveD = reflect.TypeOf(primitive.D{}) + tPrimitiveA = reflect.TypeOf(primitive.A{}) + tPrimitiveCWS = reflect.TypeOf(primitive.CodeWithScope{}) + defaultValueEncoders = bsoncodec.DefaultValueEncoders{} + defaultValueDecoders = bsoncodec.DefaultValueDecoders{} +) + +type reflectionFreeDCodec struct{} + +// ReflectionFreeDCodec is a ValueEncoder for the primitive.D type that does not use reflection. +var ReflectionFreeDCodec bsoncodec.ValueCodec = &reflectionFreeDCodec{} + +func (r *reflectionFreeDCodec) EncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tPrimitiveD { + return bsoncodec.ValueEncoderError{Name: "DEncodeValue", Types: []reflect.Type{tPrimitiveD}, Received: val} + } + + if val.IsNil() { + return vw.WriteNull() + } + + doc := val.Interface().(primitive.D) + return r.encodeDocument(ec, vw, doc) +} + +func (r *reflectionFreeDCodec) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || !val.CanSet() || val.Type() != tPrimitiveD { + return bsoncodec.ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} + } + + switch vrType := vr.Type(); vrType { + case bsontype.Type(0), bsontype.EmbeddedDocument: + case bsontype.Null: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + default: + return fmt.Errorf("cannot decode %v into a primitive.D", vrType) + } + + doc, err := r.decodeDocument(dc, vr) + if err != nil { + return err + } + + val.Set(reflect.ValueOf(doc)) + return nil +} + +func (r *reflectionFreeDCodec) decodeDocument(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (primitive.D, error) { + dr, err := vr.ReadDocument() + if err != nil { + return nil, err + } + + doc := primitive.D{} + for { + key, elemVr, err := dr.ReadElement() + if err == bsonrw.ErrEOD { + break + } + if err != nil { + return nil, err + } + + val, err := r.decodeValue(dc, elemVr) + if err != nil { + return nil, err + } + doc = append(doc, primitive.E{Key: key, Value: val}) + } + + return doc, nil +} + +func (r *reflectionFreeDCodec) decodeArray(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (primitive.A, error) { + ar, err := vr.ReadArray() + if err != nil { + return nil, err + } + + array := primitive.A{} + for { + arrayValReader, err := ar.ReadValue() + if err == bsonrw.ErrEOA { + break + } + if err != nil { + return nil, err + } + + val, err := r.decodeValue(dc, arrayValReader) + if err != nil { + return nil, err + } + array = append(array, val) + } + + return array, nil +} + +func (r *reflectionFreeDCodec) decodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (interface{}, error) { + switch vrType := vr.Type(); vrType { + case bsontype.Null: + return nil, vr.ReadNull() + case bsontype.Double: + return vr.ReadDouble() + case bsontype.String: + return vr.ReadString() + case bsontype.Binary: + data, subtype, err := vr.ReadBinary() + if err != nil { + return nil, err + } + + return primitive.Binary{ + Data: data, + Subtype: subtype, + }, nil + case bsontype.Undefined: + return primitive.Undefined{}, vr.ReadUndefined() + case bsontype.ObjectID: + return vr.ReadObjectID() + case bsontype.Boolean: + return vr.ReadBoolean() + case bsontype.DateTime: + dt, err := vr.ReadDateTime() + if err != nil { + return nil, err + } + + return primitive.DateTime(dt), nil + case bsontype.Regex: + pattern, options, err := vr.ReadRegex() + if err != nil { + return nil, err + } + + return primitive.Regex{ + Pattern: pattern, + Options: options, + }, nil + case bsontype.DBPointer: + ns, oid, err := vr.ReadDBPointer() + if err != nil { + return nil, err + } + + return primitive.DBPointer{ + DB: ns, + Pointer: oid, + }, nil + case bsontype.JavaScript: + js, err := vr.ReadJavascript() + if err != nil { + return nil, err + } + + return primitive.JavaScript(js), nil + case bsontype.Symbol: + sym, err := vr.ReadSymbol() + if err != nil { + return nil, err + } + + return primitive.Symbol(sym), nil + case bsontype.CodeWithScope: + cws := reflect.New(tPrimitiveCWS).Elem() + err := defaultValueDecoders.CodeWithScopeDecodeValue(dc, vr, cws) + if err != nil { + return nil, err + } + + return cws.Interface().(primitive.CodeWithScope), nil + case bsontype.Int32: + return vr.ReadInt32() + case bsontype.Int64: + return vr.ReadInt64() + case bsontype.Timestamp: + t, i, err := vr.ReadTimestamp() + if err != nil { + return nil, err + } + + return primitive.Timestamp{ + T: t, + I: i, + }, nil + case bsontype.Decimal128: + return vr.ReadDecimal128() + case bsontype.MinKey: + return primitive.MinKey{}, vr.ReadMinKey() + case bsontype.MaxKey: + return primitive.MaxKey{}, vr.ReadMaxKey() + case bsontype.Type(0), bsontype.EmbeddedDocument: + return r.decodeDocument(dc, vr) + case bsontype.Array: + return r.decodeArray(dc, vr) + default: + return nil, fmt.Errorf("cannot decode invalid BSON type %s", vrType) + } +} + +func (r *reflectionFreeDCodec) encodeDocumentValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, v interface{}) error { + switch val := v.(type) { + case int: + return r.encodeInt(ec, vw, val) + case int8: + return vw.WriteInt32(int32(val)) + case int16: + return vw.WriteInt32(int32(val)) + case int32: + return vw.WriteInt32(int32(val)) + case int64: + return r.encodeInt64(ec, vw, val) + case uint: + return r.encodeUint64(ec, vw, uint64(val)) + case uint8: + return vw.WriteInt32(int32(val)) + case uint16: + return vw.WriteInt32(int32(val)) + case uint32: + return r.encodeUint64(ec, vw, uint64(val)) + case uint64: + return r.encodeUint64(ec, vw, val) + case float32: + return vw.WriteDouble(float64(val)) + case float64: + return vw.WriteDouble(val) + case []byte: + return vw.WriteBinary(val) + case primitive.Binary: + return vw.WriteBinaryWithSubtype(val.Data, val.Subtype) + case bool: + return vw.WriteBoolean(val) + case primitive.CodeWithScope: + return defaultValueEncoders.CodeWithScopeEncodeValue(ec, vw, reflect.ValueOf(val)) + case primitive.DBPointer: + return vw.WriteDBPointer(val.DB, val.Pointer) + case primitive.DateTime: + return vw.WriteDateTime(int64(val)) + case time.Time: + dt := primitive.NewDateTimeFromTime(val) + return vw.WriteDateTime(int64(dt)) + case primitive.Decimal128: + return vw.WriteDecimal128(val) + case primitive.JavaScript: + return vw.WriteJavascript(string(val)) + case primitive.MinKey: + return vw.WriteMinKey() + case primitive.MaxKey: + return vw.WriteMaxKey() + case primitive.Null, nil: + return vw.WriteNull() + case primitive.ObjectID: + return vw.WriteObjectID(val) + case primitive.Regex: + return vw.WriteRegex(val.Pattern, val.Options) + case string: + return vw.WriteString(val) + case primitive.Symbol: + return vw.WriteSymbol(string(val)) + case primitive.Timestamp: + return vw.WriteTimestamp(val.T, val.I) + case primitive.Undefined: + return vw.WriteUndefined() + case primitive.D: + return r.encodeDocument(ec, vw, val) + case primitive.A: + return r.encodePrimitiveA(ec, vw, val) + case []interface{}: + return r.encodePrimitiveA(ec, vw, val) + case []primitive.D: + return r.encodeSliceD(ec, vw, val) + case []int: + return r.encodeSliceInt(ec, vw, val) + case []int8: + return r.encodeSliceInt8(ec, vw, val) + case []int16: + return r.encodeSliceInt16(ec, vw, val) + case []int32: + return r.encodeSliceInt32(ec, vw, val) + case []int64: + return r.encodeSliceInt64(ec, vw, val) + case []uint: + return r.encodeSliceUint(ec, vw, val) + case []uint16: + return r.encodeSliceUint16(ec, vw, val) + case []uint32: + return r.encodeSliceUint32(ec, vw, val) + case []uint64: + return r.encodeSliceUint64(ec, vw, val) + case [][]byte: + return r.encodeSliceByteSlice(ec, vw, val) + case []primitive.Binary: + return r.encodeSliceBinary(ec, vw, val) + case []bool: + return r.encodeSliceBoolean(ec, vw, val) + case []primitive.CodeWithScope: + return r.encodeSliceCWS(ec, vw, val) + case []primitive.DBPointer: + return r.encodeSliceDBPointer(ec, vw, val) + case []primitive.DateTime: + return r.encodeSliceDateTime(ec, vw, val) + case []time.Time: + return r.encodeSliceTimeTime(ec, vw, val) + case []primitive.Decimal128: + return r.encodeSliceDecimal128(ec, vw, val) + case []float32: + return r.encodeSliceFloat32(ec, vw, val) + case []float64: + return r.encodeSliceFloat64(ec, vw, val) + case []primitive.JavaScript: + return r.encodeSliceJavaScript(ec, vw, val) + case []primitive.MinKey: + return r.encodeSliceMinKey(ec, vw, val) + case []primitive.MaxKey: + return r.encodeSliceMaxKey(ec, vw, val) + case []primitive.Null: + return r.encodeSliceNull(ec, vw, val) + case []primitive.ObjectID: + return r.encodeSliceObjectID(ec, vw, val) + case []primitive.Regex: + return r.encodeSliceRegex(ec, vw, val) + case []string: + return r.encodeSliceString(ec, vw, val) + case []primitive.Symbol: + return r.encodeSliceSymbol(ec, vw, val) + case []primitive.Timestamp: + return r.encodeSliceTimestamp(ec, vw, val) + case []primitive.Undefined: + return r.encodeSliceUndefined(ec, vw, val) + default: + return fmt.Errorf("value of type %T not supported", v) + } +} + +func (r *reflectionFreeDCodec) encodeInt(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val int) error { + if fitsIn32Bits(int64(val)) { + return vw.WriteInt32(int32(val)) + } + return vw.WriteInt64(int64(val)) +} + +func (r *reflectionFreeDCodec) encodeInt64(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val int64) error { + if ec.MinSize && fitsIn32Bits(val) { + return vw.WriteInt32(int32(val)) + } + return vw.WriteInt64(int64(val)) +} + +func (r *reflectionFreeDCodec) encodeUint64(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val uint64) error { + if ec.MinSize && val <= math.MaxInt32 { + return vw.WriteInt32(int32(val)) + } + if val > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", val) + } + + return vw.WriteInt64(int64(val)) +} + +func (r *reflectionFreeDCodec) encodeDocument(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, doc primitive.D) error { + dw, err := vw.WriteDocument() + if err != nil { + return err + } + + for _, elem := range doc { + docValWriter, err := dw.WriteDocumentElement(elem.Key) + if err != nil { + return err + } + + if err := r.encodeDocumentValue(ec, docValWriter, elem.Value); err != nil { + return err + } + } + + return dw.WriteDocumentEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceByteSlice(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr [][]byte) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteBinary(val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceBinary(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.Binary) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteBinaryWithSubtype(val.Data, val.Subtype); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceBoolean(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []bool) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteBoolean(val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceCWS(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.CodeWithScope) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := defaultValueEncoders.CodeWithScopeEncodeValue(ec, arrayValWriter, reflect.ValueOf(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceDBPointer(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.DBPointer) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteDBPointer(val.DB, val.Pointer); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceDateTime(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.DateTime) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteDateTime(int64(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceTimeTime(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []time.Time) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + dt := primitive.NewDateTimeFromTime(val) + if err := arrayValWriter.WriteDateTime(int64(dt)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceDecimal128(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.Decimal128) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteDecimal128(val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceFloat32(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []float32) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteDouble(float64(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceFloat64(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []float64) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteDouble(val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceJavaScript(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.JavaScript) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteJavascript(string(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceMinKey(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.MinKey) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteMinKey(); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceMaxKey(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.MaxKey) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteMaxKey(); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceNull(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.Null) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteNull(); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceObjectID(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.ObjectID) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteObjectID(val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceRegex(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.Regex) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteRegex(val.Pattern, val.Options); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceString(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []string) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteString(val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceSymbol(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.Symbol) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteSymbol(string(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceTimestamp(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.Timestamp) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteTimestamp(val.T, val.I); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceUndefined(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.Undefined) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteUndefined(); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodePrimitiveA(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr primitive.A) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := r.encodeDocumentValue(ec, arrayValWriter, val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceD(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []primitive.D) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := r.encodeDocument(ec, arrayValWriter, val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceInt(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []int) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := r.encodeInt(ec, arrayValWriter, val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceInt8(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []int8) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteInt32(int32(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceInt16(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []int16) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteInt32(int32(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceInt32(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []int32) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteInt32(val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceInt64(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []int64) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := r.encodeInt64(ec, arrayValWriter, val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceUint(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []uint) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := r.encodeUint64(ec, arrayValWriter, uint64(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceUint16(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []uint16) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := arrayValWriter.WriteInt32(int32(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceUint32(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []uint32) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := r.encodeUint64(ec, arrayValWriter, uint64(val)); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func (r *reflectionFreeDCodec) encodeSliceUint64(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, arr []uint64) error { + aw, err := vw.WriteArray() + if err != nil { + return err + } + + for _, val := range arr { + arrayValWriter, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if err := r.encodeUint64(ec, arrayValWriter, val); err != nil { + return err + } + } + + return aw.WriteArrayEnd() +} + +func fitsIn32Bits(i int64) bool { + return math.MinInt32 <= i && i <= math.MaxInt32 +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/auth.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/auth.go index a8989b7b..adf50a37 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/auth.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/auth.go @@ -11,9 +11,10 @@ import ( "errors" "fmt" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -64,13 +65,17 @@ type authHandshaker struct { wrapped driver.Handshaker options *HandshakeOptions - conversation SpeculativeConversation + handshakeInfo driver.HandshakeInformation + conversation SpeculativeConversation } -// GetDescription performs an isMaster to retrieve the initial description for conn. -func (ah *authHandshaker) GetDescription(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) { +var _ driver.Handshaker = (*authHandshaker)(nil) + +// GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided +// connection. +func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) { if ah.wrapped != nil { - return ah.wrapped.GetDescription(ctx, addr, conn) + return ah.wrapped.GetHandshakeInformation(ctx, addr, conn) } op := operation.NewIsMaster(). @@ -84,23 +89,24 @@ func (ah *authHandshaker) GetDescription(ctx context.Context, addr address.Addre var err error ah.conversation, err = speculativeAuth.CreateSpeculativeConversation() if err != nil { - return description.Server{}, newAuthError("failed to create conversation", err) + return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err) } firstMsg, err := ah.conversation.FirstMessage() if err != nil { - return description.Server{}, newAuthError("failed to create speculative authentication message", err) + return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) } op = op.SpeculativeAuthenticate(firstMsg) } } - desc, err := op.GetDescription(ctx, addr, conn) + var err error + ah.handshakeInfo, err = op.GetHandshakeInformation(ctx, addr, conn) if err != nil { - return description.Server{}, newAuthError("handshake failure", err) + return driver.HandshakeInformation{}, newAuthError("handshake failure", err) } - return desc, nil + return ah.handshakeInfo, nil } // FinishHandshake performs authentication for conn if necessary. @@ -116,9 +122,10 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Conne desc := conn.Description() if performAuth(desc) && ah.options.Authenticator != nil { cfg := &Config{ - Description: desc, - Connection: conn, - ClusterClock: ah.options.ClusterClock, + Description: desc, + Connection: conn, + ClusterClock: ah.options.ClusterClock, + HandshakeInfo: ah.handshakeInfo, } if err := ah.authenticate(ctx, cfg); err != nil { @@ -135,12 +142,12 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Conne func (ah *authHandshaker) authenticate(ctx context.Context, cfg *Config) error { // If the initial isMaster reply included a response to the speculative authentication attempt, we only need to // conduct the remainder of the conversation. - if speculativeResponse := cfg.Description.SpeculativeAuthenticate; speculativeResponse != nil { + if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil { // Defensively ensure that the server did not include a response if speculative auth was not attempted. if ah.conversation == nil { return errors.New("speculative auth was not attempted but the server included a response") } - return ah.conversation.Finish(ctx, cfg, speculativeResponse) + return ah.conversation.Finish(ctx, cfg, bsoncore.Document(speculativeResponse)) } // If the server does not support speculative authentication or the first attempt was not successful, we need to @@ -158,9 +165,10 @@ func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshake // Config holds the information necessary to perform an authentication attempt. type Config struct { - Description description.Server - Connection driver.Connection - ClusterClock *session.ClusterClock + Description description.Server + Connection driver.Connection + ClusterClock *session.ClusterClock + HandshakeInfo driver.HandshakeInformation } // Authenticator handles authenticating a connection. diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/default.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/default.go index 018f4e5c..4da4032f 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/default.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/default.go @@ -10,7 +10,7 @@ import ( "context" "fmt" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/mongo/description" ) func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { @@ -52,7 +52,7 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { var actual Authenticator var err error - switch chooseAuthMechanism(cfg.Description) { + switch chooseAuthMechanism(cfg) { case SCRAMSHA256: actual, err = newScramSHA256Authenticator(a.Cred) case SCRAMSHA1: @@ -71,9 +71,9 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { // If a server provides a list of supported mechanisms, we choose // SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1. // Otherwise, we decide based on what is supported. -func chooseAuthMechanism(desc description.Server) string { - if desc.SaslSupportedMechs != nil { - for _, v := range desc.SaslSupportedMechs { +func chooseAuthMechanism(cfg *Config) string { + if saslSupportedMechs := cfg.HandshakeInfo.SaslSupportedMechs; saslSupportedMechs != nil { + for _, v := range saslSupportedMechs { if v == SCRAMSHA256 { return v } @@ -81,9 +81,18 @@ func chooseAuthMechanism(desc description.Server) string { return SCRAMSHA1 } - if err := description.ScramSHA1Supported(desc.WireVersion); err == nil { + if err := scramSHA1Supported(cfg.HandshakeInfo.Description.WireVersion); err == nil { return SCRAMSHA1 } return MONGODBCR } + +// scramSHA1Supported returns an error if the given server version does not support scram-sha-1. +func scramSHA1Supported(wireVersion *description.VersionRange) error { + if wireVersion != nil && wireVersion.Max < 3 { + return fmt.Errorf("SCRAM-SHA-1 is only supported for servers 3.0 or newer") + } + + return nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/x509.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/x509.go index b955f421..eed517e0 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/x509.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/auth/x509.go @@ -9,9 +9,9 @@ package auth import ( "context" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/batch_cursor.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/batch_cursor.go index 2d7339d4..9966e148 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/batch_cursor.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/batch_cursor.go @@ -8,8 +8,8 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/connstring/connstring.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/connstring/connstring.go index 230aca8e..52f81e03 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/connstring/connstring.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/connstring/connstring.go @@ -115,6 +115,7 @@ type ConnString struct { WNumber int WNumberSet bool Username string + UsernameSet bool ZlibLevel int ZlibLevelSet bool ZstdLevel int @@ -137,7 +138,7 @@ func (u *ConnString) String() string { func (u *ConnString) HasAuthParameters() bool { // Check all auth parameters except for AuthSource because an auth source without other credentials is semantically // valid and must not be interpreted as a request for authentication. - return u.AuthMechanism != "" || u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet + return u.AuthMechanism != "" || u.AuthMechanismProperties != nil || u.UsernameSet || u.PasswordSet } // Validate checks that the Auth and SSL parameters are valid values. @@ -224,6 +225,7 @@ func (p *parser) parse(original string) error { if err != nil { return internal.WrapErrorf(err, "invalid username") } + p.UsernameSet = true // Validate and process the password. if strings.Contains(password, ":") { @@ -342,6 +344,12 @@ func (p *parser) validate() error { } func (p *parser) setDefaultAuthParams(dbName string) error { + // We do this check here rather than in validateAuth because this function is called as part of parsing and sets + // the value of AuthSource if authentication is enabled. + if p.AuthSourceSet && p.AuthSource == "" { + return errors.New("authSource must be non-empty when supplied in a URI") + } + switch strings.ToLower(p.AuthMechanism) { case "plain": if p.AuthSource == "" { @@ -466,6 +474,9 @@ func (p *parser) validateAuth() error { return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties") } case "": + if p.UsernameSet && p.Username == "" { + return fmt.Errorf("username required if URI contains user info") + } default: return fmt.Errorf("invalid auth mechanism") } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/crypt.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/crypt.go index e8c8c546..812db32f 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/crypt.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/crypt.go @@ -38,7 +38,7 @@ type CryptOptions struct { CollInfoFn CollectionInfoFn KeyFn KeyRetrieverFn MarkFn MarkCommandFn - KmsProviders map[string]map[string]interface{} + KmsProviders bsoncore.Document SchemaMap map[string]bsoncore.Document BypassAutoEncryption bool } @@ -62,7 +62,9 @@ func NewCrypt(opts *CryptOptions) (*Crypt, error) { markFn: opts.MarkFn, BypassAutoEncryption: opts.BypassAutoEncryption, } - mc, err := mongocrypt.NewMongoCrypt(createMongoCryptOptions(opts)) + + mongocryptOpts := options.MongoCrypt().SetKmsProviders(opts.KmsProviders).SetLocalSchemaMap(opts.SchemaMap) + mc, err := mongocrypt.NewMongoCrypt(mongocryptOpts) if err != nil { return nil, err } @@ -297,37 +299,3 @@ func (c *Crypt) decryptKey(ctx context.Context, kmsCtx *mongocrypt.KmsContext) e } } } - -func createMongoCryptOptions(opts *CryptOptions) *options.MongoCryptOptions { - mcOpts := options.MongoCrypt().SetLocalSchemaMap(opts.SchemaMap) - // KMS providers options - for provider, providerOpts := range opts.KmsProviders { - switch provider { - case "aws": - awsOpts := options.AwsKmsProvider() - - if accessKey, ok := providerOpts["accessKeyId"]; ok { - if keyStr, ok := accessKey.(string); ok { - awsOpts.SetAccessKeyID(keyStr) - } - } - if secretAccessKey, ok := providerOpts["secretAccessKey"]; ok { - if keyStr, ok := secretAccessKey.(string); ok { - awsOpts.SetSecretAccessKey(keyStr) - } - } - mcOpts.SetAwsProviderOptions(awsOpts) - case "local": - localOpts := options.LocalKmsProvider() - - if key, ok := providerOpts["key"]; ok { - if keyBytes, ok := key.([]byte); ok { - localOpts.SetMasterKey(keyBytes) - } - } - mcOpts.SetLocalProviderOptions(localOpts) - } - } - - return mcOpts -} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/feature.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/feature.go deleted file mode 100644 index f0236c01..00000000 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/feature.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "fmt" -) - -// MaxStalenessSupported returns an error if the given server version -// does not support max staleness. -func MaxStalenessSupported(wireVersion *VersionRange) error { - if wireVersion != nil && wireVersion.Max < 5 { - return fmt.Errorf("max staleness is only supported for servers 3.4 or newer") - } - - return nil -} - -// ScramSHA1Supported returns an error if the given server version -// does not support scram-sha-1. -func ScramSHA1Supported(wireVersion *VersionRange) error { - if wireVersion != nil && wireVersion.Max < 3 { - return fmt.Errorf("SCRAM-SHA-1 is only supported for servers 3.0 or newer") - } - - return nil -} - -// SessionsSupported returns true of the given server version indicates that it supports sessions. -func SessionsSupported(wireVersion *VersionRange) bool { - return wireVersion != nil && wireVersion.Max >= 6 -} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/version.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/version.go deleted file mode 100644 index 60cda4eb..00000000 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/version.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import "strconv" - -// Version represents a software version. -type Version struct { - Desc string - Parts []uint8 -} - -// AtLeast ensures that the version is at least as large as the "other" version. -func (v Version) AtLeast(other ...uint8) bool { - for i := range other { - if i == len(v.Parts) { - return false - } - if v.Parts[i] < other[i] { - return false - } - } - return true -} - -// String provides the string represtation of the Version. -func (v Version) String() string { - if v.Desc == "" { - var s string - for i, p := range v.Parts { - if i != 0 { - s += "." - } - s += strconv.Itoa(int(p)) - } - return s - } - - return v.Desc -} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/driver.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/driver.go index 8acbc1b4..ec01d51c 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/driver.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/driver.go @@ -3,8 +3,9 @@ package driver // import "go.mongodb.org/mongo-driver/x/mongo/driver" import ( "context" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) // Deployment is implemented by types that can select a server from a deployment. @@ -87,18 +88,48 @@ type Compressor interface { CompressWireMessage(src, dst []byte) ([]byte, error) } +// ProcessErrorResult represents the result of a ErrorProcessor.ProcessError() call. Exact values for this type can be +// checked directly (e.g. res == ServerMarkedUnknown), but it is recommended that applications use the ServerChanged() +// function instead. +type ProcessErrorResult int + +const ( + // NoChange indicates that the error did not affect the state of the server. + NoChange ProcessErrorResult = iota + // ServerMarkedUnknown indicates that the error only resulted in the server being marked as Unknown. + ServerMarkedUnknown + // ConnectionPoolCleared indicates that the error resulted in the server being marked as Unknown and its connection + // pool being cleared. + ConnectionPoolCleared +) + +// ServerChanged returns true if the ProcessErrorResult indicates that the server changed from an SDAM perspective +// during a ProcessError() call. +func (p ProcessErrorResult) ServerChanged() bool { + return p != NoChange +} + // ErrorProcessor implementations can handle processing errors, which may modify their internal state. // If this type is implemented by a Server, then Operation.Execute will call it's ProcessError // method after it decodes a wire message. type ErrorProcessor interface { - ProcessError(err error, conn Connection) + ProcessError(err error, conn Connection) ProcessErrorResult +} + +// HandshakeInformation contains information extracted from a MongoDB connection handshake. This is a helper type that +// augments description.Server by also tracking authentication-related fields. We use this type rather than adding +// these fields to description.Server to avoid retaining sensitive information in a user-facing type. +type HandshakeInformation struct { + Description description.Server + SpeculativeAuthenticate bsoncore.Document + SaslSupportedMechs []string } // Handshaker is the interface implemented by types that can perform a MongoDB // handshake over a provided driver.Connection. This is used during connection // initialization. Implementations must be goroutine safe. type Handshaker interface { - GetDescription(context.Context, address.Address, Connection) (description.Server, error) + GetHandshakeInformation(context.Context, address.Address, Connection) (HandshakeInformation, error) FinishHandshake(context.Context, Connection) error } @@ -118,8 +149,7 @@ func (SingleServerDeployment) Kind() description.TopologyKind { return descripti // SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This // implementation should only be used for connection handshakes and server heartbeats as it does not implement -// ErrorProcessor, which is necessary for application operations and wraps the connection in nopCloserConnection, -// which does not implement Compressor. +// ErrorProcessor, which is necessary for application operations. type SingleConnectionDeployment struct{ C Connection } var _ Deployment = SingleConnectionDeployment{} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/errors.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/errors.go index 096acde6..a9086f2c 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/errors.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/errors.go @@ -7,8 +7,8 @@ import ( "strings" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) var ( @@ -442,7 +442,7 @@ func extractError(rdr bsoncore.Document) error { if !ok { break } - version, err := description.NewTopologyVersion(doc) + version, err := description.NewTopologyVersion(bson.Raw(doc)) if err == nil { tv = version } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/mongocrypt.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/mongocrypt.go index 4fc1daf1..edfedf1a 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/mongocrypt.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/mongocrypt.go @@ -15,23 +15,13 @@ package mongocrypt // #include import "C" import ( + "errors" "unsafe" - "github.com/pkg/errors" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" ) -// These constants reprsent valid values for KmsProvider. -const ( - AwsProvider = "aws" - LocalProvider = "local" -) - -// ErrInvalidProvider is returned when an invalid KMS provider is given. -var ErrInvalidProvider = errors.New("invalid KMS provider") - -// MongoCrypt represents a mongocrypt_t handle. type MongoCrypt struct { wrapped *C.mongocrypt_t } @@ -48,10 +38,7 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) { } // set options in mongocrypt - if err := crypt.setLocalProviderOpts(opts.LocalProviderOpts); err != nil { - return nil, err - } - if err := crypt.setAwsProviderOpts(opts.AwsProviderOpts); err != nil { + if err := crypt.setProviderOptions(opts.KmsProviders); err != nil { return nil, err } if err := crypt.setLocalSchemaMap(opts.LocalSchemaMap); err != nil { @@ -128,33 +115,23 @@ func (m *MongoCrypt) CreateDataKeyContext(kmsProvider string, opts *options.Data return nil, m.createErrorFromStatus() } - var ok bool - switch kmsProvider { - case AwsProvider: - // set region and key (required fields) - region := C.CString(lookupString(opts.MasterKey, "region")) - key := C.CString(lookupString(opts.MasterKey, "key")) - defer C.free(unsafe.Pointer(region)) - defer C.free(unsafe.Pointer(key)) - ok = bool(C.mongocrypt_ctx_setopt_masterkey_aws(ctx.wrapped, region, -1, key, -1)) - if !ok { - break - } - - // set endpoint (not a required field) - endpoint := lookupString(opts.MasterKey, "endpoint") - if endpoint == "" { - break - } - endpointCStr := C.CString(endpoint) - defer C.free(unsafe.Pointer(endpointCStr)) - ok = bool(C.mongocrypt_ctx_setopt_masterkey_aws_endpoint(ctx.wrapped, endpointCStr, -1)) - case LocalProvider: - ok = bool(C.mongocrypt_ctx_setopt_masterkey_local(ctx.wrapped)) + // Create a masterKey document of the form { "provider": , other options... }. + var masterKey bsoncore.Document + switch { + case opts.MasterKey != nil: + // The original key passed into the top-level API was already transformed into a raw BSON document and passed + // down to here, so we can modify it without copying. Remove the terminating byte to add the "provider" field. + masterKey = opts.MasterKey[:len(opts.MasterKey)-1] + masterKey = bsoncore.AppendStringElement(masterKey, "provider", kmsProvider) + masterKey, _ = bsoncore.AppendDocumentEnd(masterKey, 0) default: - return nil, ErrInvalidProvider + masterKey = bsoncore.NewDocumentBuilder().AppendString("provider", kmsProvider).Build() } - if !ok { + + masterKeyBinary := newBinaryFromBytes(masterKey) + defer masterKeyBinary.close() + + if ok := C.mongocrypt_ctx_setopt_key_encryption_key(ctx.wrapped, masterKeyBinary.wrapped); !ok { return nil, ctx.createErrorFromStatus() } @@ -228,34 +205,11 @@ func (m *MongoCrypt) Close() { C.mongocrypt_destroy(m.wrapped) } -// setLocalProviderOpts sets options for the local KMS provider in mongocrypt. -func (m *MongoCrypt) setLocalProviderOpts(opts *options.LocalKmsProviderOptions) error { - if opts == nil { - return nil - } - - keyBinary := newBinaryFromBytes(opts.MasterKey) - defer keyBinary.close() - - if ok := C.mongocrypt_setopt_kms_provider_local(m.wrapped, keyBinary.wrapped); !ok { - return m.createErrorFromStatus() - } - return nil -} - -// setAwsProviderOpts sets options for the AWS KMS provider in mongocrypt. -func (m *MongoCrypt) setAwsProviderOpts(opts *options.AwsKmsProviderOptions) error { - if opts == nil { - return nil - } - - // create C strings for function params - accessKeyID := C.CString(opts.AccessKeyID) - secretAccessKey := C.CString(opts.SecretAccessKey) - defer C.free(unsafe.Pointer(accessKeyID)) - defer C.free(unsafe.Pointer(secretAccessKey)) +func (m *MongoCrypt) setProviderOptions(kmsProviders bsoncore.Document) error { + providersBinary := newBinaryFromBytes(kmsProviders) + defer providersBinary.close() - if ok := C.mongocrypt_setopt_kms_provider_aws(m.wrapped, accessKeyID, C.int32_t(-1), secretAccessKey, C.int32_t(-1)); !ok { + if ok := C.mongocrypt_setopt_kms_providers(m.wrapped, providersBinary.wrapped); !ok { return m.createErrorFromStatus() } return nil diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/mongocrypt_options.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/mongocrypt_options.go index 09cae44c..abaf260d 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/mongocrypt_options.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/mongocrypt_options.go @@ -12,9 +12,8 @@ import ( // MongoCryptOptions specifies options to configure a MongoCrypt instance. type MongoCryptOptions struct { - AwsProviderOpts *AwsKmsProviderOptions - LocalProviderOpts *LocalKmsProviderOptions - LocalSchemaMap map[string]bsoncore.Document + KmsProviders bsoncore.Document + LocalSchemaMap map[string]bsoncore.Document } // MongoCrypt creates a new MongoCryptOptions instance. @@ -22,15 +21,9 @@ func MongoCrypt() *MongoCryptOptions { return &MongoCryptOptions{} } -// SetAwsProviderOptions specifies AWS KMS provider options. -func (mo *MongoCryptOptions) SetAwsProviderOptions(awsOpts *AwsKmsProviderOptions) *MongoCryptOptions { - mo.AwsProviderOpts = awsOpts - return mo -} - -// SetLocalProviderOptions specifies local KMS provider options. -func (mo *MongoCryptOptions) SetLocalProviderOptions(localOpts *LocalKmsProviderOptions) *MongoCryptOptions { - mo.LocalProviderOpts = localOpts +// SetKmsProviders specifies the KMS providers map. +func (mo *MongoCryptOptions) SetKmsProviders(kmsProviders bsoncore.Document) *MongoCryptOptions { + mo.KmsProviders = kmsProviders return mo } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/provider_options.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/provider_options.go deleted file mode 100644 index 63daa4c5..00000000 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options/provider_options.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package options - -// AwsKmsProviderOptions specifies options for configuring the AWS KMS provider. -type AwsKmsProviderOptions struct { - AccessKeyID string - SecretAccessKey string -} - -// AwsKmsProvider creates a new AwsKmsProviderOptions instance. -func AwsKmsProvider() *AwsKmsProviderOptions { - return &AwsKmsProviderOptions{} -} - -// SetAccessKeyID specifies the AWS access key ID. -func (akpo *AwsKmsProviderOptions) SetAccessKeyID(accessKeyID string) *AwsKmsProviderOptions { - akpo.AccessKeyID = accessKeyID - return akpo -} - -// SetSecretAccessKey specifies the AWS secret access key. -func (akpo *AwsKmsProviderOptions) SetSecretAccessKey(secretAccessKey string) *AwsKmsProviderOptions { - akpo.SecretAccessKey = secretAccessKey - return akpo -} - -// LocalKmsProviderOptions specifies options for configuring a local KMS provider. -type LocalKmsProviderOptions struct { - MasterKey []byte -} - -// LocalKmsProvider creates a new LocalKmsProviderOptions instance. -func LocalKmsProvider() *LocalKmsProviderOptions { - return &LocalKmsProviderOptions{} -} - -// SetMasterKey specifies the local master key. -func (lkpo *LocalKmsProviderOptions) SetMasterKey(key []byte) *LocalKmsProviderOptions { - lkpo.MasterKey = key - return lkpo -} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation.go index 74898500..f57565f5 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation.go @@ -13,11 +13,11 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -367,7 +367,7 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { } res, err = roundTrip(ctx, conn, wm) if ep, ok := srvr.(ErrorProcessor); ok { - ep.ProcessError(err, conn) + _ = ep.ProcessError(err, conn) } finishedInfo.response = res @@ -384,7 +384,8 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { connDesc := conn.Description() retryableErr := tt.Retryable(connDesc.WireVersion) preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9 - inTransaction := !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() + inTransaction := op.Client != nil && + !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() // If retry is enabled and the operation isn't in a transaction, add a RetryableWriteError label for // retryable errors from pre-4.4 servers if retryableErr && preRetryWriteLabelVersion && retryEnabled && !inTransaction { @@ -465,7 +466,8 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { if op.Type == Write { retryableErr = tt.RetryableWrite(connDesc.WireVersion) preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9 - inTransaction := !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() + inTransaction := op.Client != nil && + !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() // If retryWrites is enabled and the operation isn't in a transaction, add a RetryableWriteError label // for network errors and retryable errors from pre-4.4 servers if retryEnabled && !inTransaction && @@ -556,7 +558,7 @@ func (op Operation) retryable(desc description.Server) bool { if op.Client != nil && (op.Client.Committing || op.Client.Aborting) { return true } - if desc.SupportsRetryWrites() && + if retryWritesSupported(desc) && desc.WireVersion != nil && desc.WireVersion.Max >= 6 && op.Client != nil && !(op.Client.TransactionInProgress() || op.Client.TransactionStarting()) && writeconcern.AckWrite(op.WriteConcern) { @@ -924,7 +926,7 @@ func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) return dst, err } - if description.SessionsSupported(desc.WireVersion) && client != nil && client.Consistent && client.OperationTime != nil { + if sessionsSupported(desc.WireVersion) && client != nil && client.Consistent && client.OperationTime != nil { data = data[:len(data)-1] // remove the null byte data = bsoncore.AppendTimestampElement(data, "afterClusterTime", client.OperationTime.T, client.OperationTime.I) data, _ = bsoncore.AppendDocumentEnd(data, 0) @@ -958,7 +960,7 @@ func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) { client := op.Client - if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 { + if client == nil || !sessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 { return dst, nil } if err := client.UpdateUseTime(); err != nil { @@ -988,7 +990,7 @@ func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]b func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) []byte { client, clock := op.Client, op.Clock - if (clock == nil && client == nil) || !description.SessionsSupported(desc.WireVersion) { + if (clock == nil && client == nil) || !sessionsSupported(desc.WireVersion) { return dst } clusterTime := clock.GetClusterTime() @@ -1404,3 +1406,13 @@ func (op Operation) publishFinishedEvent(ctx context.Context, info finishedInfor } op.CommandMonitor.Failed(ctx, failedEvent) } + +// sessionsSupported returns true of the given server version indicates that it supports sessions. +func sessionsSupported(wireVersion *description.VersionRange) bool { + return wireVersion != nil && wireVersion.Max >= 6 +} + +// retryWritesSupported returns true if this description represents a server that supports retryable writes. +func retryWritesSupported(s description.Server) bool { + return s.SessionTimeoutMinutes != 0 && s.Kind != description.Standalone +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/abort_transaction.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/abort_transaction.go index 7c3d054e..885c9b88 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/abort_transaction.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/abort_transaction.go @@ -13,10 +13,10 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/aggregate.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/aggregate.go index d8e2fda5..2b329485 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/aggregate.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/aggregate.go @@ -14,12 +14,12 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/command.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/command.go index 11cbc4ea..3e04814f 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/command.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/command.go @@ -7,11 +7,11 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/commit_transaction.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/commit_transaction.go index bbca63d1..13307dc3 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/commit_transaction.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/commit_transaction.go @@ -13,10 +13,10 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/count.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/count.go index b98eb25f..c3c5dc75 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/count.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/count.go @@ -14,11 +14,11 @@ import ( "fmt" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/create.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/create.go index a7f15c8d..45e376a9 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/create.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/create.go @@ -13,10 +13,10 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/createIndexes.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/createIndexes.go index ef65b1e7..289a9c09 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/createIndexes.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/createIndexes.go @@ -15,10 +15,10 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/delete.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/delete.go index c697266f..f915e015 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/delete.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/delete.go @@ -14,10 +14,10 @@ import ( "fmt" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/distinct.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/distinct.go index 9733e5a8..28dd5c62 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/distinct.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/distinct.go @@ -13,11 +13,11 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_collection.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_collection.go index dcc8e7a2..5d1504b5 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_collection.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_collection.go @@ -14,10 +14,10 @@ import ( "fmt" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_database.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_database.go index 1b43d767..28068bd9 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_database.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_database.go @@ -14,10 +14,10 @@ import ( "fmt" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_indexes.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_indexes.go index 1881dd25..5d365697 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_indexes.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/drop_indexes.go @@ -14,10 +14,10 @@ import ( "fmt" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/end_sessions.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/end_sessions.go index ee9537d0..8f8a3baa 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/end_sessions.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/end_sessions.go @@ -13,9 +13,9 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find.go index 0e49d3a0..0675a960 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find.go @@ -14,11 +14,11 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find_and_modify.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find_and_modify.go index 2534243c..e2e1f8e1 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find_and_modify.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/find_and_modify.go @@ -16,10 +16,10 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/insert.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/insert.go index 779f2d55..13494d0d 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/insert.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/insert.go @@ -14,10 +14,10 @@ import ( "fmt" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/ismaster.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/ismaster.go index 1ac16232..7d24dd82 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/ismaster.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/ismaster.go @@ -7,11 +7,13 @@ import ( "runtime" "strconv" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/internal" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/version" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -89,7 +91,7 @@ func (im *IsMaster) MaxAwaitTimeMS(awaitTime int64) *IsMaster { // Result returns the result of executing this operation. func (im *IsMaster) Result(addr address.Address) description.Server { - return description.NewServer(addr, im.res) + return description.NewServer(addr, bson.Raw(im.res)) } func (im *IsMaster) decodeStringSlice(element bsoncore.Element, name string) ([]string, error) { @@ -223,9 +225,9 @@ func (im *IsMaster) createOperation() driver.Operation { } } -// GetDescription retrieves the server description for the given connection. This function implements the Handshaker -// interface. -func (im *IsMaster) GetDescription(ctx context.Context, _ address.Address, c driver.Connection) (description.Server, error) { +// GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant +// information about the server. This function implements the driver.Handshaker interface. +func (im *IsMaster) GetHandshakeInformation(ctx context.Context, _ address.Address, c driver.Connection) (driver.HandshakeInformation, error) { err := driver.Operation{ Clock: im.clock, CommandFn: im.handshakeCommand, @@ -237,9 +239,21 @@ func (im *IsMaster) GetDescription(ctx context.Context, _ address.Address, c dri }, }.Execute(ctx, nil) if err != nil { - return description.Server{}, err + return driver.HandshakeInformation{}, err } - return im.Result(c.Address()), nil + + info := driver.HandshakeInformation{ + Description: im.Result(c.Address()), + } + if speculativeAuthenticate, ok := im.res.Lookup("speculativeAuthenticate").DocumentOK(); ok { + info.SpeculativeAuthenticate = speculativeAuthenticate + } + // Cast to bson.Raw to lookup saslSupportedMechs to avoid converting from bsoncore.Value to bson.RawValue for the + // StringSliceFromRawValue call. + if saslSupportedMechs, lookupErr := bson.Raw(im.res).LookupErr("saslSupportedMechs"); lookupErr == nil { + info.SaslSupportedMechs, err = internal.StringSliceFromRawValue("saslSupportedMechs", saslSupportedMechs) + } + return info, err } // FinishHandshake implements the Handshaker interface. This is a no-op function because a non-authenticated connection diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/listDatabases.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/listDatabases.go index 77f91787..36c989b7 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/listDatabases.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/listDatabases.go @@ -15,10 +15,10 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_collections.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_collections.go index 579ec5f1..383e2f12 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_collections.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_collections.go @@ -13,10 +13,10 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -34,6 +34,7 @@ type ListCollections struct { selector description.ServerSelector retry *driver.RetryMode result driver.CursorResponse + batchSize *int32 } // NewListCollections constructs and returns a new ListCollections. @@ -95,6 +96,12 @@ func (lc *ListCollections) command(dst []byte, desc description.SelectedServer) if lc.nameOnly != nil { dst = bsoncore.AppendBooleanElement(dst, "nameOnly", *lc.nameOnly) } + cursorDoc := bsoncore.NewDocumentBuilder() + if lc.batchSize != nil { + cursorDoc.AppendInt32("batchSize", *lc.batchSize) + } + dst = bsoncore.AppendDocumentElement(dst, "cursor", cursorDoc.Build()) + return dst, nil } @@ -208,3 +215,13 @@ func (lc *ListCollections) Retry(retry driver.RetryMode) *ListCollections { lc.retry = &retry return lc } + +// BatchSize specifies the number of documents to return in every batch. +func (lc *ListCollections) BatchSize(batchSize int32) *ListCollections { + if lc == nil { + lc = new(ListCollections) + } + + lc.batchSize = &batchSize + return lc +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_indexes.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_indexes.go index 0de46f65..1fc7dbdd 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_indexes.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/list_indexes.go @@ -13,9 +13,9 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/update.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/update.go index 0aed7809..2ccbaee3 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/update.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation/update.go @@ -15,10 +15,10 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_exhaust.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_exhaust.go index fd01928c..caee435e 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_exhaust.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_exhaust.go @@ -10,7 +10,7 @@ import ( "context" "errors" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/mongo/description" ) // ExecuteExhaust reads a response from the provided StreamerConnection. This will error if the connection's diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_legacy.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_legacy.go index ffdec97c..8dd13d4d 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_legacy.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation_legacy.go @@ -13,8 +13,8 @@ import ( "time" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -297,7 +297,7 @@ func (op Operation) legacyKillCursors(ctx context.Context, dst []byte, srvr Serv if err != nil { err = Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}} if ep, ok := srvr.(ErrorProcessor); ok { - ep.ProcessError(err, conn) + _ = ep.ProcessError(err, conn) } finishedInfo.cmdErr = err @@ -635,7 +635,7 @@ func (op Operation) appendLegacyQueryDocument(dst []byte, filter bsoncore.Docume func (op Operation) roundTripLegacyCursor(ctx context.Context, wm []byte, srvr Server, conn Connection, collName, identifier string) (bsoncore.Document, error) { wm, err := op.roundTripLegacy(ctx, conn, wm) if ep, ok := srvr.(ErrorProcessor); ok { - ep.ProcessError(err, conn) + _ = ep.ProcessError(err, conn) } if err != nil { return nil, err diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session.go index a0d0d3bf..92676d8a 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session.go @@ -12,10 +12,10 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/uuid" ) @@ -49,18 +49,36 @@ const ( Implicit ) -// State indicates the state of the FSM. -type state uint8 +// TransactionState indicates the state of the transactions FSM. +type TransactionState uint8 // Client Session states const ( - None state = iota + None TransactionState = iota Starting InProgress Committed Aborted ) +// String implements the fmt.Stringer interface. +func (s TransactionState) String() string { + switch s { + case None: + return "none" + case Starting: + return "starting" + case InProgress: + return "in progress" + case Committed: + return "committed" + case Aborted: + return "aborted" + default: + return "unknown" + } +} + // Client is a session for clients to run commands. type Client struct { *Server @@ -89,10 +107,10 @@ type Client struct { transactionWc *writeconcern.WriteConcern transactionMaxCommitTime *time.Duration - pool *Pool - state state - PinnedServer *description.Server - RecoveryToken bson.Raw + pool *Pool + TransactionState TransactionState + PinnedServer *description.Server + RecoveryToken bson.Raw } func getClusterTime(clusterTime bson.Raw) (uint32, uint32) { @@ -242,29 +260,29 @@ func (c *Client) EndSession() { // TransactionInProgress returns true if the client session is in an active transaction. func (c *Client) TransactionInProgress() bool { - return c.state == InProgress + return c.TransactionState == InProgress } // TransactionStarting returns true if the client session is starting a transaction. func (c *Client) TransactionStarting() bool { - return c.state == Starting + return c.TransactionState == Starting } // TransactionRunning returns true if the client session has started the transaction // and it hasn't been committed or aborted func (c *Client) TransactionRunning() bool { - return c != nil && (c.state == Starting || c.state == InProgress) + return c != nil && (c.TransactionState == Starting || c.TransactionState == InProgress) } // TransactionCommitted returns true of the client session just committed a transaciton. func (c *Client) TransactionCommitted() bool { - return c.state == Committed + return c.TransactionState == Committed } // CheckStartTransaction checks to see if allowed to start transaction and returns // an error if not allowed func (c *Client) CheckStartTransaction() error { - if c.state == InProgress || c.state == Starting { + if c.TransactionState == InProgress || c.TransactionState == Starting { return ErrTransactInProgress } return nil @@ -309,7 +327,7 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error { return ErrUnackWCUnsupported } - c.state = Starting + c.TransactionState = Starting c.PinnedServer = nil return nil } @@ -317,9 +335,9 @@ func (c *Client) StartTransaction(opts *TransactionOptions) error { // CheckCommitTransaction checks to see if allowed to commit transaction and returns // an error if not allowed. func (c *Client) CheckCommitTransaction() error { - if c.state == None { + if c.TransactionState == None { return ErrNoTransactStarted - } else if c.state == Aborted { + } else if c.TransactionState == Aborted { return ErrCommitAfterAbort } return nil @@ -332,7 +350,7 @@ func (c *Client) CommitTransaction() error { if err != nil { return err } - c.state = Committed + c.TransactionState = Committed return nil } @@ -351,11 +369,11 @@ func (c *Client) UpdateCommitTransactionWriteConcern() { // CheckAbortTransaction checks to see if allowed to abort transaction and returns // an error if not allowed. func (c *Client) CheckAbortTransaction() error { - if c.state == None { + if c.TransactionState == None { return ErrNoTransactStarted - } else if c.state == Committed { + } else if c.TransactionState == Committed { return ErrAbortAfterCommit - } else if c.state == Aborted { + } else if c.TransactionState == Aborted { return ErrAbortTwice } return nil @@ -368,7 +386,7 @@ func (c *Client) AbortTransaction() error { if err != nil { return err } - c.state = Aborted + c.TransactionState = Aborted c.clearTransactionOpts() return nil } @@ -379,15 +397,15 @@ func (c *Client) ApplyCommand(desc description.Server) { // Do not change state if committing after already committed return } - if c.state == Starting { - c.state = InProgress + if c.TransactionState == Starting { + c.TransactionState = InProgress // If this is in a transaction and the server is a mongos, pin it if desc.Kind == description.Mongos { c.PinnedServer = &desc } - } else if c.state == Committed || c.state == Aborted { + } else if c.TransactionState == Committed || c.TransactionState == Aborted { c.clearTransactionOpts() - c.state = None + c.TransactionState = None } } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/session_pool.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/session_pool.go index 9fce3945..61616ac3 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/session_pool.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/session_pool.go @@ -9,8 +9,8 @@ package session import ( "sync" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) // Node represents a server session in a linked list diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/cancellation_listener.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/cancellation_listener.go new file mode 100644 index 00000000..caca9880 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/cancellation_listener.go @@ -0,0 +1,14 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package topology + +import "context" + +type cancellationListener interface { + Listen(context.Context, func()) + StopListening() bool +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go index 6c1a4cc1..a015a647 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go @@ -18,16 +18,23 @@ import ( "sync/atomic" "time" + "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/internal" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) var globalConnectionID uint64 = 1 +var ( + defaultMaxMessageSize uint32 = 48000000 + errResponseTooLarge error = errors.New("length of read message too large") +) + func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) } type connection struct { @@ -52,12 +59,14 @@ type connection struct { canStream bool currentlyStreaming bool connectContextMutex sync.Mutex + cancellationListener cancellationListener // pool related fields pool *pool poolID uint64 generation uint64 expireReason string + poolMonitor *event.PoolMonitor } // newConnection handles the creation of a connection. It does not connect the connection. @@ -70,14 +79,16 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection, id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID()) c := &connection{ - id: id, - addr: addr, - idleTimeout: cfg.idleTimeout, - readTimeout: cfg.readTimeout, - writeTimeout: cfg.writeTimeout, - connectDone: make(chan struct{}), - config: cfg, - connectContextMade: make(chan struct{}), + id: id, + addr: addr, + idleTimeout: cfg.idleTimeout, + readTimeout: cfg.readTimeout, + writeTimeout: cfg.writeTimeout, + connectDone: make(chan struct{}), + config: cfg, + connectContextMade: make(chan struct{}), + cancellationListener: internal.NewCancellationListener(), + poolMonitor: cfg.poolMonitor, } atomic.StoreInt32(&c.connected, initialized) @@ -104,10 +115,28 @@ func (c *connection) connect(ctx context.Context) { } defer close(c.connectDone) + // Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes. + // + // handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied + // to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no + // longer required. This is done in lock because it accesses the shared cancelConnectContext field. + // + // dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the + // cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket + // establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid + // holding the lock longer than necessary. c.connectContextMutex.Lock() - ctx, c.cancelConnectContext = context.WithCancel(ctx) + var handshakeCtx context.Context + handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx) c.connectContextMutex.Unlock() + dialCtx := handshakeCtx + var dialCancel context.CancelFunc + if c.config.connectTimeout != 0 { + dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout) + defer dialCancel() + } + defer func() { var cancelFn context.CancelFunc @@ -126,7 +155,7 @@ func (c *connection) connect(ctx context.Context) { // Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case. var err error var tempNc net.Conn - tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String()) + tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String()) if err != nil { c.processInitializationError(err) return @@ -142,7 +171,7 @@ func (c *connection) connect(ctx context.Context) { Cache: c.config.ocspCache, DisableEndpointChecking: c.config.disableOCSPEndpointCheck, } - tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) + tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) if err != nil { c.processInitializationError(err) return @@ -155,16 +184,29 @@ func (c *connection) connect(ctx context.Context) { // running isMaster and authentication is handled by a handshaker on the configuration instance. handshaker := c.config.handshaker if handshaker == nil { + if c.poolMonitor != nil { + c.poolMonitor.Event(&event.PoolEvent{ + Type: event.ConnectionReady, + Address: c.addr.String(), + ConnectionID: c.poolID, + }) + } return } + var handshakeInfo driver.HandshakeInformation handshakeStartTime := time.Now() handshakeConn := initConnection{c} - c.desc, err = handshaker.GetDescription(ctx, c.addr, handshakeConn) + handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn) if err == nil { + // We only need to retain the Description field as the connection's description. The authentication-related + // fields in handshakeInfo are tracked by the handshaker if necessary. + c.desc = handshakeInfo.Description c.isMasterRTT = time.Since(handshakeStartTime) - err = handshaker.FinishHandshake(ctx, handshakeConn) + err = handshaker.FinishHandshake(handshakeCtx, handshakeConn) } + + // We have a failed handshake here if err != nil { c.processInitializationError(err) return @@ -198,6 +240,13 @@ func (c *connection) connect(ctx context.Context) { } } } + if c.poolMonitor != nil { + c.poolMonitor.Event(&event.PoolEvent{ + Type: event.ConnectionReady, + Address: c.addr.String(), + ConnectionID: c.poolID, + }) + } } func (c *connection) wait() error { @@ -221,20 +270,32 @@ func (c *connection) closeConnectContext() { } } -func transformNetworkError(originalError error, contextDeadlineUsed bool) error { +func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error { if originalError == nil { return nil } + + // If there was an error and the context was cancelled, we assume it happened due to the cancellation. + if ctx.Err() == context.Canceled { + return context.Canceled + } + + // If there was a timeout error and the context deadline was used, we convert the error into + // context.DeadlineExceeded. if !contextDeadlineUsed { return originalError } - if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() { return context.DeadlineExceeded } + return originalError } +func (c *connection) cancellationListenerCallback() { + _ = c.close() +} + func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { var err error if atomic.LoadInt32(&c.connected) != connected { @@ -261,12 +322,12 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"} } - _, err = c.nc.Write(wm) + err = c.write(ctx, wm) if err != nil { c.close() return ConnectionError{ ConnectionID: c.id, - Wrapped: transformNetworkError(err, contextDeadlineUsed), + Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed), message: "unable to write wire message to network", } } @@ -275,6 +336,23 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { return nil } +func (c *connection) write(ctx context.Context, wm []byte) (err error) { + go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback) + defer func() { + // There is a race condition between Write and StopListening. If the context is cancelled after c.nc.Write + // succeeds, the cancellation listener could fire and close the connection. In this case, the connection has + // been invalidated but the error is nil. To account for this, overwrite the error to context.Cancelled if + // the abortedForCancellation flag was set. + + if aborted := c.cancellationListener.StopListening(); aborted && err == nil { + err = context.Canceled + } + }() + + _, err = c.nc.Write(wm) + return err +} + // readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten. func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) { if atomic.LoadInt32(&c.connected) != connected { @@ -304,6 +382,38 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"} } + dst, errMsg, err := c.read(ctx, dst) + if err != nil { + // We closeConnection the connection because we don't know if there are other bytes left to read. + c.close() + message := errMsg + if err == io.EOF { + message = "socket was unexpectedly closed" + } + return nil, ConnectionError{ + ConnectionID: c.id, + Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed), + message: message, + } + } + + c.bumpIdleDeadline() + return dst, nil +} + +func (c *connection) read(ctx context.Context, dst []byte) (bytesRead []byte, errMsg string, err error) { + go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback) + defer func() { + // If the context is cancelled after we finish reading the server response, the cancellation listener could fire + // even though the socket reads succeed. To account for this, we overwrite err to be context.Canceled if the + // abortedForCancellation flag is set. + + if aborted := c.cancellationListener.StopListening(); aborted && err == nil { + errMsg = "unable to read server response" + err = context.Canceled + } + }() + // We use an array here because it only costs 4 bytes on the stack and means we'll only need to // reslice dst once instead of twice. var sizeBuf [4]byte @@ -311,20 +421,24 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e // We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst // because there might be more than one wire message waiting to be read, for example when // reading messages from an exhaust cursor. - _, err := io.ReadFull(c.nc, sizeBuf[:]) + _, err = io.ReadFull(c.nc, sizeBuf[:]) if err != nil { - // We closeConnection the connection because we don't know if there are other bytes left to read. - c.close() - return nil, ConnectionError{ - ConnectionID: c.id, - Wrapped: transformNetworkError(err, contextDeadlineUsed), - message: "incomplete read of message header", - } + return nil, "incomplete read of message header", err } // read the length as an int32 size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) + // In the case of an isMaster response where MaxMessageSize has not yet been set, use the hard-coded + // defaultMaxMessageSize instead. + maxMessageSize := c.desc.MaxMessageSize + if maxMessageSize == 0 { + maxMessageSize = defaultMaxMessageSize + } + if uint32(size) > maxMessageSize { + return nil, errResponseTooLarge.Error(), errResponseTooLarge + } + if int(size) > cap(dst) { // Since we can't grow this slice without allocating, just allocate an entirely new slice. dst = make([]byte, 0, size) @@ -336,17 +450,10 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e _, err = io.ReadFull(c.nc, dst[4:]) if err != nil { - // We closeConnection the connection because we don't know if there are other bytes left to read. - c.close() - return nil, ConnectionError{ - ConnectionID: c.id, - Wrapped: transformNetworkError(err, contextDeadlineUsed), - message: "incomplete read of full message", - } + return nil, "incomplete read of full message", err } - c.bumpIdleDeadline() - return dst, nil + return dst, "", nil } func (c *connection) close() error { @@ -406,6 +513,10 @@ func (c *connection) setSocketTimeout(timeout time.Duration) { c.writeTimeout = timeout } +func (c *connection) ID() string { + return c.id +} + // initConnection is an adapter used during connection initialization. It has the minimum // functionality necessary to implement the driver.Connection interface, which is required to pass a // *connection to a Handshaker. diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_options.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_options.go index 21d22af8..2a1f4e83 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_options.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_options.go @@ -42,6 +42,7 @@ type connectionConfig struct { handshaker Handshaker idleTimeout time.Duration cmdMonitor *event.CommandMonitor + poolMonitor *event.PoolMonitor readTimeout time.Duration writeTimeout time.Duration tlsConfig *tls.Config @@ -69,7 +70,7 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) { } if cfg.dialer == nil { - cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout} + cfg.dialer = &net.Dialer{} } return cfg, nil @@ -166,6 +167,14 @@ func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) Connectio } } +// withPoolMonitor configures a event for connection monitoring. +func withPoolMonitor(fn func(*event.PoolMonitor) *event.PoolMonitor) ConnectionOption { + return func(c *connectionConfig) error { + c.poolMonitor = fn(c.poolMonitor) + return nil + } +} + // WithZlibLevel sets the zLib compression level. func WithZlibLevel(fn func(*int) *int) ConnectionOption { return func(c *connectionConfig) error { diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/diff.go similarity index 51% rename from vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology.go rename to vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/diff.go index b6f4e3ea..b9bf2c14 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/description/topology.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/diff.go @@ -4,89 +4,69 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package description +package topology -import ( - "go.mongodb.org/mongo-driver/x/mongo/driver/address" -) +import "go.mongodb.org/mongo-driver/mongo/description" -// Topology represents a description of a mongodb topology -type Topology struct { - Servers []Server - Kind TopologyKind - SessionTimeoutMinutes uint32 -} - -// Server returns the server for the given address. Returns false if the server -// could not be found. -func (t Topology) Server(addr address.Address) (Server, bool) { - for _, server := range t.Servers { - if server.Addr.String() == addr.String() { - return server, true - } - } - return Server{}, false -} - -// TopologyDiff is the difference between two different topology descriptions. -type TopologyDiff struct { - Added []Server - Removed []Server +// hostlistDiff is the difference between a topology and a host list. +type hostlistDiff struct { + Added []string + Removed []string } -// DiffTopology compares the two topology descriptions and returns the difference. -func DiffTopology(old, new Topology) TopologyDiff { - var diff TopologyDiff +// diffHostList compares the topology description and host list and returns the difference. +func diffHostList(t description.Topology, hostlist []string) hostlistDiff { + var diff hostlistDiff oldServers := make(map[string]bool) - for _, s := range old.Servers { + for _, s := range t.Servers { oldServers[s.Addr.String()] = true } - for _, s := range new.Servers { - addr := s.Addr.String() + for _, addr := range hostlist { if oldServers[addr] { delete(oldServers, addr) } else { - diff.Added = append(diff.Added, s) + diff.Added = append(diff.Added, addr) } } - for _, s := range old.Servers { - addr := s.Addr.String() - if oldServers[addr] { - diff.Removed = append(diff.Removed, s) - } + for addr := range oldServers { + diff.Removed = append(diff.Removed, addr) } return diff } -// HostlistDiff is the difference between a topology and a host list. -type HostlistDiff struct { - Added []string - Removed []string +// topologyDiff is the difference between two different topology descriptions. +type topologyDiff struct { + Added []description.Server + Removed []description.Server } -// DiffHostlist compares the topology description and host list and returns the difference. -func (t Topology) DiffHostlist(hostlist []string) HostlistDiff { - var diff HostlistDiff +// diffTopology compares the two topology descriptions and returns the difference. +func diffTopology(old, new description.Topology) topologyDiff { + var diff topologyDiff oldServers := make(map[string]bool) - for _, s := range t.Servers { + for _, s := range old.Servers { oldServers[s.Addr.String()] = true } - for _, addr := range hostlist { + for _, s := range new.Servers { + addr := s.Addr.String() if oldServers[addr] { delete(oldServers, addr) } else { - diff.Added = append(diff.Added, addr) + diff.Added = append(diff.Added, s) } } - for addr := range oldServers { - diff.Removed = append(diff.Removed, addr) + for _, s := range old.Servers { + addr := s.Addr.String() + if oldServers[addr] { + diff.Removed = append(diff.Removed, s) + } } return diff diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/errors.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/errors.go index c543bacb..30274ee9 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/errors.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/errors.go @@ -2,6 +2,8 @@ package topology import ( "fmt" + + "go.mongodb.org/mongo-driver/mongo/description" ) // ConnectionError represents a connection error. @@ -17,10 +19,21 @@ type ConnectionError struct { // Error implements the error interface. func (e ConnectionError) Error() string { + message := e.message + if e.init { + fullMsg := "error occured during connection handshake" + if message != "" { + fullMsg = fmt.Sprintf("%s: %s", fullMsg, message) + } + message = fullMsg + } + if e.Wrapped != nil && message != "" { + return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, message, e.Wrapped.Error()) + } if e.Wrapped != nil { - return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, e.message, e.Wrapped.Error()) + return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.Wrapped.Error()) } - return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.message) + return fmt.Sprintf("connection(%s) %s", e.ConnectionID, message) } // Unwrap returns the underlying error. @@ -28,6 +41,25 @@ func (e ConnectionError) Unwrap() error { return e.Wrapped } +// ServerSelectionError represents a Server Selection error. +type ServerSelectionError struct { + Desc description.Topology + Wrapped error +} + +// Error implements the error interface. +func (e ServerSelectionError) Error() string { + if e.Wrapped != nil { + return fmt.Sprintf("server selection error: %s, current topology: { %s }", e.Wrapped.Error(), e.Desc.String()) + } + return fmt.Sprintf("server selection error: current topology: { %s }", e.Desc.String()) +} + +// Unwrap returns the underlying error. +func (e ServerSelectionError) Unwrap() error { + return e.Wrapped +} + // WaitQueueTimeoutError represents a timeout when requesting a connection from the pool type WaitQueueTimeoutError struct { Wrapped error diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/fsm.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/fsm.go index 90f0f870..9f5d1f1b 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/fsm.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/fsm.go @@ -12,16 +12,22 @@ import ( "sync/atomic" "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" ) -var supportedWireVersions = description.NewVersionRange(2, 9) -var minSupportedMongoDBVersion = "2.6" +var ( + // SupportedWireVersions is the range of wire versions supported by the driver. + SupportedWireVersions = description.NewVersionRange(2, 9) +) + +const ( + // MinSupportedMongoDBVersion is the version string for the lowest MongoDB version supported by the driver. + MinSupportedMongoDBVersion = "2.6" +) type fsm struct { description.Topology - SetName string maxElectionID primitive.ObjectID maxSetVersion uint32 compatible atomic.Value @@ -48,6 +54,7 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser f.Topology = description.Topology{ Kind: f.Kind, Servers: newServers, + SetName: f.SetName, } // For data bearing servers, set SessionTimeoutMinutes to the lowest among them @@ -89,28 +96,30 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser for _, server := range f.Servers { if server.WireVersion != nil { - if server.WireVersion.Max < supportedWireVersions.Min { + if server.WireVersion.Max < SupportedWireVersions.Min { f.compatible.Store(false) f.compatibilityErr = fmt.Errorf( "server at %s reports wire version %d, but this version of the Go driver requires "+ "at least %d (MongoDB %s)", server.Addr.String(), server.WireVersion.Max, - supportedWireVersions.Min, - minSupportedMongoDBVersion, + SupportedWireVersions.Min, + MinSupportedMongoDBVersion, ) - return description.Topology{}, s, f.compatibilityErr + f.Topology.CompatibilityErr = f.compatibilityErr + return f.Topology, s, nil } - if server.WireVersion.Min > supportedWireVersions.Max { + if server.WireVersion.Min > SupportedWireVersions.Max { f.compatible.Store(false) f.compatibilityErr = fmt.Errorf( "server at %s requires wire version %d, but this version of the Go driver only supports up to %d", server.Addr.String(), server.WireVersion.Min, - supportedWireVersions.Max, + SupportedWireVersions.Max, ) - return description.Topology{}, s, f.compatibilityErr + f.Topology.CompatibilityErr = f.compatibilityErr + return f.Topology, s, nil } } } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool.go index 0a4fe487..7ce544f4 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool.go @@ -14,7 +14,7 @@ import ( "time" "go.mongodb.org/mongo-driver/event" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" + "go.mongodb.org/mongo-driver/mongo/address" "golang.org/x/sync/semaphore" ) @@ -138,6 +138,9 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) { if config.MaxIdleTime != time.Duration(0) { opts = append(opts, WithIdleTimeout(func(_ time.Duration) time.Duration { return config.MaxIdleTime })) } + if config.PoolMonitor != nil { + opts = append(opts, withPoolMonitor(func(_ *event.PoolMonitor) *event.PoolMonitor { return config.PoolMonitor })) + } var maxConns = config.MaxPoolSize if maxConns == 0 { diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go index 83309c4d..857730f9 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go @@ -15,10 +15,11 @@ import ( "sync/atomic" "time" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -97,6 +98,7 @@ type Server struct { // description related fields desc atomic.Value // holds a description.Server updateTopologyCallback atomic.Value + topologyID primitive.ObjectID // subscriber related fields subLock sync.Mutex @@ -126,8 +128,8 @@ type updateTopologyCallback func(description.Server) description.Server // ConnectServer creates a new Server and then initializes it using the // Connect method. -func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, opts ...ServerOption) (*Server, error) { - srvr, err := NewServer(addr, opts...) +func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) { + srvr, err := NewServer(addr, topologyID, opts...) if err != nil { return nil, err } @@ -140,7 +142,7 @@ func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, // NewServer creates a new server. The mongodb server at the address will be monitored // on an internal monitoring goroutine. -func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) { +func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) { cfg, err := newServerConfig(opts...) if err != nil { return nil, err @@ -155,6 +157,8 @@ func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) { checkNow: make(chan struct{}, 1), disconnecting: make(chan struct{}), + topologyID: topologyID, + subscribers: make(map[uint64]chan description.Server), globalCtx: globalCtx, globalCtxCancel: globalCtxCancel, @@ -180,6 +184,9 @@ func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) { if err != nil { return nil, err } + + s.publishServerOpeningEvent(s.address) + return s, nil } @@ -349,10 +356,10 @@ func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bo } // ProcessError handles SDAM error handling and implements driver.ErrorProcessor. -func (s *Server) ProcessError(err error, conn driver.Connection) { +func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult { // ignore nil error if err == nil { - return + return driver.NoChange } s.processErrorLock.Lock() @@ -360,55 +367,59 @@ func (s *Server) ProcessError(err error, conn driver.Connection) { // ignore stale error if conn.Stale() { - return + return driver.NoChange } // Invalidate server description if not master or node recovering error occurs. // These errors can be reported as a command error or a write concern error. desc := conn.Description() if cerr, ok := err.(driver.Error); ok && (cerr.NodeIsRecovering() || cerr.NotMaster()) { // ignore stale error - if description.CompareTopologyVersion(desc.TopologyVersion, cerr.TopologyVersion) >= 0 { - return + if desc.TopologyVersion.CompareToIncoming(cerr.TopologyVersion) >= 0 { + return driver.NoChange } // updates description to unknown s.updateDescription(description.NewServerFromError(s.address, err, cerr.TopologyVersion)) s.RequestImmediateCheck() + res := driver.ServerMarkedUnknown // If the node is shutting down or is older than 4.2, we synchronously clear the pool if cerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 { + res = driver.ConnectionPoolCleared s.pool.clear() } - return + return res } if wcerr, ok := getWriteConcernErrorForProcessing(err); ok { // ignore stale error - if description.CompareTopologyVersion(desc.TopologyVersion, wcerr.TopologyVersion) >= 0 { - return + if desc.TopologyVersion.CompareToIncoming(wcerr.TopologyVersion) >= 0 { + return driver.NoChange } // updates description to unknown s.updateDescription(description.NewServerFromError(s.address, err, wcerr.TopologyVersion)) s.RequestImmediateCheck() + res := driver.ServerMarkedUnknown // If the node is shutting down or is older than 4.2, we synchronously clear the pool if wcerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 { + res = driver.ConnectionPoolCleared s.pool.clear() } - return + return res } wrappedConnErr := unwrapConnectionError(err) if wrappedConnErr == nil { - return + return driver.NoChange } // Ignore transient timeout errors. if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() { - return + return driver.NoChange } if wrappedConnErr == context.Canceled || wrappedConnErr == context.DeadlineExceeded { - return + return driver.NoChange } // For a non-timeout network error, we clear the pool, set the description to Unknown, and cancel the in-progress @@ -417,6 +428,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) { s.updateDescription(description.NewServerFromError(s.address, err, nil)) s.pool.clear() s.cancelCheck() + return driver.ConnectionPoolCleared } // update handles performing heartbeats and updating any subscribers of the @@ -567,8 +579,9 @@ func (s *Server) createConnection() (*connection, error) { WithHandshaker(func(h Handshaker) Handshaker { return operation.NewIsMaster().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts) }), - // Override any command monitors specified in options with nil to avoid monitoring heartbeats. + // Override any monitors specified in options with nil to avoid monitoring heartbeats. WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }), + withPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor { return nil }), } opts = append(s.cfg.connectionOpts, opts...) @@ -629,6 +642,7 @@ func (s *Server) createBaseOperation(conn driver.Connection) *operation.IsMaster func (s *Server) check() (description.Server, error) { var descPtr *description.Server var err error + var durationNanos int64 // Create a new connection if this is the first check, the connection was closed after an error during the previous // check, or the previous check was cancelled. @@ -639,6 +653,7 @@ func (s *Server) check() (description.Server, error) { // Use the description from the connection handshake as the value for this check. s.rttMonitor.addSample(s.conn.isMasterRTT) descPtr = &s.conn.desc + durationNanos = s.conn.isMasterRTT.Nanoseconds() } } @@ -649,12 +664,15 @@ func (s *Server) check() (description.Server, error) { heartbeatConn := initConnection{s.conn} baseOperation := s.createBaseOperation(heartbeatConn) previousDescription := s.Description() + streamable := previousDescription.TopologyVersion != nil + s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable) + start := time.Now() switch { case s.conn.getCurrentlyStreaming(): // The connection is already in a streaming state, so we stream the next response. err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn) - case previousDescription.TopologyVersion != nil: + case streamable: // The server supports the streamable protocol. Set the socket timeout to // connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable isMaster request. Set conn.canStream so // the wire message will advertise streaming support to the server. @@ -680,15 +698,19 @@ func (s *Server) check() (description.Server, error) { s.conn.setSocketTimeout(s.cfg.heartbeatTimeout) err = baseOperation.Execute(s.heartbeatCtx) } + durationNanos = time.Since(start).Nanoseconds() + if err == nil { tempDesc := baseOperation.Result(s.address) descPtr = &tempDesc + s.publishServerHeartbeatSucceededEvent(s.conn.ID(), durationNanos, tempDesc, s.conn.getCurrentlyStreaming() || streamable) } else { // Close the connection here rather than below so we ensure we're not closing a connection that wasn't // successfully created. if s.conn != nil { _ = s.conn.close() } + s.publishServerHeartbeatFailedEvent(s.conn.ID(), durationNanos, err, s.conn.getCurrentlyStreaming() || streamable) } } @@ -777,6 +799,64 @@ func (ss *ServerSubscription) Unsubscribe() error { return nil } +// publishes a ServerOpeningEvent to indicate the server is being initialized +func (s *Server) publishServerOpeningEvent(addr address.Address) { + serverOpening := &event.ServerOpeningEvent{ + Address: addr, + TopologyID: s.topologyID, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerOpening != nil { + s.cfg.serverMonitor.ServerOpening(serverOpening) + } +} + +// publishes a ServerHeartbeatStartedEvent to indicate an ismaster command has started +func (s *Server) publishServerHeartbeatStartedEvent(connectionID string, await bool) { + serverHeartbeatStarted := &event.ServerHeartbeatStartedEvent{ + ConnectionID: connectionID, + Awaited: await, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatStarted != nil { + s.cfg.serverMonitor.ServerHeartbeatStarted(serverHeartbeatStarted) + } +} + +// publishes a ServerHeartbeatSucceededEvent to indicate ismaster has succeeded +func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string, + durationNanos int64, + desc description.Server, + await bool) { + serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{ + DurationNanos: durationNanos, + Reply: desc, + ConnectionID: connectionID, + Awaited: await, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatSucceeded != nil { + s.cfg.serverMonitor.ServerHeartbeatSucceeded(serverHeartbeatSucceeded) + } +} + +// publishes a ServerHeartbeatFailedEvent to indicate ismaster has failed +func (s *Server) publishServerHeartbeatFailedEvent(connectionID string, + durationNanos int64, + err error, + await bool) { + serverHeartbeatFailed := &event.ServerHeartbeatFailedEvent{ + DurationNanos: durationNanos, + Failure: err, + ConnectionID: connectionID, + Awaited: await, + } + + if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatFailed != nil { + s.cfg.serverMonitor.ServerHeartbeatFailed(serverHeartbeatFailed) + } +} + // unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error. func unwrapConnectionError(err error) error { // This is essentially an implementation of errors.As to unwrap this error until we get a ConnectionError and then diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_options.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_options.go index 6675e7a7..902bfbdd 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_options.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_options.go @@ -27,6 +27,7 @@ type serverConfig struct { maxConns uint64 minConns uint64 poolMonitor *event.PoolMonitor + serverMonitor *event.ServerMonitor connectionPoolMaxIdleTime time.Duration registry *bsoncodec.Registry monitoringDisabled bool @@ -138,6 +139,14 @@ func WithConnectionPoolMonitor(fn func(*event.PoolMonitor) *event.PoolMonitor) S } } +// WithServerMonitor configures the monitor for all SDAM events for a server +func WithServerMonitor(fn func(*event.ServerMonitor) *event.ServerMonitor) ServerOption { + return func(cfg *serverConfig) error { + cfg.serverMonitor = fn(cfg.serverMonitor) + return nil + } +} + // WithClock configures the ClusterClock for the server to use. func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) ServerOption { return func(cfg *serverConfig) error { diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/tls_connection_source.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/tls_connection_source.go index e67a0493..718a9abb 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/tls_connection_source.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/tls_connection_source.go @@ -11,16 +11,26 @@ import ( "net" ) +type tlsConn interface { + net.Conn + Handshake() error + ConnectionState() tls.ConnectionState +} + +var _ tlsConn = (*tls.Conn)(nil) + type tlsConnectionSource interface { - Client(net.Conn, *tls.Config) *tls.Conn + Client(net.Conn, *tls.Config) tlsConn } -type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn +type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn + +var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil) -func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn { +func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn { return t(nc, cfg) } -var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn { +var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn { return tls.Client(nc, cfg) } diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go index 9b4f869f..a5cde18c 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go @@ -21,9 +21,11 @@ import ( "fmt" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/address" - "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/dns" ) @@ -88,6 +90,8 @@ type Topology struct { serversLock sync.Mutex serversClosed bool servers map[address.Address]*Server + + id primitive.ObjectID } var _ driver.Deployment = &Topology{} @@ -121,28 +125,19 @@ func New(opts ...Option) (*Topology, error) { subscribers: make(map[uint64]chan description.Topology), servers: make(map[address.Address]*Server), dnsResolver: dns.DefaultResolver, + id: primitive.NewObjectID(), } t.desc.Store(description.Topology{}) t.updateCallback = func(desc description.Server) description.Server { return t.apply(context.TODO(), desc) } - // A replica set name sets the initial topology type to ReplicaSetNoPrimary unless a direct connection is also - // specified, in which case the initial type is Single. - if cfg.replicaSetName != "" { - t.fsm.SetName = cfg.replicaSetName - t.fsm.Kind = description.ReplicaSetNoPrimary - } - - // A direct connection unconditionally sets the topology type to Single. - if cfg.mode == SingleMode { - t.fsm.Kind = description.Single - } - if t.cfg.uri != "" { t.pollingRequired = strings.HasPrefix(t.cfg.uri, "mongodb+srv://") } + t.publishTopologyOpeningEvent() + return t, nil } @@ -156,9 +151,34 @@ func (t *Topology) Connect() error { t.desc.Store(description.Topology{}) var err error t.serversLock.Lock() + + // A replica set name sets the initial topology type to ReplicaSetNoPrimary unless a direct connection is also + // specified, in which case the initial type is Single. + if t.cfg.replicaSetName != "" { + t.fsm.SetName = t.cfg.replicaSetName + t.fsm.Kind = description.ReplicaSetNoPrimary + } + + // A direct connection unconditionally sets the topology type to Single. + if t.cfg.mode == SingleMode { + t.fsm.Kind = description.Single + } + + for _, a := range t.cfg.seedList { + addr := address.Address(a).Canonicalize() + t.fsm.Servers = append(t.fsm.Servers, description.NewDefaultServer(addr)) + } + + // store new description + newDesc := description.Topology{ + Kind: t.fsm.Kind, + Servers: t.fsm.Servers, + SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes, + } + t.desc.Store(newDesc) + t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology) for _, a := range t.cfg.seedList { addr := address.Address(a).Canonicalize() - t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: addr}) err = t.addServer(addr) if err != nil { return err @@ -194,6 +214,7 @@ func (t *Topology) Disconnect(ctx context.Context) error { for _, server := range servers { _ = server.Disconnect(ctx) + t.publishServerClosedEvent(server.address) } t.subLock.Lock() @@ -212,6 +233,7 @@ func (t *Topology) Disconnect(ctx context.Context) error { t.desc.Store(description.Topology{}) atomic.StoreInt32(&t.connectionstate, disconnected) + t.publishTopologyClosedEvent() return nil } @@ -420,23 +442,19 @@ func (t *Topology) FindServer(selected description.Server) (*SelectedServer, err }, nil } -func wrapServerSelectionError(err error, t *Topology) error { - return fmt.Errorf("server selection error: %v, current topology: { %s }", err, t.String()) -} - // selectServerFromSubscription loops until a topology description is available for server selection. It returns // when the given context expires, server selection timeout is reached, or a description containing a selectable // server is available. func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptionCh <-chan description.Topology, selectionState serverSelectionState) ([]description.Server, error) { - var current description.Topology + current := t.Description() for { select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, ServerSelectionError{Wrapped: ctx.Err(), Desc: current} case <-selectionState.timeoutChan: - return nil, wrapServerSelectionError(ErrServerSelectionTimeout, t) + return nil, ServerSelectionError{Wrapped: ErrServerSelectionTimeout, Desc: current} case current = <-subscriptionCh: } @@ -459,6 +477,10 @@ func (t *Topology) selectServerFromDescription(desc description.Topology, // Unlike selectServerFromSubscription, this code path does not check ctx.Done or selectionState.timeoutChan because // selecting a server from a description is not a blocking operation. + if desc.CompatibilityErr != nil { + return nil, desc.CompatibilityErr + } + var allowed []description.Server for _, s := range desc.Servers { if s.Kind != description.Unknown { @@ -468,7 +490,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology, suitable, err := selectionState.selector.SelectServer(desc, allowed) if err != nil { - return nil, wrapServerSelectionError(err, t) + return nil, ServerSelectionError{Wrapped: err, Desc: desc} } return suitable, nil } @@ -541,7 +563,8 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { if t.serversClosed { return false } - diff := t.fsm.Topology.DiffHostlist(parsedHosts) + prev := t.fsm.Topology + diff := diffHostList(t.fsm.Topology, parsedHosts) if len(diff.Added) == 0 && len(diff.Removed) == 0 { return true @@ -560,6 +583,7 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { }() delete(t.servers, addr) t.fsm.removeServerByAddr(addr) + t.publishServerClosedEvent(s.address) } for _, a := range diff.Added { addr := address.Address(a).Canonicalize() @@ -574,6 +598,10 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { } t.desc.Store(newDesc) + if !prev.Equal(newDesc) { + t.publishTopologyDescriptionChangedEvent(prev, newDesc) + } + t.subLock.Lock() for _, ch := range t.subscribers { // We drain the description if there's one in the channel @@ -602,7 +630,7 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti prev := t.fsm.Topology oldDesc := t.fsm.Servers[ind] - if description.CompareTopologyVersion(oldDesc.TopologyVersion, desc.TopologyVersion) > 0 { + if oldDesc.TopologyVersion.CompareToIncoming(desc.TopologyVersion) > 0 { return oldDesc } @@ -613,7 +641,11 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti return desc } - diff := description.DiffTopology(prev, current) + if !oldDesc.Equal(desc) { + t.publishServerDescriptionChangedEvent(oldDesc, desc) + } + + diff := diffTopology(prev, current) for _, removed := range diff.Removed { if s, ok := t.servers[removed.Addr]; ok { @@ -623,6 +655,7 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti _ = s.Disconnect(cancelCtx) }() delete(t.servers, removed.Addr) + t.publishServerClosedEvent(s.address) } } @@ -631,6 +664,9 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti } t.desc.Store(current) + if !prev.Equal(current) { + t.publishTopologyDescriptionChangedEvent(prev, current) + } t.subLock.Lock() for _, ch := range t.subscribers { @@ -651,7 +687,7 @@ func (t *Topology) addServer(addr address.Address) error { return nil } - svr, err := ConnectServer(addr, t.updateCallback, t.cfg.serverOpts...) + svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.serverOpts...) if err != nil { return err } @@ -673,3 +709,64 @@ func (t *Topology) String() string { } return fmt.Sprintf("Type: %s, Servers: [%s]", desc.Kind, serversStr) } + +// publishes a ServerDescriptionChangedEvent to indicate the server description has changed +func (t *Topology) publishServerDescriptionChangedEvent(prev description.Server, current description.Server) { + serverDescriptionChanged := &event.ServerDescriptionChangedEvent{ + Address: current.Addr, + TopologyID: t.id, + PreviousDescription: prev, + NewDescription: current, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.ServerDescriptionChanged != nil { + t.cfg.serverMonitor.ServerDescriptionChanged(serverDescriptionChanged) + } +} + +// publishes a ServerClosedEvent to indicate the server has closed +func (t *Topology) publishServerClosedEvent(addr address.Address) { + serverClosed := &event.ServerClosedEvent{ + Address: addr, + TopologyID: t.id, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.ServerClosed != nil { + t.cfg.serverMonitor.ServerClosed(serverClosed) + } +} + +// publishes a TopologyDescriptionChangedEvent to indicate the topology description has changed +func (t *Topology) publishTopologyDescriptionChangedEvent(prev description.Topology, current description.Topology) { + topologyDescriptionChanged := &event.TopologyDescriptionChangedEvent{ + TopologyID: t.id, + PreviousDescription: prev, + NewDescription: current, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.TopologyDescriptionChanged != nil { + t.cfg.serverMonitor.TopologyDescriptionChanged(topologyDescriptionChanged) + } +} + +// publishes a TopologyOpeningEvent to indicate the topology is being initialized +func (t *Topology) publishTopologyOpeningEvent() { + topologyOpening := &event.TopologyOpeningEvent{ + TopologyID: t.id, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.TopologyOpening != nil { + t.cfg.serverMonitor.TopologyOpening(topologyOpening) + } +} + +// publishes a TopologyClosedEvent to indicate the topology has been closed +func (t *Topology) publishTopologyClosedEvent() { + topologyClosed := &event.TopologyClosedEvent{ + TopologyID: t.id, + } + + if t.cfg.serverMonitor != nil && t.cfg.serverMonitor.TopologyClosed != nil { + t.cfg.serverMonitor.TopologyClosed(topologyClosed) + } +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options.go index b58f7f18..c40f5510 100644 --- a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options.go +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options.go @@ -17,6 +17,7 @@ import ( "strings" "time" + "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" @@ -34,6 +35,7 @@ type config struct { cs connstring.ConnString // This must not be used for any logic in topology.Topology. uri string serverSelectionTimeout time.Duration + serverMonitor *event.ServerMonitor } func newConfig(opts ...Option) (*config, error) { @@ -274,6 +276,14 @@ func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option { } } +// WithTopologyServerMonitor configures the monitor for all SDAM events +func WithTopologyServerMonitor(fn func(*event.ServerMonitor) *event.ServerMonitor) Option { + return func(cfg *config) error { + cfg.serverMonitor = fn(cfg.serverMonitor) + return nil + } +} + // WithURI specifies the URI that was used to create the topology. func WithURI(fn func(string) string) Option { return func(cfg *config) error { diff --git a/vendor/modules.txt b/vendor/modules.txt index 15bc94c4..e12b6533 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -186,7 +186,9 @@ github.com/vmihailenco/tagparser/internal/parser github.com/xdg/scram # github.com/xdg/stringprep v0.0.0-20180714160509-73f8eece6fdc github.com/xdg/stringprep -# go.mongodb.org/mongo-driver v1.4.6 +# github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d +github.com/youmark/pkcs8 +# go.mongodb.org/mongo-driver v1.5.0 ## explicit go.mongodb.org/mongo-driver/bson go.mongodb.org/mongo-driver/bson/bsoncodec @@ -197,6 +199,8 @@ go.mongodb.org/mongo-driver/bson/primitive go.mongodb.org/mongo-driver/event go.mongodb.org/mongo-driver/internal go.mongodb.org/mongo-driver/mongo +go.mongodb.org/mongo-driver/mongo/address +go.mongodb.org/mongo-driver/mongo/description go.mongodb.org/mongo-driver/mongo/options go.mongodb.org/mongo-driver/mongo/readconcern go.mongodb.org/mongo-driver/mongo/readpref @@ -206,11 +210,9 @@ go.mongodb.org/mongo-driver/version go.mongodb.org/mongo-driver/x/bsonx go.mongodb.org/mongo-driver/x/bsonx/bsoncore go.mongodb.org/mongo-driver/x/mongo/driver -go.mongodb.org/mongo-driver/x/mongo/driver/address go.mongodb.org/mongo-driver/x/mongo/driver/auth go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi go.mongodb.org/mongo-driver/x/mongo/driver/connstring -go.mongodb.org/mongo-driver/x/mongo/driver/description go.mongodb.org/mongo-driver/x/mongo/driver/dns go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options