From f72b767755ea7880ecf69e1a2c580d4ceea23ed3 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Thu, 11 Nov 2021 16:15:51 +0000 Subject: [PATCH 1/4] set histogram v3 as default --- tensorboard/plugins/histogram/summary.py | 22 +- tensorboard/plugins/histogram/summary_test.py | 46 +-- tensorboard/plugins/histogram/summary_v2.py | 334 ++++-------------- .../plugins/metrics/metrics_plugin_test.py | 16 +- 4 files changed, 83 insertions(+), 335 deletions(-) diff --git a/tensorboard/plugins/histogram/summary.py b/tensorboard/plugins/histogram/summary.py index a889fcbefe2..fb8f257ef7f 100644 --- a/tensorboard/plugins/histogram/summary.py +++ b/tensorboard/plugins/histogram/summary.py @@ -16,13 +16,20 @@ A histogram summary stores a list of buckets. Each bucket is encoded as a triple `[left_edge, right_edge, count]`. Thus, a full histogram is -encoded as a tensor of dimension `[k, 3]`. +encoded as a tensor of dimension `[k, 3]`, where the first `k - 1` buckets +are closed-open and the last bucket is closed-closed. In general, the value of `k` (the number of buckets) will be a constant, -like 30. There are two edge cases: if there is no data, then there are -no buckets (the shape is `[0, 3]`); and if there is data but all points -have the same value, then there is one bucket whose left and right -endpoints are the same (the shape is `[1, 3]`). +like 30. For V2 format (deprecated), there are two edge cases: if there +is no data, then there are no buckets (the shape is `[0, 3]`); and if +there is data but all points have the same value, then there is one +bucket whose left and right endpoints are the same (the shape is `[1, 3]`). + +For V3 format, the shape of the output histogram is always constant (`[k, 3]`). +In the case of empty data, the output will be an all-zero histogram of shape +`[k, 3]`, where all edges and counts are zeros. If there is data but all points +have the same value, then all buckets' left and right edges are the same and +only the last bucket has nonzero count. NOTE: This module is in beta, and its API is subject to change, but the data that it stores to disk will be supported forever. @@ -35,13 +42,10 @@ from tensorboard.plugins.histogram import summary_v2 -# Export V2 versions. +# Export the latest versions. histogram = summary_v2.histogram histogram_pb = summary_v2.histogram_pb -# Export V3 versions. -histogram_v3 = summary_v2.histogram_v3 - def _buckets(data, bucket_count=None): """Create a TensorFlow op to group data into histogram buckets. diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index 89af32953ba..eac45782beb 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -264,50 +264,6 @@ def test_default_step(self): # Reset to default state for other tests. tf2.summary.experimental.set_step(None) - -class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): - def write_histogram_event(self, *args, **kwargs): - kwargs.setdefault("step", 1) - # Hack to extract current scope since there's no direct API for it. - with tf.name_scope("_") as temp_scope: - scope = temp_scope.rstrip("/_") - - @tf2.function - def graph_fn(): - # Recreate the active scope inside the defun since it won't propagate. - with tf.name_scope(scope): - self.call_histogram_op(*args, **kwargs) - - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn() - writer.close() - - def test_no_gradient_error_xla(self): - @tf2.function(jit_compile=True) - def graph_fn(): - x = tf.constant(1.0) - with tf2.GradientTape() as tape1: - with tf2.GradientTape() as tape2: - tape1.watch(x) - tape2.watch(x) - self.call_histogram_op( - name="loss", step=0, data=x, buckets=10 - ) - - # Note that XLA CPU/GPU has no outside compilation support, so summaries - # won't actually run in a jit_compiled function. TPUs do, and follow - # some similar codepaths, so this test stops at graph building to - # exercise those paths without a TPU available. - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - graph_fn.get_concrete_function() - - -class SummaryV3OpTest(SummaryV2OpTest, tf.test.TestCase): - def call_histogram_op(self, *args, **kwargs): - summary.histogram_v3(*args, **kwargs) - def test_singleton_input(self): pb = self.histogram("twelve", [12]) buckets = tensor_util.make_ndarray(pb.value[0].tensor) @@ -344,7 +300,7 @@ def test_zero_bucket_count(self): np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3))) -class SummaryV3OpGraphTest(SummaryV3OpTest, tf.test.TestCase): +class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): def write_histogram_event(self, *args, **kwargs): kwargs.setdefault("step", 1) # Hack to extract current scope since there's no direct API for it. diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 1b850f25872..10a0ef27bd3 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -20,16 +20,11 @@ last bucket is closed-closed. In general, the value of `k` (the number of buckets) will be a constant, like 30. -For V2 format, there are two edge cases: if there is no data, then there are no -buckets (the shape is `[0, 3]`); and if there is data but all points have the -same value, then there is one bucket whose left and right endpoints are the same -(the shape is `[1, 3]`). - -For V3 format, the shape of the output histogram is always constant (`[k, 3]`). -In the case of empty data, the output will be an all-zero histogram of shape -`[k, 3]`, where all edges and counts are zeros. If there is data but all points -have the same value, then all buckets' left and right edges are the same and only -the last bucket has nonzero count. +The shape of the output histogram is always constant (`[k, 3]`). In the case of +empty data, the output will be an all-zero histogram of shape `[k, 3]`, where all +edges and counts are zeros. If there is data but all points have the same value, +then all buckets' left and right edges are the same and only the last bucket has +nonzero count. """ import contextlib @@ -99,8 +94,8 @@ def histogram(name, data, step=None, buckets=None, description=None): buckets: Optional positive `int`. The output will have this many buckets, except in two edge cases. If there is no data, then there are no buckets. If there is data but all points have the - same value, then there is one bucket whose left and right - endpoints are the same. + same value, then all buckets' left and right endpoints are the same + and only the last bucket has nonzero count. description: Optional long-form description for this summary, as a constant `str`. Markdown is supported. Defaults to empty. @@ -125,64 +120,24 @@ def histogram(name, data, step=None, buckets=None, description=None): or tf.summary.summary_scope ) - # Try to capture current name scope so we can re-enter it below within our - # histogram_summary helper. We do this to avoid having the `tf.cond` below - # insert an extra `cond` into the tag name. - # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this - # special handling once the format no longer requires dynamic output shapes. - name_scope_cms = [] - if hasattr(tf, "get_current_name_scope"): - # Coerce None to ""; this API should only return a string but as of TF - # 2.5 it returns None in contexts that re-enter the empty scope. - current_scope = tf.get_current_name_scope() or "" - # Append a "/" to the scope name, which causes that scope to be treated - # as absolute instead of relative to the current scope, so that we can - # re-enter it. It also prevents auto-incrementing of the scope name. - # This is legacy graph mode behavior, undocumented except in comments: - # https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/framework/ops.py#L6664-L6666 - scope_to_reenter = current_scope + "/" if current_scope else "" - name_scope_cms.append(tf.name_scope(scope_to_reenter)) - - def histogram_summary(data, buckets, histogram_metadata, step): - with contextlib.ExitStack() as stack: - for cm in name_scope_cms: - stack.enter_context(cm) - with summary_scope( - name, "histogram_summary", values=[data, buckets, step] - ) as (tag, _): - # Defer histogram bucketing logic by passing it as a callable to - # write(), wrapped in a LazyTensorCreator for backwards - # compatibility, so that we only do this work when summaries are - # actually written. - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return _buckets(data, buckets) - - return tf.summary.write( - tag=tag, - tensor=lazy_tensor, - step=step, - metadata=summary_metadata, - ) + # TODO(ytjing): add special case handling. + with summary_scope( + name, "histogram_summary", values=[data, buckets, step] + ) as (tag, _): + # Defer histogram bucketing logic by passing it as a callable to + # write(), wrapped in a LazyTensorCreator for backwards + # compatibility, so that we only do this work when summaries are + # actually written. + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return _buckets(data, buckets) - # `_buckets()` has dynamic output shapes which is not supported on TPU's. - # To address this, explicitly mark this logic for outside compilation so it - # will be executed on the CPU, and skip it entirely if we aren't actually - # recording summaries to avoid overhead of transferring data. - # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this - # special handling once the format no longer requires dynamic output shapes. - if isinstance( - tf.distribute.get_strategy(), - (tf.distribute.experimental.TPUStrategy, tf.distribute.TPUStrategy), - ): - return tf.cond( - tf.summary.should_record_summaries(), - lambda: tf.compat.v1.tpu.outside_compilation( - histogram_summary, data, buckets, summary_metadata, step - ), - lambda: False, + return tf.summary.write( + tag=tag, + tensor=lazy_tensor, + step=step, + metadata=summary_metadata, ) - return histogram_summary(data, buckets, summary_metadata, step) def _buckets(data, bucket_count=None): @@ -190,31 +145,46 @@ def _buckets(data, bucket_count=None): Arguments: data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional positive `int` or scalar `int32` `Tensor`. + bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`, + defaults to 30. Returns: A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is a triple `[left_edge, right_edge, count]` for a single bucket. - The value of `k` is either `bucket_count` or `1` or `0`. + The value of `k` is either `bucket_count` or `0` (when input data + is empty). """ if bucket_count is None: bucket_count = DEFAULT_BUCKET_COUNT with tf.name_scope("buckets"): tf.debugging.assert_scalar(bucket_count) tf.debugging.assert_type(bucket_count, tf.int32) + # Treat a negative bucket count as zero. + bucket_count = tf.math.maximum(0, bucket_count) data = tf.reshape(data, shape=[-1]) # flatten data = tf.cast(data, tf.float64) - is_empty = tf.equal(tf.size(input=data), 0) + data_size = tf.size(input=data) + is_empty = tf.logical_or( + tf.equal(data_size, 0), tf.less_equal(bucket_count, 0) + ) def when_empty(): - return tf.constant([], shape=(0, 3), dtype=tf.float64) + """When input data is empty or bucket_count is zero. + + 1. If bucket_count is specified as zero, an empty tensor of shape + (0, 3) will be returned. + 2. If the input data is empty, a tensor of shape (bucket_count, 3) + of all zero values will be returned. + """ + return tf.zeros((bucket_count, 3), dtype=tf.float64) def when_nonempty(): min_ = tf.reduce_min(input_tensor=data) max_ = tf.reduce_max(input_tensor=data) range_ = max_ - min_ - is_singular = tf.equal(range_, 0) + has_single_value = tf.equal(range_, 0) - def when_nonsingular(): + def when_multiple_values(): + """When input data contains multiple values.""" bucket_width = range_ / tf.cast(bucket_count, tf.float64) offsets = data - min_ bucket_indices = tf.cast( @@ -241,18 +211,23 @@ def when_nonsingular(): a=tf.stack([left_edges, right_edges, bucket_counts]) ) - def when_singular(): - center = min_ - bucket_starts = tf.stack([center - 0.5]) - bucket_ends = tf.stack([center + 0.5]) - bucket_counts = tf.stack( - [tf.cast(tf.size(input=data), tf.float64)] - ) - return tf.transpose( - a=tf.stack([bucket_starts, bucket_ends, bucket_counts]) + def when_single_value(): + """When input data contains a single unique value.""" + # Left and right edges are the same for single value input. + edges = tf.fill([bucket_count], max_) + # Bucket counts are 0 except the last bucket (if bucket_count > 0), + # which is `data_size`. Ensure that the resulting counts vector has + # length `bucket_count` always, including the bucket_count==0 case. + zeroes = tf.fill([bucket_count], 0) + bucket_counts = tf.cast( + tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count], + dtype=tf.float64, ) + return tf.transpose(a=tf.stack([edges, edges, bucket_counts])) - return tf.cond(is_singular, when_singular, when_nonsingular) + return tf.cond( + has_single_value, when_single_value, when_multiple_values + ) return tf.cond(is_empty, when_empty, when_nonempty) @@ -316,194 +291,3 @@ def histogram_pb(tag, data, buckets=None, description=None): summary = summary_pb2.Summary() summary.value.add(tag=tag, metadata=summary_metadata, tensor=tensor) return summary - - -def histogram_v3(name, data, step=None, buckets=None, description=None): - """Write a histogram summary. - - See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. - - Writes a histogram to the current default summary writer, for later analysis - in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written - using this API will appear in both places). Like `tf.summary.scalar` points, - each histogram is associated with a `step` and a `name`. All the histograms - with the same `name` constitute a time series of histograms. - - The histogram is calculated over all the elements of the given `Tensor` - without regard to its shape or rank. - - This example writes 2 histograms: - - ```python - w = tf.summary.create_file_writer('test/logs') - with w.as_default(): - tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0) - tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0) - ``` - - A common use case is to examine the changing activation patterns (or lack - thereof) at specific layers in a neural network, over time. - - ```python - w = tf.summary.create_file_writer('test/logs') - with w.as_default(): - for step in range(100): - # Generate fake "activations". - activations = [ - tf.random.normal([1000], mean=step, stddev=1), - tf.random.normal([1000], mean=step, stddev=10), - tf.random.normal([1000], mean=step, stddev=100), - ] - - tf.summary.histogram("layer1/activate", activations[0], step=step) - tf.summary.histogram("layer2/activate", activations[1], step=step) - tf.summary.histogram("layer3/activate", activations[2], step=step) - ``` - - Arguments: - name: A name for this summary. The summary tag used for TensorBoard will - be this name prefixed by any active name scopes. - data: A `Tensor` of any shape. The histogram is computed over its elements, - which must be castable to `float64`. - step: Explicit `int64`-castable monotonic step value for this summary. If - omitted, this defaults to `tf.summary.experimental.get_step()`, which must - not be None. - buckets: Optional positive `int`. The output will have this - many buckets, except in two edge cases. If there is no data, then - there are no buckets. If there is data but all points have the - same value, then all buckets' left and right endpoints are the same - and only the last bucket has nonzero count. - description: Optional long-form description for this summary, as a - constant `str`. Markdown is supported. Defaults to empty. - - Returns: - True on success, or false if no summary was emitted because no default - summary writer was available. - - Raises: - ValueError: if a default writer exists, but no step was provided and - `tf.summary.experimental.get_step()` is None. - """ - # Avoid building unused gradient graphs for conds below. This works around - # an error building second-order gradient graphs when XlaDynamicUpdateSlice - # is used, and will generally speed up graph building slightly. - data = tf.stop_gradient(data) - summary_metadata = metadata.create_summary_metadata( - display_name=None, description=description - ) - # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback - summary_scope = ( - getattr(tf.summary.experimental, "summary_scope", None) - or tf.summary.summary_scope - ) - - # TODO(ytjing): add special case handling. - with summary_scope( - name, "histogram_summary", values=[data, buckets, step] - ) as (tag, _): - # Defer histogram bucketing logic by passing it as a callable to - # write(), wrapped in a LazyTensorCreator for backwards - # compatibility, so that we only do this work when summaries are - # actually written. - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return _buckets_v3(data, buckets) - - return tf.summary.write( - tag=tag, - tensor=lazy_tensor, - step=step, - metadata=summary_metadata, - ) - - -def _buckets_v3(data, bucket_count=None): - """Create a TensorFlow op to group data into histogram buckets. - - Arguments: - data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`, - defaults to 30. - Returns: - A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is - a triple `[left_edge, right_edge, count]` for a single bucket. - The value of `k` is either `bucket_count` or `0` (when input data - is empty). - """ - if bucket_count is None: - bucket_count = DEFAULT_BUCKET_COUNT - with tf.name_scope("buckets"): - tf.debugging.assert_scalar(bucket_count) - tf.debugging.assert_type(bucket_count, tf.int32) - # Treat a negative bucket count as zero. - bucket_count = tf.math.maximum(0, bucket_count) - data = tf.reshape(data, shape=[-1]) # flatten - data = tf.cast(data, tf.float64) - data_size = tf.size(input=data) - is_empty = tf.logical_or( - tf.equal(data_size, 0), tf.less_equal(bucket_count, 0) - ) - - def when_empty(): - """When input data is empty or bucket_count is zero. - - 1. If bucket_count is specified as zero, an empty tensor of shape - (0, 3) will be returned. - 2. If the input data is empty, a tensor of shape (bucket_count, 3) - of all zero values will be returned. - """ - return tf.zeros((bucket_count, 3), dtype=tf.float64) - - def when_nonempty(): - min_ = tf.reduce_min(input_tensor=data) - max_ = tf.reduce_max(input_tensor=data) - range_ = max_ - min_ - has_single_value = tf.equal(range_, 0) - - def when_multiple_values(): - """When input data contains multiple values.""" - bucket_width = range_ / tf.cast(bucket_count, tf.float64) - offsets = data - min_ - bucket_indices = tf.cast( - tf.floor(offsets / bucket_width), dtype=tf.int32 - ) - clamped_indices = tf.minimum(bucket_indices, bucket_count - 1) - # Use float64 instead of float32 to avoid accumulating floating point error - # later in tf.reduce_sum when summing more than 2^24 individual `1.0` values. - # See https://github.com/tensorflow/tensorflow/issues/51419 for details. - one_hots = tf.one_hot( - clamped_indices, depth=bucket_count, dtype=tf.float64 - ) - bucket_counts = tf.cast( - tf.reduce_sum(input_tensor=one_hots, axis=0), - dtype=tf.float64, - ) - edges = tf.linspace(min_, max_, bucket_count + 1) - # Ensure edges[-1] == max_, which TF's linspace implementation does not - # do, leaving it subject to the whim of floating point rounding error. - edges = tf.concat([edges[:-1], [max_]], 0) - left_edges = edges[:-1] - right_edges = edges[1:] - return tf.transpose( - a=tf.stack([left_edges, right_edges, bucket_counts]) - ) - - def when_single_value(): - """When input data contains a single unique value.""" - # Left and right edges are the same for single value input. - edges = tf.fill([bucket_count], max_) - # Bucket counts are 0 except the last bucket (if bucket_count > 0), - # which is `data_size`. Ensure that the resulting counts vector has - # length `bucket_count` always, including the bucket_count==0 case. - zeroes = tf.fill([bucket_count], 0) - bucket_counts = tf.cast( - tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count], - dtype=tf.float64, - ) - return tf.transpose(a=tf.stack([edges, edges, bucket_counts])) - - return tf.cond( - has_single_value, when_single_value, when_multiple_values - ) - - return tf.cond(is_empty, when_empty, when_nonempty) diff --git a/tensorboard/plugins/metrics/metrics_plugin_test.py b/tensorboard/plugins/metrics/metrics_plugin_test.py index a719b949d73..95852b5727c 100644 --- a/tensorboard/plugins/metrics/metrics_plugin_test.py +++ b/tensorboard/plugins/metrics/metrics_plugin_test.py @@ -420,6 +420,14 @@ def test_time_series_histogram(self): ) clean_response = self._clean_time_series_responses(response) + # By default 30 bins will be generated. + bins_zero = [{"min": 0, "max": 0, "count": 0}] * 29 + [ + {"min": 0, "max": 0, "count": 1.0} + ] + bins_ten = [{"min": 10, "max": 10, "count": 0}] * 29 + [ + {"min": 10, "max": 10, "count": 1.0} + ] + self.assertEqual( [ { @@ -431,16 +439,12 @@ def test_time_series_histogram(self): { "wallTime": "", "step": 0, - "bins": [ - {"min": -0.5, "max": 0.5, "count": 1.0} - ], + "bins": bins_zero, }, { "wallTime": "", "step": 1, - "bins": [ - {"min": 9.5, "max": 10.5, "count": 1.0} - ], + "bins": bins_ten, }, ] }, From 40bf2e13d5e3ba1c26ef2dfe35ee8054c1440d2b Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Thu, 11 Nov 2021 16:30:47 +0000 Subject: [PATCH 2/4] remove unused import --- tensorboard/plugins/histogram/summary_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 10a0ef27bd3..b0e69bd9555 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -27,8 +27,6 @@ nonzero count. """ -import contextlib - import numpy as np from tensorboard.compat import tf2 as tf From a945fa76fe6bb3f8414d8aa899239a98867ec027 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Thu, 11 Nov 2021 21:15:25 +0000 Subject: [PATCH 3/4] keep the v2 version for now --- tensorboard/plugins/histogram/summary.py | 22 +- tensorboard/plugins/histogram/summary_test.py | 48 ++- tensorboard/plugins/histogram/summary_v2.py | 342 +++++++++++++++--- 3 files changed, 337 insertions(+), 75 deletions(-) diff --git a/tensorboard/plugins/histogram/summary.py b/tensorboard/plugins/histogram/summary.py index fb8f257ef7f..1b25d922210 100644 --- a/tensorboard/plugins/histogram/summary.py +++ b/tensorboard/plugins/histogram/summary.py @@ -16,20 +16,13 @@ A histogram summary stores a list of buckets. Each bucket is encoded as a triple `[left_edge, right_edge, count]`. Thus, a full histogram is -encoded as a tensor of dimension `[k, 3]`, where the first `k - 1` buckets -are closed-open and the last bucket is closed-closed. +encoded as a tensor of dimension `[k, 3]`. In general, the value of `k` (the number of buckets) will be a constant, -like 30. For V2 format (deprecated), there are two edge cases: if there -is no data, then there are no buckets (the shape is `[0, 3]`); and if -there is data but all points have the same value, then there is one -bucket whose left and right endpoints are the same (the shape is `[1, 3]`). - -For V3 format, the shape of the output histogram is always constant (`[k, 3]`). -In the case of empty data, the output will be an all-zero histogram of shape -`[k, 3]`, where all edges and counts are zeros. If there is data but all points -have the same value, then all buckets' left and right edges are the same and -only the last bucket has nonzero count. +like 30. There are two edge cases: if there is no data, then there are +no buckets (the shape is `[0, 3]`); and if there is data but all points +have the same value, then there is one bucket whose left and right +endpoints are the same (the shape is `[1, 3]`). NOTE: This module is in beta, and its API is subject to change, but the data that it stores to disk will be supported forever. @@ -42,7 +35,10 @@ from tensorboard.plugins.histogram import summary_v2 -# Export the latest versions. +# Export V2 versions. +histogram_v2 = summary_v2.histogram_v2 + +# Export the default versions. histogram = summary_v2.histogram histogram_pb = summary_v2.histogram_pb diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index eac45782beb..b3d12f3c3d0 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -240,7 +240,7 @@ def write_histogram_event(self, *args, **kwargs): writer.close() def call_histogram_op(self, *args, **kwargs): - summary.histogram(*args, **kwargs) + summary.histogram_v2(*args, **kwargs) def test_scoped_tag(self): with tf.name_scope("scope"): @@ -264,6 +264,50 @@ def test_default_step(self): # Reset to default state for other tests. tf2.summary.experimental.set_step(None) + +class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): + def write_histogram_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + # Hack to extract current scope since there's no direct API for it. + with tf.name_scope("_") as temp_scope: + scope = temp_scope.rstrip("/_") + + @tf2.function + def graph_fn(): + # Recreate the active scope inside the defun since it won't propagate. + with tf.name_scope(scope): + self.call_histogram_op(*args, **kwargs) + + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn() + writer.close() + + def test_no_gradient_error_xla(self): + @tf2.function(jit_compile=True) + def graph_fn(): + x = tf.constant(1.0) + with tf2.GradientTape() as tape1: + with tf2.GradientTape() as tape2: + tape1.watch(x) + tape2.watch(x) + self.call_histogram_op( + name="loss", step=0, data=x, buckets=10 + ) + + # Note that XLA CPU/GPU has no outside compilation support, so summaries + # won't actually run in a jit_compiled function. TPUs do, and follow + # some similar codepaths, so this test stops at graph building to + # exercise those paths without a TPU available. + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn.get_concrete_function() + + +class SummaryV3OpTest(SummaryV2OpTest, tf.test.TestCase): + def call_histogram_op(self, *args, **kwargs): + summary.histogram(*args, **kwargs) + def test_singleton_input(self): pb = self.histogram("twelve", [12]) buckets = tensor_util.make_ndarray(pb.value[0].tensor) @@ -300,7 +344,7 @@ def test_zero_bucket_count(self): np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3))) -class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): +class SummaryV3OpGraphTest(SummaryV3OpTest, tf.test.TestCase): def write_histogram_event(self, *args, **kwargs): kwargs.setdefault("step", 1) # Hack to extract current scope since there's no direct API for it. diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index b0e69bd9555..f1d76f96e93 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -20,13 +20,20 @@ last bucket is closed-closed. In general, the value of `k` (the number of buckets) will be a constant, like 30. -The shape of the output histogram is always constant (`[k, 3]`). In the case of -empty data, the output will be an all-zero histogram of shape `[k, 3]`, where all -edges and counts are zeros. If there is data but all points have the same value, -then all buckets' left and right edges are the same and only the last bucket has -nonzero count. +For V2 format, there are two edge cases: if there is no data, then there are no +buckets (the shape is `[0, 3]`); and if there is data but all points have the +same value, then there is one bucket whose left and right endpoints are the same +(the shape is `[1, 3]`). + +For V3 format, the shape of the output histogram is always constant (`[k, 3]`). +In the case of empty data, the output will be an all-zero histogram of shape +`[k, 3]`, where all edges and counts are zeros. If there is data but all points +have the same value, then all buckets' left and right edges are the same and only +the last bucket has nonzero count. """ +import contextlib + import numpy as np from tensorboard.compat import tf2 as tf @@ -39,7 +46,7 @@ DEFAULT_BUCKET_COUNT = 30 -def histogram(name, data, step=None, buckets=None, description=None): +def histogram_v2(name, data, step=None, buckets=None, description=None): """Write a histogram summary. See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. @@ -92,8 +99,8 @@ def histogram(name, data, step=None, buckets=None, description=None): buckets: Optional positive `int`. The output will have this many buckets, except in two edge cases. If there is no data, then there are no buckets. If there is data but all points have the - same value, then all buckets' left and right endpoints are the same - and only the last bucket has nonzero count. + same value, then there is one bucket whose left and right + endpoints are the same. description: Optional long-form description for this summary, as a constant `str`. Markdown is supported. Defaults to empty. @@ -118,24 +125,64 @@ def histogram(name, data, step=None, buckets=None, description=None): or tf.summary.summary_scope ) - # TODO(ytjing): add special case handling. - with summary_scope( - name, "histogram_summary", values=[data, buckets, step] - ) as (tag, _): - # Defer histogram bucketing logic by passing it as a callable to - # write(), wrapped in a LazyTensorCreator for backwards - # compatibility, so that we only do this work when summaries are - # actually written. - @lazy_tensor_creator.LazyTensorCreator - def lazy_tensor(): - return _buckets(data, buckets) + # Try to capture current name scope so we can re-enter it below within our + # histogram_summary helper. We do this to avoid having the `tf.cond` below + # insert an extra `cond` into the tag name. + # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this + # special handling once the format no longer requires dynamic output shapes. + name_scope_cms = [] + if hasattr(tf, "get_current_name_scope"): + # Coerce None to ""; this API should only return a string but as of TF + # 2.5 it returns None in contexts that re-enter the empty scope. + current_scope = tf.get_current_name_scope() or "" + # Append a "/" to the scope name, which causes that scope to be treated + # as absolute instead of relative to the current scope, so that we can + # re-enter it. It also prevents auto-incrementing of the scope name. + # This is legacy graph mode behavior, undocumented except in comments: + # https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/framework/ops.py#L6664-L6666 + scope_to_reenter = current_scope + "/" if current_scope else "" + name_scope_cms.append(tf.name_scope(scope_to_reenter)) + + def histogram_summary(data, buckets, histogram_metadata, step): + with contextlib.ExitStack() as stack: + for cm in name_scope_cms: + stack.enter_context(cm) + with summary_scope( + name, "histogram_summary", values=[data, buckets, step] + ) as (tag, _): + # Defer histogram bucketing logic by passing it as a callable to + # write(), wrapped in a LazyTensorCreator for backwards + # compatibility, so that we only do this work when summaries are + # actually written. + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return _buckets(data, buckets) + + return tf.summary.write( + tag=tag, + tensor=lazy_tensor, + step=step, + metadata=summary_metadata, + ) - return tf.summary.write( - tag=tag, - tensor=lazy_tensor, - step=step, - metadata=summary_metadata, + # `_buckets()` has dynamic output shapes which is not supported on TPU's. + # To address this, explicitly mark this logic for outside compilation so it + # will be executed on the CPU, and skip it entirely if we aren't actually + # recording summaries to avoid overhead of transferring data. + # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this + # special handling once the format no longer requires dynamic output shapes. + if isinstance( + tf.distribute.get_strategy(), + (tf.distribute.experimental.TPUStrategy, tf.distribute.TPUStrategy), + ): + return tf.cond( + tf.summary.should_record_summaries(), + lambda: tf.compat.v1.tpu.outside_compilation( + histogram_summary, data, buckets, summary_metadata, step + ), + lambda: False, ) + return histogram_summary(data, buckets, summary_metadata, step) def _buckets(data, bucket_count=None): @@ -143,46 +190,31 @@ def _buckets(data, bucket_count=None): Arguments: data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`, - defaults to 30. + bucket_count: Optional positive `int` or scalar `int32` `Tensor`. Returns: A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is a triple `[left_edge, right_edge, count]` for a single bucket. - The value of `k` is either `bucket_count` or `0` (when input data - is empty). + The value of `k` is either `bucket_count` or `1` or `0`. """ if bucket_count is None: bucket_count = DEFAULT_BUCKET_COUNT with tf.name_scope("buckets"): tf.debugging.assert_scalar(bucket_count) tf.debugging.assert_type(bucket_count, tf.int32) - # Treat a negative bucket count as zero. - bucket_count = tf.math.maximum(0, bucket_count) data = tf.reshape(data, shape=[-1]) # flatten data = tf.cast(data, tf.float64) - data_size = tf.size(input=data) - is_empty = tf.logical_or( - tf.equal(data_size, 0), tf.less_equal(bucket_count, 0) - ) + is_empty = tf.equal(tf.size(input=data), 0) def when_empty(): - """When input data is empty or bucket_count is zero. - - 1. If bucket_count is specified as zero, an empty tensor of shape - (0, 3) will be returned. - 2. If the input data is empty, a tensor of shape (bucket_count, 3) - of all zero values will be returned. - """ - return tf.zeros((bucket_count, 3), dtype=tf.float64) + return tf.constant([], shape=(0, 3), dtype=tf.float64) def when_nonempty(): min_ = tf.reduce_min(input_tensor=data) max_ = tf.reduce_max(input_tensor=data) range_ = max_ - min_ - has_single_value = tf.equal(range_, 0) + is_singular = tf.equal(range_, 0) - def when_multiple_values(): - """When input data contains multiple values.""" + def when_nonsingular(): bucket_width = range_ / tf.cast(bucket_count, tf.float64) offsets = data - min_ bucket_indices = tf.cast( @@ -209,23 +241,18 @@ def when_multiple_values(): a=tf.stack([left_edges, right_edges, bucket_counts]) ) - def when_single_value(): - """When input data contains a single unique value.""" - # Left and right edges are the same for single value input. - edges = tf.fill([bucket_count], max_) - # Bucket counts are 0 except the last bucket (if bucket_count > 0), - # which is `data_size`. Ensure that the resulting counts vector has - # length `bucket_count` always, including the bucket_count==0 case. - zeroes = tf.fill([bucket_count], 0) - bucket_counts = tf.cast( - tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count], - dtype=tf.float64, + def when_singular(): + center = min_ + bucket_starts = tf.stack([center - 0.5]) + bucket_ends = tf.stack([center + 0.5]) + bucket_counts = tf.stack( + [tf.cast(tf.size(input=data), tf.float64)] + ) + return tf.transpose( + a=tf.stack([bucket_starts, bucket_ends, bucket_counts]) ) - return tf.transpose(a=tf.stack([edges, edges, bucket_counts])) - return tf.cond( - has_single_value, when_single_value, when_multiple_values - ) + return tf.cond(is_singular, when_singular, when_nonsingular) return tf.cond(is_empty, when_empty, when_nonempty) @@ -289,3 +316,198 @@ def histogram_pb(tag, data, buckets=None, description=None): summary = summary_pb2.Summary() summary.value.add(tag=tag, metadata=summary_metadata, tensor=tensor) return summary + + +def histogram_v3(name, data, step=None, buckets=None, description=None): + """Write a histogram summary. + + See also `tf.summary.scalar`, `tf.summary.SummaryWriter`. + + Writes a histogram to the current default summary writer, for later analysis + in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written + using this API will appear in both places). Like `tf.summary.scalar` points, + each histogram is associated with a `step` and a `name`. All the histograms + with the same `name` constitute a time series of histograms. + + The histogram is calculated over all the elements of the given `Tensor` + without regard to its shape or rank. + + This example writes 2 histograms: + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0) + tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0) + ``` + + A common use case is to examine the changing activation patterns (or lack + thereof) at specific layers in a neural network, over time. + + ```python + w = tf.summary.create_file_writer('test/logs') + with w.as_default(): + for step in range(100): + # Generate fake "activations". + activations = [ + tf.random.normal([1000], mean=step, stddev=1), + tf.random.normal([1000], mean=step, stddev=10), + tf.random.normal([1000], mean=step, stddev=100), + ] + + tf.summary.histogram("layer1/activate", activations[0], step=step) + tf.summary.histogram("layer2/activate", activations[1], step=step) + tf.summary.histogram("layer3/activate", activations[2], step=step) + ``` + + Arguments: + name: A name for this summary. The summary tag used for TensorBoard will + be this name prefixed by any active name scopes. + data: A `Tensor` of any shape. The histogram is computed over its elements, + which must be castable to `float64`. + step: Explicit `int64`-castable monotonic step value for this summary. If + omitted, this defaults to `tf.summary.experimental.get_step()`, which must + not be None. + buckets: Optional positive `int`. The output will have this + many buckets, except in two edge cases. If there is no data, then + there are no buckets. If there is data but all points have the + same value, then all buckets' left and right endpoints are the same + and only the last bucket has nonzero count. + description: Optional long-form description for this summary, as a + constant `str`. Markdown is supported. Defaults to empty. + + Returns: + True on success, or false if no summary was emitted because no default + summary writer was available. + + Raises: + ValueError: if a default writer exists, but no step was provided and + `tf.summary.experimental.get_step()` is None. + """ + # Avoid building unused gradient graphs for conds below. This works around + # an error building second-order gradient graphs when XlaDynamicUpdateSlice + # is used, and will generally speed up graph building slightly. + data = tf.stop_gradient(data) + summary_metadata = metadata.create_summary_metadata( + display_name=None, description=description + ) + # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback + summary_scope = ( + getattr(tf.summary.experimental, "summary_scope", None) + or tf.summary.summary_scope + ) + + # TODO(ytjing): add special case handling. + with summary_scope( + name, "histogram_summary", values=[data, buckets, step] + ) as (tag, _): + # Defer histogram bucketing logic by passing it as a callable to + # write(), wrapped in a LazyTensorCreator for backwards + # compatibility, so that we only do this work when summaries are + # actually written. + @lazy_tensor_creator.LazyTensorCreator + def lazy_tensor(): + return _buckets_v3(data, buckets) + + return tf.summary.write( + tag=tag, + tensor=lazy_tensor, + step=step, + metadata=summary_metadata, + ) + + +def _buckets_v3(data, bucket_count=None): + """Create a TensorFlow op to group data into histogram buckets. + + Arguments: + data: A `Tensor` of any shape. Must be castable to `float64`. + bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`, + defaults to 30. + Returns: + A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is + a triple `[left_edge, right_edge, count]` for a single bucket. + The value of `k` is either `bucket_count` or `0` (when input data + is empty). + """ + if bucket_count is None: + bucket_count = DEFAULT_BUCKET_COUNT + with tf.name_scope("buckets"): + tf.debugging.assert_scalar(bucket_count) + tf.debugging.assert_type(bucket_count, tf.int32) + # Treat a negative bucket count as zero. + bucket_count = tf.math.maximum(0, bucket_count) + data = tf.reshape(data, shape=[-1]) # flatten + data = tf.cast(data, tf.float64) + data_size = tf.size(input=data) + is_empty = tf.logical_or( + tf.equal(data_size, 0), tf.less_equal(bucket_count, 0) + ) + + def when_empty(): + """When input data is empty or bucket_count is zero. + + 1. If bucket_count is specified as zero, an empty tensor of shape + (0, 3) will be returned. + 2. If the input data is empty, a tensor of shape (bucket_count, 3) + of all zero values will be returned. + """ + return tf.zeros((bucket_count, 3), dtype=tf.float64) + + def when_nonempty(): + min_ = tf.reduce_min(input_tensor=data) + max_ = tf.reduce_max(input_tensor=data) + range_ = max_ - min_ + has_single_value = tf.equal(range_, 0) + + def when_multiple_values(): + """When input data contains multiple values.""" + bucket_width = range_ / tf.cast(bucket_count, tf.float64) + offsets = data - min_ + bucket_indices = tf.cast( + tf.floor(offsets / bucket_width), dtype=tf.int32 + ) + clamped_indices = tf.minimum(bucket_indices, bucket_count - 1) + # Use float64 instead of float32 to avoid accumulating floating point error + # later in tf.reduce_sum when summing more than 2^24 individual `1.0` values. + # See https://github.com/tensorflow/tensorflow/issues/51419 for details. + one_hots = tf.one_hot( + clamped_indices, depth=bucket_count, dtype=tf.float64 + ) + bucket_counts = tf.cast( + tf.reduce_sum(input_tensor=one_hots, axis=0), + dtype=tf.float64, + ) + edges = tf.linspace(min_, max_, bucket_count + 1) + # Ensure edges[-1] == max_, which TF's linspace implementation does not + # do, leaving it subject to the whim of floating point rounding error. + edges = tf.concat([edges[:-1], [max_]], 0) + left_edges = edges[:-1] + right_edges = edges[1:] + return tf.transpose( + a=tf.stack([left_edges, right_edges, bucket_counts]) + ) + + def when_single_value(): + """When input data contains a single unique value.""" + # Left and right edges are the same for single value input. + edges = tf.fill([bucket_count], max_) + # Bucket counts are 0 except the last bucket (if bucket_count > 0), + # which is `data_size`. Ensure that the resulting counts vector has + # length `bucket_count` always, including the bucket_count==0 case. + zeroes = tf.fill([bucket_count], 0) + bucket_counts = tf.cast( + tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count], + dtype=tf.float64, + ) + return tf.transpose(a=tf.stack([edges, edges, bucket_counts])) + + return tf.cond( + has_single_value, when_single_value, when_multiple_values + ) + + return tf.cond(is_empty, when_empty, when_nonempty) + + +# Set V3 as default. +histogram = histogram_v3 From 74712dc703e699b60e0cd5c43068c6b04f46e657 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Thu, 11 Nov 2021 22:52:01 +0000 Subject: [PATCH 4/4] keep the v3 export to avoid breaking internal tests --- tensorboard/plugins/histogram/summary.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorboard/plugins/histogram/summary.py b/tensorboard/plugins/histogram/summary.py index 1b25d922210..9ddd9787c9a 100644 --- a/tensorboard/plugins/histogram/summary.py +++ b/tensorboard/plugins/histogram/summary.py @@ -38,6 +38,9 @@ # Export V2 versions. histogram_v2 = summary_v2.histogram_v2 +# Export V3 versions. +histogram_v3 = summary_v2.histogram_v3 + # Export the default versions. histogram = summary_v2.histogram histogram_pb = summary_v2.histogram_pb