diff options
author | 2017-10-06 18:07:17 -0700 | |
---|---|---|
committer | 2017-10-06 18:11:09 -0700 | |
commit | 646db3e3f91cdfcb1d00eb2bd8bc510ce453e7d3 (patch) | |
tree | f10315d4fb33784dbb9c093387ab481a76f9df9f /tensorflow/python/eager/context.py | |
parent | fb3c68db3fd9d1f18f8c5f8d6b005523dfcdf34d (diff) |
eager: Compute num_gpus() correctly.
Without this change, if TensorFlow is compiled with support for other devices
(such with XLA, which makes XLA_CPU and XLA_GPU devices available), then
tfe.num_gpus() was incorrectly overcounting the number of available GPUs.
PiperOrigin-RevId: 171373389
Diffstat (limited to 'tensorflow/python/eager/context.py')
-rw-r--r-- | tensorflow/python/eager/context.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 02ff567e9e..be3d535271 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -95,11 +95,18 @@ class Context(object): device_list = pywrap_tensorflow.TFE_ContextListDevices( self._context_handle, status) try: + self._num_gpus = 0 for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): with errors.raise_exception_on_not_ok_status() as status: dev_name = pywrap_tensorflow.TF_DeviceListName( device_list, i, status) self._context_devices.append(pydev.canonical_name(dev_name)) + with errors.raise_exception_on_not_ok_status() as status: + dev_type = pywrap_tensorflow.TF_DeviceListType( + device_list, i, status) + if dev_type == "GPU": + self._num_gpus += 1 + finally: pywrap_tensorflow.TF_DeleteDeviceList(device_list) @@ -238,8 +245,8 @@ class Context(object): def num_gpus(self): """The number of GPUs available to execute operations.""" - # TODO(ashankar): Use TF_DeviceListType to count GPU devices. - return len(self._devices) - 1 + self._initialize_handle_and_devices() + return self._num_gpus def add_function_def(self, fdef): """Add a function definition to the context. |