From 92e6150d84f128463f6e8f5f6b9d0e2537fef64e Mon Sep 17 00:00:00 2001 From: Julius Park Date: Thu, 10 Nov 2022 14:19:55 -0800 Subject: [PATCH] PYTHON-3493 Bulk Write InsertOne Should Be Parameter Of Collection Type (#1106) --- doc/examples/type_hints.rst | 20 ++++++++ mypy.ini | 2 +- pymongo/collection.py | 11 ++++- pymongo/encryption.py | 5 +- pymongo/operations.py | 13 +++--- pymongo/typings.py | 7 +++ test/__init__.py | 2 +- test/mockupdb/test_cluster_time.py | 2 +- test/mockupdb/test_op_msg.py | 6 +-- test/test_bulk.py | 4 +- test/test_client.py | 1 + test/test_database.py | 5 +- test/test_mypy.py | 75 +++++++++++++++++++++++++++--- test/test_server_selection.py | 6 ++- test/test_session.py | 6 ++- test/test_transactions.py | 4 +- test/utils.py | 13 +++--- 17 files changed, 144 insertions(+), 38 deletions(-) diff --git a/doc/examples/type_hints.rst b/doc/examples/type_hints.rst index 38349038b1..b413ad7b24 100644 --- a/doc/examples/type_hints.rst +++ b/doc/examples/type_hints.rst @@ -113,6 +113,26 @@ These methods automatically add an "_id" field. >>> assert result is not None >>> assert result["year"] == 1993 >>> # This will raise a type-checking error, despite being present, because it is added by PyMongo. + >>> assert result["_id"] # type:ignore[typeddict-item] + +This same typing scheme works for all of the insert methods (:meth:`~pymongo.collection.Collection.insert_one`, +:meth:`~pymongo.collection.Collection.insert_many`, and :meth:`~pymongo.collection.Collection.bulk_write`). +For `bulk_write` both :class:`~pymongo.operations.InsertOne` and :class:`~pymongo.operations.ReplaceOne` operators are generic. + +.. doctest:: + :pyversion: >= 3.8 + + >>> from typing import TypedDict + >>> from pymongo import MongoClient + >>> from pymongo.operations import InsertOne + >>> from pymongo.collection import Collection + >>> client: MongoClient = MongoClient() + >>> collection: Collection[Movie] = client.test.test + >>> inserted = collection.bulk_write([InsertOne(Movie(name="Jurassic Park", year=1993))]) + >>> result = collection.find_one({"name": "Jurassic Park"}) + >>> assert result is not None + >>> assert result["year"] == 1993 + >>> # This will raise a type-checking error, despite being present, because it is added by PyMongo. >>> assert result["_id"] # type:ignore[typeddict-item] Modeling Document Types with TypedDict diff --git a/mypy.ini b/mypy.ini index 9b1348472c..2562177ab1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -33,7 +33,7 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-test.test_mypy] -warn_unused_ignores = false +warn_unused_ignores = True [mypy-winkerberos.*] ignore_missing_imports = True diff --git a/pymongo/collection.py b/pymongo/collection.py index 23efe8fd35..600d73c4bc 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -77,7 +77,14 @@ _FIND_AND_MODIFY_DOC_FIELDS = {"value": 1} -_WriteOp = Union[InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany] +_WriteOp = Union[ + InsertOne[_DocumentType], + DeleteOne, + DeleteMany, + ReplaceOne[_DocumentType], + UpdateOne, + UpdateMany, +] # Hint supports index name, "myIndex", or list of index pairs: [('x', 1), ('y', -1)] _IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]] _IndexKeyHint = Union[str, _IndexList] @@ -436,7 +443,7 @@ def with_options( @_csot.apply def bulk_write( self, - requests: Sequence[_WriteOp], + requests: Sequence[_WriteOp[_DocumentType]], ordered: bool = True, bypass_document_validation: bool = False, session: Optional["ClientSession"] = None, diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 9fef5963a6..92a268f452 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -18,7 +18,7 @@ import enum import socket import weakref -from typing import Any, Mapping, Optional, Sequence +from typing import Any, Generic, Mapping, Optional, Sequence try: from pymongocrypt.auto_encrypter import AutoEncrypter @@ -55,6 +55,7 @@ from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult from pymongo.ssl_support import get_ssl_context +from pymongo.typings import _DocumentType from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern @@ -430,7 +431,7 @@ class QueryType(str, enum.Enum): """Used to encrypt a value for an equality query.""" -class ClientEncryption(object): +class ClientEncryption(Generic[_DocumentType]): """Explicit client-side field level encryption.""" def __init__( diff --git a/pymongo/operations.py b/pymongo/operations.py index 84e8bf4d35..92a4dad0ac 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -13,21 +13,22 @@ # limitations under the License. """Operation class definitions.""" -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Generic, List, Mapping, Optional, Sequence, Tuple, Union +from bson.raw_bson import RawBSONDocument from pymongo import helpers from pymongo.collation import validate_collation_or_none from pymongo.common import validate_boolean, validate_is_mapping, validate_list from pymongo.helpers import _gen_index_name, _index_document, _index_list -from pymongo.typings import _CollationIn, _DocumentIn, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline -class InsertOne(object): +class InsertOne(Generic[_DocumentType]): """Represents an insert_one operation.""" __slots__ = ("_doc",) - def __init__(self, document: _DocumentIn) -> None: + def __init__(self, document: Union[_DocumentType, RawBSONDocument]) -> None: """Create an InsertOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. @@ -170,7 +171,7 @@ def __ne__(self, other: Any) -> bool: return not self == other -class ReplaceOne(object): +class ReplaceOne(Generic[_DocumentType]): """Represents a replace_one operation.""" __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") @@ -178,7 +179,7 @@ class ReplaceOne(object): def __init__( self, filter: Mapping[str, Any], - replacement: Mapping[str, Any], + replacement: Union[_DocumentType, RawBSONDocument], upsert: bool = False, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, diff --git a/pymongo/typings.py b/pymongo/typings.py index 14e059a8f0..fe0e8bd523 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -37,3 +37,10 @@ _Pipeline = Sequence[Mapping[str, Any]] _DocumentOut = _DocumentIn _DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any]) + + +def strip_optional(elem): + """This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T + while inside a list comprehension.""" + assert elem is not None + return elem diff --git a/test/__init__.py b/test/__init__.py index eb66e45667..20b1d00ca8 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1090,7 +1090,7 @@ def print_thread_stacks(pid: int) -> None: class IntegrationTest(PyMongoTestCase): """Base class for TestCases that need a connection to MongoDB to pass.""" - client: MongoClient + client: MongoClient[dict] db: Database credentials: Dict[str, str] diff --git a/test/mockupdb/test_cluster_time.py b/test/mockupdb/test_cluster_time.py index cb06a129d2..e4f3e12d07 100644 --- a/test/mockupdb/test_cluster_time.py +++ b/test/mockupdb/test_cluster_time.py @@ -60,7 +60,7 @@ def callback(client): self.cluster_time_conversation(callback, [{"ok": 1}] * 2) def test_bulk(self): - def callback(client): + def callback(client: MongoClient[dict]) -> None: client.db.collection.bulk_write( [InsertOne({}), InsertOne({}), UpdateOne({}, {"$inc": {"x": 1}}), DeleteMany({})] ) diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py index da7ff3d33e..22fe38fd02 100755 --- a/test/mockupdb/test_op_msg.py +++ b/test/mockupdb/test_op_msg.py @@ -137,14 +137,14 @@ # Legacy methods Operation( "bulk_write_insert", - lambda coll: coll.bulk_write([InsertOne({}), InsertOne({})]), + lambda coll: coll.bulk_write([InsertOne[dict]({}), InsertOne[dict]({})]), request=OpMsg({"insert": "coll"}, flags=0), reply={"ok": 1, "n": 2}, ), Operation( "bulk_write_insert-w0", lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( - [InsertOne({}), InsertOne({})] + [InsertOne[dict]({}), InsertOne[dict]({})] ), request=OpMsg({"insert": "coll"}, flags=0), reply={"ok": 1, "n": 2}, @@ -152,7 +152,7 @@ Operation( "bulk_write_insert-w0-unordered", lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( - [InsertOne({}), InsertOne({})], ordered=False + [InsertOne[dict]({}), InsertOne[dict]({})], ordered=False ), request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), reply=None, diff --git a/test/test_bulk.py b/test/test_bulk.py index fae1c7e201..ac7073c0ef 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -296,7 +296,7 @@ def test_upsert(self): def test_numerous_inserts(self): # Ensure we don't exceed server's maxWriteBatchSize size limit. n_docs = client_context.max_write_batch_size + 100 - requests = [InsertOne({}) for _ in range(n_docs)] + requests = [InsertOne[dict]({}) for _ in range(n_docs)] result = self.coll.bulk_write(requests, ordered=False) self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, self.coll.count_documents({})) @@ -347,7 +347,7 @@ def test_bulk_write_no_results(self): def test_bulk_write_invalid_arguments(self): # The requests argument must be a list. - generator = (InsertOne({}) for _ in range(10)) + generator = (InsertOne[dict]({}) for _ in range(10)) with self.assertRaises(TypeError): self.coll.bulk_write(generator) # type: ignore[arg-type] diff --git a/test/test_client.py b/test/test_client.py index 5bb116dbda..a33881fded 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1652,6 +1652,7 @@ def test_network_error_message(self): with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} ): + assert client.address is not None expected = "%s:%s: " % client.address with self.assertRaisesRegex(AutoReconnect, expected): client.pymongo_test.test.find_one({}) diff --git a/test/test_database.py b/test/test_database.py index d49ac8324f..49387b8bb9 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -16,7 +16,7 @@ import re import sys -from typing import Any, Iterable, List, Mapping +from typing import Any, Iterable, List, Mapping, Union sys.path[0:0] = [""] @@ -201,7 +201,7 @@ def test_list_collection_names_filter(self): db.capped.insert_one({}) db.non_capped.insert_one({}) self.addCleanup(client.drop_database, db.name) - + filter: Union[None, dict] # Should not send nameOnly. for filter in ({"options.capped": True}, {"options.capped": True, "name": "capped"}): results.clear() @@ -210,7 +210,6 @@ def test_list_collection_names_filter(self): self.assertNotIn("nameOnly", results["started"][0].command) # Should send nameOnly (except on 2.6). - filter: Any for filter in (None, {}, {"name": {"$in": ["capped", "non_capped"]}}): results.clear() names = db.list_collection_names(filter=filter) diff --git a/test/test_mypy.py b/test/test_mypy.py index 807f0e8ef3..58e69853ca 100644 --- a/test/test_mypy.py +++ b/test/test_mypy.py @@ -17,7 +17,7 @@ import os import tempfile import unittest -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Union try: from typing_extensions import NotRequired, TypedDict @@ -42,7 +42,7 @@ class ImplicitMovie(TypedDict): Movie = dict # type:ignore[misc,assignment] ImplicitMovie = dict # type: ignore[assignment,misc] MovieWithId = dict # type: ignore[assignment,misc] - TypedDict = None # type: ignore[assignment] + TypedDict = None NotRequired = None # type: ignore[assignment] @@ -59,7 +59,7 @@ class ImplicitMovie(TypedDict): from bson.son import SON from pymongo import ASCENDING, MongoClient from pymongo.collection import Collection -from pymongo.operations import InsertOne +from pymongo.operations import DeleteOne, InsertOne, ReplaceOne from pymongo.read_preferences import ReadPreference TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails") @@ -124,11 +124,40 @@ def to_list(iterable: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: docs = to_list(cursor) self.assertTrue(docs) + @only_type_check def test_bulk_write(self) -> None: self.coll.insert_one({}) - requests = [InsertOne({})] - result = self.coll.bulk_write(requests) - self.assertTrue(result.acknowledged) + coll: Collection[Movie] = self.coll + requests: List[InsertOne[Movie]] = [InsertOne(Movie(name="American Graffiti", year=1973))] + self.assertTrue(coll.bulk_write(requests).acknowledged) + new_requests: List[Union[InsertOne[Movie], ReplaceOne[Movie]]] = [] + input_list: List[Union[InsertOne[Movie], ReplaceOne[Movie]]] = [ + InsertOne(Movie(name="American Graffiti", year=1973)), + ReplaceOne({}, Movie(name="American Graffiti", year=1973)), + ] + for i in input_list: + new_requests.append(i) + self.assertTrue(coll.bulk_write(new_requests).acknowledged) + + # Because ReplaceOne is not generic, type checking is not enforced for ReplaceOne in the first example. + @only_type_check + def test_bulk_write_heterogeneous(self): + coll: Collection[Movie] = self.coll + requests: List[Union[InsertOne[Movie], ReplaceOne, DeleteOne]] = [ + InsertOne(Movie(name="American Graffiti", year=1973)), + ReplaceOne({}, {"name": "American Graffiti", "year": "WRONG_TYPE"}), + DeleteOne({}), + ] + self.assertTrue(coll.bulk_write(requests).acknowledged) + requests_two: List[Union[InsertOne[Movie], ReplaceOne[Movie], DeleteOne]] = [ + InsertOne(Movie(name="American Graffiti", year=1973)), + ReplaceOne( + {}, + {"name": "American Graffiti", "year": "WRONG_TYPE"}, # type:ignore[typeddict-item] + ), + DeleteOne({}), + ] + self.assertTrue(coll.bulk_write(requests_two).acknowledged) def test_command(self) -> None: result: Dict = self.client.admin.command("ping") @@ -340,6 +369,40 @@ def test_typeddict_document_type_insertion(self) -> None: ) coll.insert_many([bad_movie]) + @only_type_check + def test_bulk_write_document_type_insertion(self): + client: MongoClient[MovieWithId] = MongoClient() + coll: Collection[MovieWithId] = client.test.test + coll.bulk_write( + [InsertOne(Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type] + ) + mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971} + coll.bulk_write( + [InsertOne(mov_dict)] # type:ignore[arg-type] + ) + coll.bulk_write( + [ + InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) + ] # No error because it is in-line. + ) + + @only_type_check + def test_bulk_write_document_type_replacement(self): + client: MongoClient[MovieWithId] = MongoClient() + coll: Collection[MovieWithId] = client.test.test + coll.bulk_write( + [ReplaceOne({}, Movie({"name": "THX-1138", "year": 1971}))] # type:ignore[arg-type] + ) + mov_dict = {"_id": ObjectId(), "name": "THX-1138", "year": 1971} + coll.bulk_write( + [ReplaceOne({}, mov_dict)] # type:ignore[arg-type] + ) + coll.bulk_write( + [ + ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) + ] # No error because it is in-line. + ) + @only_type_check def test_typeddict_explicit_document_type(self) -> None: out = MovieWithId(_id=ObjectId(), name="THX-1138", year=1971) diff --git a/test/test_server_selection.py b/test/test_server_selection.py index a80d5f13d9..c3f3762f9a 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -23,6 +23,7 @@ from pymongo.server_selectors import writable_server_selector from pymongo.settings import TopologySettings from pymongo.topology import Topology +from pymongo.typings import strip_optional sys.path[0:0] = [""] @@ -85,7 +86,10 @@ def all_hosts_started(): ) wait_until(all_hosts_started, "receive heartbeat from all hosts") - expected_port = max([n.address[1] for n in client._topology._description.readable_servers]) + + expected_port = max( + [strip_optional(n.address[1]) for n in client._topology._description.readable_servers] + ) # Insert 1 record and access it 10 times. coll.insert_one({"name": "John Doe"}) diff --git a/test/test_session.py b/test/test_session.py index f22a2d5eab..386bab295c 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -898,7 +898,9 @@ def _test_writes(self, op): @client_context.require_no_standalone def test_writes(self): - self._test_writes(lambda coll, session: coll.bulk_write([InsertOne({})], session=session)) + self._test_writes( + lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session) + ) self._test_writes(lambda coll, session: coll.insert_one({}, session=session)) self._test_writes(lambda coll, session: coll.insert_many([{}], session=session)) self._test_writes( @@ -944,7 +946,7 @@ def _test_no_read_concern(self, op): @client_context.require_no_standalone def test_writes_do_not_include_read_concern(self): self._test_no_read_concern( - lambda coll, session: coll.bulk_write([InsertOne({})], session=session) + lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session) ) self._test_no_read_concern(lambda coll, session: coll.insert_one({}, session=session)) self._test_no_read_concern(lambda coll, session: coll.insert_many([{}], session=session)) diff --git a/test/test_transactions.py b/test/test_transactions.py index 4cee3fa236..02e691329e 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -363,7 +363,7 @@ def test_transaction_direct_connection(self): coll.insert_one({}) self.assertEqual(client.topology_description.topology_type_name, "Single") ops = [ - (coll.bulk_write, [[InsertOne({})]]), + (coll.bulk_write, [[InsertOne[dict]({})]]), (coll.insert_one, [{}]), (coll.insert_many, [[{}, {}]]), (coll.replace_one, [{}, {}]), @@ -385,7 +385,7 @@ def test_transaction_direct_connection(self): ] for f, args in ops: with client.start_session() as s, s.start_transaction(): - res = f(*args, session=s) + res = f(*args, session=s) # type:ignore[operator] if isinstance(res, (CommandCursor, Cursor)): list(res) diff --git a/test/utils.py b/test/utils.py index 59349f4fdc..6b0876a158 100644 --- a/test/utils.py +++ b/test/utils.py @@ -29,6 +29,7 @@ from collections import abc, defaultdict from functools import partial from test import client_context, db_pwd, db_user +from typing import Any from bson import json_util from bson.objectid import ObjectId @@ -557,27 +558,27 @@ def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs return MongoClient(uri, port, **client_options) -def single_client_noauth(h=None, p=None, **kwargs): +def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: """Make a direct connection. Don't authenticate.""" return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) -def single_client(h=None, p=None, **kwargs): +def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: """Make a direct connection, and authenticate if necessary.""" return _mongo_client(h, p, directConnection=True, **kwargs) -def rs_client_noauth(h=None, p=None, **kwargs): +def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: """Connect to the replica set. Don't authenticate.""" return _mongo_client(h, p, authenticate=False, **kwargs) -def rs_client(h=None, p=None, **kwargs): +def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: """Connect to the replica set and authenticate if necessary.""" return _mongo_client(h, p, **kwargs) -def rs_or_single_client_noauth(h=None, p=None, **kwargs): +def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: """Connect to the replica set if there is one, otherwise the standalone. Like rs_or_single_client, but does not authenticate. @@ -585,7 +586,7 @@ def rs_or_single_client_noauth(h=None, p=None, **kwargs): return _mongo_client(h, p, authenticate=False, **kwargs) -def rs_or_single_client(h=None, p=None, **kwargs): +def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]: """Connect to the replica set if there is one, otherwise the standalone. Authenticates if necessary.