From a1df092d60e2c2e6fa0e2de668224b536892c244 Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Tue, 14 Aug 2018 12:54:24 -0700
Subject: Add a custom numbering option to tf.contrib.checkpoint.CheckpointManager

Still increments the save counter, but uses the provided variable/integer for
checkpoint numbering.

PiperOrigin-RevId: 208696240
---
 .../python/training/checkpoint_management.py       | 20 +++++++++++++----
 .../python/training/checkpoint_management_test.py  | 26 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 9a90f91a7c..85f2904318 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -33,7 +33,9 @@ from tensorflow.python.framework import ops
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -622,13 +624,19 @@ class CheckpointManager(object):
     """
     return self._checkpoint_prefix
 
-  def save(self, session=None):
+  def save(self, session=None, checkpoint_number=None):
     """Creates a new checkpoint and manages it.
 
     Args:
       session: The session to evaluate variables in. Ignored when executing
         eagerly. If not provided when graph building, the default session is
         used.
+      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
+        `Tensor`, used to number the checkpoint. If `None` (default),
+        checkpoints are numbered using `checkpoint.save_counter`. Even if
+        `checkpoint_number` is provided, `save_counter` is still incremented. A
+        user-provided `checkpoint_number` is not incremented even if it is a
+        `Variable`.
 
     Returns:
       The path to the new checkpoint. It is also recorded in the `checkpoints`
@@ -639,7 +647,6 @@ class CheckpointManager(object):
     if context.executing_eagerly():
       save_counter = self._checkpoint.save_counter
       save_counter.assign_add(1)
-      checkpoint_number = save_counter.numpy()
     else:
       if session is None:
         session = ops.get_default_session()
@@ -653,8 +660,13 @@ class CheckpointManager(object):
       with variable_scope.variable_creator_scope(_initializing_creator):
         save_counter = self._checkpoint.save_counter
       if self._save_counter_assign is None:
-        self._save_counter_assign = save_counter.assign_add(1, read_value=True)
-      checkpoint_number = session.run(self._save_counter_assign)
+        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
+      session.run(self._save_counter_assign)
+    if checkpoint_number is None:
+      checkpoint_number = save_counter
+    if not isinstance(checkpoint_number, compat.integral_types):
+      checkpoint_number = training_util.global_step(
+          sess=session, global_step_tensor=checkpoint_number)
     prefix = "%s-%d" % (self._prefix, checkpoint_number)
     save_path = self._checkpoint.write(prefix)
     timestamp = time.time()
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 95e688d3c7..1e2827d0a4 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,6 +26,7 @@ import tempfile
 from google.protobuf import text_format
 
 from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops as ops_lib
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import file_io
@@ -486,6 +487,31 @@ class CheckpointManagerTest(test.TestCase):
     self.assertEqual([5020.],
                      state.all_model_checkpoint_timestamps)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testCustomNumbering(self):
+    directory = self.get_temp_dir()
+    step = variables.Variable(0, dtype=dtypes.int64)
+    checkpoint = util.Checkpoint(step=step)
+    manager = checkpoint_management.CheckpointManager(
+        checkpoint, directory, max_to_keep=2)
+    self.evaluate(step.initializer)
+    for i in range(5):
+      path = manager.save(checkpoint_number=step)
+      expected_suffix = "-%d" % (2 * i,)
+      if not path.endswith(expected_suffix):
+        self.fail("%s should have suffix %s" % (path, expected_suffix))
+      self.evaluate(step.assign_add(2))
+    self.assertEqual(5, self.evaluate(checkpoint.save_counter))
+    # Test regular integers
+    last_path = manager.save(checkpoint_number=32)
+    self.assertIn("-32", last_path)
+    self.assertEqual(last_path, manager.latest_checkpoint)
+    self.assertEqual(
+        last_path, checkpoint_management.latest_checkpoint(directory))
+    state = checkpoint_management.get_checkpoint_state(directory)
+    # Only the most recent two checkpoints are saved
+    self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
+
 
 if __name__ == "__main__":
   test.main()
-- 
cgit v1.2.3