diff options
author | 2018-04-24 10:31:17 -0700 | |
---|---|---|
committer | 2018-04-24 10:33:57 -0700 | |
commit | 55a4a479df8e1fbc8aa726596e6d4591364b3585 (patch) | |
tree | 3fac4b5aefe3a051ff330293fcb1da16cfd23929 /tensorflow | |
parent | 9c7e819352581bf5a97509b1fa5dc71dffa26500 (diff) |
Added a call in CheckpointSaverHook.after_create_session to always save a
checkpoint before the first training step.
PiperOrigin-RevId: 194107958
Diffstat (limited to 'tensorflow')
4 files changed, 58 insertions, 24 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index d81a534b79..9e5aaf3118 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -715,7 +715,9 @@ class EstimatorTest(test.TestCase): ckpt = checkpoint_state_pb2.CheckpointState() text_format.Merge(checkpoint_file_content, ckpt) self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') - self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'], + # TODO(b/78461127): Please modify tests to not directly rely on names of + # checkpoints. + self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) def test_train_save_copy_reload(self): diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index d453e19357..0fea86124c 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -679,8 +679,10 @@ class EstimatorTrainTest(test.TestCase): ckpt = checkpoint_state_pb2.CheckpointState() text_format.Merge(checkpoint_file_content, ckpt) self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') + # TODO(b/78461127): Please modify tests to not directly rely on names of + # checkpoints. 
self.assertAllEqual( - ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) + ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) def test_train_save_copy_reload(self): tmpdir = tempfile.mkdtemp() diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 3651291bdf..47339e057f 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -434,23 +434,27 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): for l in self._listeners: l.begin() - def before_run(self, run_context): # pylint: disable=unused-argument - if self._timer.last_triggered_step() is None: - # We do write graph and saver_def at the first call of before_run. - # We cannot do this in begin, since we let other hooks to change graph and - # add variables in begin. Graph is finalized after all begin calls. - training_util.write_graph( - ops.get_default_graph().as_graph_def(add_shapes=True), - self._checkpoint_dir, - "graph.pbtxt") - saver_def = self._get_saver().saver_def if self._get_saver() else None - graph = ops.get_default_graph() - meta_graph_def = meta_graph.create_meta_graph_def( - graph_def=graph.as_graph_def(add_shapes=True), - saver_def=saver_def) - self._summary_writer.add_graph(graph) - self._summary_writer.add_meta_graph(meta_graph_def) + def after_create_session(self, session, coord): + global_step = session.run(self._global_step_tensor) + # We do write graph and saver_def at the first call of before_run. + # We cannot do this in begin, since we let other hooks to change graph and + # add variables in begin. Graph is finalized after all begin calls. 
+ training_util.write_graph( + ops.get_default_graph().as_graph_def(add_shapes=True), + self._checkpoint_dir, + "graph.pbtxt") + saver_def = self._get_saver().saver_def if self._get_saver() else None + graph = ops.get_default_graph() + meta_graph_def = meta_graph.create_meta_graph_def( + graph_def=graph.as_graph_def(add_shapes=True), + saver_def=saver_def) + self._summary_writer.add_graph(graph) + self._summary_writer.add_meta_graph(meta_graph_def) + # The checkpoint saved here is the state at step "global_step". + self._save(session, global_step) + self._timer.update_last_triggered_step(global_step) + def before_run(self, run_context): # pylint: disable=unused-argument return SessionRunArgs(self._global_step_tensor) def after_run(self, run_context, run_values): diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 25962f6bf7..31898562f8 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -466,8 +466,8 @@ class CheckpointSaverHookTest(test.TestCase): self.assertEqual(2, global_step_val) self.assertEqual({ 'begin': 1, - 'before_save': 2, - 'after_save': 2, + 'before_save': 3, + 'after_save': 3, 'end': 1 }, listener_counts) @@ -490,8 +490,8 @@ class CheckpointSaverHookTest(test.TestCase): self.assertEqual(2, global_step_val) self.assertEqual({ 'begin': 1, - 'before_save': 2, - 'after_save': 2, + 'before_save': 3, + 'after_save': 3, 'end': 1 }, listener_counts) @@ -523,8 +523,8 @@ class CheckpointSaverHookTest(test.TestCase): self.assertEqual(2, global_step_val) self.assertEqual({ 'begin': 1, - 'before_save': 2, - 'after_save': 2, + 'before_save': 3, + 'after_save': 3, 'end': 1 }, listener1_counts) self.assertEqual(listener1_counts, listener2_counts) @@ -706,6 +706,7 @@ class CheckpointSaverHookTest(test.TestCase): with session_lib.Session() as sess: sess.run(self.scaffold.init_op) mon_sess = 
monitored_session._HookedSession(sess, [hook]) + hook.after_create_session(sess, None) mon_sess.run(self.train_op) summary_writer.assert_summaries( test_case=self, @@ -718,6 +719,31 @@ class CheckpointSaverHookTest(test.TestCase): fake_summary_writer.FakeSummaryWriter.uninstall() + def test_save_checkpoint_before_first_train_step(self): + with self.graph.as_default(): + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, save_steps=2, scaffold=self.scaffold) + hook.begin() + self.scaffold.finalize() + with session_lib.Session() as sess: + mon_sess = monitored_session._HookedSession(sess, [hook]) + sess.run(self.scaffold.init_op) + hook.after_create_session(sess, None) + # Verifies that checkpoint is saved at step 0. + self.assertEqual(0, + checkpoint_utils.load_variable(self.model_dir, + self.global_step.name)) + # Verifies that no checkpoint is saved after one training step. + mon_sess.run(self.train_op) + self.assertEqual(0, + checkpoint_utils.load_variable(self.model_dir, + self.global_step.name)) + # Verifies that checkpoint is saved after save_steps. + mon_sess.run(self.train_op) + self.assertEqual(2, + checkpoint_utils.load_variable(self.model_dir, + self.global_step.name)) + class CheckpointSaverHookMultiStepTest(test.TestCase): |