From 05d8c78225012552baf0a6e8f8ee7755243a8297 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Fri, 1 Oct 2021 14:23:40 +0000 Subject: [PATCH 01/12] rename functions to avoid confusion Use `single_value` here, `singular` has unrelated mathematical meaning: https://en.wikipedia.org/wiki/Singularity_(mathematics). --- tensorboard/plugins/histogram/summary_v2.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index d4ed4b76ff9..68b14b32a14 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -431,14 +431,14 @@ def _buckets_v3(data, bucket_count=None): def when_empty(): return tf.constant([], shape=(0, 3), dtype=tf.float64) - # TODO(ytjing): Make the nonempty case handling TPU compatible. def when_nonempty(): min_ = tf.reduce_min(input_tensor=data) max_ = tf.reduce_max(input_tensor=data) range_ = max_ - min_ - is_singular = tf.equal(range_, 0) + has_single_value = tf.equal(range_, 0) - def when_nonsingular(): + def when_multiple_values(): + """When input data contains multiple values.""" bucket_width = range_ / tf.cast(bucket_count, tf.float64) offsets = data - min_ bucket_indices = tf.cast( @@ -465,7 +465,8 @@ def when_nonsingular(): a=tf.stack([left_edges, right_edges, bucket_counts]) ) - def when_singular(): + def when_single_value(): + """When input data contains a single unique value.""" center = min_ bucket_starts = tf.stack([center - 0.5]) bucket_ends = tf.stack([center + 0.5]) @@ -476,6 +477,8 @@ def when_singular(): a=tf.stack([bucket_starts, bucket_ends, bucket_counts]) ) - return tf.cond(is_singular, when_singular, when_nonsingular) + return tf.cond( + has_single_value, when_single_value, when_multiple_values + ) return tf.cond(is_empty, when_empty, when_nonempty) From 3ea37da44a6fd121465b1a555edbec0a443ad84d Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Fri, 1 Oct 2021 15:29:13 +0000 Subject: [PATCH 02/12] push the single_value test cases down to all test cases Since v3 change the outcome for single value input, this enables the customization of these test cases. --- tensorboard/plugins/histogram/summary_test.py | 60 +++++++++++++++---- 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index ba087d6c58d..6a48e9e7be3 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -70,16 +70,6 @@ def test_empty_input_of_high_rank(self): buckets = tensor_util.make_ndarray(pb.value[0].tensor) np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) - def test_singleton_input(self): - pb = self.histogram("twelve", [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) - - def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) - def test_fixed_input(self): pass # TODO: test a small fixed input @@ -145,6 +135,16 @@ def test_tag(self): "a/b/histogram_summary", self.histogram("a/b", []).value[0].tag ) + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) + class SummaryV1OpTest(SummaryBaseTest, tf.test.TestCase): def histogram(self, *args, **kwargs): @@ -170,11 +170,31 @@ def test_scoped_tag(self): self.histogram("a", []).value[0].tag, ) + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) + class SummaryV2PbTest(SummaryBaseTest, tf.test.TestCase): def histogram(self, *args, **kwargs): return summary.histogram_pb(*args, **kwargs) + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) + class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): def setUp(self): @@ -226,6 +246,16 @@ def test_default_step(self): # Reset to default state for other tests. tf2.summary.experimental.set_step(None) + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) + class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): def write_histogram_event(self, *args, **kwargs): @@ -263,6 +293,16 @@ def graph_fn(): with writer.as_default(): graph_fn.get_concrete_function() + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) + if __name__ == "__main__": tf.test.main() From f0815dcba6d920df3ad3e7b78b8d395268629260 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Fri, 1 Oct 2021 15:51:23 +0000 Subject: [PATCH 03/12] add v3 implementation for single value input --- tensorboard/plugins/histogram/summary.py | 3 +++ tensorboard/plugins/histogram/summary_v2.py | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorboard/plugins/histogram/summary.py b/tensorboard/plugins/histogram/summary.py index 72503c6f9a0..a889fcbefe2 100644 --- a/tensorboard/plugins/histogram/summary.py +++ b/tensorboard/plugins/histogram/summary.py @@ -39,6 +39,9 @@ histogram = summary_v2.histogram histogram_pb = summary_v2.histogram_pb +# Export V3 versions. +histogram_v3 = summary_v2.histogram_v3 + def _buckets(data, bucket_count=None): """Create a TensorFlow op to group data into histogram buckets. diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 68b14b32a14..9b70557fb1d 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -426,7 +426,8 @@ def _buckets_v3(data, bucket_count=None): tf.debugging.assert_type(bucket_count, tf.int32) data = tf.reshape(data, shape=[-1]) # flatten data = tf.cast(data, tf.float64) - is_empty = tf.equal(tf.size(input=data), 0) + data_size = tf.size(input=data) + is_empty = tf.equal(data_size, 0) def when_empty(): return tf.constant([], shape=(0, 3), dtype=tf.float64) @@ -467,15 +468,19 @@ def when_multiple_values(): def when_single_value(): """When input data contains a single unique value.""" - center = min_ - bucket_starts = tf.stack([center - 0.5]) - bucket_ends = tf.stack([center + 0.5]) - bucket_counts = tf.stack( - [tf.cast(tf.size(input=data), tf.float64)] + # Left and right edges are the same for single value input. + edges = tf.fill([bucket_count], max_) + # Counts for the first {bucket_count - 1} buckets [v, v) are 0. + zero_bucket_counts = tf.constant( + [0] * (bucket_count - 1), dtype=tf.int32 ) - return tf.transpose( - a=tf.stack([bucket_starts, bucket_ends, bucket_counts]) + # Count for last bucket [v, v] is {data_size}. + last_bucket_count = tf.expand_dims(data_size, 0) + bucket_counts = tf.cast( + tf.concat([zero_bucket_counts, last_bucket_count], 0), + dtype=tf.float64, ) + return tf.transpose(a=tf.stack([edges, edges, bucket_counts])) return tf.cond( has_single_value, when_single_value, when_multiple_values From 6774bea798da55a76e3be62e9d37176b8a0c5d29 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Fri, 1 Oct 2021 15:52:34 +0000 Subject: [PATCH 04/12] add test got v3 graph ops More tests for v3 pb ops will be added after histogram_pb_v3 is implemented. --- tensorboard/plugins/histogram/summary_test.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index 6a48e9e7be3..a33f93bd2a1 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -304,5 +304,88 @@ def test_input_with_all_same_values(self): np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) +class SummaryV3OpTest(SummaryBaseTest, tf.test.TestCase): + def setUp(self): + super(SummaryV3OpTest, self).setUp() + if tf2 is None: + self.skipTest("v2 summary API not available") + + def histogram(self, *args, **kwargs): + return self.histogram_event(*args, **kwargs).summary + + def histogram_event(self, *args, **kwargs): + self.write_histogram_event(*args, **kwargs) + event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) + self.assertEqual(len(event_files), 1) + events = list(tf.compat.v1.train.summary_iterator(event_files[0])) + # Expect a boilerplate event for the file_version, then the summary one. + self.assertEqual(len(events), 2) + # Delete the event file to reset to an empty directory for later calls. + # TODO(nickfelt): use a unique subdirectory per writer instead. + os.remove(event_files[0]) + return events[1] + + def write_histogram_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + summary.histogram(*args, **kwargs) + writer.close() + + +class SummaryV3OpGraphTest(SummaryV3OpTest, tf.test.TestCase): + def write_histogram_event(self, *args, **kwargs): + kwargs.setdefault("step", 1) + # Hack to extract current scope since there's no direct API for it. + with tf.name_scope("_") as temp_scope: + scope = temp_scope.rstrip("/_") + + @tf2.function + def graph_fn(): + # Recreate the active scope inside the defun since it won't propagate. + with tf.name_scope(scope): + summary.histogram_v3(*args, **kwargs) + + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn() + writer.close() + + def test_no_gradient_error_xla(self): + @tf2.function(jit_compile=True) + def graph_fn(): + x = tf.constant(1.0) + with tf2.GradientTape() as tape1: + with tf2.GradientTape() as tape2: + tape1.watch(x) + tape2.watch(x) + summary.histogram_v3( + name="loss", step=0, data=x, buckets=10 + ) + + # Note that XLA CPU/GPU has no outside compilation support, so summaries + # won't actually run in a jit_compiled function. TPUs do, and follow + # some similar codepaths, so this test stops at graph building to + # exercise those paths without a TPU available. + writer = tf2.summary.create_file_writer(self.get_temp_dir()) + with writer.as_default(): + graph_fn.get_concrete_function() + + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + # By default there will be 30 buckets. + expected_buckets = np.array( + [[12, 12, 0] for _ in range(29)] + [[12, 12, 1]] + ) + np.testing.assert_allclose(buckets, expected_buckets) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12], buckets=2) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + expected_buckets = np.array([[12, 12, 0], [12, 12, 3]]) + np.testing.assert_allclose(buckets, expected_buckets) + + if __name__ == "__main__": tf.test.main() From fffb21b5d150e7bd6d18fb509b87481adac72246 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Fri, 1 Oct 2021 22:40:42 +0000 Subject: [PATCH 05/12] Revert "push the single_value test cases down to all test cases" This reverts commit be2213cac544d17a9cec65e9e3b37a4c8a59d5d2. --- tensorboard/plugins/histogram/summary_test.py | 60 ++++--------------- 1 file changed, 10 insertions(+), 50 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index a33f93bd2a1..75350c214a6 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -70,6 +70,16 @@ def test_empty_input_of_high_rank(self): buckets = tensor_util.make_ndarray(pb.value[0].tensor) np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) + def test_fixed_input(self): pass # TODO: test a small fixed input @@ -135,16 +145,6 @@ def test_tag(self): "a/b/histogram_summary", self.histogram("a/b", []).value[0].tag ) - def test_singleton_input(self): - pb = self.histogram("twelve", [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) - - def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) - class SummaryV1OpTest(SummaryBaseTest, tf.test.TestCase): def histogram(self, *args, **kwargs): @@ -170,31 +170,11 @@ def test_scoped_tag(self): self.histogram("a", []).value[0].tag, ) - def test_singleton_input(self): - pb = self.histogram("twelve", [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) - - def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) - class SummaryV2PbTest(SummaryBaseTest, tf.test.TestCase): def histogram(self, *args, **kwargs): return summary.histogram_pb(*args, **kwargs) - def test_singleton_input(self): - pb = self.histogram("twelve", [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) - - def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) - class SummaryV2OpTest(SummaryBaseTest, tf.test.TestCase): def setUp(self): @@ -246,16 +226,6 @@ def test_default_step(self): # Reset to default state for other tests. tf2.summary.experimental.set_step(None) - def test_singleton_input(self): - pb = self.histogram("twelve", [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) - - def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) - class SummaryV2OpGraphTest(SummaryV2OpTest, tf.test.TestCase): def write_histogram_event(self, *args, **kwargs): @@ -293,16 +263,6 @@ def graph_fn(): with writer.as_default(): graph_fn.get_concrete_function() - def test_singleton_input(self): - pb = self.histogram("twelve", [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 1]])) - - def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([[11.5, 12.5, 3]])) - class SummaryV3OpTest(SummaryBaseTest, tf.test.TestCase): def setUp(self): From 5bf44ade04d78f209199587e343fc97504620fcf Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Fri, 1 Oct 2021 23:46:56 +0000 Subject: [PATCH 06/12] check if bucket_count <= 0 and fix tf ops --- tensorboard/plugins/histogram/summary_v2.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 9b70557fb1d..012deb403ff 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -427,7 +427,11 @@ def _buckets_v3(data, bucket_count=None): data = tf.reshape(data, shape=[-1]) # flatten data = tf.cast(data, tf.float64) data_size = tf.size(input=data) - is_empty = tf.equal(data_size, 0) + # If bucket_count is specified as zero, an empty tensor of shape + # (0, 3) will be returned. + is_empty = tf.math.logical_or( + tf.equal(data_size, 0), tf.constant([bucket_count <= 0]) + ) def when_empty(): return tf.constant([], shape=(0, 3), dtype=tf.float64) @@ -471,13 +475,10 @@ def when_single_value(): # Left and right edges are the same for single value input. edges = tf.fill([bucket_count], max_) # Counts for the first {bucket_count - 1} buckets [v, v) are 0. - zero_bucket_counts = tf.constant( - [0] * (bucket_count - 1), dtype=tf.int32 - ) + zero_bucket_counts = tf.repeat([0], bucket_count - 1) # Count for last bucket [v, v] is {data_size}. - last_bucket_count = tf.expand_dims(data_size, 0) bucket_counts = tf.cast( - tf.concat([zero_bucket_counts, last_bucket_count], 0), + tf.concat([zero_bucket_counts, [data_size]], 0), dtype=tf.float64, ) return tf.transpose(a=tf.stack([edges, edges, bucket_counts])) From 3f4cfd3309a32925c095f6450b05d571179de2fd Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Fri, 1 Oct 2021 23:47:33 +0000 Subject: [PATCH 07/12] Make SummaryV3OpGraphTest inherit from V2 test case and add test for zero bucket count --- tensorboard/plugins/histogram/summary_test.py | 49 +++++++------------ 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index 75350c214a6..27259f62bbe 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -201,9 +201,12 @@ def write_histogram_event(self, *args, **kwargs): kwargs.setdefault("step", 1) writer = tf2.summary.create_file_writer(self.get_temp_dir()) with writer.as_default(): - summary.histogram(*args, **kwargs) + self.call_histogram_op(*args, **kwargs) writer.close() + def call_histogram_op(self, *args, **kwargs): + summary.histogram(*args, **kwargs) + def test_scoped_tag(self): with tf.name_scope("scope"): self.assertEqual("scope/a", self.histogram("a", []).value[0].tag) @@ -264,36 +267,10 @@ def graph_fn(): graph_fn.get_concrete_function() -class SummaryV3OpTest(SummaryBaseTest, tf.test.TestCase): - def setUp(self): - super(SummaryV3OpTest, self).setUp() - if tf2 is None: - self.skipTest("v2 summary API not available") +class SummaryV3OpGraphTest(SummaryV2OpTest, tf.test.TestCase): + def call_histogram_op(self, *args, **kwargs): + summary.histogram_v3(*args, **kwargs) - def histogram(self, *args, **kwargs): - return self.histogram_event(*args, **kwargs).summary - - def histogram_event(self, *args, **kwargs): - self.write_histogram_event(*args, **kwargs) - event_files = sorted(glob.glob(os.path.join(self.get_temp_dir(), "*"))) - self.assertEqual(len(event_files), 1) - events = list(tf.compat.v1.train.summary_iterator(event_files[0])) - # Expect a boilerplate event for the file_version, then the summary one. - self.assertEqual(len(events), 2) - # Delete the event file to reset to an empty directory for later calls. - # TODO(nickfelt): use a unique subdirectory per writer instead. - os.remove(event_files[0]) - return events[1] - - def write_histogram_event(self, *args, **kwargs): - kwargs.setdefault("step", 1) - writer = tf2.summary.create_file_writer(self.get_temp_dir()) - with writer.as_default(): - summary.histogram(*args, **kwargs) - writer.close() - - -class SummaryV3OpGraphTest(SummaryV3OpTest, tf.test.TestCase): def write_histogram_event(self, *args, **kwargs): kwargs.setdefault("step", 1) # Hack to extract current scope since there's no direct API for it. @@ -341,11 +318,19 @@ def test_singleton_input(self): np.testing.assert_allclose(buckets, expected_buckets) def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12], buckets=2) + pb = self.histogram("twelven", [12, 12, 12]) buckets = tensor_util.make_ndarray(pb.value[0].tensor) - expected_buckets = np.array([[12, 12, 0], [12, 12, 3]]) + # By default there will be 30 buckets. + expected_buckets = np.array( + [[12, 12, 0] for _ in range(29)] + [[12, 12, 3]] + ) np.testing.assert_allclose(buckets, expected_buckets) + def test_zero_bucket_count(self): + pb = self.histogram("zero_bucket_count", [1, 1, 1], buckets=0) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) + if __name__ == "__main__": tf.test.main() From 9e186648b3d15b67b1afaf8b4d5b454eafaba7d9 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Mon, 4 Oct 2021 21:34:15 +0000 Subject: [PATCH 08/12] use an alternative (tf.fill) to be consistent --- tensorboard/plugins/histogram/summary_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 012deb403ff..4ee3fc75904 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -475,7 +475,7 @@ def when_single_value(): # Left and right edges are the same for single value input. edges = tf.fill([bucket_count], max_) # Counts for the first {bucket_count - 1} buckets [v, v) are 0. - zero_bucket_counts = tf.repeat([0], bucket_count - 1) + zero_bucket_counts = tf.fill([bucket_count - 1], 0) # Count for last bucket [v, v] is {data_size}. bucket_counts = tf.cast( tf.concat([zero_bucket_counts, [data_size]], 0), From 3e72bde0597d0fec353691d3e0c4fe0e7555b3da Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Tue, 5 Oct 2021 22:14:16 +0000 Subject: [PATCH 09/12] distinguish zero bucket count case v.s. empty input data case --- tensorboard/plugins/histogram/summary_v2.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index 4ee3fc75904..bff980492bf 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -412,7 +412,8 @@ def _buckets_v3(data, bucket_count=None): Arguments: data: A `Tensor` of any shape. Must be castable to `float64`. - bucket_count: Optional positive `int` or scalar `int32` `Tensor`. + bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`, + defaults to 30. Returns: A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is a triple `[left_edge, right_edge, count]` for a single bucket. @@ -427,14 +428,19 @@ def _buckets_v3(data, bucket_count=None): data = tf.reshape(data, shape=[-1]) # flatten data = tf.cast(data, tf.float64) data_size = tf.size(input=data) - # If bucket_count is specified as zero, an empty tensor of shape - # (0, 3) will be returned. - is_empty = tf.math.logical_or( - tf.equal(data_size, 0), tf.constant([bucket_count <= 0]) + is_empty = tf.logical_or( + tf.equal(data_size, 0), tf.less_equal(bucket_count, 0) ) def when_empty(): - return tf.constant([], shape=(0, 3), dtype=tf.float64) + """When input data is empty or bucket_count is zero. + + 1. If bucket_count is specified as zero, an empty tensor of shape + (0, 3) will be returned. + 2. If the input data is empty, a tensor of shape (bucket_count, 3) + of all zero values will be returned. + """ + return tf.zeros((max(0, bucket_count), 3), dtype=tf.float64) def when_nonempty(): min_ = tf.reduce_min(input_tensor=data) From 9d328427d3eb901c5ca5f13eccce6aecb6a24eff Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Tue, 5 Oct 2021 22:15:03 +0000 Subject: [PATCH 10/12] add SummaryV3OpGraphTest test case and update tests TODO: figure out why `test_zero_bucket_count` doesn't work in the v3 graph test case (works in v3 op test case). --- tensorboard/plugins/histogram/summary_test.py | 72 ++++++++++++------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index 27259f62bbe..d228dd80e39 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -241,7 +241,7 @@ def write_histogram_event(self, *args, **kwargs): def graph_fn(): # Recreate the active scope inside the defun since it won't propagate. with tf.name_scope(scope): - summary.histogram(*args, **kwargs) + self.call_histogram_op(*args, **kwargs) writer = tf2.summary.create_file_writer(self.get_temp_dir()) with writer.as_default(): @@ -256,7 +256,9 @@ def graph_fn(): with tf2.GradientTape() as tape2: tape1.watch(x) tape2.watch(x) - summary.histogram(name="loss", step=0, data=x, buckets=10) + self.call_histogram_op( + name="loss", step=0, data=x, buckets=10 + ) # Note that XLA CPU/GPU has no outside compilation support, so summaries # won't actually run in a jit_compiled function. TPUs do, and follow @@ -267,10 +269,47 @@ def graph_fn(): graph_fn.get_concrete_function() -class SummaryV3OpGraphTest(SummaryV2OpTest, tf.test.TestCase): +class SummaryV3OpTest(SummaryV2OpTest, tf.test.TestCase): def call_histogram_op(self, *args, **kwargs): summary.histogram_v3(*args, **kwargs) + def test_singleton_input(self): + pb = self.histogram("twelve", [12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + # By default there will be 30 buckets. + expected_buckets = np.array( + [[12, 12, 0] for _ in range(29)] + [[12, 12, 1]] + ) + np.testing.assert_allclose(buckets, expected_buckets) + + def test_input_with_all_same_values(self): + pb = self.histogram("twelven", [12, 12, 12]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + # By default there will be 30 buckets. + expected_buckets = np.array( + [[12, 12, 0] for _ in range(29)] + [[12, 12, 3]] + ) + np.testing.assert_allclose(buckets, expected_buckets) + + def test_empty_input(self): + pb = self.histogram("empty", []) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + # By default there will be 30 buckets. + np.testing.assert_allclose(buckets, np.zeros((30, 3))) + + def test_empty_input_of_high_rank(self): + pb = self.histogram("empty_but_fancy", [[[], []], [[], []]]) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + # By default there will be 30 buckets. + np.testing.assert_allclose(buckets, np.zeros((30, 3))) + + def test_zero_bucket_count(self): + pb = self.histogram("zero_bucket_count", [1, 1, 1], buckets=0) + buckets = tensor_util.make_ndarray(pb.value[0].tensor) + np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3))) + + +class SummaryV3OpGraphTest(SummaryV3OpTest, tf.test.TestCase): def write_histogram_event(self, *args, **kwargs): kwargs.setdefault("step", 1) # Hack to extract current scope since there's no direct API for it. @@ -281,7 +320,7 @@ def write_histogram_event(self, *args, **kwargs): def graph_fn(): # Recreate the active scope inside the defun since it won't propagate. with tf.name_scope(scope): - summary.histogram_v3(*args, **kwargs) + self.call_histogram_op(*args, **kwargs) writer = tf2.summary.create_file_writer(self.get_temp_dir()) with writer.as_default(): @@ -296,7 +335,7 @@ def graph_fn(): with tf2.GradientTape() as tape2: tape1.watch(x) tape2.watch(x) - summary.histogram_v3( + self.call_histogram_op( name="loss", step=0, data=x, buckets=10 ) @@ -308,28 +347,13 @@ def graph_fn(): with writer.as_default(): graph_fn.get_concrete_function() - def test_singleton_input(self): - pb = self.histogram("twelve", [12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - # By default there will be 30 buckets. - expected_buckets = np.array( - [[12, 12, 0] for _ in range(29)] + [[12, 12, 1]] - ) - np.testing.assert_allclose(buckets, expected_buckets) - - def test_input_with_all_same_values(self): - pb = self.histogram("twelven", [12, 12, 12]) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - # By default there will be 30 buckets. - expected_buckets = np.array( - [[12, 12, 0] for _ in range(29)] + [[12, 12, 3]] - ) - np.testing.assert_allclose(buckets, expected_buckets) - def test_zero_bucket_count(self): + self.skipTest( + "TODO: figure out why this doesn't work in graph test case" + ) pb = self.histogram("zero_bucket_count", [1, 1, 1], buckets=0) buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_allclose(buckets, np.array([]).reshape((0, 3))) + np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3))) if __name__ == "__main__": From 72cb90b8a2275af630128e04d1fdfb488360de88 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Wed, 6 Oct 2021 23:59:28 +0000 Subject: [PATCH 11/12] make bucket_count a variable before tf.cond op Move tf.math.maximum() to the top and avoid compile time shape inference that fails the conditional branch that isn't supposed to be execute when bucket_count is 0. --- tensorboard/plugins/histogram/summary_test.py | 11 ++--------- tensorboard/plugins/histogram/summary_v2.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index d228dd80e39..b90e0af8935 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -316,7 +316,7 @@ def write_histogram_event(self, *args, **kwargs): with tf.name_scope("_") as temp_scope: scope = temp_scope.rstrip("/_") - @tf2.function + @tf2.function(autograph=True) def graph_fn(): # Recreate the active scope inside the defun since it won't propagate. with tf.name_scope(scope): @@ -324,6 +324,7 @@ def graph_fn(): writer = tf2.summary.create_file_writer(self.get_temp_dir()) with writer.as_default(): + # print(tf2.autograph.to_code(graph_fn.python_function)) graph_fn() writer.close() @@ -347,14 +348,6 @@ def graph_fn(): with writer.as_default(): graph_fn.get_concrete_function() - def test_zero_bucket_count(self): - self.skipTest( - "TODO: figure out why this doesn't work in graph test case" - ) - pb = self.histogram("zero_bucket_count", [1, 1, 1], buckets=0) - buckets = tensor_util.make_ndarray(pb.value[0].tensor) - np.testing.assert_array_equal(buckets, np.array([]).reshape((0, 3))) - if __name__ == "__main__": tf.test.main() diff --git a/tensorboard/plugins/histogram/summary_v2.py b/tensorboard/plugins/histogram/summary_v2.py index bff980492bf..e1cd01846db 100644 --- a/tensorboard/plugins/histogram/summary_v2.py +++ b/tensorboard/plugins/histogram/summary_v2.py @@ -425,6 +425,8 @@ def _buckets_v3(data, bucket_count=None): with tf.name_scope("buckets"): tf.debugging.assert_scalar(bucket_count) tf.debugging.assert_type(bucket_count, tf.int32) + # Treat a negative bucket count as zero. + bucket_count = tf.math.maximum(0, bucket_count) data = tf.reshape(data, shape=[-1]) # flatten data = tf.cast(data, tf.float64) data_size = tf.size(input=data) @@ -440,7 +442,7 @@ def when_empty(): 2. If the input data is empty, a tensor of shape (bucket_count, 3) of all zero values will be returned. """ - return tf.zeros((max(0, bucket_count), 3), dtype=tf.float64) + return tf.zeros((bucket_count, 3), dtype=tf.float64) def when_nonempty(): min_ = tf.reduce_min(input_tensor=data) @@ -480,11 +482,12 @@ def when_single_value(): """When input data contains a single unique value.""" # Left and right edges are the same for single value input. edges = tf.fill([bucket_count], max_) - # Counts for the first {bucket_count - 1} buckets [v, v) are 0. - zero_bucket_counts = tf.fill([bucket_count - 1], 0) - # Count for last bucket [v, v] is {data_size}. + # Bucket counts are 0 except the last bucket (if bucket_count > 0), + # which is `data_size`. Ensure that the resulting counts vector has + # length `bucket_count` always, including the bucket_count==0 case. + zeroes = tf.fill([bucket_count], 0) bucket_counts = tf.cast( - tf.concat([zero_bucket_counts, [data_size]], 0), + tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count], dtype=tf.float64, ) return tf.transpose(a=tf.stack([edges, edges, bucket_counts])) From 2d5fafecca843d8d95ff98e0209df5c99819a0c1 Mon Sep 17 00:00:00 2001 From: Yating Jing Date: Thu, 7 Oct 2021 15:14:57 +0000 Subject: [PATCH 12/12] remove debug lines added by accident --- tensorboard/plugins/histogram/summary_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorboard/plugins/histogram/summary_test.py b/tensorboard/plugins/histogram/summary_test.py index b90e0af8935..aceae45b50b 100644 --- a/tensorboard/plugins/histogram/summary_test.py +++ b/tensorboard/plugins/histogram/summary_test.py @@ -316,7 +316,7 @@ def write_histogram_event(self, *args, **kwargs): with tf.name_scope("_") as temp_scope: scope = temp_scope.rstrip("/_") - @tf2.function(autograph=True) + @tf2.function def graph_fn(): # Recreate the active scope inside the defun since it won't propagate. with tf.name_scope(scope): @@ -324,7 +324,6 @@ def graph_fn(): writer = tf2.summary.create_file_writer(self.get_temp_dir()) with writer.as_default(): - # print(tf2.autograph.to_code(graph_fn.python_function)) graph_fn() writer.close()