diff options
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/graph_io.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py | 12 |
2 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index cf4ea6ae2a..7e7974bada 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -673,6 +673,11 @@ def queue_parsed_features(parsed_features, errors.CancelledError))) dequeued_tensors = input_queue.dequeue() + if not isinstance(dequeued_tensors, list): + # input_queue.dequeue() returns a single tensor instead of a list of + # tensors if there is only one tensor to dequeue, which breaks the + # assumption of a list below. + dequeued_tensors = [dequeued_tensors] # Reset shapes on dequeued tensors. for i in range(len(tensors_to_enqueue)): diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 90d58dec14..0f7307e406 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -34,6 +34,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.learn.python.learn.learn_io import graph_io from tensorflow.contrib.learn.python.learn.learn_io.graph_io import _read_keyed_batch_examples_shared_queue from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -743,6 +744,17 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def test_queue_parsed_features_single_tensor(self): + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + features = {"test": constant_op.constant([1, 2, 3])} + _, queued_features = graph_io.queue_parsed_features(features) + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + out_features = session.run(queued_features["test"]) + self.assertAllEqual([1, 2, 3], out_features) + coord.request_stop() + coord.join(threads) + if __name__ == "__main__": test.main() |