aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2017-05-26 11:04:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-26 11:07:54 -0700
commit9c495f9499199ea46fff9028774374fa0c52e018 (patch)
tree9c0009cdc0cbe1eebfec9bdc7e24e271dbf0584b
parent7b401106a488eba759b6cce370393dce05d1d173 (diff)
Add session.list_devices() API
In order to debug a TensorFlow cluster or check whether devices are available in a local session (e.g. GPU drivers are loaded), this change adds a `sess.list_devices` API to list all devices within the cluster. This CL implements the list_devices() feature via extensions to the TensorFlow C API, and the corresponding additions to the session.h session class and corresponding subclasses for direct sessions, grpc_sessions, tensorflow_serving, and others. Additionally, in order to accommodate ClusterSpec propagation clusters, Master::ListDevices now also includes a session_handle in order to identify the appropriate master_session on which it should list the available devices. (With ClusterSpec propagation, different sessions can have different servers with different device capabilities.) This CL adds a ListDevices() API to MasterSession. It is most efficient to implement this API call there, because the MasterSession already has a list of devices. Additionally, this change upgrades the implementation of Master::ListDevices() to delegate to the MasterSession if a session handle is specified, and to return an error if no corresponding session is found. PiperOrigin-RevId: 157239656
-rw-r--r--tensorflow/c/c_api.cc43
-rw-r--r--tensorflow/c/c_api.h49
-rw-r--r--tensorflow/c/c_api_internal.h5
-rw-r--r--tensorflow/contrib/session_bundle/signature_test.cc4
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc11
-rw-r--r--tensorflow/core/common_runtime/direct_session.h2
-rw-r--r--tensorflow/core/distributed_runtime/master.cc20
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc24
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.cc33
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpcbench_test.cc2
-rw-r--r--tensorflow/core/protobuf/master.proto9
-rw-r--r--tensorflow/core/public/session.h7
-rw-r--r--tensorflow/python/client/session.py70
-rw-r--r--tensorflow/python/client/session_test.py52
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-session.pbtxt4
19 files changed, 332 insertions, 15 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index f4775783f9..61089e2a50 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -718,6 +718,49 @@ TF_Buffer* TF_GetAllOpList() {
return ret;
}
+// --------------------------------------------------------------------------
+// ListDevices & SessionListDevices API
+
+// Destroys a device list with `delete`; passing a null pointer is a no-op.
+void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
+
+// Asks the wrapped session for its devices. A freshly allocated list is
+// returned even when the call fails; the outcome is recorded in `status`.
+TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
+  TF_DeviceList* devices = new TF_DeviceList;
+  status->status = session->session->ListDevices(&devices->response);
+  return devices;
+}
+
+// Asks the wrapped deprecated session for its devices. A freshly allocated
+// list is returned even when the call fails; the outcome is recorded in
+// `status`.
+TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
+                                               TF_Status* status) {
+  TF_DeviceList* devices = new TF_DeviceList;
+  status->status = session->session->ListDevices(&devices->response);
+  return devices;
+}
+
+// Returns the number of entries in `list`. `list` is dereferenced without a
+// null check, so it must be non-null.
+int TF_DeviceListCount(const TF_DeviceList* list) {
+  return static_cast<int>(list->response.size());
+}
+
+// Defines a bounds-checked accessor `method_name` that returns `return_type`
+// for the entry at `index` of a TF_DeviceList. A null list or an out-of-range
+// index sets an InvalidArgument status and yields `err_val`. A successful
+// lookup resets the status to OK so that a stale error left in `status` by an
+// earlier call is not misattributed to this one.
+#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
+  return_type method_name(const TF_DeviceList* list, const int index,     \
+                          TF_Status* status) {                            \
+    if (list == nullptr) {                                                \
+      status->status = InvalidArgument("list is null!");                  \
+      return err_val;                                                     \
+    }                                                                     \
+    if (index < 0 ||                                                      \
+        static_cast<size_t>(index) >= list->response.size()) {            \
+      status->status = InvalidArgument("index out of bounds");            \
+      return err_val;                                                     \
+    }                                                                     \
+    status->status = tensorflow::Status::OK();                            \
+    return list->response[index].accessor;                                \
+  }
+
+TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
+TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
+                     nullptr);
+TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
+
+#undef TF_DEVICELIST_METHOD
+
} // end extern "C"
// --------------------------------------------------------------------------
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index c55bf70b9f..15139a47ac 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1183,6 +1183,55 @@ TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle,
const char** target_oper_names, int ntargets,
TF_Status*);
+typedef struct TF_DeviceList TF_DeviceList;
+
+// Lists all devices in a TF_Session.
+//
+// Caller takes ownership of the returned TF_DeviceList* which must eventually
+// be freed with a call to TF_DeleteDeviceList.
+TF_CAPI_EXPORT extern TF_DeviceList* TF_SessionListDevices(TF_Session* session,
+ TF_Status* status);
+
+// Lists all devices in a TF_DeprecatedSession.
+//
+// Caller takes ownership of the returned TF_DeviceList* which must eventually
+// be freed with a call to TF_DeleteDeviceList.
+TF_CAPI_EXPORT extern TF_DeviceList* TF_DeprecatedSessionListDevices(
+ TF_DeprecatedSession* session, TF_Status* status);
+
+// Deallocates the device list.
+TF_CAPI_EXPORT extern void TF_DeleteDeviceList(TF_DeviceList* list);
+
+// Counts the number of elements in the device list.
+TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list);
+
+// Retrieves the full name of the device (e.g. /job:worker/replica:0/...)
+// The return value will be a pointer to a null terminated string. The caller
+// must not modify or delete the string. It will be deallocated upon a call to
+// TF_DeleteDeviceList.
+//
+// If index is out of bounds, an error code will be set in the status object,
+// and a null pointer will be returned.
+TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list,
+ int index, TF_Status*);
+
+// Retrieves the type of the device at the given index.
+//
+// The caller must not modify or delete the string. It will be deallocated upon
+// a call to TF_DeleteDeviceList.
+//
+// If index is out of bounds, an error code will be set in the status object,
+// and a null pointer will be returned.
+TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
+ int index, TF_Status*);
+
+// Retrieve the amount of memory associated with a given device.
+//
+// If index is out of bounds, an error code will be set in the status object,
+// and -1 will be returned.
+TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
+ const TF_DeviceList* list, int index, TF_Status*);
+
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index b5320d20da..f17ac26ad9 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
-
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
@@ -114,3 +113,7 @@ struct TF_Session {
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
};
+
+struct TF_DeviceList {
+ std::vector<tensorflow::DeviceAttributes> response;
+};
diff --git a/tensorflow/contrib/session_bundle/signature_test.cc b/tensorflow/contrib/session_bundle/signature_test.cc
index 582f945538..741b7fde9b 100644
--- a/tensorflow/contrib/session_bundle/signature_test.cc
+++ b/tensorflow/contrib/session_bundle/signature_test.cc
@@ -261,6 +261,10 @@ struct MockSession : public tensorflow::Session {
return errors::Unimplemented("Not implemented for mock.");
}
+ Status ListDevices(std::vector<DeviceAttributes>* response) override {
+ return errors::Unimplemented("Not implemented for mock.");
+ }
+
// Arguments stored on a Run call.
std::vector<std::pair<string, Tensor>> inputs;
std::vector<string> output_tensor_names;
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index df99ce6f7a..d0748f98f9 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1354,6 +1354,17 @@ Status DirectSession::CreateGraphs(
return s;
}
+// Replaces the contents of *response with a copy of the attributes of every
+// device in devices_. Always returns OK.
+::tensorflow::Status DirectSession::ListDevices(
+    std::vector<DeviceAttributes>* response) {
+  response->clear();
+  response->reserve(devices_.size());
+  for (Device* device : devices_) {
+    response->push_back(device->attributes());
+  }
+  return ::tensorflow::Status::OK();
+}
+
::tensorflow::Status DirectSession::Reset(
const std::vector<string>& containers) {
device_mgr_->ClearContainers(containers);
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index d4450544c3..b14a517188 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -95,6 +95,8 @@ class DirectSession : public Session {
// If 'containers' is empty, then Reset clears the default container.
::tensorflow::Status Reset(const std::vector<string>& containers);
+ ::tensorflow::Status ListDevices(
+ std::vector<DeviceAttributes>* response) override;
::tensorflow::Status Close() override;
void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index e860c99d95..1cbf30fe4b 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -524,6 +524,26 @@ void Master::CloseSession(const CloseSessionRequest* req,
void Master::ListDevices(const ListDevicesRequest* req,
ListDevicesResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
+ if (!req->session_handle().empty()) {
+ MasterSession* session = nullptr;
+ {
+ mutex_lock l(mu_);
+ session = gtl::FindPtrOrNull(sessions_, req->session_handle());
+ if (session != nullptr) {
+ session->Ref();
+ }
+ }
+ if (session == nullptr) {
+ done(errors::InvalidArgument(
+ "Session ", req->session_handle(),
+ " is not found. Possibly, this master has restarted."));
+ return;
+ }
+ core::ScopedUnref ref(session);
+ Status s = session->ListDevices(resp);
+ done(s);
+ return;
+ }
std::vector<std::unique_ptr<Device>> remote_devices;
Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache,
&remote_devices);
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index ea3dc8f2fe..a2160816fe 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1102,6 +1102,30 @@ Status MasterSession::CreateWorkerSessions(
return status;
}
+// Appends the attributes of every device known to this session to `resp`,
+// split into local_device and remote_device entries. Always returns OK.
+Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
+  if (worker_cache_) {
+    // This is a ClusterSpec-propagated session, and thus env_->local_devices
+    // are invalid.
+
+    // Mark the "client_device" as the sole local device.
+    const Device* client_device = devices_->client_device();
+    for (const Device* dev : devices_->devices()) {
+      if (dev != client_device) {
+        *(resp->add_remote_device()) = dev->attributes();
+      }
+    }
+    *(resp->add_local_device()) = client_device->attributes();
+  } else {
+    for (Device* dev : env_->local_devices) {
+      *(resp->add_local_device()) = dev->attributes();
+    }
+    // Entries of remote_devs_ are reported as remote devices, consistent with
+    // how the ClusterSpec-propagated branch reports non-client devices.
+    for (auto&& dev : *remote_devs_) {
+      *(resp->add_remote_device()) = dev->attributes();
+    }
+  }
+  return Status::OK();
+}
+
Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) {
UpdateLastAccessTime();
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 3acc5bc5f0..10fc4868ca 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -87,6 +87,8 @@ class MasterSession : public core::RefCounted {
Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
+ Status ListDevices(ListDevicesResponse* resp) const;
+
// Close this session and delete "*this". Returns OK if all known
// states are cleanup successfully.
//
@@ -112,7 +114,7 @@ class MasterSession : public core::RefCounted {
// The optional session-specific worker cluster.
// TODO(saeta): Convert to std::optional when available.
- std::unique_ptr<WorkerCacheInterface> worker_cache_;
+ const std::unique_ptr<WorkerCacheInterface> worker_cache_;
// Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
WorkerCacheInterface* get_worker_cache() const;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 38d59d5bb5..6d5b8e3934 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -30,8 +30,7 @@ limitations under the License.
namespace tensorflow {
GrpcSession::GrpcSession(const SessionOptions& options)
- : options_(options),
- current_graph_version_(-1) {}
+ : options_(options), current_graph_version_(-1) {}
GrpcSession::~GrpcSession() {}
@@ -321,27 +320,41 @@ Status GrpcSession::Close() {
return master_->CloseSession(&call_options, &req, &resp);
}
-std::vector<DeviceAttributes> GrpcSession::ListDevices() {
- std::vector<DeviceAttributes> devices;
-
+Status GrpcSession::ListDevices(std::vector<DeviceAttributes>* response) {
ListDevicesRequest req;
+ {
+ mutex_lock l(mu_);
+ req.set_session_handle(handle_);
+ }
+ if (req.session_handle().empty()) {
+ LOG(WARNING) << "GrpcSession::ListDevices will initialize the session with "
+ "an empty graph and other defaults because the session has "
+ "not yet been created.";
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(Create(graph_def));
+ {
+ mutex_lock l(mu_);
+ req.set_session_handle(handle_);
+ }
+ }
ListDevicesResponse resp;
CallOptions call_options;
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
Status s = master_->ListDevices(&call_options, &req, &resp);
if (!s.ok()) {
LOG(ERROR) << "Could not list devices: " << s;
- return devices;
+ return s;
}
+ response->clear();
+ response->reserve(resp.local_device_size() + resp.remote_device_size());
for (const auto& device_attr : resp.local_device()) {
- devices.push_back(device_attr);
+ response->emplace_back(device_attr);
}
for (const auto& device_attr : resp.remote_device()) {
- devices.push_back(device_attr);
+ response->emplace_back(device_attr);
}
-
- return devices;
+ return Status::OK();
}
void GrpcSession::SetRemoteMaster(std::unique_ptr<MasterInterface> master) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index 8fd17a053b..300f727124 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -94,7 +94,7 @@ class GrpcSession : public Session {
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) override;
- std::vector<DeviceAttributes> ListDevices();
+ Status ListDevices(std::vector<DeviceAttributes>* response) override;
protected:
// Takes ownership of `*master`.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
index bcf5181e15..c237f2dce4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
@@ -70,7 +70,7 @@ Status TestCluster::MakeTestCluster(const SessionOptions& options, int n,
std::unique_ptr<GrpcSession> session;
TF_RETURN_IF_ERROR(GrpcSession::Create(options_copy, &session));
std::vector<DeviceAttributes> device_attributes;
- ret->devices_ = session->ListDevices();
+ TF_RETURN_IF_ERROR(session->ListDevices(&ret->devices_));
*out_cluster = std::move(ret);
return Status::OK();
diff --git a/tensorflow/core/distributed_runtime/rpcbench_test.cc b/tensorflow/core/distributed_runtime/rpcbench_test.cc
index 3d2b1894b8..8d3402e99c 100644
--- a/tensorflow/core/distributed_runtime/rpcbench_test.cc
+++ b/tensorflow/core/distributed_runtime/rpcbench_test.cc
@@ -95,7 +95,7 @@ void MakeGRPCCluster(const SessionOptions& options, int n,
options_copy.target = (*workers)[0];
std::unique_ptr<GrpcSession> session;
TF_CHECK_OK(GrpcSession::Create(options_copy, &session));
- *devices = session->ListDevices();
+ TF_CHECK_OK(session->ListDevices(devices));
}
struct Cluster {
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
index e607b1c42a..22bcdf0f0c 100644
--- a/tensorflow/core/protobuf/master.proto
+++ b/tensorflow/core/protobuf/master.proto
@@ -234,6 +234,15 @@ message ResetResponse {
////////////////////////////////////////////////////////////////////////////////
message ListDevicesRequest {
+ // Optional: session_handle must be returned by a CreateSession call to the
+ // same master service.
+ //
+ // When session_handle is empty, the ClusterSpec provided when the master was
+ // started is used to compute the available devices. If the session_handle is
+ // provided but not recognized, an error is returned. Finally, if a valid
+ // session_handle is provided, the cluster configuration for that session is
+ // used when computing the response.
+ string session_handle = 1;
}
message ListDevicesResponse {
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
index eaa076ffb9..acd1482418 100644
--- a/tensorflow/core/public/session.h
+++ b/tensorflow/core/public/session.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -170,6 +171,12 @@ class Session {
const std::vector<string>& output_names,
std::vector<Tensor>* outputs);
+ /// \brief List devices in the session.
+ ///
+ /// Retrieves the list of available devices within the session, and populates
+ /// *response. This API is optional. If it is unimplemented, Status will
+ /// return a corresponding error message, and *response will be unmodified.
+ virtual Status ListDevices(std::vector<DeviceAttributes>* response) = 0;
/// \brief Closes this session.
///
/// Closing a session releases the resources used by this session
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index af7e6bd40c..55f2179788 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
+from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -512,6 +513,39 @@ def _name_list(tensor_list):
return [compat.as_bytes(t.name) for t in tensor_list]
+class _DeviceAttributes(object):
+ """Struct-like object describing a device's attributes.
+
+ Each device has 3 key properties:
+ - name: the fully-qualified TensorFlow path to the device. For
+ example: /job:worker/replica:0/task:3/device:CPU:0
+ - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.)
+ - memory_limit_bytes: the maximum amount of memory available on the device
+ (in bytes).
+ """
+
+ def __init__(self, name, device_type, memory_limit_bytes):
+ self._name = device.canonical_name(name)
+ self._device_type = device_type
+ self._memory_limit_bytes = memory_limit_bytes
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def device_type(self):
+ return self._device_type
+
+ @property
+ def memory_limit_bytes(self):
+ return self._memory_limit_bytes
+
+ def __repr__(self):
+ return '_DeviceAttributes(%s, %s, %d)' % (self.name, self.device_type,
+ self.memory_limit_bytes,)
+
+
class BaseSession(SessionInterface):
"""A class for interacting with a TensorFlow computation.
@@ -574,6 +608,42 @@ class BaseSession(SessionInterface):
finally:
tf_session.TF_DeleteSessionOptions(opts)
+  def list_devices(self):
+    """Lists available devices in this session.
+
+    ```python
+    devices = sess.list_devices()
+    for d in devices:
+      print(d.name)
+    ```
+
+    Each element in the list has the following properties:
+     - `name`: A string with the full name of the device. ex:
+          `/job:worker/replica:0/task:3/device:CPU:0`
+     - `device_type`: The type of the device (e.g. `CPU`, `GPU`, `TPU`.)
+     - `memory_limit_bytes`: The maximum amount of memory available on the
+          device, in bytes. Note: depending on the device, it is possible the
+          usable memory could be substantially less.
+
+    Raises:
+      tf.errors.OpError: If it encounters an error (e.g. session is in an
+        invalid state, or network errors occur).
+
+    Returns:
+      A list of devices in the session.
+    """
+    with errors.raise_exception_on_not_ok_status() as status:
+      raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
+          self._session, status)
+      device_list = []
+      try:
+        size = tf_session.TF_DeviceListCount(raw_device_list)
+        for i in range(size):
+          name = tf_session.TF_DeviceListName(raw_device_list, i, status)
+          device_type = tf_session.TF_DeviceListType(raw_device_list, i, status)
+          memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i,
+                                                       status)
+          device_list.append(_DeviceAttributes(name, device_type, memory))
+      finally:
+        # Release the native list even if building an entry raises.
+        tf_session.TF_DeleteDeviceList(raw_device_list)
+      return device_list
+
def close(self):
"""Closes this session.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index f1bfbd5edf..c2f5763e86 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1923,6 +1923,58 @@ class SessionTest(test_util.TensorFlowTestCase):
found_valid_nodes += 1
self.assertEqual(3, found_valid_nodes)
+ def testDeviceAttributes(self):
+ attrs = session._DeviceAttributes(
+ '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
+ self.assertEqual(1337, attrs.memory_limit_bytes)
+ self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
+ self.assertEqual('TYPE', attrs.device_type)
+ str_repr = '%s' % attrs
+ self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
+
+ def testDeviceAttributesCanonicalization(self):
+ attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
+ 'TYPE', 1337)
+ self.assertEqual(1337, attrs.memory_limit_bytes)
+ self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
+ self.assertEqual('TYPE', attrs.device_type)
+ str_repr = '%s' % attrs
+ self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
+
+  def testListDevices(self):
+    """A local session must report at least its CPU:0 device."""
+    with session.Session() as sess:
+      devices = sess.list_devices()
+      self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
+          [d.name for d in devices]), devices)
+      # assertGreaterEqual(first, second) checks first >= second, so the
+      # device count goes first to express "at least one device".
+      self.assertGreaterEqual(len(devices), 1, devices)
+
+  def testListDevicesGrpcSession(self):
+    """A session on a local gRPC server must report at least its CPU:0."""
+    server = server_lib.Server.create_local_server()
+    with session.Session(server.target) as sess:
+      devices = sess.list_devices()
+      self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
+          [d.name for d in devices]), devices)
+      # assertGreaterEqual(first, second) checks first >= second, so the
+      # device count goes first to express "at least one device".
+      self.assertGreaterEqual(len(devices), 1, devices)
+
+  def testListDevicesClusterSpecPropagation(self):
+    """With a propagated two-task ClusterSpec, both tasks' CPUs are listed."""
+    server1 = server_lib.Server.create_local_server()
+    server2 = server_lib.Server.create_local_server()
+
+    cluster_def = cluster_pb2.ClusterDef()
+    job = cluster_def.job.add()
+    job.name = 'worker'
+    job.tasks[0] = server1.target[len('grpc://'):]
+    job.tasks[1] = server2.target[len('grpc://'):]
+    config = config_pb2.ConfigProto(cluster_def=cluster_def)
+    with session.Session(server1.target, config=config) as sess:
+      devices = sess.list_devices()
+      device_names = set([d.name for d in devices])
+      self.assertTrue(
+          '/job:worker/replica:0/task:0/device:CPU:0' in device_names)
+      self.assertTrue(
+          '/job:worker/replica:0/task:1/device:CPU:0' in device_names)
+      # assertGreaterEqual(first, second) checks first >= second, so the
+      # device count goes first to express "at least two devices".
+      self.assertGreaterEqual(len(devices), 2, devices)
+
class PartialRunTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt
index 9503ec440f..f5b0bae58d 100644
--- a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt
@@ -29,6 +29,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "list_devices"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "make_callable"
argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt
index 5eec14f365..173cd1963e 100644
--- a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-session.pbtxt
@@ -29,6 +29,10 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "list_devices"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "make_callable"
argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
}