aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-08-03 12:39:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 12:43:59 -0700
commit249e639d634b26d1684d0df14ab6933feeba83a3 (patch)
treef3c900664ba8a3ee8150f2bffaf49c2133d66fa2 /tensorflow/c
parent401a8d4c6562b9ab591068e4e5418d8967543ef6 (diff)
Allow setting server_def directly on TFE_Context.
Any time that the server def is updated, the context is effectively "reset" by clearing all the caches. - Check that the FLR returned is not a nullptr instead of seg faulting. - Consolidate caches within the context object. PiperOrigin-RevId: 207308086
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/eager/c_api.cc55
-rw-r--r--tensorflow/c/eager/c_api.h21
-rw-r--r--tensorflow/c/eager/c_api_internal.h18
-rw-r--r--tensorflow/c/eager/c_api_test.cc165
4 files changed, 193 insertions, 66 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 555dab3e89..a0a44440c8 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -150,8 +150,8 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::Status::OK();
}
-tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
- TFE_Context** ctx) {
+tensorflow::Status UpdateTFE_ContextWithServerDef(
+ const tensorflow::ServerDef& server_def, TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
@@ -165,12 +165,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
} \
} while (0);
- string worker_name = tensorflow::strings::StrCat(
- "/job:", opts->server_def.job_name(),
- "/replica:0/task:", opts->server_def.task_index());
+ string worker_name =
+ tensorflow::strings::StrCat("/job:", server_def.job_name(),
+ "/replica:0/task:", server_def.task_index());
std::unique_ptr<tensorflow::ServerInterface> server;
- LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server));
+ LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(server.get());
@@ -202,15 +202,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
- remote_workers, rendezvous_id, opts->server_def,
- remote_eager_workers.get(), opts->async, &remote_contexts));
+ remote_workers, rendezvous_id, server_def, remote_eager_workers.get(),
+ ctx->context.Async(), &remote_contexts));
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
- session_name, opts->server_def, true));
+ session_name, server_def, true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
@@ -221,10 +221,10 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr;
- *ctx = new TFE_Context(opts->session_options.options, opts->policy,
- opts->async, device_mgr, r, std::move(server),
- std::move(remote_eager_workers),
- std::move(remote_device_mgr), remote_contexts);
+
+ ctx->context.InitializeRemote(
+ std::move(server), std::move(remote_eager_workers),
+ std::move(remote_device_mgr), remote_contexts, r, device_mgr);
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@@ -249,15 +249,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status) {
- if (!options->server_def.ParseFromArray(proto, proto_len)) {
- status->status = tensorflow::errors::InvalidArgument(
- "Invalid tensorflow.ServerDef protocol buffer");
- }
-}
-
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@@ -267,12 +258,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
- if (!opts->server_def.job_name().empty()) {
- TFE_Context* ctx = nullptr;
- status->status = NewRemoteAwareTFE_Context(opts, &ctx);
- return ctx;
- }
-
std::vector<tensorflow::Device*> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
@@ -301,6 +286,20 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); }
+// Set server_def on the context, possibly updating it.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status) {
+ tensorflow::ServerDef server_def;
+ if (!server_def.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Invalid tensorflow.ServerDef protocol buffer");
+ return;
+ }
+ status->status = UpdateTFE_ContextWithServerDef(server_def, ctx);
+}
+
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context.SetThreadLocalDevicePlacementPolicy(
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index ea019a5711..25cf7adbc7 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
-// A tensorflow.ServerDef specifies remote workers (in addition to the current
-// workers name). Operations created on this context can then be executed on
-// any of these remote workers by setting an appropriate device.
-//
-// If the following is set, all servers identified by the
-// ServerDef must be up when the context is created.
-TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
- TFE_ContextOptions* options, const void* proto, size_t proto_len,
- TF_Status* status);
-
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
@@ -127,6 +117,17 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*,
unsigned char async,
TF_Status* status);
+// A tensorflow.ServerDef specifies remote workers (in addition to the current
+// workers name). Operations created on this context can then be executed on
+// any of these remote workers by setting an appropriate device.
+//
+// If the following is set, all servers identified by the
+// ServerDef must be up when the context is created.
+TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
// Causes the calling thread to block till all ops dispatched in async mode
// have been executed. Note that "execution" here refers to kernel execution /
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 4c5077023d..a5c0681e2e 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -59,7 +59,6 @@ struct TFE_ContextOptions {
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
- tensorflow::ServerDef server_def;
};
struct TFE_Context {
@@ -73,23 +72,6 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
- explicit TFE_Context(
- const tensorflow::SessionOptions& opts,
- TFE_ContextDevicePlacementPolicy default_policy, bool async,
- tensorflow::DeviceMgr* local_device_mgr,
- tensorflow::Rendezvous* rendezvous,
- std::unique_ptr<tensorflow::ServerInterface> server,
- std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
- std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
- const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
- remote_contexts)
- : context(opts,
- static_cast<tensorflow::ContextDevicePlacementPolicy>(
- default_policy),
- async, local_device_mgr, rendezvous, std::move(server),
- std::move(remote_eager_workers), std::move(remote_device_mgr),
- remote_contexts) {}
-
tensorflow::EagerContext context;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 6f2fbee884..00a0a71fca 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -108,14 +108,14 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
-tensorflow::ServerDef GetServerDef(int num_tasks) {
+tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
- server_def.set_job_name("localhost");
+ server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
- job_def->set_name("localhost");
+ job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
@@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
return server_def;
}
+tensorflow::ServerDef GetServerDef(int num_tasks) {
+ return GetServerDef("localhost", num_tasks);
+}
+
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
@@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) {
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
@@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) {
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
- TFE_DeleteContext(ctx);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
@@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
- TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
- status);
- EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
@@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
+void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
+ const std::vector<float>& expected_values) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
+ EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
+ memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+
+ for (int i = 0; i < expected_values.size(); i++) {
+ EXPECT_EQ(expected_values[i], actual_values[i])
+ << "Mismatch in expected values at (zero-based) index " << i;
+ }
+}
+
+void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
+ const char* remote_device_name,
+ const char* local_device_name) {
+ TF_Status* status = TF_NewStatus();
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
+ TFE_OpSetDevice(matmul, remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 =
+ TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
+
+ TFE_DeleteTensorHandle(retval_task0);
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+}
+
+void TestRemoteExecuteChangeServerDef(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(2);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+
+ std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+ const char local_device_name[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+ CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+
+ // Update the server def with a new set of names (worker instead of
+ // localhost).
+ tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
+ serialized = updated_server_def.SerializeAsString();
+
+ updated_server_def.set_task_index(1);
+ tensorflow::Status s = tensorflow::GrpcServer::Create(
+ updated_server_def, tensorflow::Env::Default(), &worker_server);
+ ASSERT_TRUE(s.ok()) << s.error_message();
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TFE_ContextSetServerDef(ctx, serialized.data(), serialized.size(), status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Create a new tensor_handle.
+ TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
+
+ // Check that copying it to the old remote device (named localhost) fails.
+ TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
+ EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Copying and executing on the new remote device works.
+ const char new_remote_device_name[] =
+ "/job:worker/replica:0/task:1/device:CPU:0";
+ const char new_local_device_name[] =
+ "/job:worker/replica:0/task:0/device:CPU:0";
+
+ auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
+ h0_task0_new, ctx, new_remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_DeleteTensorHandle(h0_task0_new);
+ TFE_DeleteTensorHandle(h0_task1_new);
+
+ CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
+ new_local_device_name);
+
+ TFE_ContextAsyncWait(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ TFE_DeleteContext(ctx);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteChangeServerDef) {
+ TestRemoteExecuteChangeServerDef(false);
+}
+TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
+ TestRemoteExecuteChangeServerDef(true);
+}
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));