Commit

simplify
kptkin committed Aug 16, 2022
1 parent 9053171 commit 5a62e14
Showing 2 changed files with 88 additions and 65 deletions.
37 changes: 18 additions & 19 deletions wandb/sdk/wandb_watch.py
@@ -45,22 +45,15 @@ def watch(
tel.feature.watch = True

logger.info("Watching")
# TODO: temporary override for huggingface remove after: https://github.com/huggingface/transformers/pull/4220
if os.getenv("WANDB_WATCH") == "false":
return

if wandb.run is None:
raise ValueError("You must call `wandb.init` before calling watch")

log_parameters = False
log_gradients = True
if log == "all":
log_parameters = True
elif log == "parameters":
log_parameters = True
log_gradients = False
elif log is None:
log_gradients = False
if log not in {"gradients", "parameters", "all", None}:
raise ValueError("log must be one of 'gradients', 'parameters', 'all', or None")

log_parameters = log in {"parameters", "all"}
log_gradients = log in {"gradients", "all"}

if not isinstance(models, (tuple, list)):
models = (models,)
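
For reference, here is a minimal standalone sketch of the new validation-and-mapping logic, showing how the two set-membership checks replace the old if/elif chain (the `log` value below is hypothetical, for illustration only):

log = "parameters"  # hypothetical input value

if log not in {"gradients", "parameters", "all", None}:
    raise ValueError("log must be one of 'gradients', 'parameters', 'all', or None")

log_parameters = log in {"parameters", "all"}  # True for "parameters" and "all"
log_gradients = log in {"gradients", "all"}    # True for "gradients" and "all"
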
@@ -88,13 +81,19 @@ def watch(
# TODO: this makes ugly chart names like gradients/graph_1conv1d.bias
prefix = "graph_%i" % global_idx

wandb.run._torch.add_log_hooks_to_pytorch_module(
model,
log_parameters=log_parameters,
log_gradients=log_gradients,
prefix=prefix,
log_freq=log_freq,
)
if log_parameters:
wandb.run._torch.add_log_parameters_hook(
model,
prefix=prefix,
log_freq=log_freq,
)

if log_gradients:
wandb.run._torch.add_log_gradients_hook(
model,
prefix=prefix,
log_freq=log_freq,
)

if log_graph:
graph = wandb.run._torch.hook_torch(model, criterion, graph_idx=global_idx)
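
As context for this split into two separate hook-registration calls, a minimal usage sketch of the public API (the project name and model below are illustrative, not taken from the diff):

import torch.nn as nn
import wandb

run = wandb.init(project="demo")  # illustrative project name
model = nn.Linear(4, 2)           # any torch.nn.Module

# log="all" registers both the parameter hook (forward pass) and the
# gradient hooks (backward pass); log="parameters" or log="gradients"
# registers only the matching one, and any other value raises ValueError.
wandb.watch(model, log="all", log_freq=10)
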
116 changes: 70 additions & 46 deletions wandb/wandb_torch.py
@@ -6,6 +6,7 @@
import itertools
from functools import reduce
from operator import mul
from typing import List

from wandb import util
from wandb.data_types import Node
@@ -48,14 +49,14 @@ def nested_shape(array_or_tuple, seen=None):
LOG_TRACK_COUNT, LOG_TRACK_THRESHOLD = range(2)


def log_track_init(log_freq):
def log_track_init(log_freq: int) -> List[int]:
"""create tracking structure used by log_track_update"""
l = [0] * 2
l[LOG_TRACK_THRESHOLD] = log_freq
return l


def log_track_update(log_track):
def log_track_update(log_track: List[int]) -> bool:
"""count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached"""
log_track[LOG_TRACK_COUNT] += 1
if log_track[LOG_TRACK_COUNT] < log_track[LOG_TRACK_THRESHOLD]:
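
To illustrate the counter/threshold pair these helpers maintain (assuming, per the docstring, that the part of `log_track_update` cut off by the hunk resets the count and returns True once the threshold is reached):

track = log_track_init(log_freq=3)  # -> [0, 3], i.e. [count, threshold]
fired = [log_track_update(track) for _ in range(6)]
# fired == [False, False, True, False, False, True]; stats are only logged
# on the calls that return True, i.e. every `log_freq` batches.
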
@@ -75,56 +76,73 @@ def __init__(self):
self._is_cuda_histc_supported = None
self.hook_torch = TorchGraph.hook_torch

def add_log_hooks_to_pytorch_module(
def add_log_parameters_hook(
self,
module,
name=None,
prefix="",
log_parameters=True,
log_gradients=True,
log_freq=0,
):
module: "torch.nn.Module",
name: str = "",
prefix: str = "",
log_freq: int = 0,
) -> None:
"""This instruments hooks into the pytorch module
log_parameters - log parameters after a forward pass
log_gradients - log gradients after a backward pass
log parameters after a forward pass
log_freq - log gradients/parameters every N batches
"""
if name is not None:
prefix = prefix + name
# if name is not None:
prefix = prefix + name

if not hasattr(module, "_wandb_hook_names"):
module._wandb_hook_names = []

if log_parameters:

def parameter_log_hook(module, input_, output, log_track):
if not log_track_update(log_track):
return
for name, parameter in module.named_parameters():
# for pytorch 0.3 Variables
if isinstance(parameter, torch.autograd.Variable):
data = parameter.data
else:
data = parameter
self.log_tensor_stats(data.cpu(), "parameters/" + prefix + name)

log_track_params = log_track_init(log_freq)
def parameter_log_hook(module, input_, output, log_track):
if not log_track_update(log_track):
return
for name, parameter in module.named_parameters():
# for pytorch 0.3 Variables
if isinstance(parameter, torch.autograd.Variable):
data = parameter.data
else:
data = parameter
self.log_tensor_stats(data.cpu(), "parameters/" + prefix + name)

log_track_params = log_track_init(log_freq)
try:
hook = module.register_forward_hook(
lambda mod, inp, outp: parameter_log_hook(
mod, inp, outp, log_track_params
)
)
self._hook_handles["parameters/" + prefix] = hook
module._wandb_hook_names.append("parameters/" + prefix)
except RuntimeError as e:
wandb.termwarn(
f"Trying to register forward_hook failed ({e}) - skipping parameter tracking."
)
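
A self-contained sketch of the pattern the new add_log_parameters_hook follows: a forward hook that walks named_parameters() every log_freq calls (standalone example with a hypothetical logging body, not wandb's internals):

import torch
import torch.nn as nn

def make_param_hook(log_freq: int):
    state = {"count": 0}

    def hook(module, inputs, output):
        state["count"] += 1
        if state["count"] < log_freq:
            return
        state["count"] = 0
        for name, param in module.named_parameters():
            # wandb would compute histogram stats here; we just print a norm
            print(f"parameters/{name}: {param.detach().norm().item():.4f}")

    return hook

model = nn.Linear(4, 2)
handle = model.register_forward_hook(make_param_hook(log_freq=2))
for _ in range(4):
    model(torch.randn(1, 4))  # the hook logs on the 2nd and 4th forward pass
handle.remove()               # stored handles allow unhooking later
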

if log_gradients:
for name, parameter in module.named_parameters():
if parameter.requires_grad:
log_track_grad = log_track_init(log_freq)
module._wandb_hook_names.append("gradients/" + prefix + name)
self._hook_variable_gradient_stats(
parameter, "gradients/" + prefix + name, log_track_grad
)
def add_log_gradients_hook(
self,
module: "torch.nn.Module",
name: str = "",
prefix: str = "",
log_freq: int = 0,
) -> None:
"""This instruments hooks into the pytorch module
log gradients after a backward pass
log_freq - log gradients/parameters every N batches
"""

# if name is not None:
prefix = prefix + name

if not hasattr(module, "_wandb_hook_names"):
module._wandb_hook_names = []

for name, parameter in module.named_parameters():
if parameter.requires_grad:
log_track_grad = log_track_init(log_freq)
module._wandb_hook_names.append("gradients/" + prefix + name)
self._hook_variable_gradient_stats(
parameter, "gradients/" + prefix + name, log_track_grad
)
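
The gradient path relies on per-tensor hooks rather than module hooks; _hook_variable_gradient_stats is not shown in this diff, but a sketch of the standard PyTorch mechanism it presumably builds on (Tensor.register_hook) looks like this:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
for name, param in model.named_parameters():
    if param.requires_grad:
        # Tensor.register_hook fires with the gradient after backward();
        # the real code would log histogram stats instead of printing.
        param.register_hook(
            lambda grad, name=name: print(f"gradients/{name}: {grad.norm().item():.4f}")
        )

loss = model(torch.randn(1, 4)).sum()
loss.backward()  # triggers one callback per parameter gradient
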

def log_tensor_stats(self, tensor, name):
"""Add distribution statistics on a tensor's elements to the current History entry"""
@@ -394,16 +412,22 @@ def hook_torch_modules(
self.hook_torch_modules(sub_module, prefix=name, parent=parent)
else:
self._graph_hooks |= {id(sub_module)}
graph_hook = sub_module.register_forward_hook(
self.create_forward_hook(name, graph_idx)
)
wandb.run._torch._hook_handles[
"topology/" + str(id(graph_hook))
] = graph_hook
if not hasattr(parent, "_wandb_hook_names"):
# should never happen but let's be extra safe
parent._wandb_hook_names = []
parent._wandb_hook_names.append("topology/" + str(id(graph_hook)))
try:
graph_hook = sub_module.register_forward_hook(
self.create_forward_hook(name, graph_idx)
)
wandb.run._torch._hook_handles[
"topology/" + str(id(graph_hook))
] = graph_hook
if not hasattr(parent, "_wandb_hook_names"):
# should never happen but let's be extra safe
parent._wandb_hook_names = []
parent._wandb_hook_names.append("topology/" + str(id(graph_hook)))
except RuntimeError as e:
wandb.termwarn(
f"Trying to register forward_hook failed ({e}) - skipping graph tracking.",
repeat=False,
)
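
The bookkeeping above depends on register_forward_hook returning a removable handle; a minimal sketch of that contract, using the same keying scheme (the names here are illustrative):

import torch.nn as nn

hook_handles = {}
layer = nn.Conv1d(1, 1, kernel_size=3)
handle = layer.register_forward_hook(lambda mod, inp, out: None)
hook_handles["topology/" + str(id(handle))] = handle
# ...later, e.g. when unhooking during cleanup:
handle.remove()  # RemovableHandle detaches the hook again
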

@classmethod
def from_torch_layers(cls, module_graph, variable):
