Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix UniqueTogetherValidator with field sources #7086

Merged
merged 5 commits into from Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion rest_framework/serializers.py
Expand Up @@ -448,7 +448,7 @@ def _read_only_defaults(self):
default = field.get_default()
except SkipField:
continue
defaults[field.field_name] = default
defaults[field.source] = default

return defaults

Expand Down
18 changes: 12 additions & 6 deletions rest_framework/validators.py
Expand Up @@ -106,7 +106,7 @@ def enforce_required_fields(self, attrs, serializer):
missing_items = {
field_name: self.missing_message
for field_name in self.fields
if field_name not in attrs
if serializer.fields[field_name].source not in attrs
}
if missing_items:
raise ValidationError(missing_items, code='required')
Expand All @@ -115,17 +115,23 @@ def filter_queryset(self, attrs, queryset, serializer):
"""
Filter the queryset to all instances matching the given attributes.
"""
# field names => field sources
sources = [
serializer.fields[field_name].source
for field_name in self.fields
]

# If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute.
if serializer.instance is not None:
for field_name in self.fields:
if field_name not in attrs:
attrs[field_name] = getattr(serializer.instance, field_name)
for source in sources:
if source not in attrs:
attrs[source] = getattr(serializer.instance, source)

# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = {
field_name: attrs[field_name]
for field_name in self.fields
source: attrs[source]
for source in sources
}
return qs_filter(queryset, **filter_kwargs)

Expand Down
40 changes: 35 additions & 5 deletions tests/test_validators.py
Expand Up @@ -301,6 +301,40 @@ class Meta:
]
}

def test_read_only_fields_with_default_and_source(self):
class ReadOnlySerializer(serializers.ModelSerializer):
name = serializers.CharField(source='race_name', default='test', read_only=True)

class Meta:
model = UniquenessTogetherModel
fields = ['name', 'position']
validators = [
UniqueTogetherValidator(
queryset=UniquenessTogetherModel.objects.all(),
fields=['name', 'position']
)
]

serializer = ReadOnlySerializer(data={'position': 1})
assert serializer.is_valid(raise_exception=True)

def test_writeable_fields_with_source(self):
class WriteableSerializer(serializers.ModelSerializer):
name = serializers.CharField(source='race_name')

class Meta:
model = UniquenessTogetherModel
fields = ['name', 'position']
validators = [
UniqueTogetherValidator(
queryset=UniquenessTogetherModel.objects.all(),
fields=['name', 'position']
)
]

serializer = WriteableSerializer(data={'name': 'test', 'position': 1})
assert serializer.is_valid(raise_exception=True)

def test_allow_explict_override(self):
"""
Ensure validators can be explicitly removed..
Expand Down Expand Up @@ -357,13 +391,9 @@ class MockQueryset:
def filter(self, **kwargs):
self.called_with = kwargs

class MockSerializer:
def __init__(self, instance):
self.instance = instance

data = {'race_name': 'bar'}
queryset = MockQueryset()
serializer = MockSerializer(instance=self.instance)
serializer = UniquenessTogetherSerializer(instance=self.instance)
validator = UniqueTogetherValidator(queryset, fields=('race_name',
'position'))
validator.filter_queryset(attrs=data, queryset=queryset, serializer=serializer)
Expand Down