diff options
author | Reed Wanderman-Milne <reedwm@google.com> | 2018-09-06 12:58:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 13:10:39 -0700 |
commit | 612166a4f4c79efbe9e34e75652e10300150ec7a (patch) | |
tree | a22ff2043702bc4949c08d3c9fd86e842ef1558c /tensorflow/python/training | |
parent | 1aabc8beacd27b5577c72329310ce309f2e45eca (diff) |
Do not have ProfilerHook output a timeline for the first step.
This is because many ops take longer during the first step due to autotune. Instead, the first timeline is now output after N seconds/steps.
PiperOrigin-RevId: 211854304
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/basic_session_run_hooks.py | 6 | ||||
-rw-r--r-- | tensorflow/python/training/basic_session_run_hooks_test.py | 37 |
2 files changed, 23 insertions, 20 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 76625624e4..3bd4bd75bd 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -1025,7 +1025,7 @@ class ProfilerHook(session_run_hook.SessionRunHook): def before_run(self, run_context): self._request_summary = ( - self._next_step is None or + self._next_step is not None and self._timer.should_trigger_for_step(self._next_step)) requests = {"global_step": self._global_step_tensor} opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) @@ -1035,6 +1035,10 @@ class ProfilerHook(session_run_hook.SessionRunHook): def after_run(self, run_context, run_values): stale_global_step = run_values.results["global_step"] + if self._next_step is None: + # Update the timer so that it does not activate until N steps or seconds + # have passed. + self._timer.update_last_triggered_step(stale_global_step) global_step = stale_global_step + 1 if self._request_summary: global_step = run_context.session.run(self._global_step_tensor) diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index b49a871a56..fe8a3e9062 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -1454,52 +1454,50 @@ class ProfilerHookTest(test.TestCase): with self.assertRaises(ValueError): basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None) - def test_save_secs_saves_in_first_step(self): + def test_save_secs_does_not_save_in_first_step(self): with self.graph.as_default(): hook = basic_session_run_hooks.ProfilerHook( save_secs=2, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: sess.run(self.train_op) - self.assertEqual(1, self._count_timeline_files()) + 
self.assertEqual(0, self._count_timeline_files()) @test.mock.patch.object(time, 'time') def test_save_secs_saves_periodically(self, mock_time): # Pick a fixed start time. - current_time = 1484863632.320497 + current_time = 1484863632. with self.graph.as_default(): mock_time.return_value = current_time hook = basic_session_run_hooks.ProfilerHook( save_secs=2, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) + self.assertEqual(0, self._count_timeline_files()) # Simulate 2.5 seconds of sleep. mock_time.return_value = current_time + 2.5 sess.run(self.train_op) # Saved. + self.assertEqual(1, self._count_timeline_files()) # Pretend some small amount of time has passed. - mock_time.return_value = current_time + 0.1 + mock_time.return_value = current_time + 2.6 sess.run(self.train_op) # Not saved. # Edge test just before we should save the timeline. - mock_time.return_value = current_time + 1.9 + mock_time.return_value = current_time + 4.4 sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) + self.assertEqual(1, self._count_timeline_files()) mock_time.return_value = current_time + 4.5 sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) + self.assertEqual(2, self._count_timeline_files()) - def test_save_steps_saves_in_first_step(self): + def test_save_steps_does_not_save_in_first_step(self): with self.graph.as_default(): hook = basic_session_run_hooks.ProfilerHook( - save_secs=2, output_dir=self.output_dir) + save_steps=1, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. sess.run(self.train_op) # Not saved. 
- self.assertEqual(1, self._count_timeline_files()) + self.assertEqual(0, self._count_timeline_files()) def test_save_steps_saves_periodically(self): with self.graph.as_default(): @@ -1507,6 +1505,8 @@ class ProfilerHookTest(test.TestCase): save_steps=2, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: self.assertEqual(0, self._count_timeline_files()) + sess.run(self.train_op) # Not saved. + self.assertEqual(0, self._count_timeline_files()) sess.run(self.train_op) # Saved. self.assertEqual(1, self._count_timeline_files()) sess.run(self.train_op) # Not saved. @@ -1515,20 +1515,19 @@ class ProfilerHookTest(test.TestCase): self.assertEqual(2, self._count_timeline_files()) sess.run(self.train_op) # Not saved. self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - def test_run_metadata_saves_in_first_step(self): + def test_run_metadata_saves(self): writer_cache.FileWriterCache.clear() fake_summary_writer.FakeSummaryWriter.install() fake_writer = writer_cache.FileWriterCache.get(self.output_dir) with self.graph.as_default(): hook = basic_session_run_hooks.ProfilerHook( - save_secs=2, output_dir=self.output_dir) + save_steps=1, output_dir=self.output_dir) with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: + sess.run(self.train_op) # Not saved. sess.run(self.train_op) # Saved. self.assertEqual( - list(fake_writer._added_run_metadata.keys()), ['step_1']) + list(fake_writer._added_run_metadata.keys()), ['step_2']) fake_summary_writer.FakeSummaryWriter.uninstall() |