author    Sherry Moore <sherrym@google.com>                2018-04-24 10:31:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-24 10:33:57 -0700
commit    55a4a479df8e1fbc8aa726596e6d4591364b3585 (patch)
tree      3fac4b5aefe3a051ff330293fcb1da16cfd23929 /tensorflow
parent    9c7e819352581bf5a97509b1fa5dc71dffa26500 (diff)
Added a call in CheckpointSaverHook.after_create_session to always save a
checkpoint before the first training step.

PiperOrigin-RevId: 194107958
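
For context, a minimal sketch of the resulting behavior, assuming a plain TF 1.x training loop; the directory '/tmp/ckpt_demo' and the trivial train_op are placeholders, not part of this change. tf.train.MonitoredTrainingSession installs a CheckpointSaverHook, which after this change writes model.ckpt-0 as soon as the session is created, before the first run call.

# Sketch only (TF 1.x graph/session APIs); '/tmp/ckpt_demo' is a placeholder path.
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for a real training step

with tf.train.MonitoredTrainingSession(checkpoint_dir='/tmp/ckpt_demo',
                                       save_checkpoint_steps=5) as sess:
  # The CheckpointSaverHook installed above has already saved a checkpoint for
  # step 0 from after_create_session(), before any training step runs.
  print(tf.train.latest_checkpoint('/tmp/ckpt_demo'))  # .../model.ckpt-0
  sess.run(train_op)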
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/estimator_test.py |  4
-rw-r--r-- tensorflow/python/estimator/estimator_test.py                      |  4
-rw-r--r-- tensorflow/python/training/basic_session_run_hooks.py              | 36
-rw-r--r-- tensorflow/python/training/basic_session_run_hooks_test.py         | 38
4 files changed, 58 insertions, 24 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index d81a534b79..9e5aaf3118 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -715,7 +715,9 @@ class EstimatorTest(test.TestCase):
     ckpt = checkpoint_state_pb2.CheckpointState()
     text_format.Merge(checkpoint_file_content, ckpt)
     self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
-    self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'],
+    # TODO(b/78461127): Please modify tests to not directly rely on names of
+    # checkpoints.
+    self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],
                         ckpt.all_model_checkpoint_paths)

   def test_train_save_copy_reload(self):
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index d453e19357..0fea86124c 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -679,8 +679,10 @@ class EstimatorTrainTest(test.TestCase):
     ckpt = checkpoint_state_pb2.CheckpointState()
     text_format.Merge(checkpoint_file_content, ckpt)
     self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
+    # TODO(b/78461127): Please modify tests to not directly rely on names of
+    # checkpoints.
     self.assertAllEqual(
-        ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+        ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)

   def test_train_save_copy_reload(self):
     tmpdir = tempfile.mkdtemp()
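
The updated expectations in both estimator tests reflect that the oldest retained checkpoint is now model.ckpt-0, written at session creation, rather than model.ckpt-1. A small sketch of inspecting the same checkpoint state outside the tests, assuming a hypothetical model_dir populated by an Estimator trained for 5 steps:

# Sketch only; model_dir is a placeholder for a directory of Estimator checkpoints.
import tensorflow as tf

model_dir = '/tmp/estimator_model_dir'
ckpt_state = tf.train.get_checkpoint_state(model_dir)
print(ckpt_state.model_checkpoint_path)       # e.g. '.../model.ckpt-5'
print(ckpt_state.all_model_checkpoint_paths)  # e.g. starts with '.../model.ckpt-0'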
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 3651291bdf..47339e057f 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -434,23 +434,27 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
     for l in self._listeners:
       l.begin()

-  def before_run(self, run_context):  # pylint: disable=unused-argument
-    if self._timer.last_triggered_step() is None:
-      # We do write graph and saver_def at the first call of before_run.
-      # We cannot do this in begin, since we let other hooks to change graph and
-      # add variables in begin. Graph is finalized after all begin calls.
-      training_util.write_graph(
-          ops.get_default_graph().as_graph_def(add_shapes=True),
-          self._checkpoint_dir,
-          "graph.pbtxt")
-      saver_def = self._get_saver().saver_def if self._get_saver() else None
-      graph = ops.get_default_graph()
-      meta_graph_def = meta_graph.create_meta_graph_def(
-          graph_def=graph.as_graph_def(add_shapes=True),
-          saver_def=saver_def)
-      self._summary_writer.add_graph(graph)
-      self._summary_writer.add_meta_graph(meta_graph_def)
+  def after_create_session(self, session, coord):
+    global_step = session.run(self._global_step_tensor)
+    # We do write graph and saver_def at the first call of before_run.
+    # We cannot do this in begin, since we let other hooks to change graph and
+    # add variables in begin. Graph is finalized after all begin calls.
+    training_util.write_graph(
+        ops.get_default_graph().as_graph_def(add_shapes=True),
+        self._checkpoint_dir,
+        "graph.pbtxt")
+    saver_def = self._get_saver().saver_def if self._get_saver() else None
+    graph = ops.get_default_graph()
+    meta_graph_def = meta_graph.create_meta_graph_def(
+        graph_def=graph.as_graph_def(add_shapes=True),
+        saver_def=saver_def)
+    self._summary_writer.add_graph(graph)
+    self._summary_writer.add_meta_graph(meta_graph_def)
+    # The checkpoint saved here is the state at step "global_step".
+    self._save(session, global_step)
+    self._timer.update_last_triggered_step(global_step)
+
+  def before_run(self, run_context):  # pylint: disable=unused-argument
     return SessionRunArgs(self._global_step_tensor)

   def after_run(self, run_context, run_values):
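
Because _save() is now also called from after_create_session, any attached CheckpointSaverListener observes one extra before_save/after_save pair, for the step-0 checkpoint, before the first training step; the listener-count changes from 2 to 3 in the test diff below follow from this. A minimal sketch of such a listener, assuming the public tf.train.CheckpointSaverListener interface; LoggingListener and the path are hypothetical:

# Sketch only; LoggingListener is illustrative and not part of this change.
import tensorflow as tf

class LoggingListener(tf.train.CheckpointSaverListener):

  def before_save(self, session, global_step_value):
    # After this change, the first call arrives with global_step_value == 0,
    # triggered from CheckpointSaverHook.after_create_session.
    print('about to save at step', global_step_value)

  def after_save(self, session, global_step_value):
    print('saved at step', global_step_value)

hook = tf.train.CheckpointSaverHook(
    checkpoint_dir='/tmp/ckpt_demo',  # placeholder path
    save_steps=2,
    listeners=[LoggingListener()])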
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 25962f6bf7..31898562f8 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -466,8 +466,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 2,
-        'after_save': 2,
+        'before_save': 3,
+        'after_save': 3,
         'end': 1
     }, listener_counts)
@@ -490,8 +490,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 2,
-        'after_save': 2,
+        'before_save': 3,
+        'after_save': 3,
         'end': 1
     }, listener_counts)
@@ -523,8 +523,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 2,
-        'after_save': 2,
+        'before_save': 3,
+        'after_save': 3,
         'end': 1
     }, listener1_counts)
     self.assertEqual(listener1_counts, listener2_counts)
@@ -706,6 +706,7 @@ class CheckpointSaverHookTest(test.TestCase):
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
+      hook.after_create_session(sess, None)
       mon_sess.run(self.train_op)
       summary_writer.assert_summaries(
           test_case=self,
@@ -718,6 +719,31 @@ class CheckpointSaverHookTest(test.TestCase):
     fake_summary_writer.FakeSummaryWriter.uninstall()

+  def test_save_checkpoint_before_first_train_step(self):
+    with self.graph.as_default():
+      hook = basic_session_run_hooks.CheckpointSaverHook(
+          self.model_dir, save_steps=2, scaffold=self.scaffold)
+      hook.begin()
+      self.scaffold.finalize()
+      with session_lib.Session() as sess:
+        mon_sess = monitored_session._HookedSession(sess, [hook])
+        sess.run(self.scaffold.init_op)
+        hook.after_create_session(sess, None)
+        # Verifies that checkpoint is saved at step 0.
+        self.assertEqual(0,
+                         checkpoint_utils.load_variable(self.model_dir,
+                                                        self.global_step.name))
+        # Verifies that no checkpoint is saved after one training step.
+        mon_sess.run(self.train_op)
+        self.assertEqual(0,
+                         checkpoint_utils.load_variable(self.model_dir,
+                                                        self.global_step.name))
+        # Verifies that checkpoint is saved after save_steps.
+        mon_sess.run(self.train_op)
+        self.assertEqual(2,
+                         checkpoint_utils.load_variable(self.model_dir,
+                                                        self.global_step.name))
+

 class CheckpointSaverHookMultiStepTest(test.TestCase):
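
The new test depends on the timer bookkeeping added in after_create_session: the step-0 save is recorded with update_last_triggered_step(0), so with save_steps=2 the hook does not save again until the global step reaches 2. A short sketch of that trigger logic, assuming the SecondOrStepTimer helper already defined in basic_session_run_hooks:

# Sketch only, illustrating the every-N-steps trigger used by CheckpointSaverHook.
from tensorflow.python.training import basic_session_run_hooks

timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=2)
timer.update_last_triggered_step(0)      # recorded by after_create_session
print(timer.should_trigger_for_step(1))  # False: no checkpoint after one step
print(timer.should_trigger_for_step(2))  # True: next save is due at step 2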