From d56eface20da6adf5a12507053c16ef22594739b Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Thu, 8 Mar 2018 16:45:45 -0800 Subject: Fixes a bug where the ProcFLR doesn't lookup existing instantiations in the distributed (ClusterFLR) case. As a result multiple instantiations for the same function were happening. PiperOrigin-RevId: 188411978 --- tensorflow/core/BUILD | 1 + .../process_function_library_runtime.cc | 55 ++++++++++--- .../process_function_library_runtime.h | 32 ++++++-- .../process_function_library_runtime_test.cc | 94 +++++++++++++++++++++- 4 files changed, 160 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 0fbe4eba6e..f2b0d542dd 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -3156,6 +3156,7 @@ tf_cc_test( ":core_cpu", ":core_cpu_internal", ":framework", + ":lib", ":test", ":test_main", ":testlib", diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 929f5c67bc..44dc6f9459 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -25,6 +25,19 @@ namespace tensorflow { const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null"; +Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit( + DistributedFunctionLibraryRuntime* parent, const string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options) { + mutex_lock l(mu_); + if (!init_started_) { + init_started_ = true; + init_result_ = parent->Instantiate(function_name, lib_def, attrs, options, + &local_handle_); + } + return init_result_; +} + ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, @@ -167,7 +180,8 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle( if (function_data_.count(h) != 0) return h; } h = next_handle_; - function_data_.insert({h, FunctionData(device_name, local_handle)}); + FunctionData* fd = new FunctionData(device_name, local_handle); + function_data_[h] = std::unique_ptr(fd); table_[function_key] = h; next_handle_++; return h; @@ -196,19 +210,19 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( if (function_data_.count(handle) == 0) { return kInvalidLocalHandle; } - const FunctionData& function_data = function_data_[handle]; - if (function_data.target_device != device_name) { + FunctionData* function_data = function_data_[handle].get(); + if (function_data->target_device() != device_name) { return kInvalidLocalHandle; } - return function_data.local_handle; + return function_data->local_handle(); } string ProcessFunctionLibraryRuntime::GetDeviceName( FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); CHECK_EQ(1, function_data_.count(handle)); - const FunctionData& function_data = function_data_[handle]; - return function_data.target_device; + FunctionData* function_data = function_data_[handle].get(); + return function_data->target_device(); } Status ProcessFunctionLibraryRuntime::Instantiate( @@ -225,11 +239,26 @@ Status ProcessFunctionLibraryRuntime::Instantiate( "Currently don't support instantiating functions on device: ", options.target); } - FunctionLibraryRuntime::Handle cluster_handle; - TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs, - options, &cluster_handle)); + string function_key = Canonicalize(function_name, attrs); - *handle = AddHandle(function_key, options.target, cluster_handle); + FunctionData* f; + { + mutex_lock l(mu_); + FunctionLibraryRuntime::Handle h = + gtl::FindWithDefault(table_, function_key, kInvalidHandle); + if (h == kInvalidHandle || function_data_.count(h) == 0) { + h = next_handle_; + FunctionData* fd = new FunctionData(options.target, kInvalidHandle); + function_data_[h] = std::unique_ptr(fd); + table_[function_key] = h; + next_handle_++; + } + f = function_data_[h].get(); + *handle = h; + } + TF_RETURN_IF_ERROR( + f->DistributedInit(parent_, function_name, *lib_def_, attrs, options)); + return Status::OK(); } @@ -247,7 +276,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle( { mutex_lock l(mu_); CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle; - target_device = function_data_[handle].target_device; + target_device = function_data_[handle]->target_device(); } flr = GetFLR(target_device); if (flr != nullptr) { @@ -276,8 +305,8 @@ void ProcessFunctionLibraryRuntime::Run( done(errors::NotFound("Handle: ", handle, " not found.")); return; } - target_device = function_data_[handle].target_device; - local_handle = function_data_[handle].local_handle; + target_device = function_data_[handle]->target_device(); + local_handle = function_data_[handle]->local_handle(); } flr = GetFLR(target_device); if (flr != nullptr) { diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 0473e16d24..10619ba6ea 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -145,14 +145,31 @@ class ProcessFunctionLibraryRuntime { mutable mutex mu_; - struct FunctionData { - const string target_device; - const FunctionLibraryRuntime::LocalHandle local_handle; - + class FunctionData { + public: FunctionData(const string& target_device, FunctionLibraryRuntime::LocalHandle local_handle) - : target_device(target_device), local_handle(local_handle) {} - FunctionData() : FunctionData("", -1) {} + : target_device_(target_device), local_handle_(local_handle) {} + + string target_device() { return target_device_; } + + FunctionLibraryRuntime::LocalHandle local_handle() { return local_handle_; } + + // Initializes the FunctionData object by potentially making an Initialize + // call to the DistributedFunctionLibraryRuntime. + Status DistributedInit( + DistributedFunctionLibraryRuntime* parent, const string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options); + + private: + mutex mu_; + + const string target_device_; + FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_); + bool init_started_ GUARDED_BY(mu_) = false; + Status init_result_ GUARDED_BY(mu_); + Notification init_done_; }; const DeviceMgr* const device_mgr_; @@ -160,7 +177,8 @@ class ProcessFunctionLibraryRuntime { // Holds all the function invocations here. std::unordered_map table_ GUARDED_BY(mu_); - std::unordered_map + std::unordered_map> function_data_ GUARDED_BY(mu_); std::unordered_map> flr_map_; int next_handle_ GUARDED_BY(mu_); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 439ba1ce96..ab1f919852 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -19,9 +19,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -29,8 +31,32 @@ limitations under the License. namespace tensorflow { namespace { +class TestClusterFLR : public DistributedFunctionLibraryRuntime { + public: + TestClusterFLR() {} + + Status Instantiate(const string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle) { + mutex_lock l(mu_); + *handle = next_handle_; + next_handle_++; + return Status::OK(); + } + + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + gtl::ArraySlice args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) {} + + private: + mutex mu_; + int next_handle_ GUARDED_BY(mu_) = 0; +}; + class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { - protected: + public: void Init(const std::vector& flib) { SessionOptions options; auto* device_count = options.config.mutable_device_count(); @@ -42,12 +68,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { for (const auto& fdef : flib) *(proto.add_function()) = fdef; lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); OptimizerOptions opts; + cluster_flr_.reset(new TestClusterFLR()); proc_flr_.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, nullptr /* cluster_flr */)); + opts, cluster_flr_.get())); rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); } + Status Instantiate( + const string& name, test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, + FunctionLibraryRuntime::Handle* handle) { + return proc_flr_->Instantiate(name, attrs, instantiate_opts, handle); + } + Status Run(const string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, @@ -106,6 +140,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { std::vector devices_; std::unique_ptr device_mgr_; std::unique_ptr lib_def_; + std::unique_ptr cluster_flr_; std::unique_ptr proc_flr_; IntraProcessRendezvous* rendezvous_; }; @@ -250,5 +285,60 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { rendezvous_->Unref(); } +TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) { + Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; + FunctionLibraryRuntime::InstantiateOptions instantiate_opts; + instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0"; + FunctionLibraryRuntime::Handle h; + TF_CHECK_OK(Instantiate("FindDevice", + {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}}, + instantiate_opts, &h)); + EXPECT_EQ(0, proc_flr_->GetHandleOnDevice( + "/job:b/replica:0/task:0/device:CPU:0", h)); + TF_CHECK_OK(Instantiate("FindDevice", + {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}}, + instantiate_opts, &h)); + EXPECT_EQ(0, proc_flr_->GetHandleOnDevice( + "/job:b/replica:0/task:0/device:CPU:0", h)); + instantiate_opts.target = "/job:c/replica:0/task:0/device:CPU:0"; + TF_CHECK_OK(Instantiate("FindDevice", + {{"_target", "/job:c/replica:0/task:0/device:CPU:0"}}, + instantiate_opts, &h)); + EXPECT_EQ(1, proc_flr_->GetHandleOnDevice( + "/job:c/replica:0/task:0/device:CPU:0", h)); + rendezvous_->Unref(); +} + +TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) { + Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; + FunctionLibraryRuntime::InstantiateOptions instantiate_opts; + instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0"; + + thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4); + auto fn = [this, &instantiate_opts]() { + FunctionLibraryRuntime::Handle h; + TF_CHECK_OK(Instantiate( + "FindDevice", {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}}, + instantiate_opts, &h)); + EXPECT_EQ(0, proc_flr_->GetHandleOnDevice( + "/job:b/replica:0/task:0/device:CPU:0", h)); + }; + + for (int i = 0; i < 100; ++i) { + tp->Schedule(fn); + } + delete tp; + + rendezvous_->Unref(); +} + } // anonymous namespace } // namespace tensorflow -- cgit v1.2.3