about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Rohan Jain <rohanj@google.com>, 2018-03-08 16:45:45 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-03-08 16:52:20 -0800
commit: d56eface20da6adf5a12507053c16ef22594739b (patch)
tree: 8df07708a2865fe93c3d1a93c61c862a76a9ef9c
parent: 44bcb41f7edae78b69ab52acbc58934242cf13b8 (diff)
Fixes a bug where the ProcFLR doesn't look up existing instantiations in the
distributed (ClusterFLR) case. As a result, multiple instantiations of the same function were happening.

PiperOrigin-RevId: 188411978
-rw-r--r-- tensorflow/core/BUILD | 1
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime.cc | 55
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime.h | 32
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime_test.cc | 94
4 files changed, 160 insertions, 22 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 0fbe4eba6e..f2b0d542dd 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3156,6 +3156,7 @@ tf_cc_test(
":core_cpu",
":core_cpu_internal",
":framework",
+ ":lib",
":test",
":test_main",
":testlib",
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 929f5c67bc..44dc6f9459 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -25,6 +25,19 @@ namespace tensorflow {
const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
+Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
+ DistributedFunctionLibraryRuntime* parent, const string& function_name,
+ const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options) {
+ mutex_lock l(mu_);
+ if (!init_started_) {
+ init_started_ = true;
+ init_result_ = parent->Instantiate(function_name, lib_def, attrs, options,
+ &local_handle_);
+ }
+ return init_result_;
+}
+
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
@@ -167,7 +180,8 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
if (function_data_.count(h) != 0) return h;
}
h = next_handle_;
- function_data_.insert({h, FunctionData(device_name, local_handle)});
+ FunctionData* fd = new FunctionData(device_name, local_handle);
+ function_data_[h] = std::unique_ptr<FunctionData>(fd);
table_[function_key] = h;
next_handle_++;
return h;
@@ -196,19 +210,19 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
if (function_data_.count(handle) == 0) {
return kInvalidLocalHandle;
}
- const FunctionData& function_data = function_data_[handle];
- if (function_data.target_device != device_name) {
+ FunctionData* function_data = function_data_[handle].get();
+ if (function_data->target_device() != device_name) {
return kInvalidLocalHandle;
}
- return function_data.local_handle;
+ return function_data->local_handle();
}
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_EQ(1, function_data_.count(handle));
- const FunctionData& function_data = function_data_[handle];
- return function_data.target_device;
+ FunctionData* function_data = function_data_[handle].get();
+ return function_data->target_device();
}
Status ProcessFunctionLibraryRuntime::Instantiate(
@@ -225,11 +239,26 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
"Currently don't support instantiating functions on device: ",
options.target);
}
- FunctionLibraryRuntime::Handle cluster_handle;
- TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
- options, &cluster_handle));
+
string function_key = Canonicalize(function_name, attrs);
- *handle = AddHandle(function_key, options.target, cluster_handle);
+ FunctionData* f;
+ {
+ mutex_lock l(mu_);
+ FunctionLibraryRuntime::Handle h =
+ gtl::FindWithDefault(table_, function_key, kInvalidHandle);
+ if (h == kInvalidHandle || function_data_.count(h) == 0) {
+ h = next_handle_;
+ FunctionData* fd = new FunctionData(options.target, kInvalidHandle);
+ function_data_[h] = std::unique_ptr<FunctionData>(fd);
+ table_[function_key] = h;
+ next_handle_++;
+ }
+ f = function_data_[h].get();
+ *handle = h;
+ }
+ TF_RETURN_IF_ERROR(
+ f->DistributedInit(parent_, function_name, *lib_def_, attrs, options));
+
return Status::OK();
}
@@ -247,7 +276,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle(
{
mutex_lock l(mu_);
CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
- target_device = function_data_[handle].target_device;
+ target_device = function_data_[handle]->target_device();
}
flr = GetFLR(target_device);
if (flr != nullptr) {
@@ -276,8 +305,8 @@ void ProcessFunctionLibraryRuntime::Run(
done(errors::NotFound("Handle: ", handle, " not found."));
return;
}
- target_device = function_data_[handle].target_device;
- local_handle = function_data_[handle].local_handle;
+ target_device = function_data_[handle]->target_device();
+ local_handle = function_data_[handle]->local_handle();
}
flr = GetFLR(target_device);
if (flr != nullptr) {
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 0473e16d24..10619ba6ea 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -145,14 +145,31 @@ class ProcessFunctionLibraryRuntime {
mutable mutex mu_;
- struct FunctionData {
- const string target_device;
- const FunctionLibraryRuntime::LocalHandle local_handle;
-
+ class FunctionData {
+ public:
FunctionData(const string& target_device,
FunctionLibraryRuntime::LocalHandle local_handle)
- : target_device(target_device), local_handle(local_handle) {}
- FunctionData() : FunctionData("", -1) {}
+ : target_device_(target_device), local_handle_(local_handle) {}
+
+ string target_device() { return target_device_; }
+
+ FunctionLibraryRuntime::LocalHandle local_handle() { return local_handle_; }
+
+ // Initializes the FunctionData object by potentially making an Instantiate
+ // call to the DistributedFunctionLibraryRuntime.
+ Status DistributedInit(
+ DistributedFunctionLibraryRuntime* parent, const string& function_name,
+ const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options);
+
+ private:
+ mutex mu_;
+
+ const string target_device_;
+ FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
+ bool init_started_ GUARDED_BY(mu_) = false;
+ Status init_result_ GUARDED_BY(mu_);
+ Notification init_done_;
};
const DeviceMgr* const device_mgr_;
@@ -160,7 +177,8 @@ class ProcessFunctionLibraryRuntime {
// Holds all the function invocations here.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
GUARDED_BY(mu_);
- std::unordered_map<FunctionLibraryRuntime::Handle, FunctionData>
+ std::unordered_map<FunctionLibraryRuntime::Handle,
+ std::unique_ptr<FunctionData>>
function_data_ GUARDED_BY(mu_);
std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
int next_handle_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 439ba1ce96..ab1f919852 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -19,9 +19,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -29,8 +31,32 @@ limitations under the License.
namespace tensorflow {
namespace {
+class TestClusterFLR : public DistributedFunctionLibraryRuntime {
+ public:
+ TestClusterFLR() {}
+
+ Status Instantiate(const string& function_name,
+ const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& options,
+ FunctionLibraryRuntime::LocalHandle* handle) {
+ mutex_lock l(mu_);
+ *handle = next_handle_;
+ next_handle_++;
+ return Status::OK();
+ }
+
+ void Run(const FunctionLibraryRuntime::Options& opts,
+ FunctionLibraryRuntime::LocalHandle handle,
+ gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
+ FunctionLibraryRuntime::DoneCallback done) {}
+
+ private:
+ mutex mu_;
+ int next_handle_ GUARDED_BY(mu_) = 0;
+};
+
class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
- protected:
+ public:
void Init(const std::vector<FunctionDef>& flib) {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
@@ -42,12 +68,20 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
+ cluster_flr_.reset(new TestClusterFLR());
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
- opts, nullptr /* cluster_flr */));
+ opts, cluster_flr_.get()));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
}
+ Status Instantiate(
+ const string& name, test::function::Attrs attrs,
+ const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
+ FunctionLibraryRuntime::Handle* handle) {
+ return proc_flr_->Instantiate(name, attrs, instantiate_opts, handle);
+ }
+
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
@@ -106,6 +140,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+ std::unique_ptr<TestClusterFLR> cluster_flr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
IntraProcessRendezvous* rendezvous_;
};
@@ -250,5 +285,60 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
rendezvous_->Unref();
}
+TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRSerialTest) {
+ Init({test::function::FindDevice()});
+ FunctionLibraryRuntime::Options opts;
+ opts.source_device = "/job:a/replica:0/task:0/cpu:0";
+ opts.rendezvous = rendezvous_;
+ opts.remote_execution = true;
+ FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
+ instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
+ FunctionLibraryRuntime::Handle h;
+ TF_CHECK_OK(Instantiate("FindDevice",
+ {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
+ instantiate_opts, &h));
+ EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
+ "/job:b/replica:0/task:0/device:CPU:0", h));
+ TF_CHECK_OK(Instantiate("FindDevice",
+ {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
+ instantiate_opts, &h));
+ EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
+ "/job:b/replica:0/task:0/device:CPU:0", h));
+ instantiate_opts.target = "/job:c/replica:0/task:0/device:CPU:0";
+ TF_CHECK_OK(Instantiate("FindDevice",
+ {{"_target", "/job:c/replica:0/task:0/device:CPU:0"}},
+ instantiate_opts, &h));
+ EXPECT_EQ(1, proc_flr_->GetHandleOnDevice(
+ "/job:c/replica:0/task:0/device:CPU:0", h));
+ rendezvous_->Unref();
+}
+
+TEST_F(ProcessFunctionLibraryRuntimeTest, ClusterFLRParallelTest) {
+ Init({test::function::FindDevice()});
+ FunctionLibraryRuntime::Options opts;
+ opts.source_device = "/job:a/replica:0/task:0/cpu:0";
+ opts.rendezvous = rendezvous_;
+ opts.remote_execution = true;
+ FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
+ instantiate_opts.target = "/job:b/replica:0/task:0/device:CPU:0";
+
+ thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
+ auto fn = [this, &instantiate_opts]() {
+ FunctionLibraryRuntime::Handle h;
+ TF_CHECK_OK(Instantiate(
+ "FindDevice", {{"_target", "/job:b/replica:0/task:0/device:CPU:0"}},
+ instantiate_opts, &h));
+ EXPECT_EQ(0, proc_flr_->GetHandleOnDevice(
+ "/job:b/replica:0/task:0/device:CPU:0", h));
+ };
+
+ for (int i = 0; i < 100; ++i) {
+ tp->Schedule(fn);
+ }
+ delete tp;
+
+ rendezvous_->Unref();
+}
+
} // anonymous namespace
} // namespace tensorflow