diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8c80d6bd5a..11a2915683 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -30,7 +30,7 @@ from django.utils.translation import gettext_lazy as _ from pytz.exceptions import InvalidTimeError -from rest_framework import ISO_8601 +from rest_framework import ISO_8601, RemovedInDRF313Warning from rest_framework.compat import ProhibitNullCharactersValidator from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.settings import api_settings @@ -263,10 +263,10 @@ def __call__(self, serializer_field): if hasattr(self.default, 'set_context'): warnings.warn( "Method `set_context` on defaults is deprecated and will " - "no longer be called starting with 3.12. Instead set " + "no longer be called starting with 3.13. Instead set " "`requires_context = True` on the class, and accept the " "context as an additional argument.", - DeprecationWarning, stacklevel=2 + RemovedInDRF313Warning, stacklevel=2 ) self.default.set_context(self) @@ -502,10 +502,10 @@ def get_default(self): if hasattr(self.default, 'set_context'): warnings.warn( "Method `set_context` on defaults is deprecated and will " - "no longer be called starting with 3.12. Instead set " + "no longer be called starting with 3.13. Instead set " "`requires_context = True` on the class, and accept the " "context as an additional argument.", - DeprecationWarning, stacklevel=2 + RemovedInDRF313Warning, stacklevel=2 ) self.default.set_context(self) @@ -576,10 +576,10 @@ def run_validators(self, value): if hasattr(validator, 'set_context'): warnings.warn( "Method `set_context` on validators is deprecated and will " - "no longer be called starting with 3.12. Instead set " + "no longer be called starting with 3.13. Instead set " "`requires_context = True` on the class, and accept the " "context as an additional argument.", - DeprecationWarning, stacklevel=2 + RemovedInDRF313Warning, stacklevel=2 ) validator.set_context(self) diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 2907312a9b..aa79377142 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -41,7 +41,6 @@ class UniqueValidator: def __init__(self, queryset, message=None, lookup='exact'): self.queryset = queryset - self.serializer_field = None self.message = message or self.message self.lookup = lookup @@ -94,15 +93,14 @@ class UniqueTogetherValidator: def __init__(self, queryset, fields, message=None): self.queryset = queryset self.fields = fields - self.serializer_field = None self.message = message or self.message - def enforce_required_fields(self, attrs, instance): + def enforce_required_fields(self, attrs, serializer): """ The `UniqueTogetherValidator` always forces an implied 'required' state on the fields it applies to. """ - if instance is not None: + if serializer.instance is not None: return missing_items = { @@ -113,16 +111,16 @@ def enforce_required_fields(self, attrs, instance): if missing_items: raise ValidationError(missing_items, code='required') - def filter_queryset(self, attrs, queryset, instance): + def filter_queryset(self, attrs, queryset, serializer): """ Filter the queryset to all instances matching the given attributes. """ # If this is an update, then any unprovided field should # have it's value set based on the existing instance attribute. - if instance is not None: + if serializer.instance is not None: for field_name in self.fields: if field_name not in attrs: - attrs[field_name] = getattr(instance, field_name) + attrs[field_name] = getattr(serializer.instance, field_name) # Determine the filter keyword arguments and filter the queryset. filter_kwargs = { @@ -141,13 +139,10 @@ def exclude_current_instance(self, attrs, queryset, instance): return queryset def __call__(self, attrs, serializer): - # Determine the existing instance, if this is an update operation. - instance = getattr(serializer, 'instance', None) - - self.enforce_required_fields(attrs, instance) + self.enforce_required_fields(attrs, serializer) queryset = self.queryset - queryset = self.filter_queryset(attrs, queryset, instance) - queryset = self.exclude_current_instance(attrs, queryset, instance) + queryset = self.filter_queryset(attrs, queryset, serializer) + queryset = self.exclude_current_instance(attrs, queryset, serializer.instance) # Ignore validation if any field is None checked_values = [ @@ -207,13 +202,11 @@ def __call__(self, attrs, serializer): # same as the serializer field names if `source=<>` is set. field_name = serializer.fields[self.field].source_attrs[-1] date_field_name = serializer.fields[self.date_field].source_attrs[-1] - # Determine the existing instance, if this is an update operation. - instance = getattr(serializer, 'instance', None) self.enforce_required_fields(attrs) queryset = self.queryset queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name) - queryset = self.exclude_current_instance(attrs, queryset, instance) + queryset = self.exclude_current_instance(attrs, queryset, serializer.instance) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) raise ValidationError({ diff --git a/tests/test_fields.py b/tests/test_fields.py index 1d302b730e..0be1b1a7a0 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -565,11 +565,10 @@ def test_create_only_default_callable_sets_context(self): on the callable if possible """ class TestCallableDefault: - def set_context(self, serializer_field): - self.field = serializer_field + requires_context = True - def __call__(self): - return "success" if hasattr(self, 'field') else "failure" + def __call__(self, field=None): + return "success" if field is not None else "failure" class TestSerializer(serializers.Serializer): context_set = serializers.CharField(default=serializers.CreateOnlyDefault(TestCallableDefault())) diff --git a/tests/test_validators.py b/tests/test_validators.py index bb29a4305b..5c4a62b314 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -357,11 +357,16 @@ class MockQueryset: def filter(self, **kwargs): self.called_with = kwargs + class MockSerializer: + def __init__(self, instance): + self.instance = instance + data = {'race_name': 'bar'} queryset = MockQueryset() + serializer = MockSerializer(instance=self.instance) validator = UniqueTogetherValidator(queryset, fields=('race_name', 'position')) - validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance) + validator.filter_queryset(attrs=data, queryset=queryset, serializer=serializer) assert queryset.called_with == {'race_name': 'bar', 'position': 1}