diff --git a/rest_framework/views.py b/rest_framework/views.py index 69db053d64..9120079d07 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ """ from django.conf import settings from django.core.exceptions import PermissionDenied -from django.db import connection, models, transaction +from django.db import connections, models, transaction from django.http import Http404 from django.http.response import HttpResponseBase from django.utils.cache import cc_delim_re, patch_vary_headers @@ -63,9 +63,13 @@ def get_view_description(view, html=False): def set_rollback(): - atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) - if atomic_requests and connection.in_atomic_block: - transaction.set_rollback(True) + # Rollback all connections that have ATOMIC_REQUESTS set, if it looks their + # @atomic for the request was started + # Note this check in_atomic_block may be a false positive due to + # transactions started another way, e.g. through testing with TestCase + for db in connections.all(): + if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block: + transaction.set_rollback(True, using=db.alias) def exception_handler(exc, context): @@ -223,9 +227,9 @@ def get_exception_handler_context(self): """ return { 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}), - 'request': getattr(self, 'request', None) + 'args': self.args, + 'kwargs': self.kwargs, + 'request': self.request, } def get_view_name(self): diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index de04d2c069..e6969bdfa7 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -39,12 +39,12 @@ def dispatch(self, *args, **kwargs): return super().dispatch(*args, **kwargs) def get(self, request, *args, **kwargs): - BasicModel.objects.all() + list(BasicModel.objects.all()) raise Http404 urlpatterns = ( - url(r'^$', NonAtomicAPIExceptionView.as_view()), + url(r'^non-atomic-exception$', NonAtomicAPIExceptionView.as_view()), ) @@ -94,8 +94,8 @@ def test_generic_exception_delegate_transaction_management(self): # 1 - begin savepoint # 2 - insert # 3 - release savepoint - with transaction.atomic(): - self.assertRaises(Exception, self.view, request) + with transaction.atomic(), self.assertRaises(Exception): + self.view(request) assert not transaction.get_rollback() assert BasicModel.objects.count() == 1 @@ -139,12 +139,15 @@ class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): def setUp(self): connections.databases['default']['ATOMIC_REQUESTS'] = True - def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + @self.addCleanup + def restore_atomic_requests(): + connections.databases['default']['ATOMIC_REQUESTS'] = False def test_api_exception_rollback_transaction_non_atomic_view(self): - response = self.client.get('/') + response = self.client.get('/non-atomic-exception') - # without checking connection.in_atomic_block view raises 500 - # due attempt to rollback without transaction + # without check for db.in_atomic_block, would raise 500 due to attempt + # to rollback without transaction assert response.status_code == status.HTTP_404_NOT_FOUND + # Check we can still perform DB queries + list(BasicModel.objects.all())