diff --git a/tensorboard/plugins/histogram/summary.py b/tensorboard/plugins/histogram/summary.py index 9ddd9787c9a..d804852f638 100644 --- a/tensorboard/plugins/histogram/summary.py +++ b/tensorboard/plugins/histogram/summary.py @@ -35,13 +35,7 @@ from tensorboard.plugins.histogram import summary_v2 -# Export V2 versions. -histogram_v2 = summary_v2.histogram_v2 - # Export V3 versions. -histogram_v3 = summary_v2.histogram_v3 - -# Export the default versions. histogram = summary_v2.histogram histogram_pb = summary_v2.histogram_pb diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index b3d12f3c3d0..c1f8f4e4201 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -211,11 +211,11 @@ def test_zero_bucket_count(self): np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3))) -class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): +class SummaryV3OpTest(SummaryBaseTest, tf.test.TestCase): def setUp(self): - super(SummaryV2OpTest, self).setUp() + super(SummaryV3OpTest, self).setUp() if tf2 is None: - self.skipTest("v2 summary API not available") + self.skipTest("v3 histogram summary API not available") def histogram(self, *args, **kwargs): return self.histogram_event(*args, **kwargs).summary @@ -240,7 +240,7 @@ def write_histogram_event(self, *args, **kwargs): writer.close() def call_histogram_op(self, *args, **kwargs): - summary.histogram_v2(*args, **kwargs) + summary.histogram(*args, **kwargs) def test_scoped_tag(self): with tf.name_scope("scope"): @@ -264,50 +264,6 @@ def test_default_step(self): # Reset to default state for other tests. tf2.summary.experimental.set_step(None) - -class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): - def write_histogram_event(self, *args, **kwargs): - kwargs.setdefault("step", 1) - # Hack to extract current scope since there's no direct API for it. - with tf.name_scope("_") as temp_scope: - scope = temp_scope.rstrip("/_") - - @tf2.function - def graph_fn(): - # Recreate the active scope inside the defun since it won't propagate. - with tf.name_scope(scope): - self.call_histogram_op(*args, **kwargs) - - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn() - writer.close() - - def test_no_gradient_error_xla(self): - @tf2.function(jit_compile=True) - def graph_fn(): - x = tf.constant(1.0) - with tf2.GradientTape() as tape1: - with tf2.GradientTape() as tape2: - tape1.watch(x) - tape2.watch(x) - self.call_histogram_op( - name="loss", step=0, data=x, buckets=10 - ) - - # Note that XLA CPU/GPU has no outside compilation support, so summaries - # won't actually run in a jit_compiled function. TPUs do, and follow - # some similar codepaths, so this test stops at graph building to - # exercise those paths without a TPU available. - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn.get_concrete_function() - - -class SummaryV3OpTest(SummaryV2OpTest, tf.test.TestCase): - def call_histogram_op(self, *args, **kwargs): - summary.histogram(*args, **kwargs) - def test_singleton_input(self): pb = self.histogram("twelve", [12]) buckets = tensor_util.make_ndarray(pb.value[0].tensor) diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index f1d76f96e93..2367ce00c0c 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -19,21 +19,13 @@ of dimension `[k, 3]`, where the first `k - 1` buckets are closed-open and the last bucket is closed-closed. -In general, the value of `k` (the number of buckets) will be a constant, like 30. -For V2 format, there are two edge cases: if there is no data, then there are no -buckets (the shape is `[0, 3]`); and if there is data but all points have the -same value, then there is one bucket whose left and right endpoints are the same -(the shape is `[1, 3]`). - -For V3 format, the shape of the output histogram is always constant (`[k, 3]`). +In general, the shape of the output histogram is always constant (`[k, 3]`). In the case of empty data, the output will be an all-zero histogram of shape `[k, 3]`, where all edges and counts are zeros. If there is data but all points have the same value, then all buckets' left and right edges are the same and only the last bucket has nonzero count. """ -import contextlib - import numpy as np from tensorboard.compat import tf2 as tf @@ -46,217 +38,6 @@ DEFAULT_BUCKET_COUNT = 30 -def histogram_v2(name, data, step=None, buckets=None, description=None): - """Write a histogram summary. - - See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. - - Writes a histogram to the current default summary writer, for later analysis - in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written - using this API will appear in both places). Like `tf.summary.scalar` points, - each histogram is associated with a `step` and a `name`. All the histograms - with the same `name` constitute a time series of histograms. - - The histogram is calculated over all the elements of the given `Tensor` - without regard to its shape or rank. - - This example writes 2 histograms: - - ```python - w = tf.summary.create_file_writer('test/logs') - with w.as_default(): - tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0) - tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0) - ``` - - A common use case is to examine the changing activation patterns (or lack - thereof) at specific layers in a neural network, over time. - - ```python - w = tf.summary.create_file_writer('test/logs') - with w.as_default(): - for step in range(100): - # Generate fake "activations". - activations = [ - tf.random.normal([1000], mean=step, stddev=1), - tf.random.normal([1000], mean=step, stddev=10), - tf.random.normal([1000], mean=step, stddev=100), - ] - - tf.summary.histogram("layer1/activate", activations[0], step=step) - tf.summary.histogram("layer2/activate", activations[1], step=step) - tf.summary.histogram("layer3/activate", activations[2], step=step) - ``` - - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - data: A `Tensor` of any shape. The histogram is computed over its elements, - which must be castable to `float64`. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - buckets: Optional positive `int`. The output will have this - many buckets, except in two edge cases. If there is no data, then - there are no buckets. If there is data but all points have the - same value, then there is one bucket whose left and right - endpoints are the same. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - - Returns: - True on success, or false if no summary was emitted because no default - summary writer was available. - - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - # Avoid building unused gradient graphs for conds below. This works around - # an error building second-order gradient graphs when XlaDynamicUpdateSlice - # is used, and will generally speed up graph building slightly. - data = tf.stop_gradient(data) - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description - ) - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback - summary_scope = ( - getattr(tf.summary.experimental, "summary_scope", None) - or tf.summary.summary_scope - ) - - # Try to capture current name scope so we can re-enter it below within our - # histogram_summary helper. We do this to avoid having the `tf.cond` below - # insert an extra `cond` into the tag name. - # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this - # special handling once the format no longer requires dynamic output shapes. - name_scope_cms = [] - if hasattr(tf, "get_current_name_scope"): - # Coerce None to ""; this API should only return a string but as of TF - # 2.5 it returns None in contexts that re-enter the empty scope. - current_scope = tf.get_current_name_scope() or "" - # Append a "/" to the scope name, which causes that scope to be treated - # as absolute instead of relative to the current scope, so that we can - # re-enter it. It also prevents auto-incrementing of the scope name. - # This is legacy graph mode behavior, undocumented except in comments: - # https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/framework/ops.py#L6664-L6666 - scope_to_reenter = current_scope + "/" if current_scope else "" - name_scope_cms.append(tf.name_scope(scope_to_reenter)) - - def histogram_summary(data, buckets, histogram_metadata, step): - with contextlib.ExitStack() as stack: - for cm in name_scope_cms: - stack.enter_context(cm) - with summary_scope( - name, "histogram_summary", values=[data, buckets, step] - ) as (tag, _): - # Defer histogram bucketing logic by passing it as a callable to - # write(), wrapped in a LazyTensorCreator for backwards - # compatibility, so that we only do this work when summaries are - # actually written. - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return _buckets(data, buckets) - - return tf.summary.write( - tag=tag, - tensor=lazy_tensor, - step=step, - metadata=summary_metadata, - ) - - # `_buckets()` has dynamic output shapes which is not supported on TPU's. - # To address this, explicitly mark this logic for outside compilation so it - # will be executed on the CPU, and skip it entirely if we aren't actually - # recording summaries to avoid overhead of transferring data. - # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this - # special handling once the format no longer requires dynamic output shapes. - if isinstance( - tf.distribute.get_strategy(), - (tf.distribute.experimental.TPUStrategy, tf.distribute.TPUStrategy), - ): - return tf.cond( - tf.summary.should_record_summaries(), - lambda: tf.compat.v1.tpu.outside_compilation( - histogram_summary, data, buckets, summary_metadata, step - ), - lambda: False, - ) - return histogram_summary(data, buckets, summary_metadata, step) - - -def _buckets(data, bucket_count=None): - """Create a TensorFlow op to group data into histogram buckets. - - Arguments: - data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional positive `int` or scalar `int32` `Tensor`. - Returns: - A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is - a triple `[left_edge, right_edge, count]` for a single bucket. - The value of `k` is either `bucket_count` or `1` or `0`. - """ - if bucket_count is None: - bucket_count = DEFAULT_BUCKET_COUNT - with tf.name_scope("buckets"): - tf.debugging.assert_scalar(bucket_count) - tf.debugging.assert_type(bucket_count, tf.int32) - data = tf.reshape(data, shape=[-1]) # flatten - data = tf.cast(data, tf.float64) - is_empty = tf.equal(tf.size(input=data), 0) - - def when_empty(): - return tf.constant([], shape=(0, 3), dtype=tf.float64) - - def when_nonempty(): - min_ = tf.reduce_min(input_tensor=data) - max_ = tf.reduce_max(input_tensor=data) - range_ = max_ - min_ - is_singular = tf.equal(range_, 0) - - def when_nonsingular(): - bucket_width = range_ / tf.cast(bucket_count, tf.float64) - offsets = data - min_ - bucket_indices = tf.cast( - tf.floor(offsets / bucket_width), dtype=tf.int32 - ) - clamped_indices = tf.minimum(bucket_indices, bucket_count - 1) - # Use float64 instead of float32 to avoid accumulating floating point error - # later in tf.reduce_sum when summing more than 2^24 individual `1.0` values. - # See https://github.com/tensorflow/tensorflow/issues/51419 for details. - one_hots = tf.one_hot( - clamped_indices, depth=bucket_count, dtype=tf.float64 - ) - bucket_counts = tf.cast( - tf.reduce_sum(input_tensor=one_hots, axis=0), - dtype=tf.float64, - ) - edges = tf.linspace(min_, max_, bucket_count + 1) - # Ensure edges[-1] == max_, which TF's linspace implementation does not - # do, leaving it subject to the whim of floating point rounding error. - edges = tf.concat([edges[:-1], [max_]], 0) - left_edges = edges[:-1] - right_edges = edges[1:] - return tf.transpose( - a=tf.stack([left_edges, right_edges, bucket_counts]) - ) - - def when_singular(): - center = min_ - bucket_starts = tf.stack([center - 0.5]) - bucket_ends = tf.stack([center + 0.5]) - bucket_counts = tf.stack( - [tf.cast(tf.size(input=data), tf.float64)] - ) - return tf.transpose( - a=tf.stack([bucket_starts, bucket_ends, bucket_counts]) - ) - - return tf.cond(is_singular, when_singular, when_nonsingular) - - return tf.cond(is_empty, when_empty, when_nonempty) - - def histogram_pb(tag, data, buckets=None, description=None): """Create a histogram summary protobuf. @@ -318,7 +99,8 @@ def histogram_pb(tag, data, buckets=None, description=None): return summary -def histogram_v3(name, data, step=None, buckets=None, description=None): +# This is the TPU compatible V3 histogram implementation as of 2021-12-01. +def histogram(name, data, step=None, buckets=None, description=None): """Write a histogram summary. See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. @@ -407,7 +189,7 @@ def histogram_v3(name, data, step=None, buckets=None, description=None): # actually written. @lazy_tensor_creator.LazyTensorCreator def lazy_tensor(): - return _buckets_v3(data, buckets) + return _buckets(data, buckets) return tf.summary.write( tag=tag, @@ -417,7 +199,7 @@ def lazy_tensor(): ) -def _buckets_v3(data, bucket_count=None): +def _buckets(data, bucket_count=None): """Create a TensorFlow op to group data into histogram buckets. Arguments: @@ -507,7 +289,3 @@ def when_single_value(): ) return tf.cond(is_empty, when_empty, when_nonempty) - - -# Set V3 as default. -histogram = histogram_v3