Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrades mypy to version 0.971 #3952

Merged
merged 13 commits into from Aug 16, 2022
6 changes: 5 additions & 1 deletion mypy.ini
Expand Up @@ -4,6 +4,9 @@

[mypy]
warn_redundant_casts = True
exclude = (?x)(
wandb/vendor/.
)

[mypy-wandb.*]
ignore_errors = True
Expand Down Expand Up @@ -396,7 +399,8 @@ ignore_errors = True
# -----------------------------

[mypy-wandb.proto.*]
ignore_errors = False
;ignore_errors = False
ignore_missing_imports = True

# -----------------------------
# wandb/docker/
Expand Down
2 changes: 1 addition & 1 deletion requirements_sweeps.txt
@@ -1 +1 @@
sweeps>=0.1.0
sweeps>=0.2.0
7 changes: 4 additions & 3 deletions tox.ini
Expand Up @@ -202,21 +202,22 @@ commands=
basepython=python3
skip_install = true
deps=
mypy==0.812
types-click==7.1.8
mypy==0.971
lxml
grpcio==1.40.0
setenv =
MYPYPATH = {toxinidir}
commands=
mypy --show-error-codes --config-file {toxinidir}/mypy.ini -p wandb --html-report mypy-results/ --cobertura-xml-report mypy-results/ --exclude wandb/vendor/
mypy --install-types --non-interactive --show-error-codes --config-file {toxinidir}/mypy.ini -p wandb --html-report mypy-results/ --cobertura-xml-report mypy-results/ --lineprecision-report mypy-results/

[testenv:mypy-report]
basepython=python3
skip_install = true
deps=
pycobertura
commands=
pycobertura show mypy-results/cobertura.xml
pycobertura show --format text mypy-results/cobertura.xml

[black]
deps=
Expand Down
6 changes: 5 additions & 1 deletion wandb/apis/public.py
Expand Up @@ -25,7 +25,6 @@
from typing import List, Optional
import urllib

from pkg_resources import parse_version
import requests
import wandb
from wandb import __version__, env, util
Expand Down Expand Up @@ -241,6 +240,8 @@ def server_info(self):
return self._server_info

def version_supported(self, min_version):
from pkg_resources import parse_version

return parse_version(min_version) <= parse_version(
self.server_info["cliVersionInfo"]["max_cli_version"]
)
Expand Down Expand Up @@ -2944,6 +2945,9 @@ def mongo_to_filter(self, filter):


class PythonMongoishQueryGenerator:

from pkg_resources import parse_version

def __init__(self, run_set):
self.run_set = run_set
self.panel_metrics_helper = PanelMetricsHelper()
Expand Down
2 changes: 1 addition & 1 deletion wandb/docker/__init__.py
Expand Up @@ -141,7 +141,7 @@ def auth_token(registry: str, repo: str) -> Dict[str, str]:
auth_info = auth_config.resolve_authconfig(registry)
if auth_info:
normalized = {k.lower(): v for k, v in auth_info.items()}
normalized_auth_info: Optional[Tuple] = (
normalized_auth_info: Optional[Tuple[str, str]] = ( # type: ignore
normalized.get("username"),
normalized.get("password"),
)
Expand Down
2 changes: 1 addition & 1 deletion wandb/docker/auth.py
Expand Up @@ -162,7 +162,7 @@ def parse_auth(
conf = {}
for registry, entry in entries.items():
if not isinstance(entry, dict):
log.debug(f"Config entry for key {registry} is not auth config")
log.debug(f"Config entry for key {registry} is not auth config") # type: ignore
# We sometimes fall back to parsing the whole config as if it
# was the auth config by itself, for legacy purposes. In that
# case, we fail silently and return an empty conf if any of the
Expand Down
6 changes: 4 additions & 2 deletions wandb/sdk/data_types/base_types/wb_value.py
@@ -1,6 +1,5 @@
from typing import Any, ClassVar, Dict, List, Optional, Type, TYPE_CHECKING, Union

from pkg_resources import parse_version
from wandb import util

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -13,6 +12,8 @@


def _server_accepts_client_ids() -> bool:
from pkg_resources import parse_version

# First, if we are offline, assume the backend server cannot
# accept client IDs. Unfortunately, this is the best we can do
# until we are sure that all local versions are > "0.11.0" max_cli_version.
Expand All @@ -30,7 +31,8 @@ def _server_accepts_client_ids() -> bool:
max_cli_version = util._get_max_cli_version()
if max_cli_version is None:
return False
return parse_version("0.11.0") <= parse_version(max_cli_version)
accepts_client_ids: bool = parse_version("0.11.0") <= parse_version(max_cli_version)
return accepts_client_ids


class _WBValueArtifactSource:
Expand Down
13 changes: 10 additions & 3 deletions wandb/sdk/data_types/image.py
Expand Up @@ -4,7 +4,6 @@
import os
from typing import Any, cast, Dict, List, Optional, Sequence, Type, TYPE_CHECKING, Union

from pkg_resources import parse_version
import wandb
from wandb import util

Expand Down Expand Up @@ -32,20 +31,28 @@


def _server_accepts_image_filenames() -> bool:
from pkg_resources import parse_version

# Newer versions of wandb accept large image filenames arrays
# but older versions would have issues with this.
max_cli_version = util._get_max_cli_version()
if max_cli_version is None:
return False
return parse_version("0.12.10") <= parse_version(max_cli_version)
accepts_image_filenames: bool = parse_version("0.12.10") <= parse_version(
max_cli_version
)
return accepts_image_filenames


def _server_accepts_artifact_path() -> bool:
from pkg_resources import parse_version

target_version = "0.12.14"
max_cli_version = util._get_max_cli_version() if not util._is_offline() else None
return max_cli_version is not None and parse_version(
accepts_artifact_path: bool = max_cli_version is not None and parse_version(
target_version
) <= parse_version(max_cli_version)
return accepts_artifact_path


class Image(BatchableMedia):
Expand Down
4 changes: 2 additions & 2 deletions wandb/sdk/internal/handler.py
Expand Up @@ -128,15 +128,15 @@ def handle(self, record: Record) -> None:
record_type = record.WhichOneof("record_type")
assert record_type
handler_str = "handle_" + record_type
handler: Callable[[Record], None] = getattr(self, handler_str, None)
handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
assert handler, f"unknown handle: {handler_str}"
handler(record)

def handle_request(self, record: Record) -> None:
request_type = record.request.WhichOneof("request_type")
assert request_type
handler_str = "handle_request_" + request_type
handler: Callable[[Record], None] = getattr(self, handler_str, None)
handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
if request_type != "network_status":
logger.debug(f"handle_request: {request_type}")
assert handler, f"unknown handle: {handler_str}"
Expand Down
6 changes: 3 additions & 3 deletions wandb/sdk/internal/internal_api.py
Expand Up @@ -24,10 +24,8 @@
import base64
from copy import deepcopy
import datetime
from io import BytesIO
import json
import os
from pkg_resources import parse_version
import re
import requests
import logging
Expand Down Expand Up @@ -1767,7 +1765,7 @@ def download_file(self, url: str) -> Tuple[int, requests.Response]:
Returns:
A tuple of the content length and the streaming response
"""
response = requests.get(url, auth=("user", self.api_key), stream=True)
response = requests.get(url, auth=("user", self.api_key), stream=True) # type: ignore
response.raise_for_status()
return int(response.headers.get("content-length", 0)), response

Expand Down Expand Up @@ -2535,6 +2533,8 @@ def create_artifact(
enable_digest_deduplication: Optional[bool] = False,
history_step: Optional[int] = None,
) -> Tuple[Dict, Dict]:
from pkg_resources import parse_version

_, server_info = self.viewer_server_info()
max_cli_version = server_info.get("cliVersionInfo", {}).get(
"max_cli_version", None
Expand Down
5 changes: 3 additions & 2 deletions wandb/sdk/internal/profiler.py
@@ -1,7 +1,6 @@
"""Integration with pytorch profiler."""
import os

from pkg_resources import parse_version
import wandb
from wandb.errors import Error, UsageError
from wandb.sdk.lib import telemetry
Expand All @@ -15,7 +14,7 @@ def torch_trace_handler():

Provide as an argument to `torch.profiler.profile`:
```python
torch.profiler.profile(..., on_trace_ready = wandb.profiler.torch_trace_handler())
torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
```

Calling this function ensures that profiler charts & tables can be viewed in your run dashboard
Expand Down Expand Up @@ -53,6 +52,8 @@ def torch_trace_handler():
prof.step()
```
"""
from pkg_resources import parse_version

torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

Expand Down
5 changes: 4 additions & 1 deletion wandb/sdk/internal/sender.py
Expand Up @@ -24,7 +24,6 @@
TYPE_CHECKING,
)

from pkg_resources import parse_version
import requests
import wandb
from wandb import util
Expand Down Expand Up @@ -1251,6 +1250,8 @@ def send_artifact(self, record: "Record") -> None:
def _send_artifact(
self, artifact: "ArtifactRecord", history_step: Optional[int] = None
) -> Optional[Dict]:
from pkg_resources import parse_version

assert self._pusher
saver = artifacts.ArtifactSaver(
api=self._api,
Expand Down Expand Up @@ -1288,6 +1289,8 @@ def _send_artifact(
)

def send_alert(self, record: "Record") -> None:
from pkg_resources import parse_version

alert = record.alert
max_cli_version = self._max_cli_version()
if max_cli_version is None or parse_version(max_cli_version) < parse_version(
Expand Down
4 changes: 2 additions & 2 deletions wandb/sdk/internal/stats.py
Expand Up @@ -78,7 +78,7 @@ def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None:
self._interface = interface
self.sampler = {}
self.samples = 0
self._shutdown = False
self._shutdown: bool = False
self._telem = telemetry.TelemetryRecord()
if psutil:
net = psutil.net_io_counters()
Expand Down Expand Up @@ -150,7 +150,7 @@ def _thread_body(self) -> None:
time.sleep(0.1)
seconds += 0.1
if self._shutdown:
self.flush()
self.flush() # type: ignore
dmitryduev marked this conversation as resolved.
Show resolved Hide resolved
return

def shutdown(self) -> None:
Expand Down
7 changes: 4 additions & 3 deletions wandb/sdk/internal/update.py
@@ -1,13 +1,14 @@
from typing import Dict, Optional, Tuple

from pkg_resources import parse_version
import requests
import wandb


def _find_available(
current_version: str,
) -> Optional[Tuple[str, bool, bool, bool, Optional[str]]]:
from pkg_resources import parse_version

pypi_url = f"https://pypi.org/pypi/{wandb._wandb_module}/json"

yanked_dict = {}
Expand Down Expand Up @@ -49,13 +50,13 @@ def _find_available(
if parse_version(latest_version) <= parsed_current_version:
# pre-releases are not included in latest_version
# so if we are currently running a pre-release we check more
if not parsed_current_version.is_prerelease: # type: ignore
if not parsed_current_version.is_prerelease:
return None
# Candidates are pre-releases with the same base_version
release_list = map(parse_version, release_list)
release_list = filter(lambda v: v.is_prerelease, release_list)
release_list = filter(
lambda v: v.base_version == parsed_current_version.base_version, # type: ignore
lambda v: v.base_version == parsed_current_version.base_version,
release_list,
)
release_list = sorted(release_list)
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/lib/exit_hooks.py
Expand Up @@ -28,7 +28,7 @@ def hook(self) -> None:
!= sys.__excepthook__ # respect hooks by other libraries like pdb
else None
)
sys.excepthook = self.exc_handler
sys.excepthook = self.exc_handler # type: ignore

def exit(self, code: object = 0) -> "NoReturn":
orig_code = code
Expand Down
15 changes: 10 additions & 5 deletions wandb/sdk/lib/printer.py
Expand Up @@ -174,13 +174,16 @@ def progress_close(self) -> None:
wandb.termlog(" " * 79)

def code(self, text: str) -> str:
return click.style(text, bold=True)
ret: str = click.style(text, bold=True)
return ret

def name(self, text: str) -> str:
return click.style(text, fg="yellow")
ret: str = click.style(text, fg="yellow")
return ret

def link(self, link: str, text: Optional[str] = None) -> str:
return click.style(link, fg="blue", underline=True)
ret: str = click.style(link, fg="blue", underline=True)
return ret

def emoji(self, name: str) -> str:
emojis = dict()
Expand All @@ -191,10 +194,12 @@ def emoji(self, name: str) -> str:

def status(self, text: str, failure: Optional[bool] = None) -> str:
color = "red" if failure else "green"
return click.style(text, fg=color)
ret: str = click.style(text, fg=color)
return ret

def files(self, text: str) -> str:
return click.style(text, fg="magenta", bold=True)
ret: str = click.style(text, fg="magenta", bold=True)
return ret

def grid(self, rows: List[List[str]], title: Optional[str] = None) -> str:
max_len = max(len(row[0]) for row in rows)
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/service/server_sock.py
Expand Up @@ -106,7 +106,7 @@ def run(self) -> None:
assert sreq, "read_server_request should never timeout"
sreq_type = sreq.WhichOneof("server_request_type")
shandler_str = "server_" + sreq_type
shandler: "Callable[[spb.ServerRequest], None]" = getattr(
shandler: "Callable[[spb.ServerRequest], None]" = getattr( # type: ignore
self, shandler_str, None
)
assert shandler, f"unknown handle: {shandler_str}"
Expand Down
3 changes: 2 additions & 1 deletion wandb/sdk/wandb_artifacts.py
Expand Up @@ -25,6 +25,7 @@
from urllib.parse import parse_qsl, quote, urlparse

import requests
import urllib3
import wandb
from wandb import env
from wandb import util
Expand Down Expand Up @@ -59,7 +60,7 @@

# This makes the first sleep 1s, and then doubles it up to total times,
# which makes for ~18 hours.
_REQUEST_RETRY_STRATEGY = requests.packages.urllib3.util.retry.Retry(
_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
backoff_factor=1,
total=16,
status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
Expand Down