aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/process_function_library_runtime.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2017-10-02 11:10:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 11:14:41 -0700
commit9bfa43625061ec62bd9623ab014db4851307e92d (patch)
treecd59c3f044646b3bf9663b8dd8a4250b525e3d27 /tensorflow/core/common_runtime/process_function_library_runtime.cc
parentc0644791cfc064d5e4652271e51d826aeccad0c2 (diff)
Allowing for functions to run across processes using RPC's. Currently this only works for processes running on CPU's only.
PiperOrigin-RevId: 170725482
Diffstat (limited to 'tensorflow/core/common_runtime/process_function_library_runtime.cc')
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc67
1 files changed, 50 insertions, 17 deletions
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 26ae6907bc..ca7843ee67 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -27,7 +27,9 @@ const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
- const OptimizerOptions& optimizer_options) {
+ const OptimizerOptions& optimizer_options,
+ DistributedFunctionLibraryRuntime* parent)
+ : lib_def_(lib_def), parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[kDefaultFLRDevice] =
NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
@@ -45,11 +47,14 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
- CustomKernelCreator custom_kernel_creator) {
+ CustomKernelCreator custom_kernel_creator,
+ DistributedFunctionLibraryRuntime* parent)
+ : lib_def_(lib_def), parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
- custom_kernel_creator, this);
+ std::move(custom_kernel_creator), this);
+ return;
}
for (Device* d : device_mgr->ListDevices()) {
flr_map_[d->name()] = NewFunctionLibraryRuntime(
@@ -58,6 +63,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
}
}
+ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
+ const DeviceMgr* device_mgr, Env* env, int graph_def_version,
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options)
+ : ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, lib_def,
+ optimizer_options,
+ nullptr /* cluster_flr */) {}
+
+ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
+ const DeviceMgr* device_mgr, Env* env, int graph_def_version,
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options,
+ CustomKernelCreator custom_kernel_creator)
+ : ProcessFunctionLibraryRuntime(
+ device_mgr, env, graph_def_version, lib_def, optimizer_options,
+ std::move(custom_kernel_creator), nullptr /* cluster_flr */) {}
+
/* static */
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
const AttrSlice& attrs) {
@@ -176,33 +198,41 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
- std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
- function_data_[handle];
- if (p.first != device_name) {
+ const FunctionData& function_data = function_data_[handle];
+ if (function_data.target_device != device_name) {
return kInvalidLocalHandle;
}
- return p.second;
+ return function_data.local_handle;
}
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
- std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
- function_data_[handle];
- return p.first;
+ const FunctionData& function_data = function_data_[handle];
+ return function_data.target_device;
}
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
FunctionLibraryRuntime::Handle* handle) {
+ *handle = kInvalidHandle;
string target = ObtainFunctionTarget(attrs);
FunctionLibraryRuntime* flr = GetFLR(target);
if (flr != nullptr) {
return flr->Instantiate(function_name, attrs, handle);
}
- return errors::InvalidArgument("Target: ", target, " is not supported");
+ if (parent_ == nullptr) {
+ return errors::Internal(
+ "Currently don't support instantiating functions on device: ", target);
+ }
+ FunctionLibraryRuntime::Handle cluster_handle;
+ TF_RETURN_IF_ERROR(
+ parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle));
+ string function_key = Canonicalize(function_name, attrs);
+ *handle = AddHandle(function_key, target, cluster_handle);
+ return Status::OK();
}
void ProcessFunctionLibraryRuntime::Run(
@@ -218,14 +248,14 @@ void ProcessFunctionLibraryRuntime::Run(
FunctionLibraryRuntime* flr = nullptr;
string target_device;
+ FunctionLibraryRuntime::LocalHandle local_handle;
{
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
- std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
- function_data_[handle];
- target_device = p.first;
- flr = GetFLR(p.first);
+ target_device = function_data_[handle].target_device;
+ local_handle = function_data_[handle].local_handle;
}
+ flr = GetFLR(target_device);
if (flr != nullptr) {
auto rendezvous = opts.rendezvous;
string source_device = opts.source_device;
@@ -266,10 +296,13 @@ void ProcessFunctionLibraryRuntime::Run(
target_incarnation, num_returns, rendez_args,
rendezvous, rets, done);
});
- } else {
- done(errors::Internal("Could not find device"));
return;
}
+ if (parent_ != nullptr) {
+ parent_->Run(opts, local_handle, args, rets, done);
+ return;
+ }
+ done(errors::Internal("Could not find device"));
}
} // namespace tensorflow