diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-03-29 19:56:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-29 19:59:30 -0700 |
commit | 9451c12d62b272789947f475554601295ada4083 (patch) | |
tree | aae01bbe8c523fdeba6de2059e3f621583f7bdc1 /tensorflow/python/client | |
parent | 566f9041e19831a4eb8904654ddd365fd8f234c0 (diff) |
Internal change
PiperOrigin-RevId: 191024677
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/device_lib.i | 41 | ||||
-rw-r--r-- | tensorflow/python/client/device_lib.py | 10 | ||||
-rw-r--r-- | tensorflow/python/client/device_lib_test.py | 5 |
3 files changed, 50 insertions, 6 deletions
diff --git a/tensorflow/python/client/device_lib.i b/tensorflow/python/client/device_lib.i index 51c04584a5..944e855cee 100644 --- a/tensorflow/python/client/device_lib.i +++ b/tensorflow/python/client/device_lib.i @@ -15,19 +15,39 @@ limitations under the License. %include "tensorflow/python/platform/base.i" +%typemap(in) const tensorflow::ConfigProto& (tensorflow::ConfigProto temp) { + char* c_string; + Py_ssize_t py_size; + if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + SWIG_fail; + } + + if (!temp.ParseFromString(string(c_string, py_size))) { + PyErr_SetString( + PyExc_TypeError, + "The ConfigProto could not be parsed as a valid protocol buffer"); + SWIG_fail; + } + $1 = &temp; +} + %{ #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { namespace swig { -static std::vector<string> ListDevices(TF_Status* out_status) { +static std::vector<string> ListDevicesWithSessionConfig( + const tensorflow::ConfigProto& config, TF_Status* out_status) { std::vector<string> output; SessionOptions options; + options.config = config; std::vector<Device*> devices; Status status = DeviceFactory::AddDevices( options, "" /* name_prefix */, &devices); @@ -35,7 +55,8 @@ static std::vector<string> ListDevices(TF_Status* out_status) { Set_TF_Status_from_Status(out_status, status); } - std::vector<std::unique_ptr<Device>> device_holder(devices.begin(), devices.end()); + std::vector<std::unique_ptr<Device>> device_holder(devices.begin(), + devices.end()); for (const Device* device : devices) { const DeviceAttributes& attr = device->attributes(); @@ -53,6 +74,11 @@ static std::vector<string> ListDevices(TF_Status* out_status) { return output; } +std::vector<string> ListDevices(TF_Status* out_status) { + tensorflow::ConfigProto session_config; + return ListDevicesWithSessionConfig(session_config, out_status); +} + } // namespace swig } // namespace tensorflow @@ -62,21 +88,28 @@ static std::vector<string> ListDevices(TF_Status* out_status) { %unignore tensorflow; %unignore tensorflow::swig; +%unignore tensorflow::swig::ListDevicesWithSessionConfig; %unignore tensorflow::swig::ListDevices; // Wrap this function namespace tensorflow { namespace swig { std::vector<string> ListDevices(TF_Status* out_status); +static std::vector<string> ListDevicesWithSessionConfig( + const tensorflow::ConfigProto& config, TF_Status* out_status); } // namespace swig } // namespace tensorflow %insert("python") %{ -def list_devices(): +def list_devices(session_config=None): from tensorflow.python.framework import errors with errors.raise_exception_on_not_ok_status() as status: - return ListDevices(status) + if session_config: + return ListDevicesWithSessionConfig(session_config.SerializeToString(), + status) + else: + return ListDevices(status) %} %unignoreall diff --git a/tensorflow/python/client/device_lib.py b/tensorflow/python/client/device_lib.py index ad430cbae5..9d90d5395e 100644 --- a/tensorflow/python/client/device_lib.py +++ b/tensorflow/python/client/device_lib.py @@ -22,9 +22,12 @@ from tensorflow.core.framework import device_attributes_pb2 from tensorflow.python import pywrap_tensorflow -def list_local_devices(): +def list_local_devices(session_config=None): """List the available devices available in the local process. + Args: + session_config: a session config proto or None to use the default config. + Returns: A list of `DeviceAttribute` protocol buffers. """ @@ -33,4 +36,7 @@ def list_local_devices(): m.ParseFromString(pb_str) return m - return [_convert(s) for s in pywrap_tensorflow.list_devices()] + return [ + _convert(s) + for s in pywrap_tensorflow.list_devices(session_config=session_config) + ] diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py index aaf41626ab..fec41f50b6 100644 --- a/tensorflow/python/client/device_lib_test.py +++ b/tensorflow/python/client/device_lib_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import device_lib from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -31,6 +32,10 @@ class DeviceLibTest(test_util.TensorFlowTestCase): self.assertGreater(len(devices), 0) self.assertEqual(devices[0].device_type, "CPU") + devices = device_lib.list_local_devices(config_pb2.ConfigProto()) + self.assertGreater(len(devices), 0) + self.assertEqual(devices[0].device_type, "CPU") + # GPU test if test.is_gpu_available(): self.assertGreater(len(devices), 1) |