Skip to content

Commit

Permalink
Followup to set_context removal (encode#7076)
Browse files Browse the repository at this point in the history
* Raise framework-specific deprecation warnings

- Use `RemovedInDRF313Warning` instead of DeprecationWarning
- Update to follow deprecation policy

* Pass serializer instead of model to validator

The `UniqueTogetherValidator` may need to access attributes on the
serializer instead of just the model instance. For example, this is
useful for handling field sources.

* Fix framework deprecation warning in test

* Remove outdated validator attribute
  • Loading branch information
rpkilby authored and Pierre Chiquet committed Mar 24, 2020
1 parent 3b645ee commit 6a02417
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 28 deletions.
14 changes: 7 additions & 7 deletions rest_framework/fields.py
Expand Up @@ -30,7 +30,7 @@
from django.utils.translation import gettext_lazy as _
from pytz.exceptions import InvalidTimeError

from rest_framework import ISO_8601
from rest_framework import ISO_8601, RemovedInDRF313Warning
from rest_framework.compat import ProhibitNullCharactersValidator
from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings
Expand Down Expand Up @@ -263,10 +263,10 @@ def __call__(self, serializer_field):
if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"no longer be called starting with 3.13. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
RemovedInDRF313Warning, stacklevel=2
)
self.default.set_context(self)

Expand Down Expand Up @@ -502,10 +502,10 @@ def get_default(self):
if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"no longer be called starting with 3.13. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
RemovedInDRF313Warning, stacklevel=2
)
self.default.set_context(self)

Expand Down Expand Up @@ -576,10 +576,10 @@ def run_validators(self, value):
if hasattr(validator, 'set_context'):
warnings.warn(
"Method `set_context` on validators is deprecated and will "
"no longer be called starting with 3.12. Instead set "
"no longer be called starting with 3.13. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
DeprecationWarning, stacklevel=2
RemovedInDRF313Warning, stacklevel=2
)
validator.set_context(self)

Expand Down
25 changes: 9 additions & 16 deletions rest_framework/validators.py
Expand Up @@ -41,7 +41,6 @@ class UniqueValidator:

def __init__(self, queryset, message=None, lookup='exact'):
self.queryset = queryset
self.serializer_field = None
self.message = message or self.message
self.lookup = lookup

Expand Down Expand Up @@ -94,15 +93,14 @@ class UniqueTogetherValidator:
def __init__(self, queryset, fields, message=None):
self.queryset = queryset
self.fields = fields
self.serializer_field = None
self.message = message or self.message

def enforce_required_fields(self, attrs, instance):
def enforce_required_fields(self, attrs, serializer):
"""
The `UniqueTogetherValidator` always forces an implied 'required'
state on the fields it applies to.
"""
if instance is not None:
if serializer.instance is not None:
return

missing_items = {
Expand All @@ -113,16 +111,16 @@ def enforce_required_fields(self, attrs, instance):
if missing_items:
raise ValidationError(missing_items, code='required')

def filter_queryset(self, attrs, queryset, instance):
def filter_queryset(self, attrs, queryset, serializer):
"""
Filter the queryset to all instances matching the given attributes.
"""
# If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute.
if instance is not None:
if serializer.instance is not None:
for field_name in self.fields:
if field_name not in attrs:
attrs[field_name] = getattr(instance, field_name)
attrs[field_name] = getattr(serializer.instance, field_name)

# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = {
Expand All @@ -141,13 +139,10 @@ def exclude_current_instance(self, attrs, queryset, instance):
return queryset

def __call__(self, attrs, serializer):
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)

self.enforce_required_fields(attrs, instance)
self.enforce_required_fields(attrs, serializer)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset, instance)
queryset = self.exclude_current_instance(attrs, queryset, instance)
queryset = self.filter_queryset(attrs, queryset, serializer)
queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)

# Ignore validation if any field is None
checked_values = [
Expand Down Expand Up @@ -207,13 +202,11 @@ def __call__(self, attrs, serializer):
# same as the serializer field names if `source=<>` is set.
field_name = serializer.fields[self.field].source_attrs[-1]
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)

self.enforce_required_fields(attrs)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
queryset = self.exclude_current_instance(attrs, queryset, instance)
queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({
Expand Down
7 changes: 3 additions & 4 deletions tests/test_fields.py
Expand Up @@ -565,11 +565,10 @@ def test_create_only_default_callable_sets_context(self):
on the callable if possible
"""
class TestCallableDefault:
def set_context(self, serializer_field):
self.field = serializer_field
requires_context = True

def __call__(self):
return "success" if hasattr(self, 'field') else "failure"
def __call__(self, field=None):
return "success" if field is not None else "failure"

class TestSerializer(serializers.Serializer):
context_set = serializers.CharField(default=serializers.CreateOnlyDefault(TestCallableDefault()))
Expand Down
7 changes: 6 additions & 1 deletion tests/test_validators.py
Expand Up @@ -357,11 +357,16 @@ class MockQueryset:
def filter(self, **kwargs):
self.called_with = kwargs

class MockSerializer:
def __init__(self, instance):
self.instance = instance

data = {'race_name': 'bar'}
queryset = MockQueryset()
serializer = MockSerializer(instance=self.instance)
validator = UniqueTogetherValidator(queryset, fields=('race_name',
'position'))
validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance)
validator.filter_queryset(attrs=data, queryset=queryset, serializer=serializer)
assert queryset.called_with == {'race_name': 'bar', 'position': 1}


Expand Down

0 comments on commit 6a02417

Please sign in to comment.