Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SQL Select and Explain actions for Postgres JSON fields. #1229

Merged
merged 1 commit into from Jan 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion debug_toolbar/panels/sql/tracking.py
Expand Up @@ -8,6 +8,11 @@
from debug_toolbar import settings as dt_settings
from debug_toolbar.utils import get_stack, get_template_info, tidy_stacktrace

try:
from psycopg2._json import Json as PostgresJson
except ImportError:
PostgresJson = None


class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
Expand Down Expand Up @@ -105,6 +110,8 @@ def _quote_params(self, params):
return [self._quote_expr(p) for p in params]

def _decode(self, param):
if PostgresJson and isinstance(param, PostgresJson):
return param.dumps(param.adapted)
# If a sequence type, decode each element separately
if isinstance(param, (tuple, list)):
return [self._decode(element) for element in param]
Expand Down Expand Up @@ -136,7 +143,6 @@ def _record(self, method, sql, params):
_params = json.dumps(self._decode(params))
except TypeError:
pass # object not JSON serializable

template_info = get_template_info()

alias = getattr(self.db, "alias", "default")
Expand Down
11 changes: 11 additions & 0 deletions tests/models.py
Expand Up @@ -8,3 +8,14 @@ def __repr__(self):

class Binary(models.Model):
field = models.BinaryField()


try:
from django.contrib.postgres.fields import JSONField

class PostgresJSON(models.Model):
field = JSONField()


except ImportError:
pass
30 changes: 30 additions & 0 deletions tests/panels/test_sql.py
Expand Up @@ -11,6 +11,16 @@

from ..base import BaseTestCase

try:
from psycopg2._json import Json as PostgresJson
except ImportError:
PostgresJson = None

if connection.vendor == "postgresql":
from ..models import PostgresJSON as PostgresJSONModel
else:
PostgresJSONModel = None


class SQLPanelTestCase(BaseTestCase):
panel_id = "SQLPanel"
Expand Down Expand Up @@ -120,6 +130,26 @@ def test_param_conversion(self):
('["Foo", true, false]', "[10, 1]", '["2017-12-22 16:07:01"]'),
)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
def test_json_param_conversion(self):
self.assertEqual(len(self.panel._queries), 0)

list(PostgresJSONModel.objects.filter(field__contains={"foo": "bar"}))

response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)

# ensure query was logged
self.assertEqual(len(self.panel._queries), 1)
self.assertEqual(
self.panel._queries[0][1]["params"], '["{\\"foo\\": \\"bar\\"}"]',
)
self.assertIsInstance(
self.panel._queries[0][1]["raw_params"][0], PostgresJson,
)

def test_binary_param_force_text(self):
self.assertEqual(len(self.panel._queries), 0)

Expand Down
30 changes: 30 additions & 0 deletions tests/test_integration.py
Expand Up @@ -6,6 +6,7 @@
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
from django.core import signing
from django.core.checks import Warning, run_checks
from django.db import connection
from django.http import HttpResponse
from django.template.loader import get_template
from django.test import RequestFactory, SimpleTestCase, TestCase
Expand Down Expand Up @@ -206,6 +207,35 @@ def test_sql_explain_checks_show_toolbar(self):
)
self.assertEqual(response.status_code, 404)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
def test_sql_explain_postgres_json_field(self):
url = "/__debug__/sql_explain/"
base_query = (
'SELECT * FROM "tests_postgresjson" WHERE "tests_postgresjson"."field" @>'
)
query = base_query + """ '{"foo": "bar"}'"""
data = {
"sql": query,
"raw_sql": base_query + " %s",
"params": '["{\\"foo\\": \\"bar\\"}"]',
"alias": "default",
"duration": "0",
"hash": "2b7172eb2ac8e2a8d6f742f8a28342046e0d00ba",
}
response = self.client.post(url, data)
self.assertEqual(response.status_code, 200)
response = self.client.post(url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest")
self.assertEqual(response.status_code, 200)
with self.settings(INTERNAL_IPS=[]):
response = self.client.post(url, data)
self.assertEqual(response.status_code, 404)
response = self.client.post(
url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest"
)
self.assertEqual(response.status_code, 404)

def test_sql_profile_checks_show_toolbar(self):
url = "/__debug__/sql_profile/"
data = {
Expand Down