diff --git a/tensorboard/plugins/histogram/summary.py b/tensorboard/plugins/histogram/summary.py
index 72503c6f9a0..a889fcbefe2 100644
--- a/tensorboard/plugins/histogram/summary.py
+++ b/tensorboard/plugins/histogram/summary.py
@@ -39,6 +39,9 @@
 histogram = summary_v2.histogram
 histogram_pb = summary_v2.histogram_pb
 
+# Export V3 versions.
+histogram_v3 = summary_v2.histogram_v3
+
 
 def _buckets(data, bucket_count=None):
     """Create a TensorFlow op to group data into histogram buckets.
diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py
index ba087d6c58d..aceae45b50b 100644
--- a/tensorboard/plugins/histogram/summary_test.py
+++ b/tensorboard/plugins/histogram/summary_test.py
@@ -201,9 +201,12 @@ def write_histogram_event(self, *args, **kwargs):
         kwargs.setdefault("step", 1)
         writer = tf2.summary.create_file_writer(self.get_temp_dir())
         with writer.as_default():
-            summary.histogram(*args, **kwargs)
+            self.call_histogram_op(*args, **kwargs)
         writer.close()
 
+    def call_histogram_op(self, *args, **kwargs):
+        summary.histogram(*args, **kwargs)
+
     def test_scoped_tag(self):
         with tf.name_scope("scope"):
             self.assertEqual("scope/a", self.histogram("a", []).value[0].tag)
@@ -238,7 +241,86 @@ def write_histogram_event(self, *args, **kwargs):
         def graph_fn():
             # Recreate the active scope inside the defun since it won't propagate.
             with tf.name_scope(scope):
-                summary.histogram(*args, **kwargs)
+                self.call_histogram_op(*args, **kwargs)
+
+        writer = tf2.summary.create_file_writer(self.get_temp_dir())
+        with writer.as_default():
+            graph_fn()
+        writer.close()
+
+    def test_no_gradient_error_xla(self):
+        @tf2.function(jit_compile=True)
+        def graph_fn():
+            x = tf.constant(1.0)
+            with tf2.GradientTape() as tape1:
+                with tf2.GradientTape() as tape2:
+                    tape1.watch(x)
+                    tape2.watch(x)
+                    self.call_histogram_op(
+                        name="loss", step=0, data=x, buckets=10
+                    )
+
+        # Note that XLA CPU/GPU has no outside compilation support, so summaries
+        # won't actually run in a jit_compiled function. TPUs do, and follow
+        # some similar codepaths, so this test stops at graph building to
+        # exercise those paths without a TPU available.
+        writer = tf2.summary.create_file_writer(self.get_temp_dir())
+        with writer.as_default():
+            graph_fn.get_concrete_function()
+
+
+class SummaryV3OpTest(SummaryV2OpTest, tf.test.TestCase):
+    def call_histogram_op(self, *args, **kwargs):
+        summary.histogram_v3(*args, **kwargs)
+
+    def test_singleton_input(self):
+        pb = self.histogram("twelve", [12])
+        buckets = tensor_util.make_ndarray(pb.value[0].tensor)
+        # By default there will be 30 buckets.
+        expected_buckets = np.array(
+            [[12, 12, 0] for _ in range(29)] + [[12, 12, 1]]
+        )
+        np.testing.assert_allclose(buckets, expected_buckets)
+
+    def test_input_with_all_same_values(self):
+        pb = self.histogram("twelven", [12, 12, 12])
+        buckets = tensor_util.make_ndarray(pb.value[0].tensor)
+        # By default there will be 30 buckets.
+        expected_buckets = np.array(
+            [[12, 12, 0] for _ in range(29)] + [[12, 12, 3]]
+        )
+        np.testing.assert_allclose(buckets, expected_buckets)
+
+    def test_empty_input(self):
+        pb = self.histogram("empty", [])
+        buckets = tensor_util.make_ndarray(pb.value[0].tensor)
+        # By default there will be 30 buckets.
+        np.testing.assert_allclose(buckets, np.zeros((30, 3)))
+
+    def test_empty_input_of_high_rank(self):
+        pb = self.histogram("empty_but_fancy", [[[], []], [[], []]])
+        buckets = tensor_util.make_ndarray(pb.value[0].tensor)
+        # By default there will be 30 buckets.
+        np.testing.assert_allclose(buckets, np.zeros((30, 3)))
+
+    def test_zero_bucket_count(self):
+        pb = self.histogram("zero_bucket_count", [1, 1, 1], buckets=0)
+        buckets = tensor_util.make_ndarray(pb.value[0].tensor)
+        np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3)))
+
+
+class SummaryV3OpGraphTest(SummaryV3OpTest, tf.test.TestCase):
+    def write_histogram_event(self, *args, **kwargs):
+        kwargs.setdefault("step", 1)
+        # Hack to extract current scope since there's no direct API for it.
+        with tf.name_scope("_") as temp_scope:
+            scope = temp_scope.rstrip("/_")
+
+        @tf2.function
+        def graph_fn():
+            # Recreate the active scope inside the defun since it won't propagate.
+            with tf.name_scope(scope):
+                self.call_histogram_op(*args, **kwargs)
 
         writer = tf2.summary.create_file_writer(self.get_temp_dir())
         with writer.as_default():
@@ -253,7 +335,9 @@ def graph_fn():
                 with tf2.GradientTape() as tape2:
                     tape1.watch(x)
                     tape2.watch(x)
-                    summary.histogram(name="loss", step=0, data=x, buckets=10)
+                    self.call_histogram_op(
+                        name="loss", step=0, data=x, buckets=10
+                    )
 
         # Note that XLA CPU/GPU has no outside compilation support, so summaries
         # won't actually run in a jit_compiled function. TPUs do, and follow
diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py
index d4ed4b76ff9..e1cd01846db 100644
--- a/tensorboard/plugins/histogram/summary_v2.py
+++ b/tensorboard/plugins/histogram/summary_v2.py
@@ -412,7 +412,8 @@ def _buckets_v3(data, bucket_count=None):
     Arguments:
       data: A `Tensor` of any shape. Must be castable to `float64`.
-      bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
+      bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`,
+        defaults to 30.
 
     Returns:
       A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is a
       triple `[left_edge, right_edge, count]` for a single bucket.
@@ -424,21 +425,33 @@
     with tf.name_scope("buckets"):
         tf.debugging.assert_scalar(bucket_count)
         tf.debugging.assert_type(bucket_count, tf.int32)
+        # Treat a negative bucket count as zero.
+        bucket_count = tf.math.maximum(0, bucket_count)
         data = tf.reshape(data, shape=[-1])  # flatten
         data = tf.cast(data, tf.float64)
-        is_empty = tf.equal(tf.size(input=data), 0)
+        data_size = tf.size(input=data)
+        is_empty = tf.logical_or(
+            tf.equal(data_size, 0), tf.less_equal(bucket_count, 0)
+        )
 
         def when_empty():
-            return tf.constant([], shape=(0, 3), dtype=tf.float64)
+            """When input data is empty or bucket_count is zero.
+
+            1. If bucket_count is specified as zero, an empty tensor of shape
+              (0, 3) will be returned.
+            2. If the input data is empty, a tensor of shape (bucket_count, 3)
+              of all zero values will be returned.
+            """
+            return tf.zeros((bucket_count, 3), dtype=tf.float64)
 
-        # TODO(ytjing): Make the nonempty case handling TPU compatible.
         def when_nonempty():
             min_ = tf.reduce_min(input_tensor=data)
             max_ = tf.reduce_max(input_tensor=data)
             range_ = max_ - min_
-            is_singular = tf.equal(range_, 0)
+            has_single_value = tf.equal(range_, 0)
 
-            def when_nonsingular():
+            def when_multiple_values():
+                """When input data contains multiple values."""
                 bucket_width = range_ / tf.cast(bucket_count, tf.float64)
                 offsets = data - min_
                 bucket_indices = tf.cast(
@@ -465,17 +478,22 @@ def when_nonsingular():
                     a=tf.stack([left_edges, right_edges, bucket_counts])
                 )
 
-            def when_singular():
-                center = min_
-                bucket_starts = tf.stack([center - 0.5])
-                bucket_ends = tf.stack([center + 0.5])
-                bucket_counts = tf.stack(
-                    [tf.cast(tf.size(input=data), tf.float64)]
-                )
-                return tf.transpose(
-                    a=tf.stack([bucket_starts, bucket_ends, bucket_counts])
+            def when_single_value():
+                """When input data contains a single unique value."""
+                # Left and right edges are the same for single value input.
+                edges = tf.fill([bucket_count], max_)
+                # Bucket counts are 0 except the last bucket (if bucket_count > 0),
+                # which is `data_size`. Ensure that the resulting counts vector has
+                # length `bucket_count` always, including the bucket_count==0 case.
+                zeroes = tf.fill([bucket_count], 0)
+                bucket_counts = tf.cast(
+                    tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count],
+                    dtype=tf.float64,
                 )
+                return tf.transpose(a=tf.stack([edges, edges, bucket_counts]))
 
-            return tf.cond(is_singular, when_singular, when_nonsingular)
+            return tf.cond(
+                has_single_value, when_single_value, when_multiple_values
+            )
 
         return tf.cond(is_empty, when_empty, when_nonempty)
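
Notes (not part of the diff):

The edge-case behavior that the new tests assert can be summarized outside
of TensorFlow. The sketch below is a NumPy reference model, not the TF
implementation: `buckets_v3_reference` is a hypothetical helper, and the
interior-bucket math is approximated with `np.histogram`, which may differ
from the TF op in floating-point details.

    import numpy as np

    def buckets_v3_reference(data, bucket_count=30):
        """Hypothetical reference model of _buckets_v3's edge cases."""
        data = np.asarray(data, dtype=np.float64).reshape(-1)
        bucket_count = max(0, bucket_count)  # negative counts treated as zero
        if bucket_count == 0 or data.size == 0:
            # bucket_count == 0 -> shape (0, 3); otherwise empty data
            # yields all-zero rows of shape (bucket_count, 3), matching
            # when_empty() in the diff.
            return np.zeros((bucket_count, 3))
        if data.min() == data.max():
            # Single unique value: every edge equals that value and the
            # whole count lands in the last bucket, matching
            # when_single_value() in the diff.
            out = np.tile([data.min(), data.max(), 0.0], (bucket_count, 1))
            out[-1, 2] = data.size
            return out
        edges = np.linspace(data.min(), data.max(), bucket_count + 1)
        counts, _ = np.histogram(data, bins=edges)
        return np.stack([edges[:-1], edges[1:], counts], axis=1)

    assert buckets_v3_reference([]).shape == (30, 3)
    assert buckets_v3_reference([12])[-1].tolist() == [12.0, 12.0, 1.0]
    assert buckets_v3_reference([1, 1, 1], bucket_count=0).shape == (0, 3)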
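
A detail worth calling out in when_single_value(): both branches of a
tf.cond are traced when the graph is built, so the single-value branch must
construct a valid counts vector for every bucket_count the graph could see,
even values the enclosing cond already excluded. A minimal eager-mode sketch
of the counts construction (the constants here are made up for
illustration):

    import tensorflow as tf

    bucket_count = tf.constant(10, tf.int32)
    data_size = tf.constant(3, tf.int32)  # e.g. input [12, 12, 12]

    zeroes = tf.fill([bucket_count], 0)
    # zeroes[:-1] has length bucket_count - 1, so appending data_size
    # restores length bucket_count. When bucket_count == 0 the concat
    # result still has length 1; the trailing [:bucket_count] slice
    # trims it back to 0, keeping the branch well defined.
    counts = tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count]
    print(counts.numpy())  # [0 0 0 0 0 0 0 0 0 3]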
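
With the new export in summary.py, the V3 op is reachable as
summary.histogram_v3. A minimal usage sketch, assuming histogram_v3 keeps
the keyword signature the tests exercise (name, data, step, buckets) and an
eager TF2 context; the log directory is arbitrary:

    import tensorflow as tf
    from tensorboard.plugins.histogram import summary

    writer = tf.summary.create_file_writer("/tmp/hist_v3_demo")
    with writer.as_default():
        # 30 buckets by default; pass buckets= to override.
        summary.histogram_v3(
            "weights", data=tf.random.normal([1000]), step=0, buckets=10
        )
    writer.close()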