aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/numerics_test.py
diff options
context:
space:
mode:
authorGravatar Gunhan Gulsoy <gunan@google.com>2016-09-09 22:42:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-09 23:48:49 -0700
commit9faf6fe4abc4f749f7ebda1056799d8130165c09 (patch)
treec47a9db1a8510fd4476b84da5199f1336fa85929 /tensorflow/python/kernel_tests/numerics_test.py
parent60efa7994acb2c38cc855f2915ceff6e9304779e (diff)
To make the tests run both on GPU and CPU, when available, override use_gpu to
True in test_session. Change: 132750351
Diffstat (limited to 'tensorflow/python/kernel_tests/numerics_test.py')
-rw-r--r--tensorflow/python/kernel_tests/numerics_test.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py
index 9abd25cb56..6e0799363b 100644
--- a/tensorflow/python/kernel_tests/numerics_test.py
+++ b/tensorflow/python/kernel_tests/numerics_test.py
@@ -29,7 +29,7 @@ class VerifyTensorAllFiniteTest(tf.test.TestCase):
def testVerifyTensorAllFiniteSucceeds(self):
x_shape = [5, 4]
x = np.random.random_sample(x_shape).astype(np.float32)
- with self.test_session():
+ with self.test_session(use_gpu=True):
t = tf.constant(x, shape=x_shape, dtype=tf.float32)
t_verified = tf.verify_tensor_all_finite(t, "Input is not a number.")
self.assertAllClose(x, t_verified.eval())
@@ -41,7 +41,7 @@ class VerifyTensorAllFiniteTest(tf.test.TestCase):
# Test NaN.
x[0] = np.nan
- with self.test_session():
+ with self.test_session(use_gpu=True):
with self.assertRaisesOpError(my_msg):
t = tf.constant(x, shape=x_shape, dtype=tf.float32)
t_verified = tf.verify_tensor_all_finite(t, my_msg)
@@ -49,7 +49,7 @@ class VerifyTensorAllFiniteTest(tf.test.TestCase):
# Test Inf.
x[0] = np.inf
- with self.test_session():
+ with self.test_session(use_gpu=True):
with self.assertRaisesOpError(my_msg):
t = tf.constant(x, shape=x_shape, dtype=tf.float32)
t_verified = tf.verify_tensor_all_finite(t, my_msg)