diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-20 13:44:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-20 13:48:22 -0700 |
commit | 32eb07bf7b4cf5c9f5ee14e1f4cbe18b1eba6c4d (patch) | |
tree | 281885bb2f9b9e26f5a516ab0511934e86bcfbb5 /tensorflow/contrib/summary | |
parent | d2d9a6c7cc3b4f8c068054082a0fa2f2b95bb3d6 (diff) |
Simplify the graph generated for contrib/summaries in the
"always summarize" and "never summarize" cases by skipping the `cond`.
PiperOrigin-RevId: 172928083
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 4 |
3 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index d09ad48e10..bcb2d74b4a 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -43,9 +43,9 @@ py_library( deps = [ ":gen_summary_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:layers_base", "//tensorflow/python:summary_op_util", "//tensorflow/python:training", "//tensorflow/python/eager:context", diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index c8d0c14e19..ba3619bfc9 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -24,11 +24,10 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.layers import utils from tensorflow.python.ops import summary_op_util from tensorflow.python.training import training_util - # Name for a collection which is expected to have at most a single boolean # Tensor. If this tensor is True the summary ops will record summaries. _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" @@ -38,7 +37,7 @@ def should_record_summaries(): """Returns boolean Tensor which is true if summaries should be recorded.""" should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME) if not should_record_collection: - return constant_op.constant(False) + return False if len(should_record_collection) != 1: raise ValueError( "More than one tensor specified for whether summaries " @@ -56,13 +55,13 @@ def record_summaries_every_n_global_steps(n): def always_record_summaries(): """Sets the should_record_summaries Tensor to always true.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [constant_op.constant(True)] + collection_ref[:] = [True] def never_record_summaries(): """Sets the should_record_summaries Tensor to always false.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [constant_op.constant(False)] + collection_ref[:] = [False] def create_summary_file_writer(logdir, @@ -106,7 +105,7 @@ def summary_writer_function(name, tensor, function, family=None): function(tag, scope) return True - return control_flow_ops.cond( + return utils.smart_cond( should_record_summaries(), record, _nothing, name="") diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 6958ee8dd8..2cd4fce5b3 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -40,9 +40,9 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') def testShouldRecordSummary(self): - self.assertFalse(summary_ops.should_record_summaries().numpy()) + self.assertFalse(summary_ops.should_record_summaries()) summary_ops.always_record_summaries() - self.assertTrue(summary_ops.should_record_summaries().numpy()) + self.assertTrue(summary_ops.should_record_summaries()) def testSummaryOps(self): training_util.get_or_create_global_step() |