Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
Merge #326
Browse files Browse the repository at this point in the history
326: ref: Various fixes and type annotations for mypy r=dcramer a=dcramer

bors r+

Co-authored-by: David Cramer <dcramer@gmail.com>
  • Loading branch information
bors[bot] and dcramer committed Mar 12, 2020
2 parents 3672305 + cf3e605 commit 5e7afa8
Show file tree
Hide file tree
Showing 16 changed files with 101 additions and 82 deletions.
6 changes: 6 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@ ignore_missing_imports = True

[mypy-watchdog.*]
ignore_missing_imports = True

[mypy-cached_property.*]
ignore_missing_imports = True

[mypy-asyncpg.*]
ignore_missing_imports = True
8 changes: 5 additions & 3 deletions zeus/api/schemas/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from marshmallow import Schema, fields, pre_dump
from sqlalchemy import and_
from typing import List, Mapping
from typing import Dict, List, Optional
from uuid import UUID

from zeus.config import db
Expand All @@ -18,7 +18,9 @@
from .job import JobSchema


def find_failure_origins(build: Build, test_failures: List[str]) -> Mapping[str, UUID]:
def find_failure_origins(
build: Build, test_failures: List[str]
) -> Dict[str, Optional[UUID]]:
"""
Attempt to find originating causes of failures.
Expand Down Expand Up @@ -127,7 +129,7 @@ def find_failure_origins(build: Build, test_failures: List[str]) -> Mapping[str,
for test_hash, build_id in queryset:
previous_test_failures[build_id].add(test_hash)

failures_at_build = dict()
failures_at_build: Dict[str, Optional[UUID]] = {}
searching = set(t for t in test_failures)
# last_checked_run = build.id
last_checked_run = None
Expand Down
42 changes: 23 additions & 19 deletions zeus/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flask import current_app, g, request, session
from itsdangerous import BadSignature, JSONWebSignatureSerializer
from sqlalchemy.orm import joinedload
from typing import Mapping, Optional
from typing import Any, Dict, Mapping, Optional
from urllib.parse import urlparse, urljoin
from uuid import UUID

Expand All @@ -27,9 +27,9 @@


class Tenant(object):
access = {}
access: Dict[UUID, Optional[Permission]] = {}

def __init__(self, access: Optional[Mapping[UUID, Optional[Permission]]] = None):
def __init__(self, access: Optional[Dict[UUID, Optional[Permission]]] = None):
if access is not None:
self.access = access

Expand All @@ -54,7 +54,7 @@ def has_permission(self, repository_id: UUID, permission: Permission = None):
return permission in access

@classmethod
def from_user(cls, user: User):
def from_user(cls, user: Optional[User]):
if not user:
return cls()

Expand All @@ -70,7 +70,7 @@ def from_repository(
if not repository:
return cls()

return RepositoryTenant(access={repository.id: permission})
return RepositoryTenant(repository.id, permission)

@classmethod
def from_api_token(cls, token: ApiToken):
Expand All @@ -82,9 +82,7 @@ def from_api_token(cls, token: ApiToken):

class ApiTokenTenant(Tenant):
def __init__(
self,
token_id: str,
access: Optional[Mapping[UUID, Optional[Permission]]] = None,
self, token_id: str, access: Optional[Dict[UUID, Optional[Permission]]] = None
):
self.token_id = token_id
if access is not None:
Expand All @@ -94,7 +92,7 @@ def __repr__(self):
return "<{} token_id={}>".format(type(self).__name__, self.token_id)

@cached_property
def access(self) -> Mapping[UUID, Permission]:
def access(self) -> Dict[UUID, Permission]:
if not self.token_id:
return {}

Expand All @@ -108,9 +106,7 @@ def access(self) -> Mapping[UUID, Permission]:

class UserTenant(Tenant):
def __init__(
self,
user_id: UUID,
access: Optional[Mapping[UUID, Optional[Permission]]] = None,
self, user_id: UUID, access: Optional[Dict[UUID, Optional[Permission]]] = None
):
self.user_id = user_id
if access is not None:
Expand All @@ -120,7 +116,7 @@ def __repr__(self):
return "<{} user_id={}>".format(type(self).__name__, self.user_id)

@cached_property
def access(self) -> Mapping[UUID, Permission]:
def access(self) -> Dict[UUID, Permission]:
if not self.user_id:
return {}

Expand All @@ -142,7 +138,7 @@ def __repr__(self):
)

@cached_property
def access(self) -> Mapping[UUID, Permission]:
def access(self) -> Dict[UUID, Optional[Permission]]:
if not self.repository_id:
return {}

Expand All @@ -156,6 +152,7 @@ def get_tenant_from_headers(headers: Mapping) -> Optional[Tenant]:
header = headers.get("Authorization", "")
if header:
return get_tenant_from_bearer_header(header)
return None


def get_tenant_from_request() -> Tenant:
Expand All @@ -171,6 +168,8 @@ def get_tenant_from_bearer_header(header: str) -> Optional[Tenant]:
return None

match = _bearer_regexp.match(header)
if not match:
return None
token = match.group(2)
if not token.startswith("zeus-"):
# Assuming this is a legacy token
Expand Down Expand Up @@ -314,13 +313,15 @@ def get_current_tenant() -> Tenant:

def generate_token(tenant: Tenant) -> bytes:
s = JSONWebSignatureSerializer(current_app.secret_key, salt="auth")
payload = {"access": {str(k): v for k, v in tenant.access.items()}}
payload: Dict[str, Any] = {
"access": {str(k): int(v) if v else None for k, v in tenant.access.items()}
}
if getattr(tenant, "user_id", None):
payload["uid"] = str(tenant.user_id)
return s.dumps(payload)


def parse_token(token: str) -> Optional[str]:
def parse_token(token: str) -> Optional[Any]:
s = JSONWebSignatureSerializer(current_app.secret_key, salt="auth")
try:
return s.loads(token)
Expand All @@ -330,10 +331,12 @@ def parse_token(token: str) -> Optional[str]:


def get_tenant_from_signed_token(token: str) -> Tenant:
payload = parse_token(token)
payload: Optional[Dict[str, Any]] = parse_token(token)
if not payload:
return Tenant()
access = {UUID(k): v for k, v in payload["access"].items()}
access = {
UUID(k): Permission(v) if v else None for k, v in payload["access"].items()
}
if "uid" in payload:
return UserTenant(user_id=UUID(payload["uid"]), access=access)
return Tenant(access=access)
Expand All @@ -354,7 +357,7 @@ def is_safe_url(target: str) -> bool:
)


def get_redirect_target(clear=True, session=session) -> str:
def get_redirect_target(clear=True, session=session) -> Optional[str]:
if clear:
session_target = session.pop("next", None)
else:
Expand All @@ -366,6 +369,7 @@ def get_redirect_target(clear=True, session=session) -> str:

if is_safe_url(target):
return target
return None


def bind_redirect_target(target: str = None, session=session):
Expand Down
5 changes: 0 additions & 5 deletions zeus/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@
metrics = Metrics()


from flask_sqlalchemy.model import DefaultMeta

db.Model: DefaultMeta = db.Model


def with_health_check(app):
def middleware(environ, start_response):
path_info = environ.get("PATH_INFO", "")
Expand Down
5 changes: 4 additions & 1 deletion zeus/db/types/enum.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
__all__ = ["Enum", "IntEnum", "StrEnum"]

from enum import Enum as EnumType
from typing import Optional, Type

from sqlalchemy.types import TypeDecorator, INT, STRINGTYPE


class Enum(TypeDecorator):
impl = INT

def __init__(self, enum=None, *args, **kwargs):
def __init__(self, enum: Optional[Type[EnumType]] = None, *args, **kwargs):
self.enum = enum
super(Enum, self).__init__(*args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion zeus/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def try_create(model, where: dict, defaults: dict = None) -> Optional[Any]:
db.session.add(instance)
except IntegrityError as exc:
if "duplicate" not in str(exc):
return
return None
raise
return instance

Expand Down
1 change: 1 addition & 0 deletions zeus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, scope, identity):
def get_upgrade_url(self) -> Optional[str]:
if self.identity.provider == "github":
return "/auth/github"
return None


class UnknownRepositoryBackend(Exception):
Expand Down
4 changes: 2 additions & 2 deletions zeus/models/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class Hook(RepositoryBoundMixin, StandardAttributes, db.Model):
def generate_token(cls) -> bytes:
return token_bytes(64)

def get_signature(self) -> bytes:
def get_signature(self) -> str:
return hmac.new(
key=self.token, msg=self.repository_id.bytes, digestmod=sha256
).hexdigest()

def is_valid_signature(self, signature: bytes) -> bool:
def is_valid_signature(self, signature: str) -> bool:
return compare_digest(self.get_signature(), signature)

def get_provider(self):
Expand Down
7 changes: 5 additions & 2 deletions zeus/notifications/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,17 @@ def build_message(build: Build, force=False) -> Message:
current_app.logger.info("mail.missing-author", extra={"build_id": build.id})
return

emails = find_linked_emails(build)
emails: List[Tuple[UUID, str]] = find_linked_emails(build)
if not emails and not force:
current_app.logger.info("mail.no-linked-accounts", extra={"build_id": build.id})
return

elif not emails:
current_user = auth.get_current_user()
emails = [[current_user.id, current_user.email]]
if current_user:
emails = [(current_user.id, current_user.email)]
elif not force:
return

# filter it down to the users that have notifications enabled
user_options = dict(
Expand Down
4 changes: 2 additions & 2 deletions zeus/storage/mock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from io import BytesIO
from typing import Mapping
from typing import Dict

from .base import FileStorage

_cache: Mapping[str, bytes] = {}
_cache: Dict[str, bytes] = {}


class FileStorageCache(FileStorage):
Expand Down
38 changes: 19 additions & 19 deletions zeus/utils/builds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import reduce
from itertools import groupby
from operator import and_, or_
from typing import Any, List, Mapping, Set, Tuple
from typing import List, Optional, Set, Tuple
from sqlalchemy.orm import joinedload, subqueryload_all
from uuid import UUID

Expand All @@ -16,20 +16,20 @@
class MetaBuild:
original: List[Build] = dataclasses.field(default_factory=list)
ref: str = ""
revision_sha: str = None
revision_sha: Optional[str] = None
label: str = ""
stats: dict = dataclasses.field(default_factory=dict)
result: Result = Result.unknown
status: Status = Status.unknown
authors: List[Author] = dataclasses.field(default_factory=list)
date_created: datetime = None
date_started: datetime = None
date_finished: datetime = None
date_created: Optional[datetime] = None
date_started: Optional[datetime] = None
date_finished: Optional[datetime] = None

revision: Revision = None
revision: Optional[Revision] = None


def merge_builds(target: MetaBuild, build: Build, with_relations=True) -> Build:
def merge_builds(target: MetaBuild, build: Build, with_relations=True) -> MetaBuild:
# Store the original build so we can retrieve its ID or number later, or
# show a list of all builds in the UI
target.original.append(build)
Expand Down Expand Up @@ -96,10 +96,8 @@ def merge_builds(target: MetaBuild, build: Build, with_relations=True) -> Build:


def merge_build_group(
build_group: Tuple[Any, List[Build]],
required_hook_ids: List[UUID] = None,
with_relations=True,
) -> Build:
build_group: List[Build], required_hook_ids: Set[str] = None, with_relations=True
) -> MetaBuild:
# XXX(dcramer): required_hook_ids is still dirty here, but its our simplest way
# to get it into place
grouped_builds = groupby(
Expand All @@ -111,8 +109,8 @@ def merge_build_group(

build = MetaBuild()
build.original = []
if set(required_hook_ids or ()).difference(
set(str(b.hook_id) for b in build_group)
if frozenset(required_hook_ids or ()).difference(
frozenset(str(b.hook_id) for b in build_group)
):
build.result = Result.failed

Expand All @@ -125,12 +123,12 @@ def merge_build_group(

def fetch_builds_for_revisions(
revisions: List[Revision], with_relations=True
) -> Mapping[str, Build]:
) -> List[Tuple[Tuple[UUID, str], MetaBuild]]:
# we query extra builds here, but its a lot easier than trying to get
# sqlalchemy to do a ``select (subquery)`` clause and maintain tenant
# constraints
if not revisions:
return {}
return []

lookups = []
for revision in revisions:
Expand All @@ -156,21 +154,23 @@ def fetch_builds_for_revisions(
build_groups = groupby(
builds, lambda build: (build.repository_id, build.revision_sha)
)
required_hook_ids: Set[UUID] = set()
required_hook_ids: Set[str] = set()
for build in builds:
required_hook_ids.update(build.data.get("required_hook_ids") or ())
return [
(
ident,
merge_build_group(
list(group), required_hook_ids, with_relations=with_relations
list(build_group), required_hook_ids, with_relations=with_relations
),
)
for ident, group in build_groups
for ident, build_group in build_groups
]


def fetch_build_for_revision(revision: Revision, with_relations=True) -> Build:
def fetch_build_for_revision(
revision: Revision, with_relations=True
) -> Optional[MetaBuild]:
builds = fetch_builds_for_revisions([revision], with_relations=with_relations)
if len(builds) < 1:
return None
Expand Down

0 comments on commit 5e7afa8

Please sign in to comment.