aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/basic_session_run_hooks.py
diff options
context:
space:
mode:
authorGravatar Sherry Moore <sherrym@google.com>2018-04-05 09:33:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-05 09:35:47 -0700
commit3fb89650a1e7f5cc4c04f091170fac504ba10021 (patch)
tree62454da80fc0c94637db3a5c69f1c155b2447b30 /tensorflow/python/training/basic_session_run_hooks.py
parente6225d9835f63729a9006f10ca9e50068381663d (diff)
Added a call in CheckpointSaverHook.after_create_session to always save
checkpoint before the first training step. PiperOrigin-RevId: 191753026
Diffstat (limited to 'tensorflow/python/training/basic_session_run_hooks.py')
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index aae757b99a..77d4f15d52 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -429,6 +429,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
for l in self._listeners:
l.begin()
+ def after_create_session(self, session, coord):
+ global_step = session.run(self._global_step_tensor)
+ self._save(session, global_step)
+ self._timer.update_last_triggered_step(global_step)
+
def before_run(self, run_context): # pylint: disable=unused-argument
if self._timer.last_triggered_step() is None:
# We do write graph and saver_def at the first call of before_run.