author    Nick Felt <nickfelt@google.com>                  2018-07-24 09:49:47 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-24 09:53:22 -0700
commit  568727eed199dba04e37f500265b50f96fed455e (patch)
tree    999f31d1469b3b5f2dc12d5ca04061cfe6062faa /tensorflow/python/training
parent  f8bbd3ceb7e86b7595ba74a9a03cfc7c1be252a8 (diff)
Add v2 summary support to Estimator.train() and MonitoredSession hooks
This change makes Estimator.train() support v2 summaries (tf.contrib.summary.*) out of the box, matching the existing support for v1 summaries. Estimator.train() will now handle the boilerplate necessary to initialize a file writer and enable summary writing every N steps, and will ensure that its own automatically exported summaries (for loss and global_step/sec) are written to the same underlying events file.

As part of this change, tf.train.SummarySaverHook, tf.train.CheckpointSaverHook, tf.train.StepCounterHook, and tf.train.ProfilerHook have also been adapted to write summaries through the v2 summary system (via a compatibility layer), instead of through FileWriterCache.

A few additional smaller changes:
- The 'session' parameter to FileWriter() can now be a callable returning a tf.Session instance.
- tf.contrib.summary.record_summaries_if() is introduced; it takes a boolean tensor that directly controls tf.contrib.summary.should_record_summaries().
- EstimatorSpec.train_op, besides a tf.Operation, is now allowed to be any Tensor-equivalent object, rather than just a tf.Tensor.

PiperOrigin-RevId: 205843986
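A minimal sketch of the new record_summaries_if() control in TF 1.x graph mode; the logdir, scalar values, and feed schedule are illustrative, not taken from the commit:

import tensorflow as tf

# Gate v2 summary writing on a boolean tensor -- the same mechanism
# SummarySaverHook uses via its placeholder (see the diff below).
record_cond = tf.placeholder_with_default(False, shape=[])
writer = tf.contrib.summary.create_file_writer('/tmp/logdir')  # illustrative
with writer.as_default():
  # record_summaries_if() drives should_record_summaries() directly.
  with tf.contrib.summary.record_summaries_if(record_cond):
    write_op = tf.contrib.summary.scalar('loss', 0.5)

with tf.Session() as sess:
  sess.run(tf.contrib.summary.summary_writer_initializer_op())
  sess.run(write_op, feed_dict={record_cond: True})   # summary written
  sess.run(write_op, feed_dict={record_cond: False})  # summary skipped
  sess.run(writer.flush())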
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks.py       | 182
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks_test.py  | 476
-rw-r--r--  tensorflow/python/training/monitored_session.py             |  11
-rw-r--r--  tensorflow/python/training/optimizer.py                     |   6
4 files changed, 476 insertions(+), 199 deletions(-)
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index b0dd188db1..b8df7fe51b 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -31,12 +31,13 @@ from tensorflow.python.client import timeline
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary.writer import writer
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
-from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export
@@ -422,7 +423,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
self._steps_per_run = steps_per_run
def begin(self):
- self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
+ self._summary_writer = writer.FileWriter(
+ self._checkpoint_dir, session=ops.get_default_session,
+ filename_suffix="")
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError(
@@ -431,10 +434,12 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
l.begin()
def after_create_session(self, session, coord):
+ del coord
+ # Ensure summary writer resource has been initialized.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
global_step = session.run(self._global_step_tensor)
- # We do write graph and saver_def at the first call of before_run.
- # We cannot do this in begin, since we let other hooks to change graph and
- # add variables in begin. Graph is finalized after all begin calls.
+ # Write graph and saver_def once graph is finalized, which isn't true yet
+ # in begin() since later hooks can still change the graph.
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir,
@@ -444,8 +449,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True),
saver_def=saver_def)
- self._summary_writer.add_graph(graph)
- self._summary_writer.add_meta_graph(meta_graph_def)
+ with ops.default_session(session):
+ self._summary_writer.add_graph(graph)
+ self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
self._save(session, global_step)
self._timer.update_last_triggered_step(global_step)
@@ -470,6 +476,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
self._save(session, last_step)
for l in self._listeners:
l.end(session, last_step)
+ with ops.default_session(session):
+ self._summary_writer.flush()
def _save(self, session, step):
"""Saves the latest checkpoint, returns should_stop."""
@@ -479,10 +487,12 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
l.before_save(session, step)
self._get_saver().save(session, self._save_path, global_step=step)
- self._summary_writer.add_session_log(
- SessionLog(
- status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
- step)
+ with ops.default_session(session):
+ self._summary_writer.add_session_log(
+ SessionLog(
+ status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
+ step)
+ self._summary_writer.flush()
should_stop = False
for l in self._listeners:
@@ -543,13 +553,23 @@ class StepCounterHook(session_run_hook.SessionRunHook):
def begin(self):
if self._summary_writer is None and self._output_dir:
- self._summary_writer = SummaryWriterCache.get(self._output_dir)
+ self._summary_writer = writer.FileWriter(
+ self._output_dir, session=ops.get_default_session,
+ filename_suffix="")
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use StepCounterHook.")
self._summary_tag = training_util.get_global_step().op.name + "/sec"
+ def after_create_session(self, session, coord):
+ del coord
+ # Reset any stale state in case we're recovering from a previous error.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
+ self._last_global_step = None
+ self._global_step_check_count = 0
+ self._timer.reset()
+
def before_run(self, run_context): # pylint: disable=unused-argument
return SessionRunArgs(self._global_step_tensor)
@@ -562,8 +582,6 @@ class StepCounterHook(session_run_hook.SessionRunHook):
logging.info("%s: %g", self._summary_tag, steps_per_sec)
def after_run(self, run_context, run_values):
- _ = run_context
-
stale_global_step = run_values.results
if self._timer.should_trigger_for_step(
stale_global_step + self._steps_per_run):
@@ -573,7 +591,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
global_step)
if elapsed_time is not None:
- self._log_and_record(elapsed_steps, elapsed_time, global_step)
+ with ops.default_session(run_context.session):
+ self._log_and_record(elapsed_steps, elapsed_time, global_step)
# Check whether the global step has been increased. Here, we do not use the
# timer.last_triggered_step as the timer might record a different global
@@ -599,6 +618,11 @@ class StepCounterHook(session_run_hook.SessionRunHook):
self._last_global_step = stale_global_step
+ def end(self, session):
+ if self._summary_writer is not None:
+ with ops.default_session(session):
+ self._summary_writer.flush()
+
@tf_export("train.NanLossDuringTrainingError")
class NanLossDuringTrainingError(RuntimeError):
@@ -643,6 +667,25 @@ class NanTensorHook(session_run_hook.SessionRunHook):
class SummarySaverHook(session_run_hook.SessionRunHook):
"""Saves summaries every N steps."""
+ _SUMMARY_PLACEHOLDER_COLLECTION = "_SUMMARY_SAVER_PLACEHOLDER"
+
+ @classmethod
+ def _set_placeholder(cls, placeholder):
+ """Sets a `tf.placeholder` to be fed by the first SummarySaverHook.
+
+ If a placeholder is provided, the first instance of SummarySaverHook in use
+ will feed it a boolean indicating whether summaries should be written,
+ according to the `save_steps` and `save_secs` parameters of that hook. This
+ makes the placeholder usable with `tf.contrib.summary.record_summaries_if`
+ to control `tf.contrib.summary` summary writing using the same schedule as
+ the `tf.summary` summary writing (which the hook controls directly).
+
+ Args:
+ placeholder: `tf.placeholder` for the first SummarySaverHook to feed
+ """
+ collection = ops.get_collection_ref(cls._SUMMARY_PLACEHOLDER_COLLECTION)
+ collection[:] = [placeholder]
+
def __init__(self,
save_steps=None,
save_secs=None,
@@ -680,53 +723,82 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
self._scaffold = scaffold
self._timer = SecondOrStepTimer(every_secs=save_secs,
every_steps=save_steps)
+ self._placeholder = None
# TODO(mdan): Throw an error if output_dir and summary_writer are None.
def begin(self):
if self._summary_writer is None and self._output_dir:
- self._summary_writer = SummaryWriterCache.get(self._output_dir)
- self._next_step = None
- self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+ self._summary_writer = writer.FileWriter(
+ self._output_dir, filename_suffix="", session=ops.get_default_session)
+ # Designate the first SummarySaverHook to call begin() as the "primary"
+ # hook; it will control writing of v2 summaries via a placeholder bool.
+ collection = ops.get_collection_ref(self._SUMMARY_PLACEHOLDER_COLLECTION)
+ if collection:
+ self._placeholder = collection[0]
+ collection[:] = []
+ self._current_step = None
+ self._global_step_tensor = training_util.get_or_create_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use SummarySaverHook.")
- def before_run(self, run_context): # pylint: disable=unused-argument
- self._request_summary = (
- self._next_step is None or
- self._timer.should_trigger_for_step(self._next_step))
+ def after_create_session(self, session, coord):
+ del coord
+ # Reset any stale state in case we're recovering from a previous error.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
+ self._current_step = None
+ self._timer.reset()
+
+ def before_run(self, run_context):
+ # For the first run, record a SessionLog.START at the pre-run global step.
+ if self._current_step is None:
+ self._current_step = run_context.session.run(self._global_step_tensor)
+ with ops.default_session(run_context.session):
+ self._summary_writer.add_session_log(
+ SessionLog(status=SessionLog.START), self._current_step)
requests = {"global_step": self._global_step_tensor}
+ self._request_summary = self._timer.should_trigger_for_step(
+ self._current_step)
if self._request_summary:
+ self._timer.update_last_triggered_step(self._current_step)
if self._get_summary_op() is not None:
requests["summary"] = self._get_summary_op()
-
- return SessionRunArgs(requests)
+ feeds = {}
+ if self._placeholder is not None and self._request_summary:
+ feeds[self._placeholder] = self._request_summary
+ args = SessionRunArgs(fetches=requests, feed_dict=feeds)
+ return args
def after_run(self, run_context, run_values):
- _ = run_context
- if not self._summary_writer:
- return
-
+ # Collect any legacy v1 summaries to emit.
+ summaries_to_emit = []
+ if self._summary_writer and self._request_summary:
+ for summary in run_values.results.get("summary", []):
+ # Skip None results corresponding to V2 summary operations.
+ if summary is not None:
+ summaries_to_emit.append(summary)
+ # Heuristically estimate current step as possibly-stale value plus one.
stale_global_step = run_values.results["global_step"]
- global_step = stale_global_step + 1
- if self._next_step is None or self._request_summary:
- global_step = run_context.session.run(self._global_step_tensor)
-
- if self._next_step is None:
- self._summary_writer.add_session_log(
- SessionLog(status=SessionLog.START), global_step)
-
- if self._request_summary:
- self._timer.update_last_triggered_step(global_step)
- if "summary" in run_values.results:
- for summary in run_values.results["summary"]:
- self._summary_writer.add_summary(summary, global_step)
-
- self._next_step = global_step + 1
+ self._current_step = stale_global_step + 1
+ # Read the actual post-run global step if we need better accuracy because
+ # 1) we will request summaries on the next run (based on estimate now) and
+ # must ensure we record an accurate "last triggered step" value, or
+ # 2) we have legacy v1 summaries to emit using the post-run step value.
+ # Note: we could have dealt with (1) separately in before_run() but by doing
+ # it here we can consolidate the reads in case both (1) and (2) apply.
+ near_next_trigger = self._timer.should_trigger_for_step(self._current_step)
+ if near_next_trigger or summaries_to_emit:
+ self._current_step = run_context.session.run(self._global_step_tensor)
+ # Emit any legacy v1 summaries.
+ if summaries_to_emit:
+ with ops.default_session(run_context.session):
+ for summary in summaries_to_emit:
+ self._summary_writer.add_summary(summary, self._current_step)
def end(self, session=None):
- if self._summary_writer:
- self._summary_writer.flush()
+ if self._summary_writer and session:
+ with ops.default_session(session):
+ self._summary_writer.flush()
def _get_summary_op(self):
"""Fetches the summary op either from self._summary_op or self._scaffold.
@@ -893,19 +965,27 @@ class ProfilerHook(session_run_hook.SessionRunHook):
show_memory: `bool`, if True, add object snapshot events to the trace
showing the sizes and lifetimes of tensors.
"""
+ self._output_dir = output_dir
self._output_file = os.path.join(output_dir, "timeline-{}.json")
- self._file_writer = SummaryWriterCache.get(output_dir)
self._show_dataflow = show_dataflow
self._show_memory = show_memory
self._timer = SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)
def begin(self):
+ self._file_writer = writer.FileWriter(
+ self._output_dir, filename_suffix="", session=ops.get_default_session)
self._next_step = None
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use ProfilerHook.")
+ def after_create_session(self, session, coord):
+ del coord
+ # Reset any stale state in case we're recovering from a previous error.
+ session.run(summary_ops_v2.summary_writer_initializer_op())
+ self._timer.reset()
+
def before_run(self, run_context):
self._request_summary = (
self._next_step is None or
@@ -925,8 +1005,10 @@ class ProfilerHook(session_run_hook.SessionRunHook):
self._save(global_step,
self._output_file.format(global_step),
run_values.run_metadata.step_stats)
- self._file_writer.add_run_metadata(run_values.run_metadata,
- "step_%d" % global_step)
+ with ops.default_session(run_context.session):
+ self._file_writer.add_run_metadata(run_values.run_metadata,
+ "step_%d" % global_step,
+ global_step=global_step)
self._next_step = global_step + 1
@@ -938,6 +1020,10 @@ class ProfilerHook(session_run_hook.SessionRunHook):
trace.generate_chrome_trace_format(
show_dataflow=self._show_dataflow, show_memory=self._show_memory))
+ def end(self, session):
+ with ops.default_session(session):
+ self._file_writer.flush()
+
def _as_graph_element(obj):
"""Retrieves Graph element."""
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index b49a871a56..b89167f3c1 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -19,8 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import glob
import os.path
import shutil
+import sys
import tempfile
import threading
import time
@@ -28,6 +30,9 @@ import time
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.testing.python.framework import fake_summary_writer
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -35,9 +40,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -45,13 +53,27 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.summary import summary as summary_lib
+from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
+from tensorflow.python.training import saver
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
+def load_eventfile_contents(directory_path):
+ """Returns the contents of the singular event file in the given directory."""
+ writer_cache.FileWriterCache.clear()
+
+ # Get last Event written.
+ event_paths = glob.glob(os.path.join(directory_path, '*tfevent*'))
+ if len(event_paths) != 1:
+ raise AssertionError('Expected one eventfile, got %s' % str(event_paths))
+ result = list(summary_iterator.summary_iterator(event_paths[0]))
+ return result
+
+
class MockCheckpointSaverListener(
basic_session_run_hooks.CheckpointSaverListener):
@@ -717,11 +739,12 @@ class CheckpointSaverHookTest(test.TestCase):
checkpoint_utils.load_variable(self.model_dir,
self.global_step.name))
- def test_summary_writer_defs(self):
- fake_summary_writer.FakeSummaryWriter.install()
- writer_cache.FileWriterCache.clear()
- summary_writer = writer_cache.FileWriterCache.get(self.model_dir)
+ def _assertCheckpointEvent(self, event, step, checkpoint_path):
+ self.assertEqual(step, event.step)
+ self.assertEqual(SessionLog.CHECKPOINT, event.session_log.status)
+ self.assertEqual(checkpoint_path, event.session_log.checkpoint_path)
+ def test_summary_writer_defs(self):
with self.graph.as_default():
hook = basic_session_run_hooks.CheckpointSaverHook(
self.model_dir, save_steps=2, scaffold=self.scaffold)
@@ -730,18 +753,40 @@ class CheckpointSaverHookTest(test.TestCase):
with session_lib.Session() as sess:
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
- hook.after_create_session(sess, None)
- mon_sess.run(self.train_op)
- summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.model_dir,
- expected_added_meta_graphs=[
- meta_graph.create_meta_graph_def(
- graph_def=self.graph.as_graph_def(add_shapes=True),
- saver_def=self.scaffold.saver.saver_def)
- ])
-
- fake_summary_writer.FakeSummaryWriter.uninstall()
+ hook.after_create_session(sess, None) # Checkpoint saved at step 0.
+ expected_graph_def = self.graph.as_graph_def(add_shapes=True)
+ expected_meta_graph_def = meta_graph.create_meta_graph_def(
+ graph_def=expected_graph_def,
+ saver_def=self.scaffold.saver.saver_def)
+ mon_sess.run(self.train_op) # No checkpoint saved at step 1.
+ mon_sess.run(self.train_op) # Checkpoint saved at step 2.
+ mon_sess.run(self.train_op) # No checkpoint saved at step 3.
+ hook.end(sess) # Checkpoint saved at the last step (3)
+ events = iter(load_eventfile_contents(self.model_dir))
+ next(events) # Skip version event that's always there.
+
+ # Graph.
+ event = next(events)
+ self.assertEqual(0, event.step)
+ actual_graph_def = graph_pb2.GraphDef()
+ actual_graph_def.ParseFromString(event.graph_def)
+ test_util.assert_equal_graph_def(actual_graph_def, expected_graph_def)
+
+ # Metagraph.
+ event = next(events)
+ self.assertEqual(0, event.step)
+ actual_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ actual_meta_graph_def.ParseFromString(event.meta_graph_def)
+ test_util.assert_meta_graph_protos_equal(
+ self, expected_meta_graph_def, actual_meta_graph_def)
+
+ # Checkpoints.
+ # Strip the "-step#" suffix off the latest checkpoint to get base path.
+ checkpoint_path = saver.latest_checkpoint(self.model_dir).rsplit('-', 1)[0]
+ self._assertCheckpointEvent(next(events), 0, checkpoint_path)
+ self._assertCheckpointEvent(next(events), 2, checkpoint_path)
+ self._assertCheckpointEvent(next(events), 3, checkpoint_path)
+ self.assertRaises(StopIteration, lambda: next(events)) # No more events.
def test_save_checkpoint_before_first_train_step(self):
with self.graph.as_default():
@@ -1102,167 +1147,305 @@ class StepCounterHookTest(test.TestCase):
self.assertEqual('global_step/sec', summary_value.tag)
self.assertGreater(summary_value.simple_value, 0)
+ def test_summary_writer(self):
+ with ops.Graph().as_default(), session_lib.Session() as sess:
+ variables.get_or_create_global_step()
+ train_op = training_util._increment_global_step(1)
+ hook = basic_session_run_hooks.StepCounterHook(
+ output_dir=self.log_dir, every_n_steps=10)
+ hook.begin()
+ sess.run(variables_lib.global_variables_initializer())
+ mon_sess = monitored_session._HookedSession(sess, [hook])
+ for _ in range(30):
+ mon_sess.run(train_op)
+ hook.end(sess)
+ events = iter(load_eventfile_contents(self.log_dir))
+ next(events) # Skip version event that's always there.
+
+ event = next(events)
+ self.assertEqual(11, event.step)
+ self.assertEqual('global_step/sec', event.summary.value[0].tag)
+ self.assertLess(0, event.summary.value[0].simple_value)
-class SummarySaverHookTest(test.TestCase):
+ event = next(events)
+ self.assertEqual(21, event.step)
+ self.assertEqual('global_step/sec', event.summary.value[0].tag)
+ self.assertLess(0, event.summary.value[0].simple_value)
- def setUp(self):
- test.TestCase.setUp(self)
+ self.assertRaises(StopIteration, lambda: next(events)) # No more events.
- self.log_dir = 'log/dir'
- self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)
- var = variables_lib.Variable(0.0)
- tensor = state_ops.assign_add(var, 1.0)
- tensor2 = tensor * 2
- self.summary_op = summary_lib.scalar('my_summary', tensor)
- self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)
+class SummarySaverHookTest(test.TestCase):
- variables.get_or_create_global_step()
- self.train_op = training_util._increment_global_step(1)
+ def setUp(self):
+ test.TestCase.setUp(self)
+ self.logdir = self.get_temp_dir()
+ self._create_stable_global_step()
+
+ def _create_stable_global_step(self):
+ """Returns a new ResourceVariable global_step for deterministic tests."""
+ # TODO(nickfelt): remove after standard global_step is a ResourceVariable.
+ with ops.get_default_graph().name_scope(None):
+ return variable_scope.get_variable(
+ ops.GraphKeys.GLOBAL_STEP,
+ shape=[],
+ dtype=dtypes.int64,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ collections=[
+ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
+ ],
+ # Use a ResourceVariable and set caching_device to make the read
+ # behavior deterministic and well-defined.
+ caching_device='cpu:0',
+ use_resource=True)
def test_raise_when_scaffold_and_summary_op_both_missing(self):
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook()
def test_raise_when_scaffold_and_summary_op_both_present(self):
+ summary_op = summary_lib.merge_all()
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook(
- scaffold=monitored_session.Scaffold(), summary_op=self.summary_op)
+ scaffold=monitored_session.Scaffold(), summary_op=summary_op)
- def test_raise_in_both_secs_and_steps(self):
+ def test_raise_when_secs_and_steps_both_missing(self):
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook(
- save_secs=10, save_steps=20, summary_writer=self.summary_writer)
+ save_secs=None, save_steps=None, output_dir=self.logdir)
- def test_raise_in_none_secs_and_steps(self):
+ def test_raise_when_secs_and_steps_both_present(self):
with self.assertRaises(ValueError):
basic_session_run_hooks.SummarySaverHook(
- save_secs=None, save_steps=None, summary_writer=self.summary_writer)
+ save_secs=10, save_steps=20, output_dir=self.logdir)
- def test_save_steps(self):
- hook = basic_session_run_hooks.SummarySaverHook(
- save_steps=8,
- summary_writer=self.summary_writer,
- summary_op=self.summary_op)
+ def _makeHook(self, **kwargs):
+ kwargs['output_dir'] = self.logdir
+ kwargs['scaffold'] = monitored_session.Scaffold()
+ return basic_session_run_hooks.SummarySaverHook(**kwargs)
+ def _runForSteps(self, hook, steps, loop_body_fn=None):
+ train_op = training_util.get_global_step().assign_add(1)
with self.test_session() as sess:
hook.begin()
sess.run(variables_lib.global_variables_initializer())
+ scaffold = hook._scaffold # pylint: disable=protected-access
+ if scaffold is not None:
+ scaffold.finalize()
+ sess.run(scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(30):
- mon_sess.run(self.train_op)
+ for _ in range(steps):
+ mon_sess.run(train_op)
+ if loop_body_fn is not None:
+ loop_body_fn()
hook.end(sess)
- self.summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.log_dir,
- expected_summaries={
- 1: {
- 'my_summary': 1.0
- },
- 9: {
- 'my_summary': 2.0
- },
- 17: {
- 'my_summary': 3.0
- },
- 25: {
- 'my_summary': 4.0
- },
- })
+ def _assertSessionEvent(self, event, step, session_status):
+ self.assertEqual(step, event.step)
+ self.assertEqual(session_status, event.session_log.status)
+
+ def _assertSummaryEvent(self, event, step, tag_value_list):
+ self.assertEqual(step, event.step)
+ tag_value_actual_list = [
+ (value.tag, value.simple_value) for value in event.summary.value
+ ]
+ self.assertItemsEqual(tag_value_list, tag_value_actual_list)
+
+ def test_no_summaries(self):
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 3)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_basic_summaries(self):
+ summary_lib.scalar('foo-v1', 1.0)
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.always_record_summaries():
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 3)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 1, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 2, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 2, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 3, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
def test_multiple_summaries(self):
- hook = basic_session_run_hooks.SummarySaverHook(
- save_steps=8,
- summary_writer=self.summary_writer,
- summary_op=[self.summary_op, self.summary_op2])
-
+ summary_lib.scalar('foo-v1', 1.0)
+ summary_lib.scalar('bar-v1', 10.0)
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.always_record_summaries():
+ foo = summary_ops_v2.scalar('foo-v2', 2.0)
+ # Ensure deterministic write order
+ with ops.control_dependencies([foo]):
+ summary_ops_v2.scalar('bar-v2', 20.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 1)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 0, [('bar-v2', 20.0)])
+ self._assertSummaryEvent(
+ next(events), 1, [('foo-v1', 1.0), ('bar-v1', 10.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_v2_summaries_only(self):
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.always_record_summaries():
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 1)
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_v2_summaries_custom_file_writer(self):
+ other_dir = os.path.join(self.logdir, 'other')
+ other_writer = summary_ops_v2.create_file_writer(other_dir)
+ # SummarySaverHook only flushes the writer for logdir; this one needs to be
+ # manually flushed.
+ flush_op = other_writer.flush()
+ with summary_ops_v2.always_record_summaries():
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ summary_ops_v2.scalar('foo-v2', 2.0)
+ with other_writer.as_default():
+ summary_ops_v2.scalar('other-v2', 3.0)
+ hook = self._makeHook(save_steps=1)
+ self._runForSteps(hook, 1)
with self.test_session() as sess:
- hook.begin()
- sess.run(variables_lib.global_variables_initializer())
- mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(10):
- mon_sess.run(self.train_op)
- hook.end(sess)
+ sess.run(flush_op)
- self.summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.log_dir,
- expected_summaries={
- 1: {
- 'my_summary': 1.0,
- 'my_summary2': 2.0
- },
- 9: {
- 'my_summary': 2.0,
- 'my_summary2': 4.0
- },
- })
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
- def test_save_secs_saving_once_every_step(self):
- hook = basic_session_run_hooks.SummarySaverHook(
- save_secs=0.5,
- summary_writer=self.summary_writer,
- summary_op=self.summary_op)
+ events = iter(load_eventfile_contents(other_dir))
+ next(events) # Skip version event that's always there.
+ self._assertSummaryEvent(next(events), 0, [('other-v2', 3.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
- with self.test_session() as sess:
- hook.begin()
- sess.run(variables_lib.global_variables_initializer())
- mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(4):
- mon_sess.run(self.train_op)
- time.sleep(0.5)
- hook.end(sess)
+ def test_save_steps(self):
+ summary_lib.scalar('foo-v1', 1.0)
+ placeholder = array_ops.placeholder_with_default(False, shape=[])
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.record_summaries_if(placeholder):
+ summary_ops_v2.scalar('foo-v2', 2.0)
- self.summary_writer.assert_summaries(
- test_case=self,
- expected_logdir=self.log_dir,
- expected_summaries={
- 1: {
- 'my_summary': 1.0
- },
- 2: {
- 'my_summary': 2.0
- },
- 3: {
- 'my_summary': 3.0
- },
- 4: {
- 'my_summary': 4.0
- },
- })
+ basic_session_run_hooks.SummarySaverHook._set_placeholder(placeholder)
+ hook = self._makeHook(save_steps=8)
+ self._runForSteps(hook, 30)
+
+ events = load_eventfile_contents(self.logdir)
+ print('TEST SAVE STEPS EVENTS', str(events), file=sys.stderr)
+ events = iter(events)
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 8, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 9, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 16, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 17, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 24, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 25, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
@test.mock.patch.object(time, 'time')
- def test_save_secs_saving_once_every_three_steps(self, mock_time):
- mock_time.return_value = 1484695987.209386
- hook = basic_session_run_hooks.SummarySaverHook(
- save_secs=9.,
- summary_writer=self.summary_writer,
- summary_op=self.summary_op)
+ def test_save_secs_saving_once_every_step(self, mock_time):
+ mock_time.return_value = 1000.0
+ summary_lib.scalar('foo-v1', 1.0)
+ placeholder = array_ops.placeholder_with_default(False, shape=[])
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.record_summaries_if(placeholder):
+ summary_ops_v2.scalar('foo-v2', 2.0)
- with self.test_session() as sess:
- hook.begin()
- sess.run(variables_lib.global_variables_initializer())
- mon_sess = monitored_session._HookedSession(sess, [hook])
- for _ in range(8):
- mon_sess.run(self.train_op)
- mock_time.return_value += 3.1
- hook.end(sess)
+ basic_session_run_hooks.SummarySaverHook._set_placeholder(placeholder)
+ hook = self._makeHook(save_secs=0.5)
+ def fake_sleep():
+ mock_time.return_value += 0.5
+ self._runForSteps(hook, 4, fake_sleep)
+
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
+
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 1, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 2, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 2, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 3, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 3, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 4, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ @test.mock.patch.object(time, 'time')
+ def test_save_secs_saving_once_every_three_steps(self, mock_time):
+ mock_time.return_value = 1000.0
+ summary_lib.scalar('foo-v1', 1.0)
+ placeholder = array_ops.placeholder_with_default(False, shape=[])
+ with summary_ops_v2.create_file_writer(self.logdir).as_default():
+ with summary_ops_v2.record_summaries_if(placeholder):
+ summary_ops_v2.scalar('foo-v2', 2.0)
+
+ basic_session_run_hooks.SummarySaverHook._set_placeholder(placeholder)
+ hook = self._makeHook(save_secs=9)
+ def fake_sleep():
+ mock_time.return_value += 3.1
+ self._runForSteps(hook, 8, fake_sleep)
+
+ events = iter(load_eventfile_contents(self.logdir))
+ next(events) # Skip version event that's always there.
+ self._assertSessionEvent(next(events), 0, SessionLog.START)
# 24.8 seconds passed (3.1*8), it saves every 9 seconds starting from first:
- self.summary_writer.assert_summaries(
+ self._assertSummaryEvent(next(events), 0, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 1, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 3, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 4, [('foo-v1', 1.0)])
+
+ self._assertSummaryEvent(next(events), 6, [('foo-v2', 2.0)])
+ self._assertSummaryEvent(next(events), 7, [('foo-v1', 1.0)])
+ self.assertRaises(StopIteration, lambda: next(events))
+
+ def test_explicit_summary_writer_and_op(self):
+ summary_writer = fake_summary_writer.FakeSummaryWriter(self.logdir)
+ hook = basic_session_run_hooks.SummarySaverHook(
+ save_steps=1,
+ summary_writer=summary_writer,
+ summary_op=summary_lib.scalar('foo-v1', 1.0))
+ self._runForSteps(hook, 3)
+ summary_writer.assert_summaries(
test_case=self,
- expected_logdir=self.log_dir,
+ expected_logdir=self.logdir,
expected_summaries={
- 1: {
- 'my_summary': 1.0
- },
- 4: {
- 'my_summary': 2.0
- },
- 7: {
- 'my_summary': 3.0
- },
+ 1: {'foo-v1': 1.0},
+ 2: {'foo-v1': 1.0},
+ 3: {'foo-v1': 1.0},
})
@@ -1518,18 +1701,23 @@ class ProfilerHookTest(test.TestCase):
sess.run(self.train_op) # Saved.
self.assertEqual(3, self._count_timeline_files())
- def test_run_metadata_saves_in_first_step(self):
- writer_cache.FileWriterCache.clear()
- fake_summary_writer.FakeSummaryWriter.install()
- fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
+ def test_run_metadata_summary_saving(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
sess.run(self.train_op) # Saved.
- self.assertEqual(
- list(fake_writer._added_run_metadata.keys()), ['step_1'])
- fake_summary_writer.FakeSummaryWriter.uninstall()
+ sess.run(self.train_op) # Not saved.
+ sess.run(self.train_op) # Saved.
+ events = iter(load_eventfile_contents(self.output_dir))
+ next(events) # Skip version event that's always there.
+ event = next(events)
+ self.assertEqual(1, event.step)
+ self.assertEqual('step_1', event.tagged_run_metadata.tag)
+ event = next(events)
+ self.assertEqual(3, event.step)
+ self.assertEqual('step_3', event.tagged_run_metadata.tag)
+ self.assertRaises(StopIteration, lambda: next(events)) # No more events.
if __name__ == '__main__':
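These tests verify output by reading events back from disk instead of installing FakeSummaryWriter. A standalone sketch of that read-back pattern, under the same one-events-file-per-directory assumption as load_eventfile_contents() above:

import glob
import os

from tensorflow.python.summary import summary_iterator


def events_in(directory_path):
  """Yields each Event proto from the single events file in a directory."""
  event_paths = glob.glob(os.path.join(directory_path, '*tfevent*'))
  if len(event_paths) != 1:
    raise AssertionError('Expected one events file, got %r' % event_paths)
  # Records are event_pb2.Event protos; the first is always the version
  # event that the tests above skip.
  for event in summary_iterator.summary_iterator(event_paths[0]):
    yield event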
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 7b06bffa4b..8a4ca04b1e 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
+from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
@@ -204,13 +205,17 @@ class Scaffold(object):
'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
Scaffold.default_local_init_op)
if self._summary_op is None:
+ def default_summary_op():
+ v1_op = summary.merge_all()
+ v2_ops = summary_ops_v2.all_summary_ops() or []
+ if v1_op is not None:
+ return control_flow_ops.with_dependencies(v2_ops, v1_op)
+ return control_flow_ops.group(v2_ops) if v2_ops else None
self._summary_op = Scaffold.get_or_default('summary_op',
ops.GraphKeys.SUMMARY_OP,
- summary.merge_all)
- # pylint: disable=g-long-lambda
+ default_summary_op)
if self._saver is None:
self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access
- # pylint: enable=g-long-lambda
self._saver.build()
ops.get_default_graph().finalize()
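Pulled out of Scaffold, the new default summary_op logic amounts to the following; the two summaries and the logdir are illustrative:

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import summary_ops_v2

tf.summary.scalar('v1_loss', tf.constant(0.5))  # a v1 summary
with summary_ops_v2.create_file_writer('/tmp/logdir').as_default():
  with summary_ops_v2.always_record_summaries():
    summary_ops_v2.scalar('v2_loss', 2.0)  # a v2 summary op

v1_op = tf.summary.merge_all()
v2_ops = summary_ops_v2.all_summary_ops() or []
if v1_op is not None:
  # Fetching this op runs every v2 write as a side effect and returns the
  # serialized v1 Summary proto, so one fetch covers both systems.
  summary_op = control_flow_ops.with_dependencies(v2_ops, v1_op)
else:
  summary_op = control_flow_ops.group(v2_ops) if v2_ops else None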
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f75db08059..b9d42b034e 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -611,10 +611,8 @@ class Optimizer(
if isinstance(global_step, resource_variable_ops.ResourceVariable):
# TODO(apassos): the implicit read in assign_add is slow; consider
# making it less so.
- apply_updates = resource_variable_ops.assign_add_variable_op(
- global_step.handle,
- ops.convert_to_tensor(1, dtype=global_step.dtype),
- name=name)
+ apply_updates = global_step.assign_add(
+ 1, name=name, read_value=False)
else:
apply_updates = state_ops.assign_add(global_step, 1, name=name)
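For reference, a short sketch of the ResourceVariable fast path above; the variable setup and names are illustrative:

import tensorflow as tf

# use_resource=True yields a ResourceVariable, which takes the fast path.
global_step = tf.get_variable(
    'global_step', shape=[], dtype=tf.int64,
    initializer=tf.zeros_initializer(), trainable=False, use_resource=True)
# read_value=False returns the assign op itself rather than a read of the
# updated value, avoiding the implicit read flagged in the TODO above.
apply_updates = global_step.assign_add(1, name='increment_step',
                                       read_value=False)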