about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Allen Lavoie <allenl@google.com> 2018-08-14 12:54:24 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-14 13:00:52 -0700
commit: a1df092d60e2c2e6fa0e2de668224b536892c244 (patch)
tree: 91f8c258b157bcd4fb4ac59c398a51f1d31c38dc
parent: 0c8a92edd07fa0a1316dc8e31c76d0efd3e54b33 (diff)
Add a custom numbering option to tf.contrib.checkpoint.CheckpointManager
Still increments the save counter, but uses the provided variable/integer for checkpoint numbering.

PiperOrigin-RevId: 208696240
-rw-r--r-- tensorflow/python/training/checkpoint_management.py 20
-rw-r--r-- tensorflow/python/training/checkpoint_management_test.py 26
2 files changed, 42 insertions, 4 deletions
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 9a90f91a7c..85f2904318 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -33,7 +33,9 @@ from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -622,13 +624,19 @@ class CheckpointManager(object):
"""
return self._checkpoint_prefix
- def save(self, session=None):
+ def save(self, session=None, checkpoint_number=None):
"""Creates a new checkpoint and manages it.
Args:
session: The session to evaluate variables in. Ignored when executing
eagerly. If not provided when graph building, the default session is
used.
+ checkpoint_number: An optional integer, or an integer-dtype `Variable` or
+ `Tensor`, used to number the checkpoint. If `None` (default),
+ checkpoints are numbered using `checkpoint.save_counter`. Even if
+ `checkpoint_number` is provided, `save_counter` is still incremented. A
+ user-provided `checkpoint_number` is not incremented even if it is a
+ `Variable`.
Returns:
The path to the new checkpoint. It is also recorded in the `checkpoints`
@@ -639,7 +647,6 @@ class CheckpointManager(object):
if context.executing_eagerly():
save_counter = self._checkpoint.save_counter
save_counter.assign_add(1)
- checkpoint_number = save_counter.numpy()
else:
if session is None:
session = ops.get_default_session()
@@ -653,8 +660,13 @@ class CheckpointManager(object):
with variable_scope.variable_creator_scope(_initializing_creator):
save_counter = self._checkpoint.save_counter
if self._save_counter_assign is None:
- self._save_counter_assign = save_counter.assign_add(1, read_value=True)
- checkpoint_number = session.run(self._save_counter_assign)
+ self._save_counter_assign = save_counter.assign_add(1, read_value=False)
+ session.run(self._save_counter_assign)
+ if checkpoint_number is None:
+ checkpoint_number = save_counter
+ if not isinstance(checkpoint_number, compat.integral_types):
+ checkpoint_number = training_util.global_step(
+ sess=session, global_step_tensor=checkpoint_number)
prefix = "%s-%d" % (self._prefix, checkpoint_number)
save_path = self._checkpoint.write(prefix)
timestamp = time.time()
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 95e688d3c7..1e2827d0a4 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,6 +26,7 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
@@ -486,6 +487,31 @@ class CheckpointManagerTest(test.TestCase):
self.assertEqual([5020.],
state.all_model_checkpoint_timestamps)
+ @test_util.run_in_graph_and_eager_modes
+ def testCustomNumbering(self):
+ directory = self.get_temp_dir()
+ step = variables.Variable(0, dtype=dtypes.int64)
+ checkpoint = util.Checkpoint(step=step)
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=2)
+ self.evaluate(step.initializer)
+ for i in range(5):
+ path = manager.save(checkpoint_number=step)
+ expected_suffix = "-%d" % (2 * i,)
+ if not path.endswith(expected_suffix):
+ self.fail("%s should have suffix %s" % (path, expected_suffix))
+ self.evaluate(step.assign_add(2))
+ self.assertEqual(5, self.evaluate(checkpoint.save_counter))
+ # Test regular integers
+ last_path = manager.save(checkpoint_number=32)
+ self.assertIn("-32", last_path)
+ self.assertEqual(last_path, manager.latest_checkpoint)
+ self.assertEqual(
+ last_path, checkpoint_management.latest_checkpoint(directory))
+ state = checkpoint_management.get_checkpoint_state(directory)
+ # Only the most recent two checkpoints are saved
+ self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
+
if __name__ == "__main__":
test.main()