Skip to content

Commit

Permalink
Merge github.com:wandb/client into launch-fetch-run-queue-item-query
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Aug 16, 2022
2 parents a4bce79 + 458e208 commit 3326440
Show file tree
Hide file tree
Showing 22 changed files with 107 additions and 55 deletions.
6 changes: 5 additions & 1 deletion mypy.ini
Expand Up @@ -4,6 +4,9 @@

[mypy]
warn_redundant_casts = True
exclude = (?x)(
wandb/vendor/.
)

[mypy-wandb.*]
ignore_errors = True
Expand Down Expand Up @@ -396,7 +399,8 @@ ignore_errors = True
# -----------------------------

[mypy-wandb.proto.*]
ignore_errors = False
;ignore_errors = False
ignore_missing_imports = True

# -----------------------------
# wandb/docker/
Expand Down
2 changes: 1 addition & 1 deletion requirements_sweeps.txt
@@ -1 +1 @@
sweeps>=0.1.0
sweeps>=0.2.0
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -54,7 +54,7 @@
long_description_content_type="text/markdown",
author="Weights & Biases",
author_email="support@wandb.com",
url="https://github.com/wandb/client",
url="https://github.com/wandb/wandb",
packages=["wandb"],
package_dir={"wandb": "wandb"},
package_data={"wandb": ["py.typed"]},
Expand Down
7 changes: 4 additions & 3 deletions tox.ini
Expand Up @@ -202,21 +202,22 @@ commands=
basepython=python3
skip_install = true
deps=
mypy==0.812
types-click==7.1.8
mypy==0.971
lxml
grpcio==1.40.0
setenv =
MYPYPATH = {toxinidir}
commands=
mypy --show-error-codes --config-file {toxinidir}/mypy.ini -p wandb --html-report mypy-results/ --cobertura-xml-report mypy-results/ --exclude wandb/vendor/
mypy --install-types --non-interactive --show-error-codes --config-file {toxinidir}/mypy.ini -p wandb --html-report mypy-results/ --cobertura-xml-report mypy-results/ --lineprecision-report mypy-results/

[testenv:mypy-report]
basepython=python3
skip_install = true
deps=
pycobertura
commands=
pycobertura show mypy-results/cobertura.xml
pycobertura show --format text mypy-results/cobertura.xml

[black]
deps=
Expand Down
6 changes: 5 additions & 1 deletion wandb/apis/public.py
Expand Up @@ -25,7 +25,6 @@
from typing import List, Optional
import urllib

from pkg_resources import parse_version
import requests
import wandb
from wandb import __version__, env, util
Expand Down Expand Up @@ -241,6 +240,8 @@ def server_info(self):
return self._server_info

def version_supported(self, min_version):
from pkg_resources import parse_version

return parse_version(min_version) <= parse_version(
self.server_info["cliVersionInfo"]["max_cli_version"]
)
Expand Down Expand Up @@ -2947,6 +2948,9 @@ def mongo_to_filter(self, filter):


class PythonMongoishQueryGenerator:

from pkg_resources import parse_version

def __init__(self, run_set):
self.run_set = run_set
self.panel_metrics_helper = PanelMetricsHelper()
Expand Down
2 changes: 1 addition & 1 deletion wandb/docker/__init__.py
Expand Up @@ -141,7 +141,7 @@ def auth_token(registry: str, repo: str) -> Dict[str, str]:
auth_info = auth_config.resolve_authconfig(registry)
if auth_info:
normalized = {k.lower(): v for k, v in auth_info.items()}
normalized_auth_info: Optional[Tuple] = (
normalized_auth_info: Optional[Tuple[str, str]] = ( # type: ignore
normalized.get("username"),
normalized.get("password"),
)
Expand Down
2 changes: 1 addition & 1 deletion wandb/docker/auth.py
Expand Up @@ -162,7 +162,7 @@ def parse_auth(
conf = {}
for registry, entry in entries.items():
if not isinstance(entry, dict):
log.debug(f"Config entry for key {registry} is not auth config")
log.debug(f"Config entry for key {registry} is not auth config") # type: ignore
# We sometimes fall back to parsing the whole config as if it
# was the auth config by itself, for legacy purposes. In that
# case, we fail silently and return an empty conf if any of the
Expand Down
6 changes: 4 additions & 2 deletions wandb/sdk/data_types/base_types/wb_value.py
@@ -1,6 +1,5 @@
from typing import Any, ClassVar, Dict, List, Optional, Type, TYPE_CHECKING, Union

from pkg_resources import parse_version
from wandb import util

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -13,6 +12,8 @@


def _server_accepts_client_ids() -> bool:
from pkg_resources import parse_version

# First, if we are offline, assume the backend server cannot
# accept client IDs. Unfortunately, this is the best we can do
# until we are sure that all local versions are > "0.11.0" max_cli_version.
Expand All @@ -30,7 +31,8 @@ def _server_accepts_client_ids() -> bool:
max_cli_version = util._get_max_cli_version()
if max_cli_version is None:
return False
return parse_version("0.11.0") <= parse_version(max_cli_version)
accepts_client_ids: bool = parse_version("0.11.0") <= parse_version(max_cli_version)
return accepts_client_ids


class _WBValueArtifactSource:
Expand Down
13 changes: 10 additions & 3 deletions wandb/sdk/data_types/image.py
Expand Up @@ -4,7 +4,6 @@
import os
from typing import Any, cast, Dict, List, Optional, Sequence, Type, TYPE_CHECKING, Union

from pkg_resources import parse_version
import wandb
from wandb import util

Expand Down Expand Up @@ -32,20 +31,28 @@


def _server_accepts_image_filenames() -> bool:
from pkg_resources import parse_version

# Newer versions of wandb accept large image filenames arrays
# but older versions would have issues with this.
max_cli_version = util._get_max_cli_version()
if max_cli_version is None:
return False
return parse_version("0.12.10") <= parse_version(max_cli_version)
accepts_image_filenames: bool = parse_version("0.12.10") <= parse_version(
max_cli_version
)
return accepts_image_filenames


def _server_accepts_artifact_path() -> bool:
from pkg_resources import parse_version

target_version = "0.12.14"
max_cli_version = util._get_max_cli_version() if not util._is_offline() else None
return max_cli_version is not None and parse_version(
accepts_artifact_path: bool = max_cli_version is not None and parse_version(
target_version
) <= parse_version(max_cli_version)
return accepts_artifact_path


class Image(BatchableMedia):
Expand Down
4 changes: 2 additions & 2 deletions wandb/sdk/internal/handler.py
Expand Up @@ -128,15 +128,15 @@ def handle(self, record: Record) -> None:
record_type = record.WhichOneof("record_type")
assert record_type
handler_str = "handle_" + record_type
handler: Callable[[Record], None] = getattr(self, handler_str, None)
handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
assert handler, f"unknown handle: {handler_str}"
handler(record)

def handle_request(self, record: Record) -> None:
request_type = record.request.WhichOneof("request_type")
assert request_type
handler_str = "handle_request_" + request_type
handler: Callable[[Record], None] = getattr(self, handler_str, None)
handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
if request_type != "network_status":
logger.debug(f"handle_request: {request_type}")
assert handler, f"unknown handle: {handler_str}"
Expand Down
6 changes: 3 additions & 3 deletions wandb/sdk/internal/internal_api.py
Expand Up @@ -24,10 +24,8 @@
import base64
from copy import deepcopy
import datetime
from io import BytesIO
import json
import os
from pkg_resources import parse_version
import re
import requests
import logging
Expand Down Expand Up @@ -1767,7 +1765,7 @@ def download_file(self, url: str) -> Tuple[int, requests.Response]:
Returns:
A tuple of the content length and the streaming response
"""
response = requests.get(url, auth=("user", self.api_key), stream=True)
response = requests.get(url, auth=("user", self.api_key), stream=True) # type: ignore
response.raise_for_status()
return int(response.headers.get("content-length", 0)), response

Expand Down Expand Up @@ -2535,6 +2533,8 @@ def create_artifact(
enable_digest_deduplication: Optional[bool] = False,
history_step: Optional[int] = None,
) -> Tuple[Dict, Dict]:
from pkg_resources import parse_version

_, server_info = self.viewer_server_info()
max_cli_version = server_info.get("cliVersionInfo", {}).get(
"max_cli_version", None
Expand Down
5 changes: 3 additions & 2 deletions wandb/sdk/internal/profiler.py
@@ -1,7 +1,6 @@
"""Integration with pytorch profiler."""
import os

from pkg_resources import parse_version
import wandb
from wandb.errors import Error, UsageError
from wandb.sdk.lib import telemetry
Expand All @@ -15,7 +14,7 @@ def torch_trace_handler():
Provide as an argument to `torch.profiler.profile`:
```python
torch.profiler.profile(..., on_trace_ready = wandb.profiler.torch_trace_handler())
torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
```
Calling this function ensures that profiler charts & tables can be viewed in your run dashboard
Expand Down Expand Up @@ -53,6 +52,8 @@ def torch_trace_handler():
prof.step()
```
"""
from pkg_resources import parse_version

torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

Expand Down
5 changes: 4 additions & 1 deletion wandb/sdk/internal/sender.py
Expand Up @@ -24,7 +24,6 @@
TYPE_CHECKING,
)

from pkg_resources import parse_version
import requests
import wandb
from wandb import util
Expand Down Expand Up @@ -1251,6 +1250,8 @@ def send_artifact(self, record: "Record") -> None:
def _send_artifact(
self, artifact: "ArtifactRecord", history_step: Optional[int] = None
) -> Optional[Dict]:
from pkg_resources import parse_version

assert self._pusher
saver = artifacts.ArtifactSaver(
api=self._api,
Expand Down Expand Up @@ -1288,6 +1289,8 @@ def _send_artifact(
)

def send_alert(self, record: "Record") -> None:
from pkg_resources import parse_version

alert = record.alert
max_cli_version = self._max_cli_version()
if max_cli_version is None or parse_version(max_cli_version) < parse_version(
Expand Down
4 changes: 2 additions & 2 deletions wandb/sdk/internal/stats.py
Expand Up @@ -78,7 +78,7 @@ def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None:
self._interface = interface
self.sampler = {}
self.samples = 0
self._shutdown = False
self._shutdown: bool = False
self._telem = telemetry.TelemetryRecord()
if psutil:
net = psutil.net_io_counters()
Expand Down Expand Up @@ -150,7 +150,7 @@ def _thread_body(self) -> None:
time.sleep(0.1)
seconds += 0.1
if self._shutdown:
self.flush()
self.flush() # type: ignore
return

def shutdown(self) -> None:
Expand Down
7 changes: 4 additions & 3 deletions wandb/sdk/internal/update.py
@@ -1,13 +1,14 @@
from typing import Dict, Optional, Tuple

from pkg_resources import parse_version
import requests
import wandb


def _find_available(
current_version: str,
) -> Optional[Tuple[str, bool, bool, bool, Optional[str]]]:
from pkg_resources import parse_version

pypi_url = f"https://pypi.org/pypi/{wandb._wandb_module}/json"

yanked_dict = {}
Expand Down Expand Up @@ -49,13 +50,13 @@ def _find_available(
if parse_version(latest_version) <= parsed_current_version:
# pre-releases are not included in latest_version
# so if we are currently running a pre-release we check more
if not parsed_current_version.is_prerelease: # type: ignore
if not parsed_current_version.is_prerelease:
return None
# Candidates are pre-releases with the same base_version
release_list = map(parse_version, release_list)
release_list = filter(lambda v: v.is_prerelease, release_list)
release_list = filter(
lambda v: v.base_version == parsed_current_version.base_version, # type: ignore
lambda v: v.base_version == parsed_current_version.base_version,
release_list,
)
release_list = sorted(release_list)
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/lib/exit_hooks.py
Expand Up @@ -28,7 +28,7 @@ def hook(self) -> None:
!= sys.__excepthook__ # respect hooks by other libraries like pdb
else None
)
sys.excepthook = self.exc_handler
sys.excepthook = self.exc_handler # type: ignore

def exit(self, code: object = 0) -> "NoReturn":
orig_code = code
Expand Down
15 changes: 10 additions & 5 deletions wandb/sdk/lib/printer.py
Expand Up @@ -174,13 +174,16 @@ def progress_close(self) -> None:
wandb.termlog(" " * 79)

def code(self, text: str) -> str:
return click.style(text, bold=True)
ret: str = click.style(text, bold=True)
return ret

def name(self, text: str) -> str:
return click.style(text, fg="yellow")
ret: str = click.style(text, fg="yellow")
return ret

def link(self, link: str, text: Optional[str] = None) -> str:
return click.style(link, fg="blue", underline=True)
ret: str = click.style(link, fg="blue", underline=True)
return ret

def emoji(self, name: str) -> str:
emojis = dict()
Expand All @@ -191,10 +194,12 @@ def emoji(self, name: str) -> str:

def status(self, text: str, failure: Optional[bool] = None) -> str:
color = "red" if failure else "green"
return click.style(text, fg=color)
ret: str = click.style(text, fg=color)
return ret

def files(self, text: str) -> str:
return click.style(text, fg="magenta", bold=True)
ret: str = click.style(text, fg="magenta", bold=True)
return ret

def grid(self, rows: List[List[str]], title: Optional[str] = None) -> str:
max_len = max(len(row[0]) for row in rows)
Expand Down
2 changes: 1 addition & 1 deletion wandb/sdk/service/server_sock.py
Expand Up @@ -106,7 +106,7 @@ def run(self) -> None:
assert sreq, "read_server_request should never timeout"
sreq_type = sreq.WhichOneof("server_request_type")
shandler_str = "server_" + sreq_type
shandler: "Callable[[spb.ServerRequest], None]" = getattr(
shandler: "Callable[[spb.ServerRequest], None]" = getattr( # type: ignore
self, shandler_str, None
)
assert shandler, f"unknown handle: {shandler_str}"
Expand Down
3 changes: 2 additions & 1 deletion wandb/sdk/wandb_artifacts.py
Expand Up @@ -25,6 +25,7 @@
from urllib.parse import parse_qsl, quote, urlparse

import requests
import urllib3
import wandb
from wandb import env
from wandb import util
Expand Down Expand Up @@ -59,7 +60,7 @@

# This makes the first sleep 1s, and then doubles it up to total times,
# which makes for ~18 hours.
_REQUEST_RETRY_STRATEGY = requests.packages.urllib3.util.retry.Retry(
_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
backoff_factor=1,
total=16,
status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
Expand Down

0 comments on commit 3326440

Please sign in to comment.