Skip to content

Commit

Permalink
MOTOR-1209: Motor's DriverInfo should not be overwritten (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jibola committed Nov 14, 2023
1 parent d58545a commit a1c57b7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
29 changes: 25 additions & 4 deletions motor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,25 @@ def __init__(self, *args, **kwargs):
self._io_loop = io_loop

kwargs.setdefault("connect", False)
kwargs.setdefault(
"driver", DriverInfo("Motor", motor_version, self._framework.platform_info())
)

driver_info = DriverInfo("Motor", motor_version, self._framework.platform_info())

if kwargs.get("driver"):
provided_info = kwargs.get("driver")
if not isinstance(provided_info, DriverInfo):
raise TypeError(
f"Incorrect type for `driver` {type(provided_info)};"
" expected value of type pymongo.driver_info.DriverInfo"
)
added_version = f"|{provided_info.version}" if provided_info.version else ""
added_platform = f"|{provided_info.platform}" if provided_info.platform else ""
driver_info = DriverInfo(
f"{driver_info.name}|{provided_info.name}",
f"{driver_info.version}{added_version}",
f"{driver_info.platform}{added_platform}",
)

kwargs["driver"] = driver_info

delegate = self.__delegate_class__(*args, **kwargs)
super().__init__(delegate)
Expand Down Expand Up @@ -1650,7 +1666,12 @@ def to_list(self, length):
else:
the_list = []
self._framework.add_future(
self.get_io_loop(), self._get_more(), self._to_list, length, the_list, future
self.get_io_loop(),
self._get_more(),
self._to_list,
length,
the_list,
future,
)

return future
Expand Down
35 changes: 35 additions & 0 deletions test/tornado_tests/test_motor_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from bson import CodecOptions
from mockupdb import OpQuery
from pymongo import CursorType, ReadPreference, WriteConcern
from pymongo.driver_info import DriverInfo
from pymongo.errors import ConnectionFailure, OperationFailure
from tornado import gen
from tornado.testing import gen_test
Expand Down Expand Up @@ -305,6 +306,40 @@ async def test_handshake(self):
except Exception:
pass

@gen_test
async def test_driver_info(self):
server = self.server()
driver_info = DriverInfo(name="Foo", version="1.1.1", platform="FooPlat")
client = motor.MotorClient(server.uri, driver=driver_info)

# Trigger connection.
future = client.db.command("ping")
handshake = await self.run_thread(server.receives, "ismaster")
meta = handshake.doc["client"]
self.assertEqual(f"PyMongo|Motor|{driver_info.name}", meta["driver"]["name"])
self.assertIn("Tornado", meta["platform"])
self.assertIn(f"|{driver_info.platform}", meta["platform"])
self.assertTrue(
meta["driver"]["version"].endswith(f"{motor.version}|{driver_info.version}"),
"Version in handshake [%s] doesn't end with MotorVersion|Test version [%s]"
% (meta["driver"]["version"], f"{motor.version}|{driver_info.version}"),
)

handshake.ok()
server.stop()
client.close()
try:
await future
except Exception:
pass

def test_incorrect_driver_info(self):
with self.assertRaises(
TypeError,
msg="Allowed invalid type parameter str, driver should only be of DriverInfo",
):
motor.MotorClient(driver="string")


if __name__ == "__main__":
unittest.main()

0 comments on commit a1c57b7

Please sign in to comment.