aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-01 14:05:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-01 15:18:29 -0700
commita45bf5e45e74b6d7599453a24372c686c1818e79 (patch)
treef55bf0a906c24d67af34c72f5be1f414d5e558c8
parent26217c20fb0b2766e2c17430340bdacc4f41f5c2 (diff)
Handle input iterator exhaustion.
Change: 126450913
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py21
2 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index 1525407f36..6db0e2fda1 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -321,6 +321,8 @@ def train(graph,
except errors.OutOfRangeError as e:
logging.warn('Got exception during tf.learn training loop possibly '
'due to exhausted input queue %s.', e)
+ except StopIteration:
+ logging.info('Exhausted input iterarator.')
except BaseException as e: # pylint: disable=broad-except
# Hold on to any other exceptions while we try recording a final
# checkpoint and summary.
diff --git a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
index a52e6c97a6..7718fdd89c 100644
--- a/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/graph_actions_test.py
@@ -33,15 +33,18 @@ from tensorflow.contrib.learn.python.learn.utils import checkpoints
class _Feeder(object):
"""Simple generator for `feed_fn`, returning 10 * step."""
- def __init__(self, tensor):
+ def __init__(self, tensor, max_step):
self._step = 0
self._tensor = tensor
+ self._max_step = max_step
@property
def step(self):
return self._step
def feed_fn(self):
+ if self._step >= self._max_step:
+ raise StopIteration
value = self._step * 10.0
self._step += 1
return {self._tensor: value}
@@ -165,7 +168,7 @@ class GraphActionsTest(tf.test.TestCase):
with tf.Graph().as_default() as g, self.test_session(g):
in0, _, out = self._build_inference_graph()
self._assert_summaries(self._output_dir, expected_session_logs=[])
- feeder = _Feeder(in0)
+ feeder = _Feeder(in0, 3)
results = learn.graph_actions.evaluate(
g, output_dir=self._output_dir, checkpoint_path=None,
eval_dict={'a': out}, feed_fn=feeder.feed_fn, max_steps=3)
@@ -175,6 +178,20 @@ class GraphActionsTest(tf.test.TestCase):
self._output_dir, expected_summaries={0: {'a': 25.0}},
expected_session_logs=[])
+ def test_evaluate_feed_fn_with_exhaustion(self):
+ with tf.Graph().as_default() as g, self.test_session(g):
+ in0, _, out = self._build_inference_graph()
+ self._assert_summaries(self._output_dir, expected_session_logs=[])
+ feeder = _Feeder(in0, 2)
+ results = learn.graph_actions.evaluate(
+ g, output_dir=self._output_dir, checkpoint_path=None,
+ eval_dict={'a': out}, feed_fn=feeder.feed_fn, max_steps=3)
+ self.assertEqual(2, feeder.step)
+ self.assertEqual(({'a': 15.0}, 0), results)
+ self._assert_summaries(
+ self._output_dir, expected_summaries={0: {'a': 15.0}},
+ expected_session_logs=[])
+
def test_train_invalid_args(self):
with tf.Graph().as_default() as g, self.test_session(g):
train_op = tf.constant(1.0)