histogram: clean up v2 code (tensorflow#5443)
* clean up histogram v2 code
yatbear committed Mar 27, 2023
1 parent f037d72 commit 31a0440
Showing 3 changed files with 9 additions and 281 deletions.
6 changes: 0 additions & 6 deletions tensorboard/plugins/histogram/summary.py
@@ -35,13 +35,7 @@
from tensorboard.plugins.histogram import summary_v2


-# Export V2 versions.
-histogram_v2 = summary_v2.histogram_v2
-
-# Export V3 versions.
-histogram_v3 = summary_v2.histogram_v3
-
# Export the default versions.
histogram = summary_v2.histogram
histogram_pb = summary_v2.histogram_pb

52 changes: 4 additions & 48 deletions tensorboard/plugins/histogram/summary_test.py
@@ -211,11 +211,11 @@ def test_zero_bucket_count(self):
np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3)))


-class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase):
+class SummaryV3OpTest(SummaryBaseTest, tf.test.TestCase):
def setUp(self):
-super(SummaryV2OpTest, self).setUp()
+super(SummaryV3OpTest, self).setUp()
if tf2 is None:
self.skipTest("v2 summary API not available")
self.skipTest("v3 histogram summary API not available")

def histogram(self, *args, **kwargs):
return self.histogram_event(*args, **kwargs).summary
@@ -240,7 +240,7 @@ def write_histogram_event(self, *args, **kwargs):
writer.close()

def call_histogram_op(self, *args, **kwargs):
-summary.histogram_v2(*args, **kwargs)
+summary.histogram(*args, **kwargs)

def test_scoped_tag(self):
with tf.name_scope("scope"):
@@ -264,50 +264,6 @@ def test_default_step(self):
# Reset to default state for other tests.
tf2.summary.experimental.set_step(None)


class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase):
def write_histogram_event(self, *args, **kwargs):
kwargs.setdefault("step", 1)
# Hack to extract current scope since there's no direct API for it.
with tf.name_scope("_") as temp_scope:
scope = temp_scope.rstrip("/_")

@tf2.function
def graph_fn():
# Recreate the active scope inside the defun since it won't propagate.
with tf.name_scope(scope):
self.call_histogram_op(*args, **kwargs)

writer = tf2.summary.create_file_writer(self.get_temp_dir())
with writer.as_default():
graph_fn()
writer.close()

def test_no_gradient_error_xla(self):
@tf2.function(jit_compile=True)
def graph_fn():
x = tf.constant(1.0)
with tf2.GradientTape() as tape1:
with tf2.GradientTape() as tape2:
tape1.watch(x)
tape2.watch(x)
self.call_histogram_op(
name="loss", step=0, data=x, buckets=10
)

# Note that XLA CPU/GPU has no outside compilation support, so summaries
# won't actually run in a jit_compiled function. TPUs do, and follow
# some similar codepaths, so this test stops at graph building to
# exercise those paths without a TPU available.
writer = tf2.summary.create_file_writer(self.get_temp_dir())
with writer.as_default():
graph_fn.get_concrete_function()


class SummaryV3OpTest(SummaryV2OpTest, tf.test.TestCase):
def call_histogram_op(self, *args, **kwargs):
summary.histogram(*args, **kwargs)

def test_singleton_input(self):
pb = self.histogram("twelve", [12])
buckets = tensor_util.make_ndarray(pb.value[0].tensor)
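For reference (not part of the diff), histogram protos decode back into the `[k, 3]` bucket array exactly the way these tests read them; a minimal round-trip sketch using the public `histogram_pb` helper:

```python
from tensorboard.plugins.histogram import summary
from tensorboard.util import tensor_util

# Build a histogram proto, then decode its tensor back into an ndarray.
pb = summary.histogram_pb("twelve", [12])
buckets = tensor_util.make_ndarray(pb.value[0].tensor)
print(buckets)  # rows of [left_edge, right_edge, count]
```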
232 changes: 5 additions & 227 deletions tensorboard/plugins/histogram/summary_v2.py
@@ -19,21 +19,13 @@
of dimension `[k, 3]`, where the first `k - 1` buckets are closed-open and the
last bucket is closed-closed.
In general, the value of `k` (the number of buckets) will be a constant, like 30.
-For V2 format, there are two edge cases: if there is no data, then there are no
-buckets (the shape is `[0, 3]`); and if there is data but all points have the
-same value, then there is one bucket whose left and right endpoints are the same
-(the shape is `[1, 3]`).
-For V3 format, the shape of the output histogram is always constant (`[k, 3]`).
+In general, the shape of the output histogram is always constant (`[k, 3]`).
In the case of empty data, the output will be an all-zero histogram of shape
`[k, 3]`, where all edges and counts are zeros. If there is data but all points
have the same value, then all buckets' left and right edges are the same and only
the last bucket has nonzero count.
"""

import contextlib

import numpy as np

from tensorboard.compat import tf2 as tf
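To make the `[k, 3]` format described in the module docstring concrete (an illustration, not part of this commit), here is a pure-NumPy sketch that mirrors the documented semantics; the authoritative TF implementation is `_buckets` below:

```python
import numpy as np

def reference_buckets(data, k=30):
    """NumPy sketch of the fixed-shape [k, 3] histogram described above."""
    data = np.asarray(data, dtype=np.float64).ravel()
    if data.size == 0:
        return np.zeros((k, 3))  # empty data: all-zero histogram
    lo, hi = data.min(), data.max()
    if lo == hi:
        out = np.zeros((k, 3))
        out[:, 0] = lo           # every bucket gets the same left edge...
        out[:, 1] = hi           # ...and the same right edge,
        out[-1, 2] = data.size   # and only the last bucket has a count
        return out
    counts, edges = np.histogram(data, bins=k, range=(lo, hi))
    # First k-1 buckets are closed-open; np.histogram's last bin is closed-closed.
    return np.column_stack([edges[:-1], edges[1:], counts.astype(np.float64)])
```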
@@ -46,217 +38,6 @@
DEFAULT_BUCKET_COUNT = 30


def histogram_v2(name, data, step=None, buckets=None, description=None):
"""Write a histogram summary.
See also `tf.summary.scalar`, `tf.summary.SummaryWriter`.
Writes a histogram to the current default summary writer, for later analysis
in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written
using this API will appear in both places). Like `tf.summary.scalar` points,
each histogram is associated with a `step` and a `name`. All the histograms
with the same `name` constitute a time series of histograms.
The histogram is calculated over all the elements of the given `Tensor`
without regard to its shape or rank.
This example writes 2 histograms:
```python
w = tf.summary.create_file_writer('test/logs')
with w.as_default():
tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0)
tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0)
```
A common use case is to examine the changing activation patterns (or lack
thereof) at specific layers in a neural network, over time.
```python
w = tf.summary.create_file_writer('test/logs')
with w.as_default():
for step in range(100):
# Generate fake "activations".
activations = [
tf.random.normal([1000], mean=step, stddev=1),
tf.random.normal([1000], mean=step, stddev=10),
tf.random.normal([1000], mean=step, stddev=100),
]
tf.summary.histogram("layer1/activate", activations[0], step=step)
tf.summary.histogram("layer2/activate", activations[1], step=step)
tf.summary.histogram("layer3/activate", activations[2], step=step)
```
Arguments:
name: A name for this summary. The summary tag used for TensorBoard will
be this name prefixed by any active name scopes.
data: A `Tensor` of any shape. The histogram is computed over its elements,
which must be castable to `float64`.
step: Explicit `int64`-castable monotonic step value for this summary. If
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
not be None.
buckets: Optional positive `int`. The output will have this
many buckets, except in two edge cases. If there is no data, then
there are no buckets. If there is data but all points have the
same value, then there is one bucket whose left and right
endpoints are the same.
description: Optional long-form description for this summary, as a
constant `str`. Markdown is supported. Defaults to empty.
Returns:
True on success, or false if no summary was emitted because no default
summary writer was available.
Raises:
ValueError: if a default writer exists, but no step was provided and
`tf.summary.experimental.get_step()` is None.
"""
# Avoid building unused gradient graphs for conds below. This works around
# an error building second-order gradient graphs when XlaDynamicUpdateSlice
# is used, and will generally speed up graph building slightly.
data = tf.stop_gradient(data)
summary_metadata = metadata.create_summary_metadata(
display_name=None, description=description
)
# TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback
summary_scope = (
getattr(tf.summary.experimental, "summary_scope", None)
or tf.summary.summary_scope
)

# Try to capture current name scope so we can re-enter it below within our
# histogram_summary helper. We do this to avoid having the `tf.cond` below
# insert an extra `cond` into the tag name.
# TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this
# special handling once the format no longer requires dynamic output shapes.
name_scope_cms = []
if hasattr(tf, "get_current_name_scope"):
# Coerce None to ""; this API should only return a string but as of TF
# 2.5 it returns None in contexts that re-enter the empty scope.
current_scope = tf.get_current_name_scope() or ""
# Append a "/" to the scope name, which causes that scope to be treated
# as absolute instead of relative to the current scope, so that we can
# re-enter it. It also prevents auto-incrementing of the scope name.
# This is legacy graph mode behavior, undocumented except in comments:
# https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/framework/ops.py#L6664-L6666
scope_to_reenter = current_scope + "/" if current_scope else ""
name_scope_cms.append(tf.name_scope(scope_to_reenter))

def histogram_summary(data, buckets, histogram_metadata, step):
with contextlib.ExitStack() as stack:
for cm in name_scope_cms:
stack.enter_context(cm)
with summary_scope(
name, "histogram_summary", values=[data, buckets, step]
) as (tag, _):
# Defer histogram bucketing logic by passing it as a callable to
# write(), wrapped in a LazyTensorCreator for backwards
# compatibility, so that we only do this work when summaries are
# actually written.
@lazy_tensor_creator.LazyTensorCreator
def lazy_tensor():
return _buckets(data, buckets)

return tf.summary.write(
tag=tag,
tensor=lazy_tensor,
step=step,
metadata=summary_metadata,
)

# `_buckets()` has dynamic output shapes, which are not supported on TPUs.
# To address this, explicitly mark this logic for outside compilation so it
# will be executed on the CPU, and skip it entirely if we aren't actually
# recording summaries to avoid overhead of transferring data.
# TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this
# special handling once the format no longer requires dynamic output shapes.
if isinstance(
tf.distribute.get_strategy(),
(tf.distribute.experimental.TPUStrategy, tf.distribute.TPUStrategy),
):
return tf.cond(
tf.summary.should_record_summaries(),
lambda: tf.compat.v1.tpu.outside_compilation(
histogram_summary, data, buckets, summary_metadata, step
),
lambda: False,
)
return histogram_summary(data, buckets, summary_metadata, step)
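The name-scope re-entry trick used above (appending "/" so the scope is treated as absolute) can be seen in isolation; a small sketch, run in graph mode since eager tensors carry no names:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    with tf.name_scope("outer"):
        pass
    # Re-entering "outer" without the trailing "/" would auto-increment the
    # scope to "outer_1"; with it, the scope is re-entered exactly as given.
    with tf.name_scope("outer/"):
        x = tf.constant(0.0, name="x")
print(x.name)  # "outer/x:0"
```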


def _buckets(data, bucket_count=None):
"""Create a TensorFlow op to group data into histogram buckets.
Arguments:
data: A `Tensor` of any shape. Must be castable to `float64`.
bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
Returns:
A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
a triple `[left_edge, right_edge, count]` for a single bucket.
The value of `k` is either `bucket_count` or `1` or `0`.
"""
if bucket_count is None:
bucket_count = DEFAULT_BUCKET_COUNT
with tf.name_scope("buckets"):
tf.debugging.assert_scalar(bucket_count)
tf.debugging.assert_type(bucket_count, tf.int32)
data = tf.reshape(data, shape=[-1]) # flatten
data = tf.cast(data, tf.float64)
is_empty = tf.equal(tf.size(input=data), 0)

def when_empty():
return tf.constant([], shape=(0, 3), dtype=tf.float64)

def when_nonempty():
min_ = tf.reduce_min(input_tensor=data)
max_ = tf.reduce_max(input_tensor=data)
range_ = max_ - min_
is_singular = tf.equal(range_, 0)

def when_nonsingular():
bucket_width = range_ / tf.cast(bucket_count, tf.float64)
offsets = data - min_
bucket_indices = tf.cast(
tf.floor(offsets / bucket_width), dtype=tf.int32
)
clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
# Use float64 instead of float32 to avoid accumulating floating point error
# later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
# See https://github.com/tensorflow/tensorflow/issues/51419 for details.
one_hots = tf.one_hot(
clamped_indices, depth=bucket_count, dtype=tf.float64
)
bucket_counts = tf.cast(
tf.reduce_sum(input_tensor=one_hots, axis=0),
dtype=tf.float64,
)
edges = tf.linspace(min_, max_, bucket_count + 1)
# Ensure edges[-1] == max_, which TF's linspace implementation does not
# do, leaving it subject to the whim of floating point rounding error.
edges = tf.concat([edges[:-1], [max_]], 0)
left_edges = edges[:-1]
right_edges = edges[1:]
return tf.transpose(
a=tf.stack([left_edges, right_edges, bucket_counts])
)

def when_singular():
center = min_
bucket_starts = tf.stack([center - 0.5])
bucket_ends = tf.stack([center + 0.5])
bucket_counts = tf.stack(
[tf.cast(tf.size(input=data), tf.float64)]
)
return tf.transpose(
a=tf.stack([bucket_starts, bucket_ends, bucket_counts])
)

return tf.cond(is_singular, when_singular, when_nonsingular)

return tf.cond(is_empty, when_empty, when_nonempty)
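For intuition about the bucketing logic above (removed in this commit), a small hand-checked example; running it requires the pre-cleanup module where this `_buckets` still exists:

```python
import tensorflow as tf

# data = [1., 2., 2., 3.], bucket_count = 2:
#   min = 1, max = 3, bucket_width = 1
#   floor((data - min) / bucket_width) = [0, 1, 1, 2], clamped to [0, 1, 1, 1]
#   edges = [1., 2., 3.], so rows are [left_edge, right_edge, count]
buckets = _buckets(tf.constant([1.0, 2.0, 2.0, 3.0]), tf.constant(2, tf.int32))
print(buckets.numpy())  # [[1. 2. 1.]
                        #  [2. 3. 3.]]
```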


def histogram_pb(tag, data, buckets=None, description=None):
"""Create a histogram summary protobuf.
@@ -318,7 +99,8 @@ def histogram_pb(tag, data, buckets=None, description=None):
return summary


-def histogram_v3(name, data, step=None, buckets=None, description=None):
+# This is the TPU compatible V3 histogram implementation as of 2021-12-01.
+def histogram(name, data, step=None, buckets=None, description=None):
"""Write a histogram summary.
See also `tf.summary.scalar`, `tf.summary.SummaryWriter`.
@@ -407,7 +189,7 @@ def histogram_v3(name, data, step=None, buckets=None, description=None):
# actually written.
@lazy_tensor_creator.LazyTensorCreator
def lazy_tensor():
-return _buckets_v3(data, buckets)
+return _buckets(data, buckets)

return tf.summary.write(
tag=tag,
@@ -417,7 +199,7 @@ def lazy_tensor():
)


-def _buckets_v3(data, bucket_count=None):
+def _buckets(data, bucket_count=None):
"""Create a TensorFlow op to group data into histogram buckets.
Arguments:
@@ -507,7 +289,3 @@ def when_single_value():
)

return tf.cond(is_empty, when_empty, when_nonempty)


-# Set V3 as default.
-histogram = histogram_v3
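The two edge cases described in the module docstring can be checked against the now-default `_buckets`; a sketch assuming eager mode:

```python
import tensorflow as tf
from tensorboard.plugins.histogram import summary_v2

# Empty input: all-zero histogram with fixed shape [k, 3].
print(summary_v2._buckets(tf.constant([], tf.float64), 3).numpy())
# [[0. 0. 0.]
#  [0. 0. 0.]
#  [0. 0. 0.]]

# Constant input: identical edges everywhere; all counts in the last bucket.
print(summary_v2._buckets(tf.constant([7.0, 7.0]), 3).numpy())
# [[7. 7. 0.]
#  [7. 7. 0.]
#  [7. 7. 2.]]
```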
