about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Akshay Modi <nareshmodi@google.com> 2018-08-03 12:39:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-03 12:43:59 -0700
commit 249e639d634b26d1684d0df14ab6933feeba83a3 (patch)
tree f3c900664ba8a3ee8150f2bffaf49c2133d66fa2
parent 401a8d4c6562b9ab591068e4e5418d8967543ef6 (diff)
Allow setting server_def directly on TFE_Context.
Any time that the server def is updated, the context is effectively "reset" by clearing all the caches.
- Check that the FLR (FunctionLibraryRuntime) returned is not a nullptr instead of seg faulting.
- Consolidate caches within the context object.
PiperOrigin-RevId: 207308086
-rw-r--r-- tensorflow/c/eager/c_api.cc 55
-rw-r--r-- tensorflow/c/eager/c_api.h 21
-rw-r--r-- tensorflow/c/eager/c_api_internal.h 18
-rw-r--r-- tensorflow/c/eager/c_api_test.cc 165
-rw-r--r-- tensorflow/core/common_runtime/eager/context.cc 100
-rw-r--r-- tensorflow/core/common_runtime/eager/context.h 68
-rw-r--r-- tensorflow/core/common_runtime/eager/execute.cc 17
-rw-r--r-- tensorflow/python/eager/backprop.py 7
-rw-r--r-- tensorflow/python/eager/context.py 82
-rw-r--r-- tensorflow/python/framework/test_util.py 5
-rw-r--r-- tensorflow/python/pywrap_tfe.i 2
11 files changed, 377 insertions, 163 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 555dab3e89..a0a44440c8 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -150,8 +150,8 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::Status::OK();
}
-tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
- TFE_Context** ctx) {
+tensorflow::Status UpdateTFE_ContextWithServerDef(
+ const tensorflow::ServerDef& server_def, TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
@@ -165,12 +165,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
} \
} while (0);
- string worker_name = tensorflow::strings::StrCat(
- "/job:", opts->server_def.job_name(),
- "/replica:0/task:", opts->server_def.task_index());
+ string worker_name =
+ tensorflow::strings::StrCat("/job:", server_def.job_name(),
+ "/replica:0/task:", server_def.task_index());
std::unique_ptr<tensorflow::ServerInterface> server;
- LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server));
+ LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(server.get());
@@ -202,15 +202,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
- remote_workers, rendezvous_id, opts->server_def,
- remote_eager_workers.get(), opts->async, &remote_contexts));
+ remote_workers, rendezvous_id, server_def, remote_eager_workers.get(),
+ ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
- session_name, opts->server_def, true));
+ session_name, server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
@@ -221,10 +221,10 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr;
- *ctx = new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, device_mgr, r, std::move(server),
- std::move(remote_eager_workers),
- std::move(remote_device_mgr), remote_contexts);
+
+ ctx->context.InitializeRemote(
+ std::move(server), std::move(remote_eager_workers),
+ std::move(remote_device_mgr), remote_contexts, r, device_mgr);
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@@ -249,15 +249,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status) {
- if (!options->server_def.ParseFromArray(proto, proto_len)) {
- status->status = tensorflow::errors::InvalidArgument(
- "Invalid tensorflow.ServerDef protocol buffer");
- }
-}
-
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@@ -267,12 +258,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
- if (!opts->server_def.job_name().empty()) {
- TFE_Context* ctx = nullptr;
- status->status = NewRemoteAwareTFE_Context(opts, &ctx);
- return ctx;
- }
-
std::vector<tensorflow::Device*> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@@ -301,6 +286,20 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
+// Set server_def on the context, possibly updating it.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status) {
+ tensorflow::ServerDef server_def;
+ if (!server_def.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Invalid tensorflow.ServerDef protocol buffer");
+ return;
+ }
+ status->status = UpdateTFE_ContextWithServerDef(server_def, ctx);
+}
+
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy(
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index ea019a5711..25cf7adbc7 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
-// A tensorflow.ServerDef specifies remote workers (in addition to the current
-// workers name). Operations created on this context can then be executed on
-// any of these remote workers by setting an appropriate device.
-//
-// If the following is set, all servers identified by the
-// ServerDef must be up when the context is created.
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status);
-
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
@@ -127,6 +117,17 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
unsigned char async,
TF_Status* status);
+// A tensorflow.ServerDef specifies remote workers (in addition to the current
+// workers name). Operations created on this context can then be executed on
+// any of these remote workers by setting an appropriate device.
+//
+// If the following is set, all servers identified by the
+// ServerDef must be up when the context is created.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
// Causes the calling thread to block till all ops dispatched in async mode
// have been executed. Note that "execution" here refers to kernel execution /
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 4c5077023d..a5c0681e2e 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -59,7 +59,6 @@ struct TFE_ContextOptions {
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
- tensorflow::ServerDef server_def;
};
struct TFE_Context {
@@ -73,23 +72,6 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
- explicit TFE_Context(
- const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy, bool async,
- tensorflow::DeviceMgr* local_device_mgr,
- tensorflow::Rendezvous* rendezvous,
- std::unique_ptr<tensorflow::ServerInterface> server,
- std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
- std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
- const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
- remote_contexts)
- : context(opts,
- static_cast<tensorflow::ContextDevicePlacementPolicy>(
- default_policy),
- async, local_device_mgr, rendezvous, std::move(server),
- std::move(remote_eager_workers), std::move(remote_device_mgr),
- remote_contexts) {}
-
tensorflow::EagerContext context;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 6f2fbee884..00a0a71fca 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -108,14 +108,14 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
-tensorflow::ServerDef GetServerDef(int num_tasks) {
+tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
- server_def.set_job_name("localhost");
+ server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
- job_def->set_name("localhost");
+ job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
@@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
return server_def;
}
+tensorflow::ServerDef GetServerDef(int num_tasks) {
+ return GetServerDef("localhost", num_tasks);
+}
+
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
@@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
@@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
- TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
@@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
+void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
+ const std::vector<float>& expected_values) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
+ EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
+ memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+
+ for (int i = 0; i < expected_values.size(); i++) {
+ EXPECT_EQ(expected_values[i], actual_values[i])
+ << "Mismatch in expected values at (zero-based) index " << i;
+ }
+}
+
+void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
+ const char* remote_device_name,
+ const char* local_device_name) {
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
+ TFE_OpSetDevice(matmul, remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 =
+ TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
+
+ TFE_DeleteTensorHandle(retval_task0);
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+
+void TestRemoteExecuteChangeServerDef(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(2);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+
+ std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+ const char local_device_name[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+ CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+
+ // Update the server def with a new set of names (worker instead of
+ // localhost).
+ tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
+ serialized = updated_server_def.SerializeAsString();
+
+ updated_server_def.set_task_index(1);
+ tensorflow::Status s = tensorflow::GrpcServer::Create(
+ updated_server_def, tensorflow::Env::Default(), &worker_server);
+ ASSERT_TRUE(s.ok()) << s.error_message();
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Create a new tensor_handle.
+ TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
+
+ // Check that copying it to the old remote device (named localhost) fails.
+ TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
+ EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Copying and executing on the new remote device works.
+ const char new_remote_device_name[] =
+ "/job:worker/replica:0/task:1/device:CPU:0";
+ const char new_local_device_name[] =
+ "/job:worker/replica:0/task:0/device:CPU:0";
+
+ auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
+ h0_task0_new, ctx, new_remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_DeleteTensorHandle(h0_task0_new);
+ TFE_DeleteTensorHandle(h0_task1_new);
+
+ CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
+ new_local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ TFE_DeleteContext(ctx);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteChangeServerDef) {
+ TestRemoteExecuteChangeServerDef(false);
+}
+TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
+ TestRemoteExecuteChangeServerDef(true);
+}
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index b623ed4421..6ab2d1ebf1 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -47,6 +47,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
async_default_(async),
+ env_(opts.env),
use_send_tensor_rpc_(false) {
InitDeviceMapAndAsync();
if (opts.config.inter_op_parallelism_threads() > 0) {
@@ -58,34 +59,6 @@ EagerContext::EagerContext(const SessionOptions& opts,
}
}
-#ifndef __ANDROID__
-EagerContext::EagerContext(
- const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
- bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<ServerInterface> server,
- std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
- std::unique_ptr<DeviceMgr> remote_device_manager,
- const gtl::FlatMap<string, uint64>& remote_contexts)
- : policy_(default_policy),
- local_unowned_device_manager_(local_device_mgr),
- devices_(local_unowned_device_manager_->ListDevices()),
- rendezvous_(rendezvous),
- thread_pool_(NewThreadPoolFromSessionOptions(opts)),
- pflr_(new ProcessFunctionLibraryRuntime(
- local_unowned_device_manager_, opts.env, TF_GRAPH_DEF_VERSION,
- &func_lib_def_, {}, thread_pool_.get())),
- log_device_placement_(opts.config.log_device_placement()),
- async_default_(async),
- remote_device_manager_(std::move(remote_device_manager)),
- server_(std::move(server)),
- remote_eager_workers_(std::move(remote_eager_workers)),
- remote_contexts_(remote_contexts),
- use_send_tensor_rpc_(
- ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false)) {
- InitDeviceMapAndAsync();
-}
-#endif
-
void EagerContext::InitDeviceMapAndAsync() {
if (async_default_) {
executor_.EnableAsync();
@@ -148,15 +121,8 @@ ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
return policy_;
}
-EagerContext::~EagerContext() {
#ifndef __ANDROID__
- if (server_) {
- // TODO(nareshmodi): Fix this.
- LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
- "Servers don't support clean shutdown.";
- server_.release();
- }
-
+void EagerContext::CloseRemoteContexts() {
// Close all remote contexts.
std::vector<eager::CloseContextRequest> requests(remote_contexts_.size());
std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
@@ -183,6 +149,19 @@ EagerContext::~EagerContext() {
}
counter.Wait();
+}
+#endif
+
+EagerContext::~EagerContext() {
+#ifndef __ANDROID__
+ if (server_) {
+ // TODO(nareshmodi): Fix this.
+ LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
+ "Servers don't support clean shutdown.";
+ server_.release();
+ }
+
+ CloseRemoteContexts();
#endif
executor_.WaitForAllPendingNodes().IgnoreError();
@@ -318,6 +297,55 @@ Status EagerContext::GetClientAndContextID(Device* device,
return Status::OK();
}
+
+void EagerContext::InitializeRemote(
+ std::unique_ptr<ServerInterface> server,
+ std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+ std::unique_ptr<DeviceMgr> remote_device_manager,
+ const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
+ DeviceMgr* local_device_mgr) {
+ if (!remote_contexts_.empty()) {
+ CloseRemoteContexts();
+ }
+ remote_contexts_ = remote_contexts;
+
+ use_send_tensor_rpc_ =
+ ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);
+
+ local_unowned_device_manager_ = local_device_mgr;
+ local_device_manager_ = nullptr;
+ pflr_.reset(new ProcessFunctionLibraryRuntime(
+ local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
+ {}, thread_pool_.get()));
+
+ devices_ = local_unowned_device_manager_->ListDevices();
+ devices_map_.clear();
+
+ if (rendezvous_ != nullptr) rendezvous_->Unref();
+ rendezvous_ = r;
+
+ // Memory leak!
+ if (server_ != nullptr) {
+ LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
+ "Servers don't support clean shutdown.";
+ server_.release();
+ }
+
+ server_ = std::move(server);
+ remote_eager_workers_ = std::move(remote_eager_workers);
+
+ active_remote_contexts_.clear();
+ for (const auto& remote_context : remote_contexts_) {
+ active_remote_contexts_.insert(remote_context.second);
+ }
+
+ device_to_client_cache_.clear();
+ remote_device_manager_ = std::move(remote_device_manager);
+
+ InitDeviceMapAndAsync();
+
+ ClearCaches();
+}
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 9c8c599452..a0b612e6e5 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@@ -68,31 +69,6 @@ class EagerContext {
ContextDevicePlacementPolicy default_policy, bool async,
std::unique_ptr<DeviceMgr> device_mgr,
Rendezvous* rendezvous);
-
- // TODO(nareshmodi): Split this into 2 classes and hide functionality behind
- // an interface. Alternatively, encapsulate remote state into a separate
- // class/struct.
- //
- // Constructs an eager context that is able to communicate with remote
- // workers.
- //
- // Additional remote-specific args are:
- // - server: A ServerInterface that exports the tensorflow.WorkerService.
- // Note that this class expects the server to already have been started.
- // - remote_eager_workers: A cache from which we can get "EagerClient"s to
- // communicate with remote eager services.
- // - remote_device_mgr: A DeviceMgr* which contains all remote devices
- // (should contain no local devices).
- // - remote_contexts: A map containing task name to remote context ID.
-#ifndef __ANDROID__
- explicit EagerContext(
- const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
- bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<ServerInterface> server,
- std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
- std::unique_ptr<DeviceMgr> remote_device_manager,
- const gtl::FlatMap<string, uint64>& remote_contexts);
-#endif
~EagerContext();
// Returns the function library runtime for the given device.
@@ -183,7 +159,31 @@ class EagerContext {
Status GetClientAndContextID(Device* device, eager::EagerClient** client,
uint64* context_id);
+ // TODO(nareshmodi): Encapsulate remote state into a separate
+ // class/struct.
+ //
+ // Enables the eager context to communicate with remote devices.
+ //
+ // - server: A ServerInterface that exports the tensorflow.WorkerService.
+ // Note that this class expects the server to already have been started.
+ // - remote_eager_workers: A cache from which we can get "EagerClient"s to
+ // communicate with remote eager services.
+ // - remote_device_mgr: A DeviceMgr* which contains all remote devices
+ // (should contain no local devices).
+ // - remote_contexts: A map containing task name to remote context ID.
+ void InitializeRemote(
+ std::unique_ptr<ServerInterface> server,
+ std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+ std::unique_ptr<DeviceMgr> remote_device_manager,
+ const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
+ DeviceMgr* local_device_mgr);
+
+ bool HasActiveRemoteContext(uint64 context_id) {
+ return active_remote_contexts_.find(context_id) !=
+ active_remote_contexts_.end();
+ }
#endif
+
// If true, then tensors should be shipped across processes via the
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
// instead (which in-turn use WorkerService.RecvTensor RPCs.
@@ -203,13 +203,13 @@ class EagerContext {
// Only one of the below is set.
std::unique_ptr<DeviceMgr> local_device_manager_;
- const DeviceMgr* local_unowned_device_manager_;
+ DeviceMgr* local_unowned_device_manager_;
// Devices owned by device_manager
std::vector<Device*> devices_;
// All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
- Rendezvous* const rendezvous_;
+ Rendezvous* rendezvous_;
mutex functions_mu_;
FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
@@ -220,7 +220,7 @@ class EagerContext {
// One FunctionLibraryRuntime per device.
// func_libs[i] is the FunctionLibraryRuntime corresponding to
// session->devices[i].
- const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
std::function<void(std::function<void()>)> runner_;
@@ -243,21 +243,25 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
- const std::unique_ptr<DeviceMgr> remote_device_manager_;
+ Env* const env_;
#ifndef __ANDROID__
+ void CloseRemoteContexts();
+ std::unique_ptr<DeviceMgr> remote_device_manager_;
+
// The server_ is not const since we release it when the context is destroyed.
// Therefore the server_ object is not marked as const (even though it should
// be).
std::unique_ptr<ServerInterface> server_;
- const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
+ std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
- const gtl::FlatMap<string, uint64> remote_contexts_;
+ gtl::FlatMap<string, uint64> remote_contexts_;
+ gtl::FlatSet<uint64> active_remote_contexts_;
gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
device_to_client_cache_;
#endif
- const bool use_send_tensor_rpc_;
+ bool use_send_tensor_rpc_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 181b222b4c..3837405e7f 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -303,8 +303,14 @@ Status EagerLocalExecute(EagerOperation* op,
// See WARNING comment in Execute (before kernel->Run) - would be nice to
// rework to avoid this subtlety.
tf_shared_lock l(*ctx->FunctionsMu());
- status = KernelAndDevice::Init(ndef, ctx->func_lib(device), ctx->runner(),
- kernel);
+ auto* flr = ctx->func_lib(device);
+
+ if (flr == nullptr) {
+ return errors::Unavailable(
+ "Unable to find a FunctionLibraryRuntime corresponding to device ",
+ device->name());
+ }
+ status = KernelAndDevice::Init(ndef, flr, ctx->runner(), kernel);
if (!status.ok()) {
delete kernel;
return status;
@@ -383,6 +389,13 @@ std::function<void()> GetRemoteTensorDestructor(
EagerContext* ctx, eager::EagerClient* eager_client, uint64 context_id,
uint64 op_id, int output_num) {
return [ctx, eager_client, context_id, op_id, output_num]() {
+ if (!ctx->HasActiveRemoteContext(context_id)) {
+ // This means that this tensor was pointing to a remote device, which has
+ // been changed out from under us. Simply return since there is nothing we
+ // can do.
+ return tensorflow::Status::OK();
+ }
+
std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(context_id);
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 9cd39d02da..5f60f62874 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -591,9 +591,6 @@ def _num_elements(grad):
raise ValueError("`grad` not a Tensor or IndexedSlices.")
-_zeros_cache = context._TensorCache() # pylint: disable=protected-access
-
-
def _fast_fill(value, shape, dtype):
return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
@@ -611,10 +608,10 @@ def _zeros(shape, dtype):
device = ctx.device_name
cache_key = shape, dtype, device
- cached = _zeros_cache.get(cache_key)
+ cached = ctx.zeros_cache().get(cache_key)
if cached is None:
cached = _fast_fill(0, shape, dtype)
- _zeros_cache.put(cache_key, cached)
+ ctx.zeros_cache().put(cache_key, cached)
return cached
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 495a674526..c79294895b 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -91,6 +91,7 @@ class _EagerContext(threading.local):
self.summary_writer_resource = None
self.scalar_cache = {}
self.ones_rank_cache = _TensorCache()
+ self.zeros_cache = _TensorCache()
self.execution_mode = None
@@ -225,6 +226,24 @@ class Context(object):
"""
return self._rng.randint(0, _MAXINT32)
+ def _initialize_devices(self):
+ """Helper to initialize devices."""
+ # Store list of devices
+ self._context_devices = []
+ device_list = pywrap_tensorflow.TFE_ContextListDevices(
+ self._context_handle)
+ try:
+ self._num_gpus = 0
+ for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
+ dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
+ self._context_devices.append(pydev.canonical_name(dev_name))
+ dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
+ if dev_type == "GPU":
+ self._num_gpus += 1
+
+ finally:
+ pywrap_tensorflow.TF_DeleteDeviceList(device_list)
+
def _initialize_handle_and_devices(self):
"""Initialize handle and devices."""
with self._initialize_lock:
@@ -241,27 +260,48 @@ class Context(object):
opts, self._device_policy)
if self._execution_mode == ASYNC:
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
- if self._server_def is not None:
- server_def_str = self._server_def.SerializeToString()
- pywrap_tensorflow.TFE_ContextOptionsSetServerDef(opts, server_def_str)
self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
- # Store list of devices
- self._context_devices = []
- device_list = pywrap_tensorflow.TFE_ContextListDevices(
- self._context_handle)
- try:
- self._num_gpus = 0
- for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
- dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
- self._context_devices.append(pydev.canonical_name(dev_name))
- dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
- if dev_type == "GPU":
- self._num_gpus += 1
+ if self._server_def is not None:
+ server_def_str = self._server_def.SerializeToString()
+ pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
+ server_def_str)
- finally:
- pywrap_tensorflow.TF_DeleteDeviceList(device_list)
+ self._initialize_devices()
+
+ def _clear_caches(self):
+ self.scalar_cache().clear()
+ self.ones_rank_cache().flush()
+ self.zeros_cache().flush()
+
+ def set_server_def(self, server_def):
+ """Allow setting a server_def on the context.
+
+ When a server def is replaced, it effectively clears a bunch of caches
+ within the context. If you attempt to use a tensor object that was pointing
+ to a tensor on the remote device, it will raise an error.
+
+ Args:
+ server_def: A tensorflow::ServerDef proto.
+ Enables execution on remote devices.
+
+ Raises:
+ ValueError: if server_def is None.
+ """
+ if not server_def:
+ raise ValueError("server_def is None.")
+ if not self._context_handle:
+ self._server_def = server_def
+ else:
+ server_def_str = server_def.SerializeToString()
+ pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle,
+ server_def_str)
+
+ # Clear all the caches in case there are remote tensors in them.
+ self._clear_caches()
+
+ self._initialize_devices()
@property
def _handle(self):
@@ -324,6 +364,10 @@ class Context(object):
"""Per-device cache for scalars."""
return self._eager_context.ones_rank_cache
+ def zeros_cache(self):
+ """Per-device cache for zero-filled tensors, keyed by (shape, dtype, device)."""
+ return self._eager_context.zeros_cache
+
@property
def scope_name(self):
"""Returns scope name for the current thread."""
@@ -735,6 +779,10 @@ def export_run_metadata():
return context().export_run_metadata()
+def set_server_def(server_def):
+ context().set_server_def(server_def)
+
+
# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index fc47b1cca5..764e8bfacb 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -51,7 +51,6 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
-from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape # pylint: disable=unused-import
from tensorflow.python.framework import device as pydev
@@ -498,9 +497,7 @@ def assert_no_new_tensors(f):
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
- backprop._zeros_cache.flush()
- context.get_default_context().ones_rank_cache().flush()
- context.get_default_context().scalar_cache().clear()
+ context.get_default_context()._clear_caches() # pylint: disable=protected-access
gc.collect()
tensors_after = [
obj for obj in gc.get_objects()
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 5d7535cf34..1b69e0d06c 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -29,6 +29,7 @@ limitations under the License.
%rename("%s") TFE_ContextGetDevicePlacementPolicy;
%rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy;
%rename("%s") TFE_ContextSetAsyncForThread;
+%rename("%s") TFE_ContextSetServerDef;
%rename("%s") TFE_ContextAsyncWait;
%rename("%s") TFE_ContextAsyncClearError;
%rename("%s") TFE_OpNameGetAttrType;
@@ -59,7 +60,6 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
%rename("%s") TFE_ContextOptionsSetAsync;
-%rename("%s") TFE_ContextOptionsSetServerDef;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;