diff --git a/requirements/test.txt b/requirements/test.txt index ff46819b..271a360c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -5,5 +5,7 @@ pytest-cov>=2.6.0 pytest-xdist>=1.25.0 pytest-django>=3.4.4 datadiff==2.0.0 +psycopg2-binary==2.8.3 +django-fake-model==0.1.4 -r testproj.txt diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 6527db4e..c852e4e0 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -656,7 +656,11 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** serializer = get_parent_serializer(field) if isinstance(serializer, serializers.ModelSerializer): model = getattr(getattr(serializer, 'Meta'), 'model') - model_field = get_model_field(model, field.source) + # Use the parent source for nested fields + model_field = get_model_field(model, field.source or field.parent.source) + # If the field has a base_field its type must be used + if getattr(model_field, "base_field", None): + model_field = model_field.base_field if model_field: model_type = get_basic_type_info(model_field) if model_type: diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py index 885e5e9c..0c305a6b 100644 --- a/tests/test_schema_generator.py +++ b/tests/test_schema_generator.py @@ -3,11 +3,14 @@ import pytest from django.conf.urls import url +from django.contrib.postgres import fields as postgres_fields +from django.db import models from django.utils.inspect import get_func_args from rest_framework import routers, serializers, viewsets from rest_framework.decorators import api_view from rest_framework.response import Response +from django_fake_model import models as fake_models from drf_yasg import codecs, openapi from drf_yasg.codecs import yaml_sane_load from drf_yasg.errors import SwaggerGenerationError @@ -230,3 +233,43 @@ def retrieve(self, request, pk=None): property_schema = swagger['definitions']['Detail']['properties']['detail'] assert property_schema == openapi.Schema(title='Detail', type=expected_type, enum=choices) + + +@pytest.mark.parametrize('choices, field, expected_type', [ + ([1, 2, 3], models.IntegerField, openapi.TYPE_INTEGER), + (["A", "B"], models.CharField, openapi.TYPE_STRING), +]) +def test_nested_choice_in_array_field(choices, field, expected_type): + + # Create a model class on the fly to avoid warnings about using the several + # model class name several times + model_class = type( + "%sModel" % field.__name__, + (fake_models.FakeModel,), + { + "array": postgres_fields.ArrayField( + field(choices=((i, "choice %s" % i) for i in choices)) + ), + "__module__": "test_models", + } + ) + + class ArraySerializer(serializers.ModelSerializer): + class Meta: + model = model_class + fields = ("array",) + + class ArrayViewSet(viewsets.ModelViewSet): + serializer_class = ArraySerializer + + router = routers.DefaultRouter() + router.register(r'arrays', ArrayViewSet, **_basename_or_base_name('arrays')) + + generator = OpenAPISchemaGenerator( + info=openapi.Info(title='Test array model generator', default_version='v1'), + patterns=router.urls + ) + + swagger = generator.get_schema(None, True) + property_schema = swagger['definitions']['Array']['properties']['array']['items'] + assert property_schema == openapi.Schema(title='Array', type=expected_type, enum=choices)