diff --git a/test/test_cursor.py b/test/test_cursor.py index 021e4d7cb4..b2de09429b 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1483,10 +1483,13 @@ def test_find_raw_transaction(self): session=session).sort('_id')) cmd = listener.results['started'][0] self.assertEqual(cmd.command_name, 'find') - self.assertEqual(cmd.command['$clusterTime'], - decode_all(session.cluster_time.raw)[0]) + self.assertIn('$clusterTime', cmd.command) self.assertEqual(cmd.command['startTransaction'], True) self.assertEqual(cmd.command['txnNumber'], 1) + # Ensure we update $clusterTime from the command response. + last_cmd = listener.results['succeeded'][-1] + self.assertEqual(last_cmd.reply['$clusterTime']['clusterTime'], + session.cluster_time['clusterTime']) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1677,9 +1680,13 @@ def test_aggregate_raw_transaction(self): [{'$sort': {'_id': 1}}], session=session)) cmd = listener.results['started'][0] self.assertEqual(cmd.command_name, 'aggregate') - self.assertEqual(cmd.command['$clusterTime'], session.cluster_time) + self.assertIn('$clusterTime', cmd.command) self.assertEqual(cmd.command['startTransaction'], True) self.assertEqual(cmd.command['txnNumber'], 1) + # Ensure we update $clusterTime from the command response. + last_cmd = listener.results['succeeded'][-1] + self.assertEqual(last_cmd.reply['$clusterTime']['clusterTime'], + session.cluster_time['clusterTime']) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0]))