author Derek Murray <mrry@google.com> 2018-04-03 10:00:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-03 10:05:51 -0700
commit cfc886ac6064a04c71dd6c52e8c21ebec91eae50 (patch)
tree 1b2983a67716e3b42e57d199e28aad6877ca1128
parent cf8c504688c5f5813c8772eb107ed3d4a1385888 (diff)
[tf.data] Fix handling of nested structures in `tf.contrib.data.prefetch_to_device()`.
PiperOrigin-RevId: 191456191
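
In brief, the fix makes `prefetch_to_device()` work when the dataset's elements are nested structures rather than single tensors. A minimal sketch of the now-supported pattern, mirroring the new test below (it assumes a second CPU device, "/cpu:1", is configured):

    host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
    device_dataset = host_dataset.apply(
        prefetching_ops.prefetch_to_device("/cpu:1"))
    iterator = device_dataset.make_one_shot_iterator()
    next_element = iterator.get_next()  # A dict: {"a": <scalar int64 tensor>}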
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py | 32
-rw-r--r-- tensorflow/contrib/data/python/ops/prefetching_ops.py | 17
2 files changed, 48 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 676959a900..f2c57f92e2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -231,6 +231,37 @@ class StagingAreaOpsTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)
 
+  def testPrefetchDictToDevice(self):
+    host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+    device_dataset = host_dataset.apply(
+        prefetching_ops.prefetch_to_device("/cpu:1"))
+
+    # NOTE(mrry): This device block creates the "host" dataset and iterator on
+    # /cpu:0, and ensures that the prefetching is across devices. In typical use
+    # this would not be necessary, because the GPU device would not support any
+    # of the dataset-related ops.
+    with ops.device("/cpu:0"):
+      iterator = device_dataset.make_one_shot_iterator()
+
+    self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+    self.assertEqual(host_dataset.output_types, iterator.output_types)
+    self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+    self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+    self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+    self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+    next_element = iterator.get_next()
+    self.assertEqual(dtypes.int64, next_element["a"].dtype)
+    self.assertEqual([], next_element["a"].shape)
+
+    worker_config = config_pb2.ConfigProto()
+    worker_config.device_count["CPU"] = 2
+    with self.test_session(config=worker_config) as sess:
+      for i in range(10):
+        self.assertEqual({"a": i}, sess.run(next_element))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
   def testPrefetchToDeviceGpu(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -248,5 +279,6 @@ class StagingAreaOpsTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)
 
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 98651bb568..554bfaa2cf 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -28,6 +28,7 @@ from tensorflow.python.data.util import sparse
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 
 
 # TODO(rohanj): Add a python class that constructs resource in the __init__
@@ -77,10 +78,24 @@ class _PrefetchToDeviceIterator(object):
     @function.Defun(dtypes.string)
     def _prefetch_fn(handle):
+      """Prefetches one element from `input_iterator`."""
       remote_iterator = iterator_ops.Iterator.from_string_handle(
           handle, input_iterator.output_types, input_iterator.output_shapes,
           input_iterator.output_classes)
-      return remote_iterator.get_next()
+      ret = remote_iterator.get_next()
+
+      # Convert any `SparseTensorValue`s to `SparseTensor`s.
+      ret = nest.pack_sequence_as(ret, [
+          sparse_tensor_lib.SparseTensor.from_value(t)
+          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+      ])
+
+      # Serialize any sparse tensors and convert result to tensors.
+      ret = nest.pack_sequence_as(ret, [
+          ops.convert_to_tensor(t)
+          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
+      ])
+      return nest.flatten(ret)
 
     with ops.device(device):
       self._buffering_resource = function_buffering_resource(
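
For context (not part of the diff): `_prefetch_fn` is a `Defun`, and per the comment in the change it must hand back plain tensors, so the fix converts any `SparseTensorValue`s to `SparseTensor`s, serializes the sparse tensors, and flattens the nested structure before returning. A minimal sketch of the flatten/pack idiom the fix relies on, assuming the same `nest` module from tensorflow.python.data.util (the values here are illustrative only):

    structure = {"a": 1, "b": (2, 3)}
    flat = nest.flatten(structure)                       # [1, 2, 3]
    doubled = [2 * x for x in flat]                      # transform each leaf
    rebuilt = nest.pack_sequence_as(structure, doubled)  # {"a": 2, "b": (4, 6)}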