diff options
author | 2016-07-01 14:05:49 -0800 | |
---|---|---|
committer | 2016-07-01 15:18:29 -0700 | |
commit | a45bf5e45e74b6d7599453a24372c686c1818e79 (patch) | |
tree | f55bf0a906c24d67af34c72f5be1f414d5e558c8 | |
parent | 26217c20fb0b2766e2c17430340bdacc4f41f5c2 (diff) |
Handle input iterator exhaustion.
Change: 126450913
-rw-r--r-- | tensorflow/contrib/learn/python/learn/graph_actions.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py | 21 |
2 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 1525407f36..6db0e2fda1 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -321,6 +321,8 @@ def train(graph, except errors.OutOfRangeError as e: logging.warn('Got exception during tf.learn training loop possibly ' 'due to exhausted input queue %s.', e) + except StopIteration: + logging.info('Exhausted input iterarator.') except BaseException as e: # pylint: disable=broad-except # Hold on to any other exceptions while we try recording a final # checkpoint and summary. diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py index a52e6c97a6..7718fdd89c 100644 --- a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py @@ -33,15 +33,18 @@ from tensorflow.contrib.learn.python.learn.utils import checkpoints class _Feeder(object): """Simple generator for `feed_fn`, returning 10 * step.""" - def __init__(self, tensor): + def __init__(self, tensor, max_step): self._step = 0 self._tensor = tensor + self._max_step = max_step @property def step(self): return self._step def feed_fn(self): + if self._step >= self._max_step: + raise StopIteration value = self._step * 10.0 self._step += 1 return {self._tensor: value} @@ -165,7 +168,7 @@ class GraphActionsTest(tf.test.TestCase): with tf.Graph().as_default() as g, self.test_session(g): in0, _, out = self._build_inference_graph() self._assert_summaries(self._output_dir, expected_session_logs=[]) - feeder = _Feeder(in0) + feeder = _Feeder(in0, 3) results = learn.graph_actions.evaluate( g, output_dir=self._output_dir, checkpoint_path=None, eval_dict={'a': out}, feed_fn=feeder.feed_fn, max_steps=3) @@ -175,6 +178,20 @@ class GraphActionsTest(tf.test.TestCase): self._output_dir, expected_summaries={0: {'a': 25.0}}, expected_session_logs=[]) + def test_evaluate_feed_fn_with_exhaustion(self): + with tf.Graph().as_default() as g, self.test_session(g): + in0, _, out = self._build_inference_graph() + self._assert_summaries(self._output_dir, expected_session_logs=[]) + feeder = _Feeder(in0, 2) + results = learn.graph_actions.evaluate( + g, output_dir=self._output_dir, checkpoint_path=None, + eval_dict={'a': out}, feed_fn=feeder.feed_fn, max_steps=3) + self.assertEqual(2, feeder.step) + self.assertEqual(({'a': 15.0}, 0), results) + self._assert_summaries( + self._output_dir, expected_summaries={0: {'a': 15.0}}, + expected_session_logs=[]) + def test_train_invalid_args(self): with tf.Graph().as_default() as g, self.test_session(g): train_op = tf.constant(1.0) |