aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-11-20 15:50:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-20 15:54:03 -0800
commitb99ba0d749f04311d2c8e8c5843d78e427edb832 (patch)
treedc8dd569bae9a3775f09765fce2fb1fcd462bd80 /tensorflow/contrib/summary
parent901f3af1891804d6a5f211346a867dbb4167653d (diff)
Contrib summaries always try to run when inside loops or conditionals.
PiperOrigin-RevId: 176430089
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--tensorflow/contrib/summary/BUILD12
-rw-r--r--tensorflow/contrib/summary/summary_ops.py5
-rw-r--r--tensorflow/contrib/summary/summary_ops_graph_test.py50
3 files changed, 62 insertions, 5 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index 3892654f25..45d6454526 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -47,10 +47,16 @@ py_test(
deps = [
":summary_ops",
":summary_test_internal",
+ ":summary_test_util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:ops",
- "//tensorflow/python:platform",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "@six_archive//:six",
],
)
@@ -61,6 +67,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":gen_summary_ops",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
@@ -73,6 +80,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 3e65f83051..8e37987cb7 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -45,7 +45,6 @@ from tensorflow.python.util import tf_contextlib
# Tensor. If this tensor is True the summary ops will record summaries.
_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
-_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2"
_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2"
_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$")
@@ -298,7 +297,7 @@ def all_summary_ops():
if context.in_eager_mode():
raise RuntimeError(
"tf.contrib.summary.all_summary_ops is only supported in graph mode.")
- return ops.get_collection(_SUMMARY_COLLECTION_NAME)
+ return ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
def summary_writer_initializer_op():
@@ -340,7 +339,7 @@ def summary_writer_function(name, tensor, function, family=None):
with ops.device("cpu:0"):
op = utils.smart_cond(
should_record_summaries(), record, _nothing, name="")
- ops.add_to_collection(_SUMMARY_COLLECTION_NAME, op)
+ ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access
return op
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index 8f85f67a25..fe55bf93e2 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -16,13 +16,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import tempfile
+
import six
from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_internal
+from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.training import training_util
@@ -47,6 +54,49 @@ class DbTest(summary_test_internal.SummaryDbTest):
six.assertCountEqual(self, [name],
get_all(self.db, 'SELECT node_name FROM Nodes'))
+ def testSummaryGraphModeCond(self):
+ with ops.Graph().as_default(), self.test_session():
+ training_util.get_or_create_global_step()
+ logdir = tempfile.mkdtemp()
+ with summary_ops.create_summary_file_writer(
+ logdir, max_queue=0,
+ name='t2').as_default(), summary_ops.always_record_summaries():
+ summary_ops.initialize()
+ training_util.get_or_create_global_step().initializer.run()
+ def f():
+ summary_ops.scalar('scalar', 2.0)
+ return constant_op.constant(True)
+ pred = array_ops.placeholder(dtypes.bool)
+ x = control_flow_ops.cond(pred, f,
+ lambda: constant_op.constant(False))
+ x.eval(feed_dict={pred: True})
+
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(len(events), 2)
+ self.assertEqual(events[1].summary.value[0].tag, 'cond/scalar')
+
+ def testSummaryGraphModeWhile(self):
+ with ops.Graph().as_default(), self.test_session():
+ training_util.get_or_create_global_step()
+ logdir = tempfile.mkdtemp()
+ with summary_ops.create_summary_file_writer(
+ logdir, max_queue=0,
+ name='t2').as_default(), summary_ops.always_record_summaries():
+ summary_ops.initialize()
+ training_util.get_or_create_global_step().initializer.run()
+ def body(unused_pred):
+ summary_ops.scalar('scalar', 2.0)
+ return constant_op.constant(False)
+ def cond(pred):
+ return pred
+ pred = array_ops.placeholder(dtypes.bool)
+ x = control_flow_ops.while_loop(cond, body, [pred])
+ x.eval(feed_dict={pred: True})
+
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(len(events), 2)
+ self.assertEqual(events[1].summary.value[0].tag, 'while/scalar')
+
if __name__ == '__main__':
test.main()