diff options
author | Thomas Schumm <fwiffo@google.com> | 2017-03-29 10:37:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-29 11:47:54 -0700 |
commit | ee0d146aff8c0ace6ecf3199763b751186b3f20a (patch) | |
tree | b1faa6738923b6f136e3af1906fb80a91d83c964 /tensorflow/python/summary | |
parent | b5c3d7c5c5956857db90023c5dcee4f5fa5c93cd (diff) |
Handle optional collections parameter correctly in tf.summary.text.
Change: 151602957
Diffstat (limited to 'tensorflow/python/summary')
-rw-r--r-- | tensorflow/python/summary/text_summary.py | 2 | ||||
-rw-r--r-- | tensorflow/python/summary/text_summary_test.py | 6 |
2 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/summary/text_summary.py b/tensorflow/python/summary/text_summary.py index 4b744fc3f7..82dee45d26 100644 --- a/tensorflow/python/summary/text_summary.py +++ b/tensorflow/python/summary/text_summary.py @@ -59,7 +59,7 @@ def text_summary(name, tensor, collections=None): raise ValueError("Expected tensor %s to be scalar, has shape %s" % (tensor.name, tensor.shape)) - t_summary = tensor_summary(name, tensor, collections) + t_summary = tensor_summary(name, tensor, collections=collections) text_assets = plugin_asset.get_plugin_asset(TextSummaryPluginAsset) text_assets.register_tensor(t_summary.op.name) return t_summary diff --git a/tensorflow/python/summary/text_summary_test.py b/tensorflow/python/summary/text_summary_test.py index b4059778ed..69739573c1 100644 --- a/tensorflow/python/summary/text_summary_test.py +++ b/tensorflow/python/summary/text_summary_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops as framework_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -46,6 +47,11 @@ class TextPluginTest(test_util.TensorFlowTestCase): summ = text_summary.text_summary("foo", array_ops.constant("one")) self.assertEqual(summ.op.type, "TensorSummary") + text_summary.text_summary("bar", array_ops.constant("2"), collections=[]) + summaries = framework_ops.get_collection( + framework_ops.GraphKeys.SUMMARIES) + self.assertEqual(len(summaries), 1) + if __name__ == "__main__": googletest.main() |