diff options
author | 2016-07-21 14:29:19 -0800 | |
---|---|---|
committer | 2016-07-21 15:33:38 -0700 | |
commit | ce819aa3eeecc4f00fb6b9db82bbc559ce144605 (patch) | |
tree | 2969a57db3054fba5c8e42789f47b4e5318d4227 | |
parent | f2ca33b6b4dda1d92b68c758cfafd5394828b8ff (diff) |
Do not proceed if max_step already saved before calling train.
Change: 128112890
-rw-r--r-- | tensorflow/contrib/learn/python/learn/graph_actions.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py | 27 |
2 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index f377e90818..5c60aa4017 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -220,6 +220,16 @@ def _supervised_train(graph, if global_step_tensor is None: raise ValueError('No "global_step" was provided or found in the graph.') + if max_steps is not None: + try: + start_step = checkpoints.load_variable(output_dir, + global_step_tensor.name) + if max_steps <= start_step: + logging.info('Skipping training since max_steps has already saved.') + return None + except: # pylint: disable=bare-except + pass + with graph.as_default(): # See question about adding the summary writer to the scaffold. if supervisor_is_chief: diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py index 1e6ddff1d6..14a0c2c58e 100644 --- a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py @@ -403,6 +403,33 @@ class GraphActionsTest(tf.test.TestCase): self._output_dir, tf.contrib.framework.get_global_step().name) self.assertEqual(15, step) + def test_train_skip_train_if_max_step_already_saved(self): + with tf.Graph().as_default() as g, self.test_session(g): + with tf.control_dependencies(self._build_inference_graph()): + train_op = tf.assign_add(tf.contrib.framework.get_global_step(), 1) + learn.graph_actions._supervised_train( # pylint: disable=protected-access + g, + output_dir=self._output_dir, + train_op=train_op, + loss_op=tf.constant(2.0), + max_steps=10) + step = checkpoints.load_variable( + self._output_dir, tf.contrib.framework.get_global_step().name) + self.assertEqual(10, step) + + with tf.Graph().as_default() as g, self.test_session(g): + with tf.control_dependencies(self._build_inference_graph()): + train_op = tf.assign_add(tf.contrib.framework.get_global_step(), 1) + learn.graph_actions._supervised_train( # pylint: disable=protected-access + g, + output_dir=self._output_dir, + train_op=train_op, + loss_op=tf.constant(2.0), + max_steps=10) + step = checkpoints.load_variable( + self._output_dir, tf.contrib.framework.get_global_step().name) + self.assertEqual(10, step) + def test_train_loss(self): with tf.Graph().as_default() as g, self.test_session(g): tf.contrib.framework.create_global_step() |