aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-03-29 19:56:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 19:59:30 -0700
commit9451c12d62b272789947f475554601295ada4083 (patch)
treeaae01bbe8c523fdeba6de2059e3f621583f7bdc1 /tensorflow/python/client
parent566f9041e19831a4eb8904654ddd365fd8f234c0 (diff)
Internal change
PiperOrigin-RevId: 191024677
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/device_lib.i41
-rw-r--r--tensorflow/python/client/device_lib.py10
-rw-r--r--tensorflow/python/client/device_lib_test.py5
3 files changed, 50 insertions, 6 deletions
diff --git a/tensorflow/python/client/device_lib.i b/tensorflow/python/client/device_lib.i
index 51c04584a5..944e855cee 100644
--- a/tensorflow/python/client/device_lib.i
+++ b/tensorflow/python/client/device_lib.i
@@ -15,19 +15,39 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
+%typemap(in) const tensorflow::ConfigProto& (tensorflow::ConfigProto temp) {
+ char* c_string;
+ Py_ssize_t py_size;
+ if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
+ // Python has raised an error (likely TypeError or UnicodeEncodeError).
+ SWIG_fail;
+ }
+
+ if (!temp.ParseFromString(string(c_string, py_size))) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "The ConfigProto could not be parsed as a valid protocol buffer");
+ SWIG_fail;
+ }
+ $1 = &temp;
+}
+
%{
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace swig {
-static std::vector<string> ListDevices(TF_Status* out_status) {
+static std::vector<string> ListDevicesWithSessionConfig(
+ const tensorflow::ConfigProto& config, TF_Status* out_status) {
std::vector<string> output;
SessionOptions options;
+ options.config = config;
std::vector<Device*> devices;
Status status = DeviceFactory::AddDevices(
options, "" /* name_prefix */, &devices);
@@ -35,7 +55,8 @@ static std::vector<string> ListDevices(TF_Status* out_status) {
Set_TF_Status_from_Status(out_status, status);
}
- std::vector<std::unique_ptr<Device>> device_holder(devices.begin(), devices.end());
+ std::vector<std::unique_ptr<Device>> device_holder(devices.begin(),
+ devices.end());
for (const Device* device : devices) {
const DeviceAttributes& attr = device->attributes();
@@ -53,6 +74,11 @@ static std::vector<string> ListDevices(TF_Status* out_status) {
return output;
}
+std::vector<string> ListDevices(TF_Status* out_status) {
+ tensorflow::ConfigProto session_config;
+ return ListDevicesWithSessionConfig(session_config, out_status);
+}
+
} // namespace swig
} // namespace tensorflow
@@ -62,21 +88,28 @@ static std::vector<string> ListDevices(TF_Status* out_status) {
%unignore tensorflow;
%unignore tensorflow::swig;
+%unignore tensorflow::swig::ListDevicesWithSessionConfig;
%unignore tensorflow::swig::ListDevices;
// Wrap this function
namespace tensorflow {
namespace swig {
std::vector<string> ListDevices(TF_Status* out_status);
+static std::vector<string> ListDevicesWithSessionConfig(
+ const tensorflow::ConfigProto& config, TF_Status* out_status);
} // namespace swig
} // namespace tensorflow
%insert("python") %{
-def list_devices():
+def list_devices(session_config=None):
from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status:
- return ListDevices(status)
+ if session_config:
+ return ListDevicesWithSessionConfig(session_config.SerializeToString(),
+ status)
+ else:
+ return ListDevices(status)
%}
%unignoreall
diff --git a/tensorflow/python/client/device_lib.py b/tensorflow/python/client/device_lib.py
index ad430cbae5..9d90d5395e 100644
--- a/tensorflow/python/client/device_lib.py
+++ b/tensorflow/python/client/device_lib.py
@@ -22,9 +22,12 @@ from tensorflow.core.framework import device_attributes_pb2
from tensorflow.python import pywrap_tensorflow
-def list_local_devices():
+def list_local_devices(session_config=None):
"""List the available devices available in the local process.
+ Args:
+ session_config: a session config proto or None to use the default config.
+
Returns:
A list of `DeviceAttribute` protocol buffers.
"""
@@ -33,4 +36,7 @@ def list_local_devices():
m.ParseFromString(pb_str)
return m
- return [_convert(s) for s in pywrap_tensorflow.list_devices()]
+ return [
+ _convert(s)
+ for s in pywrap_tensorflow.list_devices(session_config=session_config)
+ ]
diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py
index aaf41626ab..fec41f50b6 100644
--- a/tensorflow/python/client/device_lib_test.py
+++ b/tensorflow/python/client/device_lib_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
@@ -31,6 +32,10 @@ class DeviceLibTest(test_util.TensorFlowTestCase):
self.assertGreater(len(devices), 0)
self.assertEqual(devices[0].device_type, "CPU")
+ devices = device_lib.list_local_devices(config_pb2.ConfigProto())
+ self.assertGreater(len(devices), 0)
+ self.assertEqual(devices[0].device_type, "CPU")
+
# GPU test
if test.is_gpu_available():
self.assertGreater(len(devices), 1)