diff --git a/debug_toolbar/panels/sql/forms.py b/debug_toolbar/panels/sql/forms.py index 4131cb775..3e1370411 100644 --- a/debug_toolbar/panels/sql/forms.py +++ b/debug_toolbar/panels/sql/forms.py @@ -52,7 +52,6 @@ def clean_raw_sql(self): def clean_params(self): value = self.cleaned_data["params"] - try: return json.loads(value) except ValueError: diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index c16f2319f..75366802c 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -8,6 +8,11 @@ from debug_toolbar import settings as dt_settings from debug_toolbar.utils import get_stack, get_template_info, tidy_stacktrace +try: + from psycopg2._json import Json as PostgresJson +except ImportError: + PostgresJson = None + class SQLQueryTriggered(Exception): """Thrown when template panel triggers a query""" @@ -105,6 +110,8 @@ def _quote_params(self, params): return [self._quote_expr(p) for p in params] def _decode(self, param): + if PostgresJson and isinstance(param, PostgresJson): + return param.dumps(param.adapted) # If a sequence type, decode each element separately if isinstance(param, (tuple, list)): return [self._decode(element) for element in param] @@ -136,7 +143,6 @@ def _record(self, method, sql, params): _params = json.dumps(self._decode(params)) except TypeError: pass # object not JSON serializable - template_info = get_template_info() alias = getattr(self.db, "alias", "default") diff --git a/tests/models.py b/tests/models.py index 652bed98a..0d9e3ee55 100644 --- a/tests/models.py +++ b/tests/models.py @@ -8,3 +8,14 @@ def __repr__(self): class Binary(models.Model): field = models.BinaryField() + + +try: + from django.contrib.postgres.fields import JSONField + + class PostgresJSON(models.Model): + field = JSONField() + + +except ModuleNotFoundError: + pass diff --git a/tests/test_integration.py b/tests/test_integration.py index 0adbdb03c..94e5ac990 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -6,6 +6,7 @@ from django.contrib.staticfiles.testing import StaticLiveServerTestCase from django.core import signing from django.core.checks import Warning, run_checks +from django.db import connection from django.http import HttpResponse from django.template.loader import get_template from django.test import RequestFactory, SimpleTestCase, TestCase @@ -206,6 +207,35 @@ def test_sql_explain_checks_show_toolbar(self): ) self.assertEqual(response.status_code, 404) + @unittest.skipUnless( + connection.vendor == "postgresql", "Test valid only on PostgreSQL" + ) + def test_sql_explain_postgres_json_field(self): + url = "/__debug__/sql_explain/" + base_query = ( + 'SELECT * FROM "tests_postgresjson" WHERE "tests_postgresjson"."field" @>' + ) + query = base_query + """ '{"foo": "bar"}'""" + data = { + "sql": query, + "raw_sql": base_query + " %s", + "params": '["{\\"foo\\": \\"bar\\"}"]', + "alias": "default", + "duration": "0", + "hash": "2b7172eb2ac8e2a8d6f742f8a28342046e0d00ba", + } + response = self.client.post(url, data) + self.assertEqual(response.status_code, 200) + response = self.client.post(url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest") + self.assertEqual(response.status_code, 200) + with self.settings(INTERNAL_IPS=[]): + response = self.client.post(url, data) + self.assertEqual(response.status_code, 404) + response = self.client.post( + url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest" + ) + self.assertEqual(response.status_code, 404) + def test_sql_profile_checks_show_toolbar(self): url = "/__debug__/sql_profile/" data = {