aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-20 13:44:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-20 13:48:22 -0700
commit32eb07bf7b4cf5c9f5ee14e1f4cbe18b1eba6c4d (patch)
tree281885bb2f9b9e26f5a516ab0511934e86bcfbb5 /tensorflow/contrib/summary
parentd2d9a6c7cc3b4f8c068054082a0fa2f2b95bb3d6 (diff)
Simplify the graph generated for contrib/summaries in the
"always summarize" and "never summarize" cases by skipping the `cond`. PiperOrigin-RevId: 172928083
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--tensorflow/contrib/summary/BUILD2
-rw-r--r--tensorflow/contrib/summary/summary_ops.py11
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py4
3 files changed, 8 insertions, 9 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index d09ad48e10..bcb2d74b4a 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -43,9 +43,9 @@ py_library(
deps = [
":gen_summary_ops",
"//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:layers_base",
"//tensorflow/python:summary_op_util",
"//tensorflow/python:training",
"//tensorflow/python/eager:context",
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index c8d0c14e19..ba3619bfc9 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -24,11 +24,10 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.layers import utils
from tensorflow.python.ops import summary_op_util
from tensorflow.python.training import training_util
-
# Name for a collection which is expected to have at most a single boolean
# Tensor. If this tensor is True the summary ops will record summaries.
_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
@@ -38,7 +37,7 @@ def should_record_summaries():
"""Returns boolean Tensor which is true if summaries should be recorded."""
should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME)
if not should_record_collection:
- return constant_op.constant(False)
+ return False
if len(should_record_collection) != 1:
raise ValueError(
"More than one tensor specified for whether summaries "
@@ -56,13 +55,13 @@ def record_summaries_every_n_global_steps(n):
def always_record_summaries():
"""Sets the should_record_summaries Tensor to always true."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
- collection_ref[:] = [constant_op.constant(True)]
+ collection_ref[:] = [True]
def never_record_summaries():
"""Sets the should_record_summaries Tensor to always false."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
- collection_ref[:] = [constant_op.constant(False)]
+ collection_ref[:] = [False]
def create_summary_file_writer(logdir,
@@ -106,7 +105,7 @@ def summary_writer_function(name, tensor, function, family=None):
function(tag, scope)
return True
- return control_flow_ops.cond(
+ return utils.smart_cond(
should_record_summaries(), record, _nothing, name="")
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 6958ee8dd8..2cd4fce5b3 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -40,9 +40,9 @@ class TargetTest(test_util.TensorFlowTestCase):
summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
def testShouldRecordSummary(self):
- self.assertFalse(summary_ops.should_record_summaries().numpy())
+ self.assertFalse(summary_ops.should_record_summaries())
summary_ops.always_record_summaries()
- self.assertTrue(summary_ops.should_record_summaries().numpy())
+ self.assertTrue(summary_ops.should_record_summaries())
def testSummaryOps(self):
training_util.get_or_create_global_step()