aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2017-10-02 11:10:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 11:14:41 -0700
commit9bfa43625061ec62bd9623ab014db4851307e92d (patch)
treecd59c3f044646b3bf9663b8dd8a4250b525e3d27 /tensorflow/core/common_runtime
parentc0644791cfc064d5e4652271e51d826aeccad0c2 (diff)
Allowing for functions to run across processes using RPC's. Currently this only works for processes running on CPU's only.
PiperOrigin-RevId: 170725482
Diffstat (limited to 'tensorflow/core/common_runtime')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc2
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc67
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h33
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc2
4 files changed, 80 insertions, 24 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index a92b245705..23d2741913 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -148,7 +148,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
device_mgr_.reset(new DeviceMgr(devices_));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
- opts));
+ opts, nullptr /* cluster_flr */));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 26ae6907bc..ca7843ee67 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -27,7 +27,9 @@ const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
- const OptimizerOptions& optimizer_options) {
+ const OptimizerOptions& optimizer_options,
+ DistributedFunctionLibraryRuntime* parent)
+ : lib_def_(lib_def), parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[kDefaultFLRDevice] =
NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
@@ -45,11 +47,14 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
- CustomKernelCreator custom_kernel_creator) {
+ CustomKernelCreator custom_kernel_creator,
+ DistributedFunctionLibraryRuntime* parent)
+ : lib_def_(lib_def), parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
- custom_kernel_creator, this);
+ std::move(custom_kernel_creator), this);
+ return;
}
for (Device* d : device_mgr->ListDevices()) {
flr_map_[d->name()] = NewFunctionLibraryRuntime(
@@ -58,6 +63,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
}
}
+ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
+ const DeviceMgr* device_mgr, Env* env, int graph_def_version,
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options)
+ : ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, lib_def,
+ optimizer_options,
+ nullptr /* cluster_flr */) {}
+
+ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
+ const DeviceMgr* device_mgr, Env* env, int graph_def_version,
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options,
+ CustomKernelCreator custom_kernel_creator)
+ : ProcessFunctionLibraryRuntime(
+ device_mgr, env, graph_def_version, lib_def, optimizer_options,
+ std::move(custom_kernel_creator), nullptr /* cluster_flr */) {}
+
/* static */
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
const AttrSlice& attrs) {
@@ -176,33 +198,41 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
- std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
- function_data_[handle];
- if (p.first != device_name) {
+ const FunctionData& function_data = function_data_[handle];
+ if (function_data.target_device != device_name) {
return kInvalidLocalHandle;
}
- return p.second;
+ return function_data.local_handle;
}
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
- std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
- function_data_[handle];
- return p.first;
+ const FunctionData& function_data = function_data_[handle];
+ return function_data.target_device;
}
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
FunctionLibraryRuntime::Handle* handle) {
+ *handle = kInvalidHandle;
string target = ObtainFunctionTarget(attrs);
FunctionLibraryRuntime* flr = GetFLR(target);
if (flr != nullptr) {
return flr->Instantiate(function_name, attrs, handle);
}
- return errors::InvalidArgument("Target: ", target, " is not supported");
+ if (parent_ == nullptr) {
+ return errors::Internal(
+ "Currently don't support instantiating functions on device: ", target);
+ }
+ FunctionLibraryRuntime::Handle cluster_handle;
+ TF_RETURN_IF_ERROR(
+ parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle));
+ string function_key = Canonicalize(function_name, attrs);
+ *handle = AddHandle(function_key, target, cluster_handle);
+ return Status::OK();
}
void ProcessFunctionLibraryRuntime::Run(
@@ -218,14 +248,14 @@ void ProcessFunctionLibraryRuntime::Run(
FunctionLibraryRuntime* flr = nullptr;
string target_device;
+ FunctionLibraryRuntime::LocalHandle local_handle;
{
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
- std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
- function_data_[handle];
- target_device = p.first;
- flr = GetFLR(p.first);
+ target_device = function_data_[handle].target_device;
+ local_handle = function_data_[handle].local_handle;
}
+ flr = GetFLR(target_device);
if (flr != nullptr) {
auto rendezvous = opts.rendezvous;
string source_device = opts.source_device;
@@ -266,10 +296,13 @@ void ProcessFunctionLibraryRuntime::Run(
target_incarnation, num_returns, rendez_args,
rendezvous, rets, done);
});
- } else {
- done(errors::Internal("Could not find device"));
return;
}
+ if (parent_ != nullptr) {
+ parent_->Run(opts, local_handle, args, rets, done);
+ return;
+ }
+ done(errors::Internal("Could not find device"));
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 7ff1d5c7a7..9f03de0f76 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -27,8 +27,21 @@ namespace tensorflow {
class ProcessFunctionLibraryRuntime {
public:
// Creates FunctionLibraryRuntime objects for each device in the provided
- // DeviceMgr. Caller needs to make sure that device_mgr and lib_def outlive
- // this object.
+ // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
+ // (if provided) outlive this object.
+ ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env,
+ int graph_def_version,
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options,
+ DistributedFunctionLibraryRuntime* parent);
+
+ ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env,
+ int graph_def_version,
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options,
+ CustomKernelCreator custom_kernel_creator,
+ DistributedFunctionLibraryRuntime* parent);
+
ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env,
int graph_def_version,
const FunctionLibraryDefinition* lib_def,
@@ -77,7 +90,7 @@ class ProcessFunctionLibraryRuntime {
// For a given canonicalized key signature of the function instantiated
// on device `device_name` and a `local_handle`, creates a handle and returns
- // that value. Use core/common_runtime/framework/function.h::Canonicalize
+ // that value. Uses core/common_runtime/framework/function.h::Canonicalize
// to canonicalize the function signature.
FunctionLibraryRuntime::Handle AddHandle(
const string& function_key, const string& device_name,
@@ -124,12 +137,22 @@ class ProcessFunctionLibraryRuntime {
mutable mutex mu_;
+ struct FunctionData {
+ const string target_device;
+ const FunctionLibraryRuntime::LocalHandle local_handle;
+
+ FunctionData(const string& target_device,
+ FunctionLibraryRuntime::LocalHandle local_handle)
+ : target_device(target_device), local_handle(local_handle) {}
+ };
+
+ const FunctionLibraryDefinition* lib_def_;
// Holds all the function invocations here.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
GUARDED_BY(mu_);
- std::vector<std::pair<string, FunctionLibraryRuntime::LocalHandle>>
- function_data_ GUARDED_BY(mu_);
+ std::vector<FunctionData> function_data_ GUARDED_BY(mu_);
std::unordered_map<string, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
+ DistributedFunctionLibraryRuntime* const parent_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 50379a52c4..b86a7f597e 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -44,7 +44,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
OptimizerOptions opts;
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
- opts));
+ opts, nullptr /* cluster_flr */));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
}