diff options
author | Alexandre Passos <apassos@google.com> | 2017-11-20 15:50:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-20 15:54:03 -0800 |
commit | b99ba0d749f04311d2c8e8c5843d78e427edb832 (patch) | |
tree | dc8dd569bae9a3775f09765fce2fb1fcd462bd80 /tensorflow/contrib/summary | |
parent | 901f3af1891804d6a5f211346a867dbb4167653d (diff) |
Contrib summaries always try to run when inside loops or conditionals.
PiperOrigin-RevId: 176430089
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 12 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_graph_test.py | 50 |
3 files changed, 62 insertions, 5 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index 3892654f25..45d6454526 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -47,10 +47,16 @@ py_test( deps = [ ":summary_ops", ":summary_test_internal", + ":summary_test_util", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:ops", - "//tensorflow/python:platform", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "@six_archive//:six", ], ) @@ -61,6 +67,7 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":gen_summary_ops", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -73,6 +80,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 3e65f83051..8e37987cb7 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -45,7 +45,6 @@ from tensorflow.python.util import tf_contextlib # Tensor. If this tensor is True the summary ops will record summaries. _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" -_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2" _SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2" _EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$") @@ -298,7 +297,7 @@ def all_summary_ops(): if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.all_summary_ops is only supported in graph mode.") - return ops.get_collection(_SUMMARY_COLLECTION_NAME) + return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access def summary_writer_initializer_op(): @@ -340,7 +339,7 @@ def summary_writer_function(name, tensor, function, family=None): with ops.device("cpu:0"): op = utils.smart_cond( should_record_summaries(), record, _nothing, name="") - ops.add_to_collection(_SUMMARY_COLLECTION_NAME, op) + ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access return op diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index 8f85f67a25..fe55bf93e2 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -16,13 +16,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import tempfile + import six from tensorflow.contrib.summary import summary_ops from tensorflow.contrib.summary import summary_test_internal +from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test from tensorflow.python.training import training_util @@ -47,6 +54,49 @@ class DbTest(summary_test_internal.SummaryDbTest): six.assertCountEqual(self, [name], get_all(self.db, 'SELECT node_name FROM Nodes')) + def testSummaryGraphModeCond(self): + with ops.Graph().as_default(), self.test_session(): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + summary_ops.initialize() + training_util.get_or_create_global_step().initializer.run() + def f(): + summary_ops.scalar('scalar', 2.0) + return constant_op.constant(True) + pred = array_ops.placeholder(dtypes.bool) + x = control_flow_ops.cond(pred, f, + lambda: constant_op.constant(False)) + x.eval(feed_dict={pred: True}) + + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'cond/scalar') + + def testSummaryGraphModeWhile(self): + with ops.Graph().as_default(), self.test_session(): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + summary_ops.initialize() + training_util.get_or_create_global_step().initializer.run() + def body(unused_pred): + summary_ops.scalar('scalar', 2.0) + return constant_op.constant(False) + def cond(pred): + return pred + pred = array_ops.placeholder(dtypes.bool) + x = control_flow_ops.while_loop(cond, body, [pred]) + x.eval(feed_dict={pred: True}) + + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'while/scalar') + if __name__ == '__main__': test.main() |