diff options
author | Todd Wang <toddw@google.com> | 2018-09-28 07:28:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 07:32:34 -0700 |
commit | 32627bfba19606d3c3a34f5d02ae9428675bbc42 (patch) | |
tree | 5d671e3c4a8185d150e33181e0415d7c6fb81e46 | |
parent | 19b2383cc0e221262be0780180558cf5bbb3e37e (diff) |
Allow testManyCPUs to encounter non-CPU devices.
PiperOrigin-RevId: 214932861
-rw-r--r-- | tensorflow/python/client/session_test.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 5c0c405306..347833ce8f 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -120,11 +120,17 @@ class SessionTest(test_util.TensorFlowTestCase): inp = constant_op.constant(10.0, name='W1') self.assertAllEqual(inp.eval(), 10.0) - devices = sess.list_devices() - self.assertEqual(2, len(devices)) - for device in devices: - self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string( - device.name).device_type) + num_cpu_devices = 0 + num_gpu_devices = 0 + for device in sess.list_devices(): + device_type = framework_device_lib.DeviceSpec.from_string( + device.name).device_type + if device_type == 'CPU': + num_cpu_devices += 1 + elif device_type == 'GPU': + num_gpu_devices += 1 + self.assertEqual(2, num_cpu_devices) + self.assertEqual(0, num_gpu_devices) def testPerSessionThreads(self): with session.Session( |