diff options
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 55 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 21 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_internal.h | 18 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 165 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.cc | 100 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 68 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/execute.cc | 17 | ||||
-rw-r--r-- | tensorflow/python/eager/backprop.py | 7 | ||||
-rw-r--r-- | tensorflow/python/eager/context.py | 82 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util.py | 5 | ||||
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 2 |
11 files changed, 377 insertions, 163 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 555dab3e89..a0a44440c8 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -150,8 +150,8 @@ tensorflow::Status CreateRemoteContexts( return tensorflow::Status::OK(); } -tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, - TFE_Context** ctx) { +tensorflow::Status UpdateTFE_ContextWithServerDef( + const tensorflow::ServerDef& server_def, TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -165,12 +165,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, } \ } while (0); - string worker_name = tensorflow::strings::StrCat( - "/job:", opts->server_def.job_name(), - "/replica:0/task:", opts->server_def.task_index()); + string worker_name = + tensorflow::strings::StrCat("/job:", server_def.job_name(), + "/replica:0/task:", server_def.task_index()); std::unique_ptr<tensorflow::ServerInterface> server; - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); tensorflow::GrpcServer* grpc_server = dynamic_cast<tensorflow::GrpcServer*>(server.get()); @@ -202,15 +202,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, // Initialize remote eager workers. tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, opts->server_def, - remote_eager_workers.get(), opts->async, &remote_contexts)); + remote_workers, rendezvous_id, server_def, remote_eager_workers.get(), + ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, opts->server_def, true)); + session_name, server_def, true)); std::shared_ptr<tensorflow::WorkerSession> worker_session; TF_RETURN_IF_ERROR( @@ -221,10 +221,10 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); auto* device_mgr = grpc_server->worker_env()->device_mgr; - *ctx = new TFE_Context(opts->session_options.options, opts->policy, - opts->async, device_mgr, r, std::move(server), - std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts); + + ctx->context.InitializeRemote( + std::move(server), std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, r, device_mgr); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -249,15 +249,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status) { - if (!options->server_def.ParseFromArray(proto, proto_len)) { - status->status = tensorflow::errors::InvalidArgument( - "Invalid tensorflow.ServerDef protocol buffer"); - } -} - TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -267,12 +258,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - if (!opts->server_def.job_name().empty()) { - TFE_Context* ctx = nullptr; - status->status = NewRemoteAwareTFE_Context(opts, &ctx); - return ctx; - } - std::vector<tensorflow::Device*> devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", @@ -301,6 +286,20 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +// Set server_def on the context, possibly updating it. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = UpdateTFE_ContextWithServerDef(server_def, ctx); +} + void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { ctx->context.SetThreadLocalDevicePlacementPolicy( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index ea019a5711..25cf7adbc7 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); -// A tensorflow.ServerDef specifies remote workers (in addition to the current -// workers name). Operations created on this context can then be executed on -// any of these remote workers by setting an appropriate device. -// -// If the following is set, all servers identified by the -// ServerDef must be up when the context is created. -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status); - // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -127,6 +117,17 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, unsigned char async, TF_Status* status); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status); + // Causes the calling thread to block till all ops dispatched in async mode // have been executed. Note that "execution" here refers to kernel execution / // scheduling of copies, etc. Similar to sync execution, it doesn't guarantee diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 4c5077023d..a5c0681e2e 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -59,7 +59,6 @@ struct TFE_ContextOptions { // true if async execution is enabled. bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; - tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -73,23 +72,6 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} - explicit TFE_Context( - const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, bool async, - tensorflow::DeviceMgr* local_device_mgr, - tensorflow::Rendezvous* rendezvous, - std::unique_ptr<tensorflow::ServerInterface> server, - std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers, - std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr, - const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>& - remote_contexts) - : context(opts, - static_cast<tensorflow::ContextDevicePlacementPolicy>( - default_policy), - async, local_device_mgr, rendezvous, std::move(server), - std::move(remote_eager_workers), std::move(remote_device_mgr), - remote_contexts) {} - tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 6f2fbee884..00a0a71fca 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -108,14 +108,14 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } -tensorflow::ServerDef GetServerDef(int num_tasks) { +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { tensorflow::ServerDef server_def; server_def.set_protocol("grpc"); - server_def.set_job_name("localhost"); + server_def.set_job_name(job_name); server_def.set_task_index(0); tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name("localhost"); + job_def->set_name(job_name); for (int i = 0; i < num_tasks; i++) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( @@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) { return server_def; } +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} + void TestRemoteExecute(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); @@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); @@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char remote_device_name[] = @@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; @@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, + const std::vector<float>& expected_values) { + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + std::unique_ptr<float[]> actual_values(new float[expected_values.size()]); + EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); + memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + + for (int i = 0; i < expected_values.size(); i++) { + EXPECT_EQ(expected_values[i], actual_values[i]) + << "Mismatch in expected values at (zero-based) index " << i; + } +} + +void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, + const char* remote_device_name, + const char* local_device_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = + TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); + + TFE_DeleteTensorHandle(retval_task0); + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +void TestRemoteExecuteChangeServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr<tensorflow::GrpcServer> worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); + + // Update the server def with a new set of names (worker instead of + // localhost). + tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); + serialized = updated_server_def.SerializeAsString(); + + updated_server_def.set_task_index(1); + tensorflow::Status s = tensorflow::GrpcServer::Create( + updated_server_def, tensorflow::Env::Default(), &worker_server); + ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(worker_server->Start().ok()); + + TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Create a new tensor_handle. + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(); + + // Check that copying it to the old remote device (named localhost) fails. + TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Copying and executing on the new remote device works. + const char new_remote_device_name[] = + "/job:worker/replica:0/task:1/device:CPU:0"; + const char new_local_device_name[] = + "/job:worker/replica:0/task:0/device:CPU:0"; + + auto* h0_task1_new = TFE_TensorHandleCopyToDevice( + h0_task0_new, ctx, new_remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0_new); + TFE_DeleteTensorHandle(h0_task1_new); + + CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, + new_local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + TFE_DeleteContext(ctx); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteChangeServerDef) { + TestRemoteExecuteChangeServerDef(false); +} +TEST(CAPI, RemoteExecuteChangeServerDefAsync) { + TestRemoteExecuteChangeServerDef(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index b623ed4421..6ab2d1ebf1 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -47,6 +47,7 @@ EagerContext::EagerContext(const SessionOptions& opts, &func_lib_def_, {}, thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), async_default_(async), + env_(opts.env), use_send_tensor_rpc_(false) { InitDeviceMapAndAsync(); if (opts.config.inter_op_parallelism_threads() > 0) { @@ -58,34 +59,6 @@ EagerContext::EagerContext(const SessionOptions& opts, } } -#ifndef __ANDROID__ -EagerContext::EagerContext( - const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, - bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous, - std::unique_ptr<ServerInterface> server, - std::unique_ptr<eager::EagerClientCache> remote_eager_workers, - std::unique_ptr<DeviceMgr> remote_device_manager, - const gtl::FlatMap<string, uint64>& remote_contexts) - : policy_(default_policy), - local_unowned_device_manager_(local_device_mgr), - devices_(local_unowned_device_manager_->ListDevices()), - rendezvous_(rendezvous), - thread_pool_(NewThreadPoolFromSessionOptions(opts)), - pflr_(new ProcessFunctionLibraryRuntime( - local_unowned_device_manager_, opts.env, TF_GRAPH_DEF_VERSION, - &func_lib_def_, {}, thread_pool_.get())), - log_device_placement_(opts.config.log_device_placement()), - async_default_(async), - remote_device_manager_(std::move(remote_device_manager)), - server_(std::move(server)), - remote_eager_workers_(std::move(remote_eager_workers)), - remote_contexts_(remote_contexts), - use_send_tensor_rpc_( - ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false)) { - InitDeviceMapAndAsync(); -} -#endif - void EagerContext::InitDeviceMapAndAsync() { if (async_default_) { executor_.EnableAsync(); @@ -148,15 +121,8 @@ ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() { return policy_; } -EagerContext::~EagerContext() { #ifndef __ANDROID__ - if (server_) { - // TODO(nareshmodi): Fix this. - LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. " - "Servers don't support clean shutdown."; - server_.release(); - } - +void EagerContext::CloseRemoteContexts() { // Close all remote contexts. std::vector<eager::CloseContextRequest> requests(remote_contexts_.size()); std::vector<eager::CloseContextResponse> responses(remote_contexts_.size()); @@ -183,6 +149,19 @@ EagerContext::~EagerContext() { } counter.Wait(); +} +#endif + +EagerContext::~EagerContext() { +#ifndef __ANDROID__ + if (server_) { + // TODO(nareshmodi): Fix this. + LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. " + "Servers don't support clean shutdown."; + server_.release(); + } + + CloseRemoteContexts(); #endif executor_.WaitForAllPendingNodes().IgnoreError(); @@ -318,6 +297,55 @@ Status EagerContext::GetClientAndContextID(Device* device, return Status::OK(); } + +void EagerContext::InitializeRemote( + std::unique_ptr<ServerInterface> server, + std::unique_ptr<eager::EagerClientCache> remote_eager_workers, + std::unique_ptr<DeviceMgr> remote_device_manager, + const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r, + DeviceMgr* local_device_mgr) { + if (!remote_contexts_.empty()) { + CloseRemoteContexts(); + } + remote_contexts_ = remote_contexts; + + use_send_tensor_rpc_ = + ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false); + + local_unowned_device_manager_ = local_device_mgr; + local_device_manager_ = nullptr; + pflr_.reset(new ProcessFunctionLibraryRuntime( + local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_, + {}, thread_pool_.get())); + + devices_ = local_unowned_device_manager_->ListDevices(); + devices_map_.clear(); + + if (rendezvous_ != nullptr) rendezvous_->Unref(); + rendezvous_ = r; + + // Memory leak! + if (server_ != nullptr) { + LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. " + "Servers don't support clean shutdown."; + server_.release(); + } + + server_ = std::move(server); + remote_eager_workers_ = std::move(remote_eager_workers); + + active_remote_contexts_.clear(); + for (const auto& remote_context : remote_contexts_) { + active_remote_contexts_.insert(remote_context.second); + } + + device_to_client_cache_.clear(); + remote_device_manager_ = std::move(remote_device_manager); + + InitDeviceMapAndAsync(); + + ClearCaches(); +} #endif } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 9c8c599452..a0b612e6e5 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -68,31 +69,6 @@ class EagerContext { ContextDevicePlacementPolicy default_policy, bool async, std::unique_ptr<DeviceMgr> device_mgr, Rendezvous* rendezvous); - - // TODO(nareshmodi): Split this into 2 classes and hide functionality behind - // an interface. Alternatively, encapsulate remote state into a separate - // class/struct. - // - // Constructs an eager context that is able to communicate with remote - // workers. - // - // Additional remote-specific args are: - // - server: A ServerInterface that exports the tensorflow.WorkerService. - // Note that this class expects the server to already have been started. - // - remote_eager_workers: A cache from which we can get "EagerClient"s to - // communicate with remote eager services. - // - remote_device_mgr: A DeviceMgr* which contains all remote devices - // (should contain no local devices). - // - remote_contexts: A map containing task name to remote context ID. -#ifndef __ANDROID__ - explicit EagerContext( - const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, - bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous, - std::unique_ptr<ServerInterface> server, - std::unique_ptr<eager::EagerClientCache> remote_eager_workers, - std::unique_ptr<DeviceMgr> remote_device_manager, - const gtl::FlatMap<string, uint64>& remote_contexts); -#endif ~EagerContext(); // Returns the function library runtime for the given device. @@ -183,7 +159,31 @@ class EagerContext { Status GetClientAndContextID(Device* device, eager::EagerClient** client, uint64* context_id); + // TODO(nareshmodi): Encapsulate remote state into a separate + // class/struct. + // + // Enables the eager context to communicate with remote devices. + // + // - server: A ServerInterface that exports the tensorflow.WorkerService. + // Note that this class expects the server to already have been started. + // - remote_eager_workers: A cache from which we can get "EagerClient"s to + // communicate with remote eager services. + // - remote_device_mgr: A DeviceMgr* which contains all remote devices + // (should contain no local devices). + // - remote_contexts: A map containing task name to remote context ID. + void InitializeRemote( + std::unique_ptr<ServerInterface> server, + std::unique_ptr<eager::EagerClientCache> remote_eager_workers, + std::unique_ptr<DeviceMgr> remote_device_manager, + const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r, + DeviceMgr* local_device_mgr); + + bool HasActiveRemoteContext(uint64 context_id) { + return active_remote_contexts_.find(context_id) != + active_remote_contexts_.end(); + } #endif + // If true, then tensors should be shipped across processes via the // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used // instead (which in-turn use WorkerService.RecvTensor RPCs. @@ -203,13 +203,13 @@ class EagerContext { // Only one of the below is set. std::unique_ptr<DeviceMgr> local_device_manager_; - const DeviceMgr* local_unowned_device_manager_; + DeviceMgr* local_unowned_device_manager_; // Devices owned by device_manager std::vector<Device*> devices_; // All devices are not owned. gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_; - Rendezvous* const rendezvous_; + Rendezvous* rendezvous_; mutex functions_mu_; FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){ @@ -220,7 +220,7 @@ class EagerContext { // One FunctionLibraryRuntime per device. // func_libs[i] is the FunctionLibraryRuntime corresponding to // session->devices[i]. - const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; std::function<void(std::function<void()>)> runner_; @@ -243,21 +243,25 @@ class EagerContext { std::unordered_map<std::thread::id, bool> thread_local_async_ GUARDED_BY(async_map_mu_); - const std::unique_ptr<DeviceMgr> remote_device_manager_; + Env* const env_; #ifndef __ANDROID__ + void CloseRemoteContexts(); + std::unique_ptr<DeviceMgr> remote_device_manager_; + // The server_ is not const since we release it when the context is destroyed. // Therefore the server_ object is not marked as const (even though it should // be). std::unique_ptr<ServerInterface> server_; - const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_; + std::unique_ptr<eager::EagerClientCache> remote_eager_workers_; - const gtl::FlatMap<string, uint64> remote_contexts_; + gtl::FlatMap<string, uint64> remote_contexts_; + gtl::FlatSet<uint64> active_remote_contexts_; gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>> device_to_client_cache_; #endif - const bool use_send_tensor_rpc_; + bool use_send_tensor_rpc_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 181b222b4c..3837405e7f 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -303,8 +303,14 @@ Status EagerLocalExecute(EagerOperation* op, // See WARNING comment in Execute (before kernel->Run) - would be nice to // rework to avoid this subtlety. tf_shared_lock l(*ctx->FunctionsMu()); - status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(), - kernel); + auto* flr = ctx->func_lib(device); + + if (flr == nullptr) { + return errors::Unavailable( + "Unable to find a FunctionLibraryRuntime corresponding to device ", + device->name()); + } + status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel); if (!status.ok()) { delete kernel; return status; @@ -383,6 +389,13 @@ std::function<void()> GetRemoteTensorDestructor( EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id, uint64 op_id, int output_num) { return [ctx, eager_client, context_id, op_id, output_num]() { + if (!ctx->HasActiveRemoteContext(context_id)) { + // This means that this tensor was pointing to a remote device, which has + // been changed out from under us. Simply return since there is nothing we + // can do. + return tensorflow::Status::OK(); + } + std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest); request->set_context_id(context_id); diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 9cd39d02da..5f60f62874 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -591,9 +591,6 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -_zeros_cache = context._TensorCache() # pylint: disable=protected-access - - def _fast_fill(value, shape, dtype): return array_ops.fill(shape, constant_op.constant(value, dtype=dtype)) @@ -611,10 +608,10 @@ def _zeros(shape, dtype): device = ctx.device_name cache_key = shape, dtype, device - cached = _zeros_cache.get(cache_key) + cached = ctx.zeros_cache().get(cache_key) if cached is None: cached = _fast_fill(0, shape, dtype) - _zeros_cache.put(cache_key, cached) + ctx.zeros_cache().put(cache_key, cached) return cached diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 495a674526..c79294895b 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -91,6 +91,7 @@ class _EagerContext(threading.local): self.summary_writer_resource = None self.scalar_cache = {} self.ones_rank_cache = _TensorCache() + self.zeros_cache = _TensorCache() self.execution_mode = None @@ -225,6 +226,24 @@ class Context(object): """ return self._rng.randint(0, _MAXINT32) + def _initialize_devices(self): + """Helper to initialize devices.""" + # Store list of devices + self._context_devices = [] + device_list = pywrap_tensorflow.TFE_ContextListDevices( + self._context_handle) + try: + self._num_gpus = 0 + for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): + dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) + self._context_devices.append(pydev.canonical_name(dev_name)) + dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) + if dev_type == "GPU": + self._num_gpus += 1 + + finally: + pywrap_tensorflow.TF_DeleteDeviceList(device_list) + def _initialize_handle_and_devices(self): """Initialize handle and devices.""" with self._initialize_lock: @@ -241,27 +260,48 @@ class Context(object): opts, self._device_policy) if self._execution_mode == ASYNC: pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) - if self._server_def is not None: - server_def_str = self._server_def.SerializeToString() - pywrap_tensorflow.TFE_ContextOptionsSetServerDef(opts, server_def_str) self._context_handle = pywrap_tensorflow.TFE_NewContext(opts) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) - # Store list of devices - self._context_devices = [] - device_list = pywrap_tensorflow.TFE_ContextListDevices( - self._context_handle) - try: - self._num_gpus = 0 - for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): - dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) - self._context_devices.append(pydev.canonical_name(dev_name)) - dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) - if dev_type == "GPU": - self._num_gpus += 1 + if self._server_def is not None: + server_def_str = self._server_def.SerializeToString() + pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, + server_def_str) - finally: - pywrap_tensorflow.TF_DeleteDeviceList(device_list) + self._initialize_devices() + + def _clear_caches(self): + self.scalar_cache().clear() + self.ones_rank_cache().flush() + self.zeros_cache().flush() + + def set_server_def(self, server_def): + """Allow setting a server_def on the context. + + When a server def is replaced, it effectively clears a bunch of caches + within the context. If you attempt to use a tensor object that was pointing + to a tensor on the remote device, it will raise an error. + + Args: + server_def: A tensorflow::ServerDef proto. + Enables execution on remote devices. + + Raises: + ValueError: if server_def is None. + """ + if not server_def: + raise ValueError("server_def is None.") + if not self._context_handle: + self._server_def = server_def + else: + server_def_str = server_def.SerializeToString() + pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, + server_def_str) + + # Clear all the caches in case there are remote tensors in them. + self._clear_caches() + + self._initialize_devices() @property def _handle(self): @@ -324,6 +364,10 @@ class Context(object): """Per-device cache for scalars.""" return self._eager_context.ones_rank_cache + def zeros_cache(self): + """Per-device cache for scalars.""" + return self._eager_context.zeros_cache + @property def scope_name(self): """Returns scope name for the current thread.""" @@ -735,6 +779,10 @@ def export_run_metadata(): return context().export_run_metadata() +def set_server_def(server_def): + context().set_server_def(server_def) + + # Not every user creates a Context via context.context() # (for example, enable_eager_execution in python/framework/ops.py), # but they do all import this file. Note that IS_IN_GRAPH_MODE and diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index fc47b1cca5..764e8bfacb 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -51,7 +51,6 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import device_lib from tensorflow.python.client import session -from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import tape # pylint: disable=unused-import from tensorflow.python.framework import device as pydev @@ -498,9 +497,7 @@ def assert_no_new_tensors(f): f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. - backprop._zeros_cache.flush() - context.get_default_context().ones_rank_cache().flush() - context.get_default_context().scalar_cache().clear() + context.get_default_context()._clear_caches() # pylint: disable=protected-access gc.collect() tensors_after = [ obj for obj in gc.get_objects() diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5d7535cf34..1b69e0d06c 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -29,6 +29,7 @@ limitations under the License. %rename("%s") TFE_ContextGetDevicePlacementPolicy; %rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy; %rename("%s") TFE_ContextSetAsyncForThread; +%rename("%s") TFE_ContextSetServerDef; %rename("%s") TFE_ContextAsyncWait; %rename("%s") TFE_ContextAsyncClearError; %rename("%s") TFE_OpNameGetAttrType; @@ -59,7 +60,6 @@ limitations under the License. %rename("%s") TFE_ContextOptionsSetConfig; %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy; %rename("%s") TFE_ContextOptionsSetAsync; -%rename("%s") TFE_ContextOptionsSetServerDef; %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; %rename("%s") TFE_Py_TensorShapeOnDevice; |