Skip to content

Commit

Permalink
Merge pull request #33 from cyx2000/add_object_permission
Browse files Browse the repository at this point in the history
Add object permission and update README.md
  • Loading branch information
em1208 committed Apr 2, 2024
2 parents 2bc387c + 8173530 commit 61feeaa
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 2 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,20 @@ class AsyncAuthentication(BaseAuthentication):
return user, None

class AsyncPermission:
def has_permission(self, request, view) -> bool:
async def has_permission(self, request, view) -> bool:
if random.random() < 0.7:
return False

return True

async def has_object_permission(self, request, view, obj):
if obj.user == request.user or request.user.is_superuser:
return True

return False

class AsyncThrottle(BaseThrottle):
def allow_request(self, request, view) -> bool:
async def allow_request(self, request, view) -> bool:
if random.random() < 0.7:
return False

Expand Down
64 changes: 64 additions & 0 deletions adrf/views.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,70 @@ def check_sync_permissions(
code=getattr(permission, "code", None),
)

def check_object_permissions(self, request: Request, obj) -> None:
permissions = self.get_permissions()

if not permissions:
return

sync_permissions, async_permissions = [], []

for permission in permissions:
if asyncio.iscoroutinefunction(permission.has_object_permission):
async_permissions.append(permission)
else:
sync_permissions.append(permission)

if async_permissions:
async_to_sync(self.check_async_object_permissions)(
request, async_permissions, obj
)

if sync_permissions:
self.check_sync_object_permissions(request, sync_permissions, obj)

async def check_async_object_permissions(
self, request: AsyncRequest, permissions: List[BasePermission], obj
) -> None:
"""
Check if the request should be permitted asynchronously.
Raises an appropriate exception if the request is not permitted.
"""

has_object_permissions = await asyncio.gather(
*[
permission.has_object_permission(request, self, obj)
for permission in permissions
],
return_exceptions=True,
)

for has_object_permission in has_object_permissions:
if isinstance(has_object_permission, Exception):
raise has_object_permission
elif not has_object_permission:
self.permission_denied(
request,
message=getattr(has_object_permission, "detail", None),
code=getattr(has_object_permission, "code", None),
)

def check_sync_object_permissions(
self, request: Request, permissions: List[BasePermission], obj
) -> None:
"""
Check if the request should be permitted synchronously.
Raises an appropriate exception if the request is not permitted.
"""

for permission in permissions:
if not permission.has_object_permission(request, self, obj):
self.permission_denied(
request,
message=getattr(permission, "detail", None),
code=getattr(permission, "code", None),
)

def check_throttles(self, request: Request) -> None:
"""
Check if the request should be throttled.
Expand Down
75 changes: 75 additions & 0 deletions tests/test_object_permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from asgiref.sync import sync_to_async
from django.http import HttpResponse
from django.test import TestCase, override_settings

from adrf.views import APIView
from rest_framework.permissions import BasePermission
from rest_framework.test import APIRequestFactory

factory = APIRequestFactory()


class AsyncObjectPermission(BasePermission):
async def has_permission(self, request, view):
return True

async def has_object_permission(self, request, view, obj):
if obj != "/async/allow":
return False
return True


class SyncObjectPermission(BasePermission):
def has_permission(self, request, view):
return True

def has_object_permission(self, request, view, obj):
if obj != "/sync/allow":
return False
return True


class ObjectPermissionTestView(APIView):
permission_classes = (AsyncObjectPermission,)

async def get(self, request):
await sync_to_async(self.check_object_permissions)(request, request.path)
return HttpResponse("ok")


@override_settings(ROOT_URLCONF=__name__)
class TestAsyncObjectPermission(TestCase):
async def test_async_object_permission(self):
request = factory.get("/async/allow")

response = await ObjectPermissionTestView.as_view()(request)

self.assertEqual(response.status_code, 200)

async def test_async_object_permission_reject(self):
request = factory.get("/async/reject")

response = await ObjectPermissionTestView.as_view()(request)

self.assertEqual(response.status_code, 403)


@override_settings(ROOT_URLCONF=__name__)
class TestSyncObjectPermission(TestCase):
async def test_sync_object_permission(self):
request = factory.get("/sync/allow")

response = await ObjectPermissionTestView.as_view(
permission_classes=(SyncObjectPermission,)
)(request)

self.assertEqual(response.status_code, 200)

async def test_sync_object_permission_reject(self):
request = factory.get("/sync/reject")

response = await ObjectPermissionTestView.as_view(
permission_classes=(SyncObjectPermission,)
)(request)

self.assertEqual(response.status_code, 403)

0 comments on commit 61feeaa

Please sign in to comment.