Skip to content

Commit

Permalink
Tests: Improve test coverage 95.91% -> 98.30%. (#862)
Browse files Browse the repository at this point in the history
  • Loading branch information
onegreyonewhite committed Jul 19, 2023
1 parent 5e239a8 commit 2548298
Show file tree
Hide file tree
Showing 16 changed files with 412 additions and 60 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/review.yml
Expand Up @@ -29,6 +29,12 @@ jobs:
PYTHON_VERSION: ${{ matrix.python }}
run: tox -e $(tox -l | grep py${PYTHON_VERSION//.} | paste -sd "," -)

- name: Report coverage
if: ${{ matrix.python == 3.9 }}
run: |
pip install coverage
coverage report
- name: Check for incompatibilities with publishing to PyPi
if: ${{ matrix.python == 3.8 }}
run: |
Expand Down
30 changes: 16 additions & 14 deletions src/drf_yasg/codecs.py
Expand Up @@ -6,18 +6,25 @@
from django.utils.encoding import force_bytes
import yaml

try:
from swagger_spec_validator.common import SwaggerValidationError as SSVErr
from swagger_spec_validator.validator20 import validate_spec as validate_ssv
except ImportError: # pragma: no cover
validate_ssv = None

try:
from flex.core import parse as validate_flex
from flex.exceptions import ValidationError
except ImportError: # pragma: no cover
validate_flex = None

from . import openapi
from .errors import SwaggerValidationError

logger = logging.getLogger(__name__)


def _validate_flex(spec):
try:
from flex.core import parse as validate_flex
from flex.exceptions import ValidationError
except ImportError:
return

try:
validate_flex(spec)
Expand All @@ -26,8 +33,6 @@ def _validate_flex(spec):


def _validate_swagger_spec_validator(spec):
from swagger_spec_validator.common import SwaggerValidationError as SSVErr
from swagger_spec_validator.validator20 import validate_spec as validate_ssv
try:
validate_ssv(spec)
except SSVErr as ex:
Expand All @@ -36,8 +41,8 @@ def _validate_swagger_spec_validator(spec):

#:
VALIDATORS = {
'flex': _validate_flex,
'ssv': _validate_swagger_spec_validator,
"flex": _validate_flex if validate_flex else lambda s: None,
"ssv": _validate_swagger_spec_validator if validate_ssv else lambda s: None,
}


Expand Down Expand Up @@ -117,10 +122,7 @@ def _dump_dict(self, spec):
:rtype: str"""
if self.pretty:
out = json.dumps(spec, indent=4, separators=(',', ': '), ensure_ascii=False)
if out[-1] != '\n':
out += '\n'
return out
return f"{json.dumps(spec, indent=4, separators=(',', ': '), ensure_ascii=False)}\n"
else:
return json.dumps(spec, ensure_ascii=False)

Expand Down Expand Up @@ -219,7 +221,7 @@ def yaml_sane_load(stream):
:param stream: YAML stream (can be a string or a file-like object)
:rtype: OrderedDict
"""
return yaml.load(stream, Loader=YamlLoader)
return yaml.load(stream, Loader=SaneYamlLoader)


class OpenAPICodecYaml(_OpenAPICodec):
Expand Down
27 changes: 17 additions & 10 deletions src/drf_yasg/inspectors/field.py
Expand Up @@ -3,6 +3,7 @@
import logging
import operator
import uuid
from contextlib import suppress
from collections import OrderedDict
from decimal import Decimal
from inspect import signature as inspect_signature
Expand Down Expand Up @@ -130,7 +131,7 @@ def make_schema_definition(serializer=field):

actual_serializer = getattr(actual_schema, '_NP_serializer', None)
this_serializer = get_serializer_class(field)
if actual_serializer and actual_serializer != this_serializer: # pragma: no cover
if actual_serializer and actual_serializer != this_serializer:
explicit_refs = self._has_ref_name(actual_serializer) and self._has_ref_name(this_serializer)
if not explicit_refs:
raise SwaggerGenerationError(
Expand Down Expand Up @@ -209,21 +210,27 @@ def get_parent_serializer(field):
return None # pragma: no cover


def get_model_from_descriptor(descriptor):
with suppress(Exception):
try:
return descriptor.rel.related_model
except Exception:
return descriptor.field.remote_field.model


def get_related_model(model, source):
"""Try to find the other side of a model relationship given the name of a related field.
:param model: one side of the relationship
:param str source: related field name
:return: related model or ``None``
"""
try:
descriptor = getattr(model, source)
try:
return descriptor.rel.related_model
except Exception:
return descriptor.field.remote_field.model
except Exception: # pragma: no cover
return None

with suppress(Exception):
if '.' in source and source.index('.'):
attr, source = source.split('.', maxsplit=1)
return get_related_model(get_model_from_descriptor(getattr(model, attr)), source)
return get_model_from_descriptor(getattr(model, source))


class RelatedFieldInspector(FieldInspector):
Expand Down Expand Up @@ -281,7 +288,7 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
elif isinstance(field, serializers.HyperlinkedRelatedField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)

return SwaggerType(type=openapi.TYPE_STRING)
return NotHandled # pragma: no cover


def find_regex(regex_field):
Expand Down
49 changes: 29 additions & 20 deletions src/drf_yasg/inspectors/query.py
Expand Up @@ -5,7 +5,6 @@
import coreschema
except ImportError:
coreschema = None
from rest_framework.pagination import CursorPagination, LimitOffsetPagination, PageNumberPagination

from .. import openapi
from ..utils import force_real_str
Expand Down Expand Up @@ -100,23 +99,33 @@ class DjangoRestResponsePagination(PaginatorInspector):
PageNumberPagination and CursorPagination
"""

def fix_paginated_property(self, key: str, value: dict):
# Need to remove useless params from schema
value.pop('example', None)
if 'nullable' in value:
value['x-nullable'] = value.pop('nullable')
if key in {'next', 'previous'} and 'format' not in value:
value['format'] = 'uri'
return openapi.Schema(**value)

def get_paginated_response(self, paginator, response_schema):
assert response_schema.type == openapi.TYPE_ARRAY, "array return expected for paged response"
paged_schema = None
if isinstance(paginator, (LimitOffsetPagination, PageNumberPagination, CursorPagination)):
has_count = not isinstance(paginator, CursorPagination)
paged_schema = openapi.Schema(
type=openapi.TYPE_OBJECT,
properties=OrderedDict((
('count', openapi.Schema(type=openapi.TYPE_INTEGER) if has_count else None),
('next', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI, x_nullable=True)),
('previous', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI, x_nullable=True)),
('results', response_schema),
)),
required=['results']
)

if has_count:
paged_schema.required.insert(0, 'count')

return paged_schema
if hasattr(paginator, 'get_paginated_response_schema'):
paginator_schema = paginator.get_paginated_response_schema(response_schema)
if paginator_schema['type'] == openapi.TYPE_OBJECT:
properties = {
k: self.fix_paginated_property(k, v)
for k, v in paginator_schema.pop('properties').items()
}
if 'required' not in paginator_schema:
paginator_schema.setdefault('required', [])
for prop in ('count', 'results'):
if prop in properties:
paginator_schema['required'].append(prop)
return openapi.Schema(
**paginator_schema,
properties=properties
)
else:
return openapi.Schema(**paginator_schema)

return response_schema
4 changes: 4 additions & 0 deletions testproj/snippets/models.py
Expand Up @@ -16,6 +16,10 @@ class Snippet(models.Model):
class Meta:
ordering = ('created',)

@property
def owner_snippets(self):
return Snippet._default_manager.filter(owner=self.owner)

@property
def nullable_secondary_language(self):
return None
Expand Down
1 change: 1 addition & 0 deletions testproj/users/urls.py
Expand Up @@ -5,4 +5,5 @@
urlpatterns = [
path('', views.UserList.as_view()),
path('<int:pk>/', views.user_detail),
path('<int:pk>/test_dummy', views.test_view_with_dummy_schema),
]
15 changes: 15 additions & 0 deletions testproj/users/views.py
Expand Up @@ -60,3 +60,18 @@ def user_detail(request, pk):
user = get_object_or_404(User.objects, pk=pk)
serializer = UserSerializer(user)
return Response(serializer.data)


class DummyAutoSchema:
def __init__(self, *args, **kwargs):
pass

def get_operation(self, keys):
pass


@swagger_auto_schema(methods=['get'], auto_schema=DummyAutoSchema)
@swagger_auto_schema(methods=['PUT'], auto_schema=None)
@api_view(['GET', 'PUT'])
def test_view_with_dummy_schema(request, pk):
return Response({})
18 changes: 18 additions & 0 deletions tests/test_schema_generator.py
Expand Up @@ -90,6 +90,24 @@ def test_security_requirements(swagger_settings, mock_schema_request):
swagger = generator.get_schema(mock_schema_request, public=True)
assert swagger['security'] == []

swagger_settings['SECURITY_REQUIREMENTS'] = None
swagger_settings['SECURITY_DEFINITIONS'] = None

swagger = generator.get_schema(mock_schema_request, public=True)
assert 'security' not in swagger


def test_default_url(swagger_settings, mock_schema_request):
swagger_settings['DEFAULT_API_URL'] = 'http://api.example.com'
generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
version="v2",
)

swagger = generator.get_schema(public=True)
assert swagger['host'] == 'api.example.com'
assert swagger['basePath'] == '/'


def _basename_or_base_name(basename):
# freaking DRF... TODO: remove when dropping support for DRF 3.8
Expand Down
59 changes: 59 additions & 0 deletions tests/test_schema_views.py
Expand Up @@ -3,6 +3,11 @@

import pytest

try:
import coreschema
except ImportError:
coreschema = None

from drf_yasg.codecs import yaml_sane_load


Expand Down Expand Up @@ -70,3 +75,57 @@ def test_non_public(client):
response = client.get('/private/swagger.yaml')
swagger = yaml_sane_load(response.content.decode('utf-8'))
assert len(swagger['paths']) == 0


@pytest.mark.skipif(coreschema is None, reason="Do not test without coreschema.")
@pytest.mark.urls('urlconfs.coreschema')
def test_paginator_schema(client, swagger_settings):
swagger_settings['DEFAULT_FILTER_INSPECTORS'] = [
'drf_yasg.inspectors.CoreAPICompatInspector',
'drf_yasg.inspectors.DrfAPICompatInspector',
]
swagger_settings['DEFAULT_PAGINATOR_INSPECTORS'] = [
'drf_yasg.inspectors.CoreAPICompatInspector',
'drf_yasg.inspectors.DrfAPICompatInspector',
]

response = client.get('/versioned/url/v1.0/swagger.yaml')
swagger = yaml_sane_load(response.content.decode('utf-8'))

assert swagger['paths']['/snippets/']['get']['responses']['200']['schema']['type'] == 'object'
assert swagger['paths']['/snippets/']['get']['responses']['200']['schema']['required'] == ['results']
assert swagger['paths']['/snippets/']['get']['parameters'][0]['name'] == 'test_param'
assert swagger['paths']['/snippets/']['get']['parameters'][0]['type'] == 'string'
assert swagger['paths']['/snippets/']['get']['parameters'][1]['name'] == 'limit'
assert swagger['paths']['/snippets/']['get']['parameters'][1]['in'] == 'query'
assert swagger['paths']['/snippets/']['get']['parameters'][1]['type'] == 'integer'

assert swagger['paths']['/other_snippets/']['get']['responses']['200']['schema']['type'] == 'array'
assert swagger['paths']['/other_snippets/']['get']['parameters'][0]['name'] == 'limit'
assert swagger['paths']['/other_snippets/']['get']['parameters'][0]['in'] == 'query'
assert swagger['paths']['/other_snippets/']['get']['parameters'][0]['type'] == 'integer'


@pytest.mark.urls('urlconfs.additional_fields_checks')
def test_extra_field_inspections(client, swagger_settings):
# swagger_settings[]
response = client.get('/versioned/url/v1.0/swagger.json')
swagger = json.loads(response.content.decode('utf-8'))

assert swagger['definitions']['Snippets']['properties']['url']['type'] == 'string'
assert swagger['definitions']['Snippets']['properties']['url']['format'] == 'uri'
assert swagger['definitions']['Snippets']['properties']['ipv4']['type'] == 'string'
assert swagger['definitions']['Snippets']['properties']['uri']['type'] == 'string'
assert swagger['definitions']['Snippets']['properties']['uri']['format'] == 'uri'
assert swagger['definitions']['Snippets']['properties']['tracks']['type'] == 'array'
assert swagger['definitions']['Snippets']['properties']['tracks']['items']['type'] == 'string'

assert swagger['definitions']['SnippetsV2']['properties']['url']['type'] == 'string'
assert swagger['definitions']['SnippetsV2']['properties']['url']['format'] == 'uri'

assert swagger['definitions']['SnippetsV2']['properties']['other_owner_snippets']['type'] == 'array'
assert swagger['definitions']['SnippetsV2']['properties']['other_owner_snippets']['items']['type'] == 'integer'

# Cannt check type of queryset in property descriptor.
assert swagger['definitions']['SnippetsV2']['properties']['owner_snippets']['type'] == 'array'
assert swagger['definitions']['SnippetsV2']['properties']['owner_snippets']['items']['type'] == 'string'

0 comments on commit 2548298

Please sign in to comment.