aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/basic_session_run_hooks_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/basic_session_run_hooks_test.py')
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py37
1 files changed, 18 insertions, 19 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index b49a871a56..fe8a3e9062 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -1454,52 +1454,50 @@ class ProfilerHookTest(test.TestCase):
with self.assertRaises(ValueError):
basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)
- def test_save_secs_saves_in_first_step(self):
+ def test_save_secs_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
sess.run(self.train_op)
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
@test.mock.patch.object(time, 'time')
def test_save_secs_saves_periodically(self, mock_time):
# Pick a fixed start time.
- current_time = 1484863632.320497
+ current_time = 1484863632.
with self.graph.as_default():
mock_time.return_value = current_time
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
- self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
# Simulate 2.5 seconds of sleep.
mock_time.return_value = current_time + 2.5
sess.run(self.train_op) # Saved.
+ self.assertEqual(1, self._count_timeline_files())
# Pretend some small amount of time has passed.
- mock_time.return_value = current_time + 0.1
+ mock_time.return_value = current_time + 2.6
sess.run(self.train_op) # Not saved.
# Edge test just before we should save the timeline.
- mock_time.return_value = current_time + 1.9
+ mock_time.return_value = current_time + 4.4
sess.run(self.train_op) # Not saved.
- self.assertEqual(2, self._count_timeline_files())
+ self.assertEqual(1, self._count_timeline_files())
mock_time.return_value = current_time + 4.5
sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
+ self.assertEqual(2, self._count_timeline_files())
- def test_save_steps_saves_in_first_step(self):
+ def test_save_steps_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
def test_save_steps_saves_periodically(self):
with self.graph.as_default():
@@ -1507,6 +1505,8 @@ class ProfilerHookTest(test.TestCase):
save_steps=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
self.assertEqual(0, self._count_timeline_files())
+ sess.run(self.train_op) # Not saved.
+ self.assertEqual(0, self._count_timeline_files())
sess.run(self.train_op) # Saved.
self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
@@ -1515,20 +1515,19 @@ class ProfilerHookTest(test.TestCase):
self.assertEqual(2, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
self.assertEqual(2, self._count_timeline_files())
- sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
- def test_run_metadata_saves_in_first_step(self):
+ def test_run_metadata_saves(self):
writer_cache.FileWriterCache.clear()
fake_summary_writer.FakeSummaryWriter.install()
fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
+ sess.run(self.train_op) # Not saved.
sess.run(self.train_op) # Saved.
self.assertEqual(
- list(fake_writer._added_run_metadata.keys()), ['step_1'])
+ list(fake_writer._added_run_metadata.keys()), ['step_2'])
fake_summary_writer.FakeSummaryWriter.uninstall()