about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
author Mingxing Tan <tanmingxing@google.com> 2018-06-22 08:35:11 -0700
committer Mingxing Tan <tanmingxing@google.com> 2018-06-22 08:35:11 -0700
commit d1e4e035619b249541bf6d9db10789d3af30b225 (patch)
tree 533745d4e2893400937a5b5b3000614e913a9e3b /tensorflow/core/distributed_runtime
parent 359f53686c87ee76e80353c32a3d22cfb1cf0989 (diff)
parent fcb519a4a3d3bce0fc14dc2e46761a22b2d665a3 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r-- tensorflow/core/distributed_runtime/eager/BUILD 3
-rw-r--r-- tensorflow/core/distributed_runtime/eager/eager_service_impl.cc 17
-rw-r--r-- tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc 56
-rw-r--r-- tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc 10
4 files changed, 60 insertions, 26 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 1a7187597d..5bcf295acd 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -62,6 +62,7 @@ cc_library(
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_cache_wrapper",
"//tensorflow/core/distributed_runtime:worker_env",
@@ -79,10 +80,12 @@ tf_cc_test(
"//tensorflow/c:c_api_internal",
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 4bd74b81a7..2fa234c810 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
@@ -80,8 +81,8 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
CreateContextResponse* response) {
- tensorflow::RemoteRendezvous* r = env_->rendezvous_mgr->Find(0);
std::vector<tensorflow::Device*> devices;
+
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
// TODO(nareshmodi): Correctly set the SessionOptions.
SessionOptions(),
@@ -89,7 +90,6 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
request->server_def().job_name().data(),
request->server_def().task_index()),
&devices));
-
response->mutable_device_attributes()->Reserve(devices.size());
for (auto& d : devices) {
*response->add_device_attributes() = d->attributes();
@@ -97,6 +97,19 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
+
+ auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
+ auto session_name = strings::StrCat("eager_", request->rendezvous_id());
+ TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
+ session_name, request->server_def(), true));
+
+ std::shared_ptr<WorkerSession> worker_session;
+ TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
+ session_name, &worker_session));
+
+ // Initialize remote tensor communication based on worker session.
+ TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
+
std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index f865ebe1be..91b58698a4 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -20,15 +20,16 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -50,6 +51,39 @@ class TestEagerServiceImpl : public EagerServiceImpl {
}
};
+class EagerServiceImplTest : public ::testing::Test {
+ public:
+ EagerServiceImplTest()
+ : rendezvous_mgr_(&worker_env_),
+ session_mgr_(new SessionMgr(
+ &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
+ std::unique_ptr<WorkerCacheInterface>(),
+ [](const ServerDef& server_def,
+ WorkerCacheInterface** worker_cache) {
+ *worker_cache = nullptr;
+ return Status::OK();
+ })) {
+ worker_env_.env = Env::Default();
+
+ worker_env_.rendezvous_mgr = &rendezvous_mgr_;
+ worker_env_.session_mgr = session_mgr_.get();
+
+ Device* device = DeviceFactory::NewDevice(
+ "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0");
+
+ worker_env_.local_devices = {device};
+
+ device_mgr_.reset(new DeviceMgr(worker_env_.local_devices));
+ worker_env_.device_mgr = device_mgr_.get();
+ }
+
+ protected:
+ WorkerEnv worker_env_;
+ tensorflow::RpcRendezvousMgr rendezvous_mgr_;
+ std::unique_ptr<SessionMgr> session_mgr_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
void SetTensorProto(AttrValue* val) {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
@@ -119,17 +153,13 @@ tensorflow::FunctionDef MatMulFunction() {
}
// Test creates a context and attempts to execute some ops.
-TEST(EagerServiceImplTest, BasicTest) {
- WorkerEnv worker_env;
- worker_env.env = Env::Default();
- tensorflow::RpcRendezvousMgr rm(&worker_env);
- worker_env.rendezvous_mgr = &rm;
-
- TestEagerServiceImpl eager_service_impl(&worker_env);
+TEST_F(EagerServiceImplTest, BasicTest) {
+ TestEagerServiceImpl eager_service_impl(&worker_env_);
CreateContextRequest request;
request.mutable_server_def()->set_job_name("localhost");
request.mutable_server_def()->set_task_index(0);
+ request.set_rendezvous_id(random::New64());
CreateContextResponse response;
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
@@ -194,17 +224,13 @@ TEST(EagerServiceImplTest, BasicTest) {
}
// Test creates a context and attempts to execute a function.
-TEST(EagerServiceImplTest, BasicFunctionTest) {
- WorkerEnv worker_env;
- worker_env.env = Env::Default();
- tensorflow::RpcRendezvousMgr rm(&worker_env);
- worker_env.rendezvous_mgr = &rm;
-
- TestEagerServiceImpl eager_service_impl(&worker_env);
+TEST_F(EagerServiceImplTest, BasicFunctionTest) {
+ TestEagerServiceImpl eager_service_impl(&worker_env_);
CreateContextRequest request;
request.mutable_server_def()->set_job_name("localhost");
request.mutable_server_def()->set_task_index(0);
+ request.set_rendezvous_id(random::New64());
CreateContextResponse response;
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index aa334f9424..ff64d78b79 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -269,15 +269,7 @@ Status GrpcServer::Init(
LocalMaster::Register(target(), master_impl_.get(),
config.operation_timeout_in_ms());
- // Generate a dummy worker session that is used to register the
- // Rendezvous for eager (we use Step 0 for eager).
- worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
- "", name_prefix,
- std::unique_ptr<WorkerCacheInterface>(
- new WorkerCacheWrapper(master_env_.worker_cache)),
- worker_env_.device_mgr, {});
- auto* r = worker_env()->rendezvous_mgr->Find(0);
- return r->Initialize(worker_session_.get());
+ return Status::OK();
}
Status GrpcServer::Init(