aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py558
1 files changed, 544 insertions, 14 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index b08132cd72..82543b1039 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -21,6 +21,7 @@ import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -68,6 +69,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
with ops.device(device1):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_remote_fn,
+ output_types=[dtypes.float32],
target_device=target,
string_arg=ds_iterator_handle,
buffer_size=3,
@@ -85,8 +87,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
return (prefetch_op, reset_op, destroy_op)
def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
@@ -125,8 +126,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
"/job:localhost/replica:0/task:0/gpu:0")
def testReinitialization(self):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
@@ -166,8 +166,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
def testReinitializationOutOfRange(self):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
@@ -201,6 +200,49 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
+ def testStringsGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/gpu:0"
+
+ ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
+ ds_iterator = ds.make_one_shot_iterator()
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.string],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name="strings")
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.string])
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ with self.test_session() as sess:
+ self.assertEqual([b"a"], sess.run(prefetch_op))
+ self.assertEqual([b"b"], sess.run(prefetch_op))
+ self.assertEqual([b"c"], sess.run(prefetch_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
class PrefetchToDeviceTest(test.TestCase):
@@ -227,14 +269,43 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testPrefetchToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ with self.test_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
@@ -258,8 +329,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element["a"].dtype)
self.assertEqual([], next_element["a"].shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
@@ -292,8 +362,7 @@ class PrefetchToDeviceTest(test.TestCase):
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
@@ -343,8 +412,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
for i in range(5):
@@ -377,5 +445,467 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
+class CopyToDeviceTest(test.TestCase):
+
+ def testCopyToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceInt32(self):
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int32, next_element.dtype)
+ self.assertEqual((4,), next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDevice(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDeviceWithPrefetch(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32AndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStrings(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStringsAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDevicePingPongCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ with compat.forward_compatibility_horizon(2018, 8, 4):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0"))
+ back_to_cpu_dataset = device_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = back_to_cpu_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInitAndPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInitAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
if __name__ == "__main__":
test.main()