author    Akshay Modi <nareshmodi@google.com>  2018-05-16 16:43:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-16 16:46:24 -0700
commit    76728dbee8732054902cda929fb8821576b63509 (patch)
tree      55f8594260aa841c0dbf910abfff54d93c137147 /tensorflow/c/eager
parent    9c1a186f66a50345731ce6e78ac561560e349866 (diff)
Allow for remote eager execution.
PiperOrigin-RevId: 196910675
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--  tensorflow/c/eager/BUILD                 23
-rw-r--r--  tensorflow/c/eager/c_api.cc             147
-rw-r--r--  tensorflow/c/eager/c_api.h               10
-rw-r--r--  tensorflow/c/eager/c_api_internal.h      26
-rw-r--r--  tensorflow/c/eager/c_api_test.cc        100
5 files changed, 302 insertions, 4 deletions
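This change introduces TFE_ContextOptionsSetServerDef, which attaches an eager context to a cluster described by a serialized tensorflow.ServerDef. The sketch below is not part of the commit; it shows one way a caller might use the new entry point, mirroring the added c_api_test.cc coverage. The two-task "localhost" job, the ports 2222/2223, and the ConnectToRemoteCluster helper are illustrative assumptions; the task-1 server must already be running, as the new header comment requires, and status checks are elided for brevity.

// Illustrative sketch (not from the commit): build a ServerDef, create a
// remote-aware eager context from it, and tear the context down again.
#include <string>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

void ConnectToRemoteCluster() {
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name("localhost");
  server_def.set_task_index(0);
  tensorflow::JobDef* job = server_def.mutable_cluster()->add_job();
  job->set_name("localhost");
  job->mutable_tasks()->insert({0, "localhost:2222"});  // this process
  job->mutable_tasks()->insert({1, "localhost:2223"});  // remote worker (must be up)
  const std::string serialized = server_def.SerializeAsString();

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // A non-empty job name routes TFE_NewContext through the new remote-aware
  // path, which connects to every task listed in the ServerDef.
  TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
                                 status);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Handles can now be copied to, and ops placed on, remote devices such as
  // "/job:localhost/replica:0/task:1/device:CPU:0".

  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
}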
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 1432119162..28f974c5d4 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -49,6 +49,17 @@ tf_cuda_library(
"//conditions:default": [],
}) + [
"//tensorflow/core/common_runtime/eager:eager_operation",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
+ "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:remote_device",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core:gpu_runtime",
],
)
@@ -74,6 +85,17 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/distributed_runtime:remote_device",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
+ "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
],
)
@@ -92,6 +114,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 3bf071f3ab..1c1020f812 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -36,11 +36,17 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@@ -71,6 +77,121 @@ string DeviceName(const tensorflow::Device* d) {
std::atomic_int_fast64_t func_id_generator(0);
#endif // TENSORFLOW_EAGER_USE_XLA
+tensorflow::Status GetAllRemoteDevices(
+ const std::vector<string>& remote_workers,
+ tensorflow::WorkerCacheInterface* worker_cache,
+ std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
+ std::vector<tensorflow::Device*> remote_devices;
+ tensorflow::Status status;
+ // TODO(nareshmodi) do this in parallel instead of serially.
+ for (const string& remote_worker : remote_workers) {
+ tensorflow::Notification n;
+ tensorflow::NewRemoteDevices(
+ tensorflow::Env::Default(), worker_cache, remote_worker,
+ [&status, &n, &remote_devices](
+ const tensorflow::Status& s,
+ std::vector<tensorflow::Device*>* devices) {
+ status = s;
+ if (s.ok()) {
+ for (tensorflow::Device* d : *devices) {
+ remote_devices.push_back(d);
+ }
+ }
+ n.Notify();
+ });
+ n.WaitForNotification();
+ }
+ std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
+ new tensorflow::DeviceMgr(remote_devices));
+
+ TF_RETURN_IF_ERROR(status);
+
+ *device_mgr = std::move(remote_device_mgr);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status CreateRemoteContexts(
+ const std::vector<string>& remote_workers,
+ tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
+ tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
+ for (int i = 0; i < remote_workers.size(); i++) {
+ const string& remote_worker = remote_workers[i];
+
+ tensorflow::eager::CreateContextRequest request;
+ tensorflow::eager::CreateContextResponse response;
+ tensorflow::DeviceNameUtils::ParsedName parsed_name;
+ if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
+ &parsed_name)) {
+ return tensorflow::errors::InvalidArgument(
+ "Unable to parse ", remote_worker, " as a device name");
+ }
+ request.mutable_server_def()->set_job_name(parsed_name.job);
+ request.mutable_server_def()->set_task_index(parsed_name.task);
+ request.set_async(async);
+ auto* eager_client = remote_eager_workers->GetClient(remote_worker);
+ if (eager_client == nullptr) {
+ return tensorflow::errors::Internal(
+ "Cannot find a client for the given target:", remote_worker);
+ }
+ tensorflow::Notification n;
+ tensorflow::Status status;
+ // TODO(nareshmodi) do this in parallel instead of serially.
+ eager_client->CreateContextAsync(
+ &request, &response, [&status, &n](const tensorflow::Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ TF_RETURN_IF_ERROR(status);
+
+ remote_contexts->emplace(remote_worker, response.context_id());
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
+ TFE_Context** ctx) {
+ string worker_name = tensorflow::strings::StrCat(
+ "/job:", opts->server_def.job_name(),
+ "/replica:0/task:", opts->server_def.task_index());
+ std::unique_ptr<tensorflow::eager::EagerGrpcServer> server;
+ TF_RETURN_IF_ERROR(
+ tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server));
+
+ TF_RETURN_IF_ERROR(server->Start());
+
+ std::vector<string> remote_workers;
+ server->master_env()->worker_cache->ListWorkers(&remote_workers);
+ remote_workers.erase(
+ std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
+ remote_workers.end());
+
+ std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
+ TF_RETURN_IF_ERROR(GetAllRemoteDevices(
+ remote_workers, server->master_env()->worker_cache, &remote_device_mgr));
+
+ std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
+ server->channel_cache();
+ std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
+ tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
+
+ // Initialize remote eager workers.
+ tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
+ TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers,
+ remote_eager_workers.get(),
+ opts->async, &remote_contexts));
+
+ tensorflow::RemoteRendezvous* r =
+ server->worker_env()->rendezvous_mgr->Find(0);
+
+ auto* device_mgr = server->worker_env()->device_mgr;
+ *ctx = new TFE_Context(opts->session_options.options, opts->policy,
+ opts->async, device_mgr, r, std::move(server),
+ std::move(remote_eager_workers),
+ std::move(remote_device_mgr), remote_contexts);
+
+ return tensorflow::Status::OK();
+}
} // namespace
extern "C" {
@@ -91,6 +212,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy(
options->policy = policy;
}
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
+ TFE_ContextOptions* options, const void* proto, size_t proto_len,
+ TF_Status* status) {
+ if (!options->server_def.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Invalid tensorflow.ServerDef protocol buffer");
+ }
+}
+
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
unsigned char async,
TF_Status* status) {
@@ -100,17 +230,23 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
+ if (!opts->server_def.job_name().empty()) {
+ TFE_Context* ctx = nullptr;
+ status->status = NewRemoteAwareTFE_Context(opts, &ctx);
+ return ctx;
+ }
+
std::vector<tensorflow::Device*> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
- if (!status->status.ok()) {
- return nullptr;
- }
+ if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
+
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
+
return new TFE_Context(opts->session_options.options, opts->policy,
opts->async, std::move(device_mgr), r);
}
@@ -119,7 +255,10 @@ void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
- ctx->context.device_mgr()->ListDeviceAttributes(&list->response);
+ ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
+ if (ctx->context.remote_device_mgr()) {
+ ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
+ }
return list;
}
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index c06ce84a8c..574a097e0d 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -81,6 +81,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
+// A tensorflow.ServerDef specifies remote workers (in addition to the current
+// worker's name). Operations created on this context can then be executed on
+// any of these remote workers by setting an appropriate device.
+//
+// If the following is set, all servers identified by the
+// ServerDef must be up when the context is created.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef(
+ TFE_ContextOptions* options, const void* proto, size_t proto_len,
+ TF_Status* status);
+
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 49e1aab1ce..f506ede087 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -37,6 +37,14 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/distributed_runtime/remote_device.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -51,6 +59,7 @@ struct TFE_ContextOptions {
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT};
+ tensorflow::ServerDef server_def;
};
struct TFE_Context {
@@ -64,6 +73,23 @@ struct TFE_Context {
default_policy),
async, std::move(device_mgr), rendezvous) {}
+ explicit TFE_Context(
+ const tensorflow::SessionOptions& opts,
+ TFE_ContextDevicePlacementPolicy default_policy, bool async,
+ tensorflow::DeviceMgr* local_device_mgr,
+ tensorflow::Rendezvous* rendezvous,
+ std::unique_ptr<tensorflow::GrpcServer> server,
+ std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
+ std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
+ const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
+ remote_contexts)
+ : context(opts,
+ static_cast<tensorflow::ContextDevicePlacementPolicy>(
+ default_policy),
+ async, local_device_mgr, rendezvous, std::move(server),
+ std::move(remote_eager_workers), std::move(remote_device_mgr),
+ remote_contexts) {}
+
tensorflow::EagerContext context;
};
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 701175e494..49646bb735 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
+#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -23,7 +24,9 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using tensorflow::string;
@@ -220,6 +223,103 @@ TEST(CAPI, Context) {
TF_DeleteStatus(status);
}
+tensorflow::ServerDef GetServerDef(int num_tasks) {
+ tensorflow::ServerDef server_def;
+ server_def.set_protocol("grpc");
+ server_def.set_job_name("localhost");
+ server_def.set_task_index(0);
+ tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
+ tensorflow::JobDef* job_def = cluster_def->add_job();
+ job_def->set_name("localhost");
+ for (int i = 0; i < num_tasks; i++) {
+ int port = tensorflow::testing::PickUnusedPortOrDie();
+ job_def->mutable_tasks()->insert(
+ {i, tensorflow::strings::StrCat("localhost:", port)});
+ }
+ return server_def;
+}
+
+void TestRemoteExecute(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(2);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+
+ std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
+ ASSERT_TRUE(
+ tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
+ .ok());
+ ASSERT_TRUE(worker_server->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
+ status);
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+ TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
+ const char remote_device_name[] =
+ "/job:localhost/replica:0/task:1/device:CPU:0";
+ auto* h0_task1 =
+ TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ auto* h1_task1 =
+ TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
+ TFE_OpSetDevice(matmul, remote_device_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 = TFE_TensorHandleCopyToDevice(
+ retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteTensorHandle(retval_task0);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(h1_task0);
+ TFE_DeleteTensorHandle(h0_task1);
+ TFE_DeleteTensorHandle(h1_task1);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server.release();
+}
+
+TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
+TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));