diff options
author | Mingxing Tan <tanmingxing@google.com> | 2018-06-22 08:35:11 -0700 |
---|---|---|
committer | Mingxing Tan <tanmingxing@google.com> | 2018-06-22 08:35:11 -0700 |
commit | d1e4e035619b249541bf6d9db10789d3af30b225 (patch) | |
tree | 533745d4e2893400937a5b5b3000614e913a9e3b /tensorflow/core/distributed_runtime | |
parent | 359f53686c87ee76e80353c32a3d22cfb1cf0989 (diff) | |
parent | fcb519a4a3d3bce0fc14dc2e46761a22b2d665a3 (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/core/distributed_runtime')
4 files changed, 60 insertions, 26 deletions
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 1a7187597d..5bcf295acd 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -62,6 +62,7 @@ cc_library( "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_cache_wrapper", "//tensorflow/core/distributed_runtime:worker_env", @@ -79,10 +80,12 @@ tf_cc_test( "//tensorflow/c:c_api_internal", "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", ], diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 4bd74b81a7..2fa234c810 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" #include "tensorflow/core/distributed_runtime/worker_env.h" @@ -80,8 +81,8 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, CreateContextResponse* response) { - tensorflow::RemoteRendezvous* r = env_->rendezvous_mgr->Find(0); std::vector<tensorflow::Device*> devices; + TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( // TODO(nareshmodi): Correctly set the SessionOptions. SessionOptions(), @@ -89,7 +90,6 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, request->server_def().job_name().data(), request->server_def().task_index()), &devices)); - response->mutable_device_attributes()->Reserve(devices.size()); for (auto& d : devices) { *response->add_device_attributes() = d->attributes(); @@ -97,6 +97,19 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, std::unique_ptr<tensorflow::DeviceMgr> device_mgr( new tensorflow::DeviceMgr(devices)); + + auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id()); + auto session_name = strings::StrCat("eager_", request->rendezvous_id()); + TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession( + session_name, request->server_def(), true)); + + std::shared_ptr<WorkerSession> worker_session; + TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( + session_name, &worker_session)); + + // Initialize remote tensor communication based on worker session. + TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); + std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index f865ebe1be..91b58698a4 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -20,15 +20,16 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/protobuf/eager_service.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -50,6 +51,39 @@ class TestEagerServiceImpl : public EagerServiceImpl { } }; +class EagerServiceImplTest : public ::testing::Test { + public: + EagerServiceImplTest() + : rendezvous_mgr_(&worker_env_), + session_mgr_(new SessionMgr( + &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0", + std::unique_ptr<WorkerCacheInterface>(), + [](const ServerDef& server_def, + WorkerCacheInterface** worker_cache) { + *worker_cache = nullptr; + return Status::OK(); + })) { + worker_env_.env = Env::Default(); + + worker_env_.rendezvous_mgr = &rendezvous_mgr_; + worker_env_.session_mgr = session_mgr_.get(); + + Device* device = DeviceFactory::NewDevice( + "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"); + + worker_env_.local_devices = {device}; + + device_mgr_.reset(new DeviceMgr(worker_env_.local_devices)); + worker_env_.device_mgr = device_mgr_.get(); + } + + protected: + WorkerEnv worker_env_; + tensorflow::RpcRendezvousMgr rendezvous_mgr_; + std::unique_ptr<SessionMgr> session_mgr_; + std::unique_ptr<DeviceMgr> device_mgr_; +}; + void SetTensorProto(AttrValue* val) { int64_t dims[] = {2, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -119,17 +153,13 @@ tensorflow::FunctionDef MatMulFunction() { } // Test creates a context and attempts to execute some ops. -TEST(EagerServiceImplTest, BasicTest) { - WorkerEnv worker_env; - worker_env.env = Env::Default(); - tensorflow::RpcRendezvousMgr rm(&worker_env); - worker_env.rendezvous_mgr = &rm; - - TestEagerServiceImpl eager_service_impl(&worker_env); +TEST_F(EagerServiceImplTest, BasicTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); CreateContextRequest request; request.mutable_server_def()->set_job_name("localhost"); request.mutable_server_def()->set_task_index(0); + request.set_rendezvous_id(random::New64()); CreateContextResponse response; TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); @@ -194,17 +224,13 @@ TEST(EagerServiceImplTest, BasicTest) { } // Test creates a context and attempts to execute a function. -TEST(EagerServiceImplTest, BasicFunctionTest) { - WorkerEnv worker_env; - worker_env.env = Env::Default(); - tensorflow::RpcRendezvousMgr rm(&worker_env); - worker_env.rendezvous_mgr = &rm; - - TestEagerServiceImpl eager_service_impl(&worker_env); +TEST_F(EagerServiceImplTest, BasicFunctionTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); CreateContextRequest request; request.mutable_server_def()->set_job_name("localhost"); request.mutable_server_def()->set_task_index(0); + request.set_rendezvous_id(random::New64()); CreateContextResponse response; TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index aa334f9424..ff64d78b79 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -269,15 +269,7 @@ Status GrpcServer::Init( LocalMaster::Register(target(), master_impl_.get(), config.operation_timeout_in_ms()); - // Generate a dummy worker session that is used to register the - // Rendezvous for eager (we use Step 0 for eager). - worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr( - "", name_prefix, - std::unique_ptr<WorkerCacheInterface>( - new WorkerCacheWrapper(master_env_.worker_cache)), - worker_env_.device_mgr, {}); - auto* r = worker_env()->rendezvous_mgr->Find(0); - return r->Initialize(worker_session_.get()); + return Status::OK(); } Status GrpcServer::Init( |