aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Todd Wang <toddw@google.com>2018-09-28 07:28:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 07:32:34 -0700
commit32627bfba19606d3c3a34f5d02ae9428675bbc42 (patch)
tree5d671e3c4a8185d150e33181e0415d7c6fb81e46
parent19b2383cc0e221262be0780180558cf5bbb3e37e (diff)
Allow testManyCPUs to encounter non-CPU devices.
PiperOrigin-RevId: 214932861
-rw-r--r--tensorflow/python/client/session_test.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 5c0c405306..347833ce8f 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -120,11 +120,17 @@ class SessionTest(test_util.TensorFlowTestCase):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
- devices = sess.list_devices()
- self.assertEqual(2, len(devices))
- for device in devices:
- self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
- device.name).device_type)
+ num_cpu_devices = 0
+ num_gpu_devices = 0
+ for device in sess.list_devices():
+ device_type = framework_device_lib.DeviceSpec.from_string(
+ device.name).device_type
+ if device_type == 'CPU':
+ num_cpu_devices += 1
+ elif device_type == 'GPU':
+ num_gpu_devices += 1
+ self.assertEqual(2, num_cpu_devices)
+ self.assertEqual(0, num_gpu_devices)
def testPerSessionThreads(self):
with session.Session(