diff options
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r-- | tensorflow/python/framework/test_util.py | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index f2fd687adf..3ea7e547ee 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -44,7 +44,14 @@ from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util.protobuf import compare +from tensorflow.python.client import device_lib +def gpu_device_name(): + """Returns the name of a GPU device if available or the empty string.""" + for x in device_lib.list_local_devices(): + if x.device_type == 'GPU' or x.device_type == 'SYCL': + return x.name + return '' def assert_ops_in_graph(expected_ops, graph): """Assert all expected operations are found. @@ -301,7 +308,12 @@ class TensorFlowTestCase(googletest.TestCase): sess = self._cached_session with sess.graph.as_default(), sess.as_default(): if force_gpu: - with sess.graph.device("/gpu:0"): + # Use the name of an actual device if one is detected, or '/gpu:0' + # otherwise + gpu_name = gpu_device_name() + if len(gpu_name) == 0: + gpu_name = '/gpu:0' + with sess.graph.device(gpu_name): yield sess elif use_gpu: yield sess @@ -311,7 +323,12 @@ class TensorFlowTestCase(googletest.TestCase): else: with session.Session(graph=graph, config=prepare_config(config)) as sess: if force_gpu: - with sess.graph.device("/gpu:0"): + # Use the name of an actual device if one is detected, or '/gpu:0' + # otherwise + gpu_name = gpu_device_name() + if len(gpu_name) == 0: + gpu_name = '/gpu:0' + with sess.graph.device(gpu_name): yield sess elif use_gpu: yield sess |