From aa8f428a9310b3fd8371bddf612e480b27618b2e Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 9 Oct 2018 10:47:19 -0700 Subject: Removing the _SHOULD_RECORD_SUMMARIES_NAME and _SUMMARY_WRITER_INIT_COLLECTION_NAME collections from the summaryV2 implementation. Replacing them with global variables. PiperOrigin-RevId: 216383152 --- tensorflow/python/ops/summary_ops_v2.py | 56 +++++++++++++++++---------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index a404507627..18cefb8e1c 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -43,11 +43,12 @@ from tensorflow.python.training import training_util from tensorflow.python.util import tf_contextlib -# Name for a collection which is expected to have at most a single boolean -# Tensor. If this tensor is True the summary ops will record summaries. -_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" +# A global dictionary mapping graph keys to boolean values indicating whether +# we should record summaries for this particular graph or not. +_SHOULD_RECORD_SUMMARIES = {} -_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2" +# A global dictionary mapping graph keys to a list of summary writer init ops. +_SUMMARY_WRITER_INIT_OP = {} _EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$") _RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$") @@ -56,14 +57,9 @@ _USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I) def should_record_summaries(): """Returns boolean Tensor which is true if summaries should be recorded.""" - should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME) - if not should_record_collection: - return False - if len(should_record_collection) != 1: - raise ValueError( - "More than one tensor specified for whether summaries " - "should be recorded: %s" % should_record_collection) - return should_record_collection[0] + global _SHOULD_RECORD_SUMMARIES + key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + return _SHOULD_RECORD_SUMMARIES.setdefault(key, False) # TODO(apassos) consider how to handle local step here. @@ -72,38 +68,41 @@ def record_summaries_every_n_global_steps(n, global_step=None): """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" if global_step is None: global_step = training_util.get_or_create_global_step() - collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - old = collection_ref[:] + global _SHOULD_RECORD_SUMMARIES + key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + old = _SHOULD_RECORD_SUMMARIES.setdefault(key, False) try: with ops.device("cpu:0"): - collection_ref[:] = [math_ops.equal(global_step % n, 0)] + _SHOULD_RECORD_SUMMARIES[key] = math_ops.equal(global_step % n, 0) yield finally: - collection_ref[:] = old + _SHOULD_RECORD_SUMMARIES[key] = old @tf_contextlib.contextmanager def always_record_summaries(): """Sets the should_record_summaries Tensor to always true.""" - collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - old = collection_ref[:] + global _SHOULD_RECORD_SUMMARIES + key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + old = _SHOULD_RECORD_SUMMARIES.setdefault(key, False) try: - collection_ref[:] = [True] + _SHOULD_RECORD_SUMMARIES[key] = True yield finally: - collection_ref[:] = old + _SHOULD_RECORD_SUMMARIES[key] = old @tf_contextlib.contextmanager def never_record_summaries(): """Sets the should_record_summaries Tensor to always false.""" - collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - old = collection_ref[:] + global _SHOULD_RECORD_SUMMARIES + key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + old = _SHOULD_RECORD_SUMMARIES.setdefault(key, False) try: - collection_ref[:] = [False] + _SHOULD_RECORD_SUMMARIES[key] = False yield finally: - collection_ref[:] = old + _SHOULD_RECORD_SUMMARIES[key] = old class SummaryWriter(object): @@ -143,7 +142,6 @@ class SummaryWriter(object): finally: context.context().summary_writer_resource = old - def init(self): """Operation to initialize the summary writer resource.""" if self._resource is not None: @@ -311,7 +309,9 @@ def _make_summary_writer(name, factory, **kwargs): if not context.executing_eagerly(): # TODO(apassos): Consider doing this instead. # ops.get_default_session().run(init_op) - ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op) + global _SUMMARY_WRITER_INIT_OP + key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + _SUMMARY_WRITER_INIT_OP.setdefault(key, []).append(init_op) return SummaryWriter(resource, init_op_fn) @@ -352,7 +352,9 @@ def summary_writer_initializer_op(): raise RuntimeError( "tf.contrib.summary.summary_writer_initializer_op is only " "supported in graph mode.") - return ops.get_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME) + global _SUMMARY_WRITER_INIT_OP + key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + return _SUMMARY_WRITER_INIT_OP.setdefault(key, []) def summary_writer_function(name, tensor, function, family=None): -- cgit v1.2.3