diff options
author | Rohan Jain <rohanj@google.com> | 2017-10-02 11:10:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-02 11:14:41 -0700 |
commit | 9bfa43625061ec62bd9623ab014db4851307e92d (patch) | |
tree | cd59c3f044646b3bf9663b8dd8a4250b525e3d27 /tensorflow/core/common_runtime/process_function_library_runtime.cc | |
parent | c0644791cfc064d5e4652271e51d826aeccad0c2 (diff) |
Allowing for functions to run across processes using RPC's. Currently this only works for processes running on CPU's only.
PiperOrigin-RevId: 170725482
Diffstat (limited to 'tensorflow/core/common_runtime/process_function_library_runtime.cc')
-rw-r--r-- | tensorflow/core/common_runtime/process_function_library_runtime.cc | 67 |
1 files changed, 50 insertions, 17 deletions
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 26ae6907bc..ca7843ee67 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -27,7 +27,9 @@ const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null"; ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options) { + const OptimizerOptions& optimizer_options, + DistributedFunctionLibraryRuntime* parent) + : lib_def_(lib_def), parent_(parent) { if (device_mgr == nullptr) { flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version, @@ -45,11 +47,14 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator) { + CustomKernelCreator custom_kernel_creator, + DistributedFunctionLibraryRuntime* parent) + : lib_def_(lib_def), parent_(parent) { if (device_mgr == nullptr) { flr_map_[kDefaultFLRDevice] = NewFunctionLibraryRuntime( nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options, - custom_kernel_creator, this); + std::move(custom_kernel_creator), this); + return; } for (Device* d : device_mgr->ListDevices()) { flr_map_[d->name()] = NewFunctionLibraryRuntime( @@ -58,6 +63,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( } } +ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options) + : ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version, lib_def, + optimizer_options, + nullptr /* cluster_flr */) {} + +ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + CustomKernelCreator custom_kernel_creator) + : ProcessFunctionLibraryRuntime( + device_mgr, env, graph_def_version, lib_def, optimizer_options, + std::move(custom_kernel_creator), nullptr /* cluster_flr */) {} + /* static */ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( const AttrSlice& attrs) { @@ -176,33 +198,41 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( const string& device_name, FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); CHECK_LE(handle, function_data_.size()); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - if (p.first != device_name) { + const FunctionData& function_data = function_data_[handle]; + if (function_data.target_device != device_name) { return kInvalidLocalHandle; } - return p.second; + return function_data.local_handle; } string ProcessFunctionLibraryRuntime::GetDeviceName( FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); CHECK_LE(handle, function_data_.size()); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - return p.first; + const FunctionData& function_data = function_data_[handle]; + return function_data.target_device; } Status ProcessFunctionLibraryRuntime::Instantiate( const string& function_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle) { + *handle = kInvalidHandle; string target = ObtainFunctionTarget(attrs); FunctionLibraryRuntime* flr = GetFLR(target); if (flr != nullptr) { return flr->Instantiate(function_name, attrs, handle); } - return errors::InvalidArgument("Target: ", target, " is not supported"); + if (parent_ == nullptr) { + return errors::Internal( + "Currently don't support instantiating functions on device: ", target); + } + FunctionLibraryRuntime::Handle cluster_handle; + TF_RETURN_IF_ERROR( + parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle)); + string function_key = Canonicalize(function_name, attrs); + *handle = AddHandle(function_key, target, cluster_handle); + return Status::OK(); } void ProcessFunctionLibraryRuntime::Run( @@ -218,14 +248,14 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime* flr = nullptr; string target_device; + FunctionLibraryRuntime::LocalHandle local_handle; { mutex_lock l(mu_); CHECK_LE(handle, function_data_.size()); - std::pair<string, FunctionLibraryRuntime::LocalHandle> p = - function_data_[handle]; - target_device = p.first; - flr = GetFLR(p.first); + target_device = function_data_[handle].target_device; + local_handle = function_data_[handle].local_handle; } + flr = GetFLR(target_device); if (flr != nullptr) { auto rendezvous = opts.rendezvous; string source_device = opts.source_device; @@ -266,10 +296,13 @@ void ProcessFunctionLibraryRuntime::Run( target_incarnation, num_returns, rendez_args, rendezvous, rets, done); }); - } else { - done(errors::Internal("Could not find device")); return; } + if (parent_ != nullptr) { + parent_->Run(opts, local_handle, args, rets, done); + return; + } + done(errors::Internal("Could not find device")); } } // namespace tensorflow |