aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2017-03-01 17:59:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-01 18:08:24 -0800
commit3e975ea978bac4d861bb09328b06f3c316212611 (patch)
tree79bac044c9723df8443495eb962c2dd98a2ed421 /tensorflow/python/framework/test_util.py
parent8043a27ed77f59bb68409070f2bfa01df0e04b89 (diff)
Merge changes from github.
Change: 148954491
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index f2fd687adf..3ea7e547ee 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -44,7 +44,14 @@ from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.protobuf import compare
+from tensorflow.python.client import device_lib
+def gpu_device_name():
+ """Returns the name of a GPU device if available or the empty string."""
+ for x in device_lib.list_local_devices():
+ if x.device_type == 'GPU' or x.device_type == 'SYCL':
+ return x.name
+ return ''
def assert_ops_in_graph(expected_ops, graph):
"""Assert all expected operations are found.
@@ -301,7 +308,12 @@ class TensorFlowTestCase(googletest.TestCase):
sess = self._cached_session
with sess.graph.as_default(), sess.as_default():
if force_gpu:
- with sess.graph.device("/gpu:0"):
+ # Use the name of an actual device if one is detected, or '/gpu:0'
+ # otherwise
+ gpu_name = gpu_device_name()
+ if len(gpu_name) == 0:
+ gpu_name = '/gpu:0'
+ with sess.graph.device(gpu_name):
yield sess
elif use_gpu:
yield sess
@@ -311,7 +323,12 @@ class TensorFlowTestCase(googletest.TestCase):
else:
with session.Session(graph=graph, config=prepare_config(config)) as sess:
if force_gpu:
- with sess.graph.device("/gpu:0"):
+ # Use the name of an actual device if one is detected, or '/gpu:0'
+ # otherwise
+ gpu_name = gpu_device_name()
+ if len(gpu_name) == 0:
+ gpu_name = '/gpu:0'
+ with sess.graph.device(gpu_name):
yield sess
elif use_gpu:
yield sess