diff options
author | 2017-10-02 11:10:44 -0700 | |
---|---|---|
committer | 2017-10-02 11:14:41 -0700 | |
commit | 9bfa43625061ec62bd9623ab014db4851307e92d (patch) | |
tree | cd59c3f044646b3bf9663b8dd8a4250b525e3d27 /tensorflow/core/common_runtime | |
parent | c0644791cfc064d5e4652271e51d826aeccad0c2 (diff) |
Allowing for functions to run across processes using RPC's. Currently this only works for processes running on CPU's only.
PiperOrigin-RevId: 170725482
Diffstat (limited to 'tensorflow/core/common_runtime')
4 files changed, 80 insertions, 24 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index a92b245705..23d2741913 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -148,7 +148,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { device_mgr_.reset(new DeviceMgr(devices_)); pflr_.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts)); + opts, nullptr /* cluster_flr */)); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 26ae6907bc..ca7843ee67 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -27,7 +27,9 @@ const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null"; ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options) { + const OptimizerOptions& optimizer_options, + DistributedFunctionLibraryRuntime* parent) + : lib_def_(lib_def), parent_(parent) { if (device_mgr == nullptr) { flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version, @@ -45,11 +47,14 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator) { + CustomKernelCreator custom_kernel_creator, + DistributedFunctionLibraryRuntime* parent) + : lib_def_(lib_def), parent_(parent) { if (device_mgr == nullptr) { flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime( nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options, - custom_kernel_creator, this); + std::move(custom_kernel_creator), this); + return; } for (Device* d : device_mgr->ListDevices()) { flr_map_[d->name()] = NewFunctionLibraryRuntime( @@ -58,6 +63,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( } } +ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options) + : ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, lib_def, + optimizer_options, + nullptr /* cluster_flr */) {} + +ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator) + : ProcessFunctionLibraryRuntime( + device_mgr, env, graph_def_version, lib_def, optimizer_options, + std::move(custom_kernel_creator), nullptr /* cluster_flr */) {} + /* static */ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( const AttrSlice& attrs) { @@ -176,33 +198,41 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( const string& device_name, FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); CHECK_LE(handle, function_data_.size()); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - if (p.first != device_name) { + const FunctionData& function_data = function_data_[handle]; + if (function_data.target_device != device_name) { return kInvalidLocalHandle; } - return p.second; + return function_data.local_handle; } string ProcessFunctionLibraryRuntime::GetDeviceName( FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); CHECK_LE(handle, function_data_.size()); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - return p.first; + const FunctionData& function_data = function_data_[handle]; + return function_data.target_device; } Status ProcessFunctionLibraryRuntime::Instantiate( const string& function_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle) { + *handle = kInvalidHandle; string target = ObtainFunctionTarget(attrs); FunctionLibraryRuntime* flr = GetFLR(target); if (flr != nullptr) { return flr->Instantiate(function_name, attrs, handle); } - return errors::InvalidArgument("Target: ", target, " is not supported"); + if (parent_ == nullptr) { + return errors::Internal( + "Currently don't support instantiating functions on device: ", target); + } + FunctionLibraryRuntime::Handle cluster_handle; + TF_RETURN_IF_ERROR( + parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle)); + string function_key = Canonicalize(function_name, attrs); + *handle = AddHandle(function_key, target, cluster_handle); + return Status::OK(); } void ProcessFunctionLibraryRuntime::Run( @@ -218,14 +248,14 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime* flr = nullptr; string target_device; + FunctionLibraryRuntime::LocalHandle local_handle; { mutex_lock l(mu_); CHECK_LE(handle, function_data_.size()); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - target_device = p.first; - flr = GetFLR(p.first); + target_device = function_data_[handle].target_device; + local_handle = function_data_[handle].local_handle; } + flr = GetFLR(target_device); if (flr != nullptr) { auto rendezvous = opts.rendezvous; string source_device = opts.source_device; @@ -266,10 +296,13 @@ void ProcessFunctionLibraryRuntime::Run( target_incarnation, num_returns, rendez_args, rendezvous, rets, done); }); - } else { - done(errors::Internal("Could not find device")); return; } + if (parent_ != nullptr) { + parent_->Run(opts, local_handle, args, rets, done); + return; + } + done(errors::Internal("Could not find device")); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 7ff1d5c7a7..9f03de0f76 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -27,8 +27,21 @@ namespace tensorflow { class ProcessFunctionLibraryRuntime { public: // Creates FunctionLibraryRuntime objects for each device in the provided - // DeviceMgr. Caller needs to make sure that device_mgr and lib_def outlive - // this object. + // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent + // (if provided) outlive this object. + ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, + int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + DistributedFunctionLibraryRuntime* parent); + + ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, + int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator, + DistributedFunctionLibraryRuntime* parent); + ProcessFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, @@ -77,7 +90,7 @@ class ProcessFunctionLibraryRuntime { // For a given canonicalized key signature of the function instantiated // on device `device_name` and a `local_handle`, creates a handle and returns - // that value. Use core/common_runtime/framework/function.h::Canonicalize + // that value. Uses core/common_runtime/framework/function.h::Canonicalize // to canonicalize the function signature. FunctionLibraryRuntime::Handle AddHandle( const string& function_key, const string& device_name, @@ -124,12 +137,22 @@ class ProcessFunctionLibraryRuntime { mutable mutex mu_; + struct FunctionData { + const string target_device; + const FunctionLibraryRuntime::LocalHandle local_handle; + + FunctionData(const string& target_device, + FunctionLibraryRuntime::LocalHandle local_handle) + : target_device(target_device), local_handle(local_handle) {} + }; + + const FunctionLibraryDefinition* lib_def_; // Holds all the function invocations here. std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ GUARDED_BY(mu_); - std::vector<std::pair<string, FunctionLibraryRuntime::LocalHandle>> - function_data_ GUARDED_BY(mu_); + std::vector<FunctionData> function_data_ GUARDED_BY(mu_); std::unordered_map<string, std::unique_ptr<FunctionLibraryRuntime>> flr_map_; + DistributedFunctionLibraryRuntime* const parent_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 50379a52c4..b86a7f597e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -44,7 +44,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { OptimizerOptions opts; proc_flr_.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts)); + opts, nullptr /* cluster_flr */)); rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); } |