Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix serializer multiple inheritance bug #6980

Merged
merged 3 commits into from Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 16 additions & 12 deletions rest_framework/serializers.py
Expand Up @@ -299,18 +299,22 @@ def _get_declared_fields(cls, bases, attrs):
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1]._creation_counter)

# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
for base in reversed(bases):
if hasattr(base, '_declared_fields'):
fields = [
(field_name, obj) for field_name, obj
in base._declared_fields.items()
if field_name not in attrs
] + fields

return OrderedDict(fields)
# Ensures a base class field doesn't override cls attrs, and maintains
# field precedence when inheriting multiple parents. e.g. if there is a
# class C(A, B), and A and B both define 'field', use 'field' from A.
known = set(attrs)

def visit(name):
known.add(name)
return name

base_fields = [
(visit(name), f)
for base in bases if hasattr(base, '_declared_fields')
for name, f in base._declared_fields.items() if name not in known
]

return OrderedDict(base_fields + fields)

def __new__(cls, name, bases, attrs):
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
Expand Down
50 changes: 50 additions & 0 deletions tests/test_serializer.py
Expand Up @@ -682,3 +682,53 @@ class Grandchild(Child):
assert len(Parent().get_fields()) == 2
assert len(Child().get_fields()) == 2
assert len(Grandchild().get_fields()) == 2

def test_multiple_inheritance(self):
class A(serializers.Serializer):
field = serializers.CharField()

class B(serializers.Serializer):
field = serializers.IntegerField()

class TestSerializer(A, B):
pass

fields = {
name: type(f) for name, f
in TestSerializer()._declared_fields.items()
}
assert fields == {
'field': serializers.CharField,
}

def test_field_ordering(self):
class Base(serializers.Serializer):
f1 = serializers.CharField()
f2 = serializers.CharField()

class A(Base):
f3 = serializers.IntegerField()

class B(serializers.Serializer):
f3 = serializers.CharField()
f4 = serializers.CharField()

class TestSerializer(A, B):
f2 = serializers.IntegerField()
f5 = serializers.CharField()

fields = {
name: type(f) for name, f
in TestSerializer()._declared_fields.items()
}

# `IntegerField`s should be the 'winners' in field name conflicts
# - `TestSerializer.f2` should override `Base.F2`
# - `A.f3` should override `B.f3`
assert fields == {
'f1': serializers.CharField,
'f2': serializers.IntegerField,
'f3': serializers.IntegerField,
'f4': serializers.CharField,
'f5': serializers.CharField,
}