aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/datasets_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/datasets_test.py')
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index 68bec9aee8..acc605247f 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -193,6 +193,20 @@ class IteratorTest(test.TestCase):
x = math_ops.add(x, x)
self.assertAllEqual([0., 2.], x.numpy())
+ def testGpuTensor(self):
+ ds = Dataset.from_tensors([0., 1.])
+ with ops.device(test.gpu_device_name()):
+ for x in ds:
+ y = math_ops.add(x, x)
+ self.assertAllEqual([0., 2.], y.numpy())
+
+ def testGpuDefinedDataset(self):
+ with ops.device(test.gpu_device_name()):
+ ds = Dataset.from_tensors([0., 1.])
+ for x in ds:
+ y = math_ops.add(x, x)
+ self.assertAllEqual([0., 2.], y.numpy())
+
def testTensorsExplicitPrefetchToDevice(self):
ds = Dataset.from_tensor_slices([0., 1.])
ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name()))