aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2016-07-21 14:29:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-21 15:33:38 -0700
commitce819aa3eeecc4f00fb6b9db82bbc559ce144605 (patch)
tree2969a57db3054fba5c8e42789f47b4e5318d4227
parentf2ca33b6b4dda1d92b68c758cfafd5394828b8ff (diff)
Do not proceed if max_step already saved before calling train.
Change: 128112890
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py27
2 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index f377e90818..5c60aa4017 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -220,6 +220,16 @@ def _supervised_train(graph,
if global_step_tensor is None:
raise ValueError('No "global_step" was provided or found in the graph.')
+ if max_steps is not None:
+ try:
+ start_step = checkpoints.load_variable(output_dir,
+ global_step_tensor.name)
+ if max_steps <= start_step:
+ logging.info('Skipping training since max_steps has already saved.')
+ return None
+ except: # pylint: disable=bare-except
+ pass
+
with graph.as_default():
# See question about adding the summary writer to the scaffold.
if supervisor_is_chief:
diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
index 1e6ddff1d6..14a0c2c58e 100644
--- a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
@@ -403,6 +403,33 @@ class GraphActionsTest(tf.test.TestCase):
self._output_dir, tf.contrib.framework.get_global_step().name)
self.assertEqual(15, step)
+ def test_train_skip_train_if_max_step_already_saved(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ with tf.control_dependencies(self._build_inference_graph()):
+ train_op = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
+ learn.graph_actions._supervised_train( # pylint: disable=protected-access
+ g,
+ output_dir=self._output_dir,
+ train_op=train_op,
+ loss_op=tf.constant(2.0),
+ max_steps=10)
+ step = checkpoints.load_variable(
+ self._output_dir, tf.contrib.framework.get_global_step().name)
+ self.assertEqual(10, step)
+
+ with tf.Graph().as_default() as g, self.test_session(g):
+ with tf.control_dependencies(self._build_inference_graph()):
+ train_op = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
+ learn.graph_actions._supervised_train( # pylint: disable=protected-access
+ g,
+ output_dir=self._output_dir,
+ train_op=train_op,
+ loss_op=tf.constant(2.0),
+ max_steps=10)
+ step = checkpoints.load_variable(
+ self._output_dir, tf.contrib.framework.get_global_step().name)
+ self.assertEqual(10, step)
+
def test_train_loss(self):
with tf.Graph().as_default() as g, self.test_session(g):
tf.contrib.framework.create_global_step()