diff --git a/debug_toolbar/management/commands/debugsqlshell.py b/debug_toolbar/management/commands/debugsqlshell.py index ea39f3e1c..1bff9220b 100644 --- a/debug_toolbar/management/commands/debugsqlshell.py +++ b/debug_toolbar/management/commands/debugsqlshell.py @@ -2,7 +2,9 @@ import sqlparse from django.core.management.commands.shell import Command # noqa +from django.db import connection from django.db.backends import utils as db_backends_utils +from django.db.backends.postgresql import base as psql_base # 'debugsqlshell' is the same as the 'shell'. @@ -20,4 +22,7 @@ def execute(self, sql, params=()): print("{} [{:.2f}ms]".format(formatted_sql, duration)) -db_backends_utils.CursorDebugWrapper = PrintQueryWrapper +if connection.vendor == "postgresql": + psql_base.CursorDebugWrapper = PrintQueryWrapper +else: + db_backends_utils.CursorDebugWrapper = PrintQueryWrapper diff --git a/tests/commands/test_debugsqlshell.py b/tests/commands/test_debugsqlshell.py index 54cd248e0..1c41369d1 100644 --- a/tests/commands/test_debugsqlshell.py +++ b/tests/commands/test_debugsqlshell.py @@ -4,6 +4,7 @@ from django.contrib.auth.models import User from django.core import management from django.db.backends import utils as db_backends_utils +from django.db.backends.postgresql import base as psql_base from django.test import TestCase from django.test.utils import override_settings @@ -11,7 +12,8 @@ @override_settings(DEBUG=True) class DebugSQLShellTestCase(TestCase): def setUp(self): - self.original_cursor_wrapper = db_backends_utils.CursorDebugWrapper + self.original_utils_wrapper = db_backends_utils.CursorDebugWrapper + self.original_psql_wrapper = psql_base.CursorDebugWrapper # Since debugsqlshell monkey-patches django.db.backends.utils, we can # test it simply by loading it, without executing it. But we have to # undo the monkey-patch on exit. @@ -20,7 +22,8 @@ def setUp(self): management.load_command_class(app_name, command_name) def tearDown(self): - db_backends_utils.CursorDebugWrapper = self.original_cursor_wrapper + db_backends_utils.CursorDebugWrapper = self.original_utils_wrapper + psql_base.CursorDebugWrapper = self.original_psql_wrapper def test_command(self): original_stdout, sys.stdout = sys.stdout, io.StringIO()