diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/datasets_test.py')
-rw-r--r-- | tensorflow/contrib/eager/python/datasets_test.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 68bec9aee8..acc605247f 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -193,6 +193,20 @@ class IteratorTest(test.TestCase): x = math_ops.add(x, x) self.assertAllEqual([0., 2.], x.numpy()) + def testGpuTensor(self): + ds = Dataset.from_tensors([0., 1.]) + with ops.device(test.gpu_device_name()): + for x in ds: + y = math_ops.add(x, x) + self.assertAllEqual([0., 2.], y.numpy()) + + def testGpuDefinedDataset(self): + with ops.device(test.gpu_device_name()): + ds = Dataset.from_tensors([0., 1.]) + for x in ds: + y = math_ops.add(x, x) + self.assertAllEqual([0., 2.], y.numpy()) + def testTensorsExplicitPrefetchToDevice(self): ds = Dataset.from_tensor_slices([0., 1.]) ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) |