Skip to content

Commit

Permalink
a large number of type changes
Browse files Browse the repository at this point in the history
  • Loading branch information
juliusgeo committed Nov 3, 2022
1 parent 032271e commit cf8f8b0
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 20 deletions.
23 changes: 23 additions & 0 deletions doc/examples/type_hints.rst
Expand Up @@ -114,6 +114,29 @@ These methods automatically add an "_id" field.
>>> # This will not be type checked, despite being present, because it is added by PyMongo.
>>> assert type(result["_id"]) == ObjectId

This same typing scheme works for all of the insert methods (`insert_one`, `insert_many`, and `bulk_write`). For `bulk_write`,
both `InsertOne/Many` and `ReplaceOne/Many` operators are generic.

.. doctest::

>>> from typing import TypedDict
>>> from pymongo import MongoClient
>>> from pymongo.operations import InsertOne
>>> from pymongo.collection import Collection
>>> class Movie(TypedDict):
... name: str
... year: int
...
>>> client: MongoClient = MongoClient()
>>> collection: Collection[Movie] = client.test.test
>>> inserted = collection.bulk_write([InsertOne(Movie(name="Jurassic Park", year=1993))])
>>> result = collection.find_one({"name": "Jurassic Park"})
>>> assert result is not None
>>> assert result["year"] == 1993
>>> # This will not be type checked, despite being present, because it is added by PyMongo.
>>> assert type(result["_id"]) == ObjectId


Typed Database
--------------

Expand Down
2 changes: 1 addition & 1 deletion test/__init__.py
Expand Up @@ -1090,7 +1090,7 @@ def print_thread_stacks(pid: int) -> None:
class IntegrationTest(PyMongoTestCase):
"""Base class for TestCases that need a connection to MongoDB to pass."""

client: MongoClient
client: MongoClient[dict]
db: Database
credentials: Dict[str, str]

Expand Down
2 changes: 1 addition & 1 deletion test/mockupdb/test_cluster_time.py
Expand Up @@ -60,7 +60,7 @@ def callback(client):
self.cluster_time_conversation(callback, [{"ok": 1}] * 2)

def test_bulk(self):
def callback(client):
def callback(client: MongoClient[dict]) -> None:
client.db.collection.bulk_write(
[InsertOne({}), InsertOne({}), UpdateOne({}, {"$inc": {"x": 1}}), DeleteMany({})]
)
Expand Down
6 changes: 3 additions & 3 deletions test/mockupdb/test_op_msg.py
Expand Up @@ -137,22 +137,22 @@
# Legacy methods
Operation(
"bulk_write_insert",
lambda coll: coll.bulk_write([InsertOne({}), InsertOne({})]),
lambda coll: coll.bulk_write([InsertOne[dict]({}), InsertOne[dict]({})]),
request=OpMsg({"insert": "coll"}, flags=0),
reply={"ok": 1, "n": 2},
),
Operation(
"bulk_write_insert-w0",
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
[InsertOne({}), InsertOne({})]
[InsertOne[dict]({}), InsertOne[dict]({})]
),
request=OpMsg({"insert": "coll"}, flags=0),
reply={"ok": 1, "n": 2},
),
Operation(
"bulk_write_insert-w0-unordered",
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
[InsertOne({}), InsertOne({})], ordered=False
[InsertOne[dict]({}), InsertOne[dict]({})], ordered=False
),
request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]),
reply=None,
Expand Down
4 changes: 2 additions & 2 deletions test/test_bulk.py
Expand Up @@ -296,7 +296,7 @@ def test_upsert(self):
def test_numerous_inserts(self):
# Ensure we don't exceed server's maxWriteBatchSize size limit.
n_docs = client_context.max_write_batch_size + 100
requests = [InsertOne({}) for _ in range(n_docs)]
requests = [InsertOne[dict]({}) for _ in range(n_docs)]
result = self.coll.bulk_write(requests, ordered=False)
self.assertEqual(n_docs, result.inserted_count)
self.assertEqual(n_docs, self.coll.count_documents({}))
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_bulk_write_no_results(self):

def test_bulk_write_invalid_arguments(self):
# The requests argument must be a list.
generator = (InsertOne({}) for _ in range(10))
generator = (InsertOne[dict]({}) for _ in range(10))
with self.assertRaises(TypeError):
self.coll.bulk_write(generator) # type: ignore[arg-type]

Expand Down
1 change: 1 addition & 0 deletions test/test_client.py
Expand Up @@ -1652,6 +1652,7 @@ def test_network_error_message(self):
with self.fail_point(
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
):
assert client.address is not None
expected = "%s:%s: " % client.address
with self.assertRaisesRegex(AutoReconnect, expected):
client.pymongo_test.test.find_one({})
Expand Down
2 changes: 1 addition & 1 deletion test/test_database.py
Expand Up @@ -201,7 +201,7 @@ def test_list_collection_names_filter(self):
db.capped.insert_one({})
db.non_capped.insert_one({})
self.addCleanup(client.drop_database, db.name)

filter: dict
# Should not send nameOnly.
for filter in ({"options.capped": True}, {"options.capped": True, "name": "capped"}):
results.clear()
Expand Down
4 changes: 3 additions & 1 deletion test/test_server_selection.py
Expand Up @@ -85,7 +85,9 @@ def all_hosts_started():
)

wait_until(all_hosts_started, "receive heartbeat from all hosts")
expected_port = max([n.address[1] for n in client._topology._description.readable_servers])
expected_port = max(
[n.address[1] for n in client._topology._description.readable_servers]
) # type:ignore[type-var]

# Insert 1 record and access it 10 times.
coll.insert_one({"name": "John Doe"})
Expand Down
10 changes: 7 additions & 3 deletions test/test_session.py
Expand Up @@ -898,7 +898,9 @@ def _test_writes(self, op):

@client_context.require_no_standalone
def test_writes(self):
self._test_writes(lambda coll, session: coll.bulk_write([InsertOne({})], session=session))
self._test_writes(
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
)
self._test_writes(lambda coll, session: coll.insert_one({}, session=session))
self._test_writes(lambda coll, session: coll.insert_many([{}], session=session))
self._test_writes(
Expand Down Expand Up @@ -944,7 +946,7 @@ def _test_no_read_concern(self, op):
@client_context.require_no_standalone
def test_writes_do_not_include_read_concern(self):
self._test_no_read_concern(
lambda coll, session: coll.bulk_write([InsertOne({})], session=session)
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
)
self._test_no_read_concern(lambda coll, session: coll.insert_one({}, session=session))
self._test_no_read_concern(lambda coll, session: coll.insert_many([{}], session=session))
Expand Down Expand Up @@ -1077,7 +1079,9 @@ def setUp(self):
def test_cluster_time(self):
listener = SessionTestListener()
# Prevent heartbeats from updating $clusterTime between operations.
client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999)
client: MongoClient[dict] = rs_or_single_client(
event_listeners=[listener], heartbeatFrequencyMS=999999
)
self.addCleanup(client.close)
collection = client.pymongo_test.collection
# Prepare for tests of find() and aggregate().
Expand Down
4 changes: 2 additions & 2 deletions test/test_transactions.py
Expand Up @@ -363,7 +363,7 @@ def test_transaction_direct_connection(self):
coll.insert_one({})
self.assertEqual(client.topology_description.topology_type_name, "Single")
ops = [
(coll.bulk_write, [[InsertOne({})]]),
(coll.bulk_write, [[InsertOne[dict]({})]]),
(coll.insert_one, [{}]),
(coll.insert_many, [[{}, {}]]),
(coll.replace_one, [{}, {}]),
Expand All @@ -385,7 +385,7 @@ def test_transaction_direct_connection(self):
]
for f, args in ops:
with client.start_session() as s, s.start_transaction():
res = f(*args, session=s)
res = f(*args, session=s) # type:ignore[operator]
if isinstance(res, (CommandCursor, Cursor)):
list(res)

Expand Down
13 changes: 7 additions & 6 deletions test/utils.py
Expand Up @@ -29,6 +29,7 @@
from collections import abc, defaultdict
from functools import partial
from test import client_context, db_pwd, db_user
from typing import Any

from bson import json_util
from bson.objectid import ObjectId
Expand Down Expand Up @@ -557,35 +558,35 @@ def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs
return MongoClient(uri, port, **client_options)


def single_client_noauth(h=None, p=None, **kwargs):
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Make a direct connection. Don't authenticate."""
return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)


def single_client(h=None, p=None, **kwargs):
def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Make a direct connection, and authenticate if necessary."""
return _mongo_client(h, p, directConnection=True, **kwargs)


def rs_client_noauth(h=None, p=None, **kwargs):
def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set. Don't authenticate."""
return _mongo_client(h, p, authenticate=False, **kwargs)


def rs_client(h=None, p=None, **kwargs):
def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set and authenticate if necessary."""
return _mongo_client(h, p, **kwargs)


def rs_or_single_client_noauth(h=None, p=None, **kwargs):
def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set if there is one, otherwise the standalone.
Like rs_or_single_client, but does not authenticate.
"""
return _mongo_client(h, p, authenticate=False, **kwargs)


def rs_or_single_client(h=None, p=None, **kwargs):
def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
"""Connect to the replica set if there is one, otherwise the standalone.
Authenticates if necessary.
Expand Down

0 comments on commit cf8f8b0

Please sign in to comment.