break gradient and parameters hooks #3509

Merged · 10 commits · Aug 17, 2022
37 changes: 37 additions & 0 deletions tests/unit_tests/test_torch.py
@@ -68,3 +68,40 @@ def test_double_log(mock_run):
run.watch(net, log_graph=True)
with pytest.raises(ValueError):
run.watch(net, log_graph=True)


@pytest.mark.parametrize("log_type", ["parameters", "all"])
def test_watch_parameters_torch_jit(mock_run, capsys, log_type):
run = mock_run(use_magic_mock=True)
net = torch.jit.script(nn.Linear(10, 2))
run.watch(net, log=log_type)

outerr = capsys.readouterr()
assert "skipping parameter tracking" in outerr.err


def test_watch_graph_torch_jit(mock_run, capsys):
run = mock_run(use_magic_mock=True)

class Net(nn.Module):
def __init__(self):
super().__init__()
self.layer_1 = nn.Linear(10, 2)

def forward(self, x):
return self.layer_1(x)

net = torch.jit.script(Net())
run.watch(net, log_graph=True)

outerr = capsys.readouterr()
assert "skipping graph tracking" in outerr.err


def test_watch_bad_argument(mock_run):
run = mock_run(use_magic_mock=True)
net = nn.Linear(10, 2)
with pytest.raises(
ValueError, match="log must be one of 'gradients', 'parameters', 'all', or None"
):
run.watch(net, log="bad_argument")
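
For orientation, a hedged usage sketch of the `log` argument that these tests exercise; the project name, model, and `log_freq` below are illustrative placeholders, not taken from the diff (the tests call the same entry point through the run object).

import torch.nn as nn
import wandb

run = wandb.init(project="watch-demo")  # placeholder project name
net = nn.Linear(10, 2)

# log accepts "gradients" (the default), "parameters", "all", or None.
wandb.watch(net, log="all", log_freq=100, log_graph=True)

# Any other value now fails fast instead of silently behaving like "gradients":
try:
    wandb.watch(net, log="bad_argument")
except ValueError as err:
    print(err)  # log must be one of 'gradients', 'parameters', 'all', or None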
45 changes: 24 additions & 21 deletions wandb/sdk/wandb_watch.py
@@ -1,9 +1,13 @@
"""watch."""

import logging
import os
from typing import Optional

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

import wandb

from .lib import telemetry
@@ -16,7 +20,7 @@
def watch(
models,
criterion=None,
log: Optional[str] = "gradients",
log: Optional[Literal["gradients", "parameters", "all"]] = "gradients",
log_freq: int = 1000,
idx: Optional[int] = None,
log_graph: bool = False,
@@ -45,22 +49,15 @@ def watch(
tel.feature.watch = True

logger.info("Watching")
# TODO: temporary override for huggingface remove after: https://github.com/huggingface/transformers/pull/4220
if os.getenv("WANDB_WATCH") == "false":
return

if wandb.run is None:
raise ValueError("You must call `wandb.init` before calling watch")

log_parameters = False
log_gradients = True
if log == "all":
log_parameters = True
elif log == "parameters":
log_parameters = True
log_gradients = False
elif log is None:
log_gradients = False
if log not in {"gradients", "parameters", "all", None}:
raise ValueError("log must be one of 'gradients', 'parameters', 'all', or None")

log_parameters = log in {"parameters", "all"}
log_gradients = log in {"gradients", "all"}

if not isinstance(models, (tuple, list)):
models = (models,)
@@ -88,13 +85,19 @@ def watch(
# TODO: this makes ugly chart names like gradients/graph_1conv1d.bias
prefix = "graph_%i" % global_idx

wandb.run._torch.add_log_hooks_to_pytorch_module(
model,
log_parameters=log_parameters,
log_gradients=log_gradients,
prefix=prefix,
log_freq=log_freq,
)
if log_parameters:
wandb.run._torch.add_log_parameters_hook(
model,
prefix=prefix,
log_freq=log_freq,
)

if log_gradients:
wandb.run._torch.add_log_gradients_hook(
model,
prefix=prefix,
log_freq=log_freq,
)

if log_graph:
graph = wandb.run._torch.hook_torch(model, criterion, graph_idx=global_idx)
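As an aside, the if/elif chain that previously derived the two booleans is replaced above by pure set membership. A minimal standalone sketch of the same mapping; the helper name `derive_flags` is illustrative and not part of the diff.

from typing import Optional, Tuple

def derive_flags(log: Optional[str]) -> Tuple[bool, bool]:
    # Reject anything outside the documented choices, mirroring the new check in watch().
    if log not in {"gradients", "parameters", "all", None}:
        raise ValueError("log must be one of 'gradients', 'parameters', 'all', or None")
    # Set membership replaces the old if/elif chain.
    log_parameters = log in {"parameters", "all"}
    log_gradients = log in {"gradients", "all"}
    return log_parameters, log_gradients

# (log_parameters, log_gradients) per input:
assert derive_flags("gradients") == (False, True)   # default
assert derive_flags("parameters") == (True, False)
assert derive_flags("all") == (True, True)
assert derive_flags(None) == (False, False)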
116 changes: 70 additions & 46 deletions wandb/wandb_torch.py
@@ -6,6 +6,7 @@
import itertools
from functools import reduce
from operator import mul
from typing import List

from wandb import util
from wandb.data_types import Node
@@ -48,14 +49,14 @@ def nested_shape(array_or_tuple, seen=None):
LOG_TRACK_COUNT, LOG_TRACK_THRESHOLD = range(2)


def log_track_init(log_freq):
def log_track_init(log_freq: int) -> List[int]:
"""create tracking structure used by log_track_update"""
l = [0] * 2
l[LOG_TRACK_THRESHOLD] = log_freq
return l


def log_track_update(log_track):
def log_track_update(log_track: List[int]) -> bool:
"""count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached"""
log_track[LOG_TRACK_COUNT] += 1
if log_track[LOG_TRACK_COUNT] < log_track[LOG_TRACK_THRESHOLD]:
@@ -75,56 +76,73 @@ def __init__(self):
self._is_cuda_histc_supported = None
self.hook_torch = TorchGraph.hook_torch

def add_log_hooks_to_pytorch_module(
def add_log_parameters_hook(
self,
module,
name=None,
prefix="",
log_parameters=True,
log_gradients=True,
log_freq=0,
):
module: "torch.nn.Module",
name: str = "",
prefix: str = "",
log_freq: int = 0,
) -> None:
"""This instruments hooks into the pytorch module
log_parameters - log parameters after a forward pass
log_gradients - log gradients after a backward pass
log parameters after a forward pass
log_freq - log gradients/parameters every N batches
"""
if name is not None:
prefix = prefix + name
# if name is not None:
prefix = prefix + name

if not hasattr(module, "_wandb_hook_names"):
module._wandb_hook_names = []

if log_parameters:

def parameter_log_hook(module, input_, output, log_track):
if not log_track_update(log_track):
return
for name, parameter in module.named_parameters():
# for pytorch 0.3 Variables
if isinstance(parameter, torch.autograd.Variable):
data = parameter.data
else:
data = parameter
self.log_tensor_stats(data.cpu(), "parameters/" + prefix + name)

log_track_params = log_track_init(log_freq)
def parameter_log_hook(module, input_, output, log_track):
if not log_track_update(log_track):
return
for name, parameter in module.named_parameters():
# for pytorch 0.3 Variables
if isinstance(parameter, torch.autograd.Variable):
data = parameter.data
else:
data = parameter
self.log_tensor_stats(data.cpu(), "parameters/" + prefix + name)

log_track_params = log_track_init(log_freq)
try:
hook = module.register_forward_hook(
lambda mod, inp, outp: parameter_log_hook(
mod, inp, outp, log_track_params
)
)
self._hook_handles["parameters/" + prefix] = hook
module._wandb_hook_names.append("parameters/" + prefix)
except RuntimeError as e:
wandb.termwarn(
f"Trying to register forward_hook failed ({e}) - skipping parameter tracking."
)

if log_gradients:
for name, parameter in module.named_parameters():
if parameter.requires_grad:
log_track_grad = log_track_init(log_freq)
module._wandb_hook_names.append("gradients/" + prefix + name)
self._hook_variable_gradient_stats(
parameter, "gradients/" + prefix + name, log_track_grad
)
def add_log_gradients_hook(
self,
module: "torch.nn.Module",
name: str = "",
prefix: str = "",
log_freq: int = 0,
) -> None:
"""This instruments hooks into the pytorch module
log gradients after a backward pass
log_freq - log gradients/parameters every N batches
"""

# if name is not None:
prefix = prefix + name

if not hasattr(module, "_wandb_hook_names"):
module._wandb_hook_names = []

for name, parameter in module.named_parameters():
if parameter.requires_grad:
log_track_grad = log_track_init(log_freq)
module._wandb_hook_names.append("gradients/" + prefix + name)
self._hook_variable_gradient_stats(
parameter, "gradients/" + prefix + name, log_track_grad
)

def log_tensor_stats(self, tensor, name):
"""Add distribution statistics on a tensor's elements to the current History entry"""
@@ -394,16 +412,22 @@ def hook_torch_modules(
self.hook_torch_modules(sub_module, prefix=name, parent=parent)
else:
self._graph_hooks |= {id(sub_module)}
graph_hook = sub_module.register_forward_hook(
self.create_forward_hook(name, graph_idx)
)
wandb.run._torch._hook_handles[
"topology/" + str(id(graph_hook))
] = graph_hook
if not hasattr(parent, "_wandb_hook_names"):
# should never happen but let's be extra safe
parent._wandb_hook_names = []
parent._wandb_hook_names.append("topology/" + str(id(graph_hook)))
try:
graph_hook = sub_module.register_forward_hook(
self.create_forward_hook(name, graph_idx)
)
wandb.run._torch._hook_handles[
"topology/" + str(id(graph_hook))
] = graph_hook
if not hasattr(parent, "_wandb_hook_names"):
# should never happen but let's be extra safe
parent._wandb_hook_names = []
parent._wandb_hook_names.append("topology/" + str(id(graph_hook)))
except RuntimeError as e:
wandb.termwarn(
f"Trying to register forward_hook failed ({e}) - skipping graph tracking.",
repeat=False,
)

@classmethod
def from_torch_layers(cls, module_graph, variable):
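
Why the try/except above: on the PyTorch versions this PR targets, registering a forward hook on a torch.jit scripted module raises RuntimeError, which is the failure the new tests simulate. A minimal reproduction sketch; behavior may differ on newer PyTorch releases.

import torch
import torch.nn as nn

scripted = torch.jit.script(nn.Linear(10, 2))

try:
    # On the PyTorch versions this PR targets, ScriptModules reject forward hooks.
    scripted.register_forward_hook(lambda module, inputs, output: None)
except RuntimeError as e:
    # wandb now warns and skips instead of letting this propagate.
    print(f"Trying to register forward_hook failed ({e}) - skipping graph tracking.")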