aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index f2fd687adf..3ea7e547ee 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -44,7 +44,14 @@ from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.protobuf import compare
+from tensorflow.python.client import device_lib
+def gpu_device_name():
+ """Returns the name of a GPU device if available or the empty string."""
+ for x in device_lib.list_local_devices():
+ if x.device_type == 'GPU' or x.device_type == 'SYCL':
+ return x.name
+ return ''
def assert_ops_in_graph(expected_ops, graph):
"""Assert all expected operations are found.
@@ -301,7 +308,12 @@ class TensorFlowTestCase(googletest.TestCase):
sess = self._cached_session
with sess.graph.as_default(), sess.as_default():
if force_gpu:
- with sess.graph.device("/gpu:0"):
+ # Use the name of an actual device if one is detected, or '/gpu:0'
+ # otherwise
+ gpu_name = gpu_device_name()
+ if len(gpu_name) == 0:
+ gpu_name = '/gpu:0'
+ with sess.graph.device(gpu_name):
yield sess
elif use_gpu:
yield sess
@@ -311,7 +323,12 @@ class TensorFlowTestCase(googletest.TestCase):
else:
with session.Session(graph=graph, config=prepare_config(config)) as sess:
if force_gpu:
- with sess.graph.device("/gpu:0"):
+ # Use the name of an actual device if one is detected, or '/gpu:0'
+ # otherwise
+ gpu_name = gpu_device_name()
+ if len(gpu_name) == 0:
+ gpu_name = '/gpu:0'
+ with sess.graph.device(gpu_name):
yield sess
elif use_gpu:
yield sess