aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-10-09 10:47:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 10:52:08 -0700
commitaa8f428a9310b3fd8371bddf612e480b27618b2e (patch)
tree82eaa5d1d28b4f3ce7286654a82dbbbf234e67d5
parent11f32ebbdcd4eaf5e9e09fe27571e26ec0bd9dd8 (diff)
Removing the _SHOULD_RECORD_SUMMARIES_NAME and
_SUMMARY_WRITER_INIT_COLLECTION_NAME collections from the summaryV2 implementation. Replacing them with global variables. PiperOrigin-RevId: 216383152
-rw-r--r--tensorflow/python/ops/summary_ops_v2.py56
1 files changed, 29 insertions, 27 deletions
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index a404507627..18cefb8e1c 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -43,11 +43,12 @@ from tensorflow.python.training import training_util
from tensorflow.python.util import tf_contextlib
-# Name for a collection which is expected to have at most a single boolean
-# Tensor. If this tensor is True the summary ops will record summaries.
-_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
+# A global dictionary mapping graph keys to boolean values indicating whether
+# we should record summaries for this particular graph or not.
+_SHOULD_RECORD_SUMMARIES = {}
-_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2"
+# A global dictionary mapping graph keys to a list of summary writer init ops.
+_SUMMARY_WRITER_INIT_OP = {}
_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$")
_RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$")
@@ -56,14 +57,9 @@ _USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I)
def should_record_summaries():
"""Returns boolean Tensor which is true if summaries should be recorded."""
- should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME)
- if not should_record_collection:
- return False
- if len(should_record_collection) != 1:
- raise ValueError(
- "More than one tensor specified for whether summaries "
- "should be recorded: %s" % should_record_collection)
- return should_record_collection[0]
+ global _SHOULD_RECORD_SUMMARIES
+ key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ return _SHOULD_RECORD_SUMMARIES.setdefault(key, False)
# TODO(apassos) consider how to handle local step here.
@@ -72,38 +68,41 @@ def record_summaries_every_n_global_steps(n, global_step=None):
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
if global_step is None:
global_step = training_util.get_or_create_global_step()
- collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
- old = collection_ref[:]
+ global _SHOULD_RECORD_SUMMARIES
+ key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ old = _SHOULD_RECORD_SUMMARIES.setdefault(key, False)
try:
with ops.device("cpu:0"):
- collection_ref[:] = [math_ops.equal(global_step % n, 0)]
+ _SHOULD_RECORD_SUMMARIES[key] = math_ops.equal(global_step % n, 0)
yield
finally:
- collection_ref[:] = old
+ _SHOULD_RECORD_SUMMARIES[key] = old
@tf_contextlib.contextmanager
def always_record_summaries():
"""Sets the should_record_summaries Tensor to always true."""
- collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
- old = collection_ref[:]
+ global _SHOULD_RECORD_SUMMARIES
+ key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ old = _SHOULD_RECORD_SUMMARIES.setdefault(key, False)
try:
- collection_ref[:] = [True]
+ _SHOULD_RECORD_SUMMARIES[key] = True
yield
finally:
- collection_ref[:] = old
+ _SHOULD_RECORD_SUMMARIES[key] = old
@tf_contextlib.contextmanager
def never_record_summaries():
"""Sets the should_record_summaries Tensor to always false."""
- collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
- old = collection_ref[:]
+ global _SHOULD_RECORD_SUMMARIES
+ key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ old = _SHOULD_RECORD_SUMMARIES.setdefault(key, False)
try:
- collection_ref[:] = [False]
+ _SHOULD_RECORD_SUMMARIES[key] = False
yield
finally:
- collection_ref[:] = old
+ _SHOULD_RECORD_SUMMARIES[key] = old
class SummaryWriter(object):
@@ -143,7 +142,6 @@ class SummaryWriter(object):
finally:
context.context().summary_writer_resource = old
-
def init(self):
"""Operation to initialize the summary writer resource."""
if self._resource is not None:
@@ -311,7 +309,9 @@ def _make_summary_writer(name, factory, **kwargs):
if not context.executing_eagerly():
# TODO(apassos): Consider doing this instead.
# ops.get_default_session().run(init_op)
- ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op)
+ global _SUMMARY_WRITER_INIT_OP
+ key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ _SUMMARY_WRITER_INIT_OP.setdefault(key, []).append(init_op)
return SummaryWriter(resource, init_op_fn)
@@ -352,7 +352,9 @@ def summary_writer_initializer_op():
raise RuntimeError(
"tf.contrib.summary.summary_writer_initializer_op is only "
"supported in graph mode.")
- return ops.get_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME)
+ global _SUMMARY_WRITER_INIT_OP
+ key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ return _SUMMARY_WRITER_INIT_OP.setdefault(key, [])
def summary_writer_function(name, tensor, function, family=None):