diff options
-rw-r--r-- | tensorflow/python/client/session_test.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 5c0c405306..347833ce8f 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -120,11 +120,17 @@ class SessionTest(test_util.TensorFlowTestCase): inp = constant_op.constant(10.0, name='W1') self.assertAllEqual(inp.eval(), 10.0) - devices = sess.list_devices() - self.assertEqual(2, len(devices)) - for device in devices: - self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string( - device.name).device_type) + num_cpu_devices = 0 + num_gpu_devices = 0 + for device in sess.list_devices(): + device_type = framework_device_lib.DeviceSpec.from_string( + device.name).device_type + if device_type == 'CPU': + num_cpu_devices += 1 + elif device_type == 'GPU': + num_gpu_devices += 1 + self.assertEqual(2, num_cpu_devices) + self.assertEqual(0, num_gpu_devices) def testPerSessionThreads(self): with session.Session( |