aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/context.py
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-10-06 18:07:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-06 18:11:09 -0700
commit646db3e3f91cdfcb1d00eb2bd8bc510ce453e7d3 (patch)
treef10315d4fb33784dbb9c093387ab481a76f9df9f /tensorflow/python/eager/context.py
parentfb3c68db3fd9d1f18f8c5f8d6b005523dfcdf34d (diff)
eager: Compute num_gpus() correctly.
Without this change, if TensorFlow is compiled with support for other devices (such with XLA, which makes XLA_CPU and XLA_GPU devices available), then tfe.num_gpus() was incorrectly overcounting the number of available GPUs. PiperOrigin-RevId: 171373389
Diffstat (limited to 'tensorflow/python/eager/context.py')
-rw-r--r--tensorflow/python/eager/context.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 02ff567e9e..be3d535271 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -95,11 +95,18 @@ class Context(object):
device_list = pywrap_tensorflow.TFE_ContextListDevices(
self._context_handle, status)
try:
+ self._num_gpus = 0
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
with errors.raise_exception_on_not_ok_status() as status:
dev_name = pywrap_tensorflow.TF_DeviceListName(
device_list, i, status)
self._context_devices.append(pydev.canonical_name(dev_name))
+ with errors.raise_exception_on_not_ok_status() as status:
+ dev_type = pywrap_tensorflow.TF_DeviceListType(
+ device_list, i, status)
+ if dev_type == "GPU":
+ self._num_gpus += 1
+
finally:
pywrap_tensorflow.TF_DeleteDeviceList(device_list)
@@ -238,8 +245,8 @@ class Context(object):
def num_gpus(self):
"""The number of GPUs available to execute operations."""
- # TODO(ashankar): Use TF_DeviceListType to count GPU devices.
- return len(self._devices) - 1
+ self._initialize_handle_and_devices()
+ return self._num_gpus
def add_function_def(self, fdef):
"""Add a function definition to the context.