aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-27 12:01:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-27 12:21:10 -0800
commit79098e13efe58ef3e56025c93761f4d7bb02dfbe (patch)
tree1ce91ce7126df8c2a93ab7edd0703f2486f57861
parent746dc3b49df98b9f3f52db6679d633752327df7e (diff)
Fix a bug where queue_parsed_examples fails for single tensors.
Change: 148676738
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/graph_io.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py12
2 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
index cf4ea6ae2a..7e7974bada 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
@@ -673,6 +673,11 @@ def queue_parsed_features(parsed_features,
errors.CancelledError)))
dequeued_tensors = input_queue.dequeue()
+ if not isinstance(dequeued_tensors, list):
+ # input_queue.dequeue() returns a single tensor instead of a list of
+ # tensors if there is only one tensor to dequeue, which breaks the
+ # assumption of a list below.
+ dequeued_tensors = [dequeued_tensors]
# Reset shapes on dequeued tensors.
for i in range(len(tensors_to_enqueue)):
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
index 90d58dec14..0f7307e406 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
@@ -34,6 +34,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
from tensorflow.contrib.learn.python.learn.learn_io.graph_io import _read_keyed_batch_examples_shared_queue
from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -743,6 +744,17 @@ class GraphIOTest(test.TestCase):
coord.request_stop()
coord.join(threads)
+ def test_queue_parsed_features_single_tensor(self):
+ with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ features = {"test": constant_op.constant([1, 2, 3])}
+ _, queued_features = graph_io.queue_parsed_features(features)
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(session, coord=coord)
+ out_features = session.run(queued_features["test"])
+ self.assertAllEqual([1, 2, 3], out_features)
+ coord.request_stop()
+ coord.join(threads)
+
if __name__ == "__main__":
test.main()