about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
author: Reed Wanderman-Milne <reedwm@google.com> 2018-09-06 12:58:04 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-06 13:10:39 -0700
commit: 612166a4f4c79efbe9e34e75652e10300150ec7a (patch)
tree: a22ff2043702bc4949c08d3c9fd86e842ef1558c /tensorflow/python/training
parent: 1aabc8beacd27b5577c72329310ce309f2e45eca (diff)
Do not have ProfilerHook output a timeline for the first step.
This is because many ops take longer during the first step due to autotune. Instead, the first timeline is now output after N seconds/steps.
PiperOrigin-RevId: 211854304
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- tensorflow/python/training/basic_session_run_hooks.py 6
-rw-r--r-- tensorflow/python/training/basic_session_run_hooks_test.py 37
2 files changed, 23 insertions, 20 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 76625624e4..3bd4bd75bd 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -1025,7 +1025,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
def before_run(self, run_context):
self._request_summary = (
- self._next_step is None or
+ self._next_step is not None and
self._timer.should_trigger_for_step(self._next_step))
requests = {"global_step": self._global_step_tensor}
opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1035,6 +1035,10 @@ class ProfilerHook(session_run_hook.SessionRunHook):
def after_run(self, run_context, run_values):
stale_global_step = run_values.results["global_step"]
+ if self._next_step is None:
+ # Update the timer so that it does not activate until N steps or seconds
+ # have passed.
+ self._timer.update_last_triggered_step(stale_global_step)
global_step = stale_global_step + 1
if self._request_summary:
global_step = run_context.session.run(self._global_step_tensor)
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index b49a871a56..fe8a3e9062 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -1454,52 +1454,50 @@ class ProfilerHookTest(test.TestCase):
with self.assertRaises(ValueError):
basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)
- def test_save_secs_saves_in_first_step(self):
+ def test_save_secs_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
sess.run(self.train_op)
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
@test.mock.patch.object(time, 'time')
def test_save_secs_saves_periodically(self, mock_time):
# Pick a fixed start time.
- current_time = 1484863632.320497
+ current_time = 1484863632.
with self.graph.as_default():
mock_time.return_value = current_time
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
- self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
# Simulate 2.5 seconds of sleep.
mock_time.return_value = current_time + 2.5
sess.run(self.train_op) # Saved.
+ self.assertEqual(1, self._count_timeline_files())
# Pretend some small amount of time has passed.
- mock_time.return_value = current_time + 0.1
+ mock_time.return_value = current_time + 2.6
sess.run(self.train_op) # Not saved.
# Edge test just before we should save the timeline.
- mock_time.return_value = current_time + 1.9
+ mock_time.return_value = current_time + 4.4
sess.run(self.train_op) # Not saved.
- self.assertEqual(2, self._count_timeline_files())
+ self.assertEqual(1, self._count_timeline_files())
mock_time.return_value = current_time + 4.5
sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
+ self.assertEqual(2, self._count_timeline_files())
- def test_save_steps_saves_in_first_step(self):
+ def test_save_steps_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
def test_save_steps_saves_periodically(self):
with self.graph.as_default():
@@ -1507,6 +1505,8 @@ class ProfilerHookTest(test.TestCase):
save_steps=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
self.assertEqual(0, self._count_timeline_files())
+ sess.run(self.train_op) # Not saved.
+ self.assertEqual(0, self._count_timeline_files())
sess.run(self.train_op) # Saved.
self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
@@ -1515,20 +1515,19 @@ class ProfilerHookTest(test.TestCase):
self.assertEqual(2, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
self.assertEqual(2, self._count_timeline_files())
- sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
- def test_run_metadata_saves_in_first_step(self):
+ def test_run_metadata_saves(self):
writer_cache.FileWriterCache.clear()
fake_summary_writer.FakeSummaryWriter.install()
fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
+ sess.run(self.train_op) # Not saved.
sess.run(self.train_op) # Saved.
self.assertEqual(
- list(fake_writer._added_run_metadata.keys()), ['step_1'])
+ list(fake_writer._added_run_metadata.keys()), ['step_2'])
fake_summary_writer.FakeSummaryWriter.uninstall()