aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/estimator/estimator_test.py2
-rw-r--r--tensorflow/python/estimator/replicate_model_fn_test.py9
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py5
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py37
4 files changed, 43 insertions, 10 deletions
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index f4255091bf..498f5294a4 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -680,7 +680,7 @@ class EstimatorTrainTest(test.TestCase):
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
self.assertAllEqual(
- ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+ ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
tmpdir = tempfile.mkdtemp()
diff --git a/tensorflow/python/estimator/replicate_model_fn_test.py b/tensorflow/python/estimator/replicate_model_fn_test.py
index ad1f9c02b9..00035ef1fe 100644
--- a/tensorflow/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/python/estimator/replicate_model_fn_test.py
@@ -27,6 +27,7 @@ import six
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import replicate_model_fn
+from tensorflow.python.estimator import run_config
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.estimator.canned import prediction_keys
@@ -593,7 +594,8 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
loss=loss,
eval_metric_ops=metrics,
predictions={'probabilities': predictions},
- train_op=optimizer.minimize(loss))
+ train_op=optimizer.minimize(
+ loss, global_step=training.get_global_step()))
@property
def params(self):
@@ -612,8 +614,9 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
estimator = estimator_lib.Estimator(
model_fn=self.model_fn,
model_dir=tempfile.mkdtemp(),
- params=self.params)
- estimator.train(train_input_fn, steps=1)
+ params=self.params,
+ config=run_config.RunConfig(save_checkpoints_steps=1))
+ estimator.train(train_input_fn, steps=2)
self.assertEqual(7.0, estimator.get_variable_value('c'))
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index aae757b99a..77d4f15d52 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -429,6 +429,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
for l in self._listeners:
l.begin()
+ def after_create_session(self, session, coord):
+ global_step = session.run(self._global_step_tensor)
+ self._save(session, global_step)
+ self._timer.update_last_triggered_step(global_step)
+
def before_run(self, run_context): # pylint: disable=unused-argument
if self._timer.last_triggered_step() is None:
# We do write graph and saver_def at the first call of before_run.
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 2547661e52..4bf4a599b4 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -466,8 +466,8 @@ class CheckpointSaverHookTest(test.TestCase):
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 2,
- 'after_save': 2,
+ 'before_save': 3,
+ 'after_save': 3,
'end': 1
}, listener_counts)
@@ -490,8 +490,8 @@ class CheckpointSaverHookTest(test.TestCase):
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 2,
- 'after_save': 2,
+ 'before_save': 3,
+ 'after_save': 3,
'end': 1
}, listener_counts)
@@ -523,8 +523,8 @@ class CheckpointSaverHookTest(test.TestCase):
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 2,
- 'after_save': 2,
+ 'before_save': 3,
+ 'after_save': 3,
'end': 1
}, listener1_counts)
self.assertEqual(listener1_counts, listener2_counts)
@@ -718,6 +718,31 @@ class CheckpointSaverHookTest(test.TestCase):
fake_summary_writer.FakeSummaryWriter.uninstall()
+ def test_save_checkpoint_before_first_train_step(self):
+ with self.graph.as_default():
+ hook = basic_session_run_hooks.CheckpointSaverHook(
+ self.model_dir, save_steps=2, scaffold=self.scaffold)
+ hook.begin()
+ self.scaffold.finalize()
+ with session_lib.Session() as sess:
+ mon_sess = monitored_session._HookedSession(sess, [hook])
+ sess.run(self.scaffold.init_op)
+ hook.after_create_session(sess, None)
+ # Verifies that checkpoint is saved at step 0.
+ self.assertEqual(0,
+ checkpoint_utils.load_variable(self.model_dir,
+ self.global_step.name))
+ # Verifies that no checkpoint is saved after one training step.
+ mon_sess.run(self.train_op)
+ self.assertEqual(0,
+ checkpoint_utils.load_variable(self.model_dir,
+ self.global_step.name))
+ # Verifies that checkpoint is saved after save_steps.
+ mon_sess.run(self.train_op)
+ self.assertEqual(2,
+ checkpoint_utils.load_variable(self.model_dir,
+ self.global_step.name))
+
class ResourceCheckpointSaverHookTest(test.TestCase):