Skip to content

Commit

Permalink
Handle enum type for nested ChoiceFields (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
etene authored and axnsan12 committed Jul 15, 2019
1 parent e9f2744 commit 6417bb3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
2 changes: 2 additions & 0 deletions requirements/test.txt
Expand Up @@ -5,5 +5,7 @@ pytest-cov>=2.6.0
pytest-xdist>=1.25.0
pytest-django>=3.4.4
datadiff==2.0.0
psycopg2-binary==2.8.3
django-fake-model==0.1.4

-r testproj.txt
6 changes: 5 additions & 1 deletion src/drf_yasg/inspectors/field.py
Expand Up @@ -656,7 +656,11 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
serializer = get_parent_serializer(field)
if isinstance(serializer, serializers.ModelSerializer):
model = getattr(getattr(serializer, 'Meta'), 'model')
model_field = get_model_field(model, field.source)
# Use the parent source for nested fields
model_field = get_model_field(model, field.source or field.parent.source)
# If the field has a base_field its type must be used
if getattr(model_field, "base_field", None):
model_field = model_field.base_field
if model_field:
model_type = get_basic_type_info(model_field)
if model_type:
Expand Down
43 changes: 43 additions & 0 deletions tests/test_schema_generator.py
Expand Up @@ -3,11 +3,14 @@

import pytest
from django.conf.urls import url
from django.contrib.postgres import fields as postgres_fields
from django.db import models
from django.utils.inspect import get_func_args
from rest_framework import routers, serializers, viewsets
from rest_framework.decorators import api_view
from rest_framework.response import Response

from django_fake_model import models as fake_models
from drf_yasg import codecs, openapi
from drf_yasg.codecs import yaml_sane_load
from drf_yasg.errors import SwaggerGenerationError
Expand Down Expand Up @@ -230,3 +233,43 @@ def retrieve(self, request, pk=None):
property_schema = swagger['definitions']['Detail']['properties']['detail']

assert property_schema == openapi.Schema(title='Detail', type=expected_type, enum=choices)


@pytest.mark.parametrize('choices, field, expected_type', [
([1, 2, 3], models.IntegerField, openapi.TYPE_INTEGER),
(["A", "B"], models.CharField, openapi.TYPE_STRING),
])
def test_nested_choice_in_array_field(choices, field, expected_type):

# Create a model class on the fly to avoid warnings about using the several
# model class name several times
model_class = type(
"%sModel" % field.__name__,
(fake_models.FakeModel,),
{
"array": postgres_fields.ArrayField(
field(choices=((i, "choice %s" % i) for i in choices))
),
"__module__": "test_models",
}
)

class ArraySerializer(serializers.ModelSerializer):
class Meta:
model = model_class
fields = ("array",)

class ArrayViewSet(viewsets.ModelViewSet):
serializer_class = ArraySerializer

router = routers.DefaultRouter()
router.register(r'arrays', ArrayViewSet, **_basename_or_base_name('arrays'))

generator = OpenAPISchemaGenerator(
info=openapi.Info(title='Test array model generator', default_version='v1'),
patterns=router.urls
)

swagger = generator.get_schema(None, True)
property_schema = swagger['definitions']['Array']['properties']['array']['items']
assert property_schema == openapi.Schema(title='Array', type=expected_type, enum=choices)

0 comments on commit 6417bb3

Please sign in to comment.