aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-03-09 12:20:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 12:24:51 -0800
commit20dfc25c378c600fac683e62dc8a1ed2a522711c (patch)
treedf618d2f7874bcc08ea3de65297f7ee259f65a51
parent61a744fffbcc68e453aafc6eaa2c7ff2318a3584 (diff)
Allowing for FunctionLibraryRuntime::Run calls to not be provided with a runner to execute kernels with. In that case, it defaults to using the threadpool provided by the device.
Also makes sure each device has a default threadpool to fall back on. PiperOrigin-RevId: 188520648
-rw-r--r--tensorflow/c/eager/runtime_test.cc2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc2
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc54
-rw-r--r--tensorflow/core/common_runtime/function.cc33
-rw-r--r--tensorflow/core/common_runtime/function.h4
-rw-r--r--tensorflow/core/common_runtime/function_test.cc116
-rw-r--r--tensorflow/core/common_runtime/function_testlib.cc53
-rw-r--r--tensorflow/core/common_runtime/function_testlib.h16
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc27
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h3
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc4
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc3
12 files changed, 221 insertions, 96 deletions
diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc
index 643153058c..4f75d27887 100644
--- a/tensorflow/c/eager/runtime_test.cc
+++ b/tensorflow/c/eager/runtime_test.cc
@@ -41,7 +41,7 @@ class TestEnv {
device_mgr_.reset(new DeviceMgr({device}));
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
device, TF_GRAPH_DEF_VERSION,
- &flib_def_, {}, nullptr);
+ &flib_def_, nullptr, {}, nullptr);
}
FunctionLibraryRuntime* function_library_runtime() const {
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index ecbffcbf6c..9def58cb9c 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1181,7 +1181,7 @@ Status DirectSession::GetOrCreateExecutors(
}
func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), options_.env, graph_def_version,
- func_info->flib_def.get(), optimizer_opts));
+ func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first));
GraphOptimizer optimizer(optimizer_opts);
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index b75a4f76d9..6fe0cba1e5 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -868,59 +869,14 @@ TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
TF_ASSERT_OK(session->Close());
}
-class BlockingOpState {
- public:
- void AwaitState(int awaiting_state) {
- mutex_lock ml(mu_);
- while (state_ != awaiting_state) {
- cv_.wait(ml);
- }
- }
- void MoveToState(int expected_current, int next) {
- mutex_lock ml(mu_);
- CHECK_EQ(expected_current, state_);
- state_ = next;
- cv_.notify_all();
- }
-
- private:
- mutex mu_;
- condition_variable cv_;
- int state_ = 0;
-};
-static BlockingOpState* blocking_op_state = nullptr;
-
-// BlockingOp blocks on the global <blocking_op_state's> state,
-// and also updates it when it is unblocked and finishing computation.
-class BlockingOp : public OpKernel {
- public:
- explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) override {
- blocking_op_state->MoveToState(0, 1);
- blocking_op_state->AwaitState(2);
- blocking_op_state->MoveToState(2, 3);
-
- Tensor* out = nullptr;
- const Tensor& in = ctx->input(0);
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
- out->flat<float>() = in.flat<float>();
- }
-};
-REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
-REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");
-
static void TestSessionInterOpThreadsImpl(bool use_function_lib,
bool use_global_pools) {
+ using test::function::blocking_op_state;
+ using test::function::BlockingOpState;
+
FunctionDefLibrary library_graph_def;
if (use_function_lib) {
- const string lib = R"proto(
- signature: {
- name: "BlockingOpFn" input_arg: { name: "x" type: DT_FLOAT }
- output_arg: { name: "y" type: DT_FLOAT }}
- node_def: { name: "y" op: "BlockingOp" input: "x" }
- ret: { key: "y" value: "y:y:0" } )proto";
- CHECK(protobuf::TextFormat::ParseFromString(
- lib, library_graph_def.add_function()));
+ *library_graph_def.add_function() = test::function::BlockingOpFn();
}
FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def);
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 3e937ceb64..effe53c961 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
@@ -141,6 +142,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
int graph_def_version,
const FunctionLibraryDefinition* lib_def,
+ thread::ThreadPool* default_thread_pool,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent);
@@ -194,6 +196,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const FunctionLibraryDefinition* const base_lib_def_;
GraphOptimizer optimizer_;
const CustomKernelCreator custom_kernel_creator_;
+ Executor::Args::Runner default_runner_;
const string device_name_;
std::function<Status(const string&, const OpDef**)> get_func_sig_;
@@ -243,6 +246,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
+ thread::ThreadPool* default_thread_pool,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent)
@@ -253,6 +257,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
base_lib_def_(lib_def),
optimizer_(optimizer_options),
custom_kernel_creator_(std::move(custom_kernel_creator)),
+ default_runner_(nullptr),
device_name_(device_ == nullptr
? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
: device_->name()),
@@ -264,6 +269,18 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
return CreateKernel(ndef, kernel);
};
+ thread::ThreadPool* pool = nullptr;
+ if (device_ != nullptr) {
+ pool = device_->tensorflow_device_thread_pool();
+ }
+ if (pool == nullptr) {
+ pool = default_thread_pool;
+ }
+ if (pool != nullptr) {
+ default_runner_ = [pool](Executor::Args::Closure c) {
+ pool->Schedule(std::move(c));
+ };
+ }
}
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
@@ -768,6 +785,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
return;
}
+ if (run_opts.runner == nullptr) {
+ run_opts.runner = &default_runner_;
+ }
DCHECK(run_opts.runner != nullptr);
Executor::Args* exec_args = new Executor::Args;
@@ -854,6 +874,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
done(s);
return;
}
+ if (run_opts.runner == nullptr) {
+ run_opts.runner = &default_runner_;
+ }
DCHECK(run_opts.runner != nullptr);
Executor::Args* exec_args = new Executor::Args;
@@ -942,21 +965,21 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
- const OptimizerOptions& optimizer_options,
+ thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent) {
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
- device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
- std::move(custom_kernel_creator), parent));
+ device_mgr, env, device, graph_def_version, lib_def, thread_pool,
+ optimizer_options, std::move(custom_kernel_creator), parent));
}
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
- const OptimizerOptions& optimizer_options,
+ thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
ProcessFunctionLibraryRuntime* parent) {
return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
- lib_def, optimizer_options,
+ lib_def, thread_pool, optimizer_options,
GetCustomCreatorSingleton()->Get(), parent);
}
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index 477340d87a..a0f9fcae0a 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -55,7 +55,7 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
- const OptimizerOptions& optimizer_options,
+ thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
ProcessFunctionLibraryRuntime* parent);
@@ -65,7 +65,7 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
- const OptimizerOptions& optimizer_options,
+ thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
ProcessFunctionLibraryRuntime* parent);
// FunctionLibraryRuntime::GetFunctionBody returns a description of an
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 63ad0d231c..d7e5f0018e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -135,7 +136,8 @@ TEST_F(FunctionTest, WXPlusB) {
class FunctionLibraryRuntimeTest : public ::testing::Test {
protected:
- void Init(const std::vector<FunctionDef>& flib) {
+ void Init(const std::vector<FunctionDef>& flib,
+ thread::ThreadPool* default_thread_pool = nullptr) {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
@@ -149,7 +151,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
device_mgr_.reset(new DeviceMgr(devices_));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
- opts, nullptr /* cluster_flr */));
+ opts, default_thread_pool, nullptr /* cluster_flr */));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
@@ -158,16 +160,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+ bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
-
+ if (add_runner) {
+ opts.runner = &runner;
+ } else {
+ opts.runner = nullptr;
+ }
Notification done;
- opts.runner = &runner;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
@@ -183,7 +189,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
*rets[i] = out[i];
}
- EXPECT_GE(call_count, 1); // Test runner is used.
+ if (add_runner) {
+ EXPECT_GE(call_count, 1); // Test runner is used.
+ }
return Status::OK();
}
@@ -204,24 +212,25 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const std::vector<Tensor>& args,
- std::vector<Tensor*> rets) {
+ std::vector<Tensor*> rets, bool add_runner = true) {
return InstantiateAndRun(flr, name, attrs,
FunctionLibraryRuntime::InstantiateOptions(), args,
- std::move(rets));
+ std::move(rets), add_runner);
}
Status InstantiateAndRun(
FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+ bool add_runner = true) {
FunctionLibraryRuntime::Handle handle;
Status status = flr->Instantiate(name, attrs, options, &handle);
if (!status.ok()) {
return status;
}
FunctionLibraryRuntime::Options opts;
- status = Run(flr, handle, opts, args, rets);
+ status = Run(flr, handle, opts, args, rets, add_runner);
if (!status.ok()) return status;
// Release the handle and try running again. It should not succeed.
@@ -237,16 +246,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
- FunctionLibraryRuntime::Options opts, CallFrameInterface* frame) {
+ FunctionLibraryRuntime::Options opts, CallFrameInterface* frame,
+ bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
-
+ if (add_runner) {
+ opts.runner = &runner;
+ } else {
+ opts.runner = nullptr;
+ }
Notification done;
- opts.runner = &runner;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
@@ -258,7 +271,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
- EXPECT_GE(call_count, 1); // Test runner is used.
+ if (add_runner) {
+ EXPECT_GE(call_count, 1); // Test runner is used.
+ }
return Status::OK();
}
@@ -447,7 +462,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
{
// Simple case: instantiating with no state_handle.
for (int32 expected : {6, 4}) {
- TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -460,7 +475,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated));
EXPECT_EQ(handle, handle_non_isolated);
for (int32 expected : {0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -475,7 +490,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -490,7 +505,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -507,7 +522,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated));
@@ -515,6 +530,59 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
}
}
+TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) {
+ using test::function::blocking_op_state;
+ using test::function::BlockingOpState;
+
+ thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "FLRTest", 1);
+ Init({test::function::BlockingOpFn(), test::function::XTimesTwo()}, tp);
+
+ auto x = test::AsScalar<float>(1.3);
+ Tensor y;
+ blocking_op_state = new BlockingOpState();
+
+ thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
+ bool finished_running = false;
+ tp1->Schedule([&x, &y, &finished_running, this]() {
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "BlockingOpFn", {}, {x}, {&y},
+ false /* add_runner */));
+ finished_running = true;
+ });
+
+ // InstantiateAndRun shouldn't finish because BlockingOpFn should be blocked.
+ EXPECT_FALSE(finished_running);
+
+ FunctionLibraryRuntime::Handle h;
+ TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h));
+
+ auto x1 = test::AsTensor<float>({1, 2, 3, 4});
+ Tensor y1;
+ std::atomic<int32> num_done(0);
+ FunctionLibraryRuntime::Options opts;
+ for (int i = 0; i < 4; ++i) {
+ tp1->Schedule([&h, &x1, &y1, &opts, &num_done, this]() {
+ TF_CHECK_OK(Run(flr0_, h, opts, {x1}, {&y1}, false /* add_runner */));
+ num_done.fetch_add(1);
+ });
+ }
+ // All the 4 Run() calls should be blocked because the runner is occupied.
+ EXPECT_EQ(0, num_done.load());
+
+ blocking_op_state->AwaitState(1);
+ blocking_op_state->MoveToState(1, 2);
+ // Now the runner should be unblocked and all the other Run() calls should
+ // proceed.
+ blocking_op_state->AwaitState(3);
+ blocking_op_state->MoveToState(3, 0);
+ delete tp1;
+ EXPECT_TRUE(finished_running);
+ EXPECT_EQ(4, num_done.load());
+
+ delete blocking_op_state;
+ blocking_op_state = nullptr;
+ delete tp;
+}
+
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@@ -787,7 +855,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__6")
+ s.WithOpName("x4/x2/scale/_12__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -993,13 +1061,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__11")
+ s.WithOpName("scale/_6__cf__15")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__10")
+ s.WithOpName("Func/_1/sy/_5__cf__14")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
@@ -1247,14 +1315,14 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
- TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<string>(
y,
test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},
TensorShape({})));
opts.remote_execution = true;
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
- TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<string>(
y,
test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},
diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc
index 87733ed2db..1720ee64c0 100644
--- a/tensorflow/core/common_runtime/function_testlib.cc
+++ b/tensorflow/core/common_runtime/function_testlib.cc
@@ -58,6 +58,59 @@ FunctionDef FindDevice() {
{{{"device_name"}, "FindDeviceOp", {}, {}}});
}
+void BlockingOpState::AwaitState(int awaiting_state) {
+ mutex_lock ml(mu_);
+ while (state_ != awaiting_state) {
+ cv_.wait(ml);
+ }
+}
+
+void BlockingOpState::MoveToState(int expected_current, int next) {
+ mutex_lock ml(mu_);
+ CHECK_EQ(expected_current, state_);
+ state_ = next;
+ cv_.notify_all();
+}
+
+BlockingOpState* blocking_op_state = nullptr;
+
+// BlockingOp blocks on the global <blocking_op_state's> state,
+// and also updates it when it is unblocked and finishing computation.
+class BlockingOp : public OpKernel {
+ public:
+ explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) override {
+ blocking_op_state->MoveToState(0, 1);
+ blocking_op_state->AwaitState(2);
+ blocking_op_state->MoveToState(2, 3);
+
+ Tensor* out = nullptr;
+ const Tensor& in = ctx->input(0);
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
+ out->flat<float>() = in.flat<float>();
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
+REGISTER_OP("BlockingOp")
+ .Input("x: float")
+ .Output("y: float")
+ .Doc("")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+FunctionDef BlockingOpFn() {
+ return FDH::Define(
+ // Name
+ "BlockingOpFn",
+ // Args
+ {"x: float"},
+ // Return values
+ {"y: float"},
+ // Attr def
+ {},
+ // Nodes
+ {{{"y"}, "BlockingOp", {"x"}, {}}});
+}
+
// TODO(phawkins): replace with C++ API for calling functions, when that exists.
Output Call(Scope* scope, const string& op_name, const string& fn_name,
gtl::ArraySlice<Input> inputs) {
diff --git a/tensorflow/core/common_runtime/function_testlib.h b/tensorflow/core/common_runtime/function_testlib.h
index 3ddb26de92..fb967a6123 100644
--- a/tensorflow/core/common_runtime/function_testlib.h
+++ b/tensorflow/core/common_runtime/function_testlib.h
@@ -25,6 +25,22 @@ namespace function {
// {} -> y:DT_STRING (device where this op runs).
FunctionDef FindDevice();
+class BlockingOpState {
+ public:
+ void AwaitState(int awaiting_state);
+
+ void MoveToState(int expected_current, int next);
+
+ private:
+ mutex mu_;
+ condition_variable cv_;
+ int state_ = 0;
+};
+
+extern BlockingOpState* blocking_op_state;
+
+FunctionDef BlockingOpFn();
+
// Adds a function call to the given scope and returns the output for the node.
// TODO(phawkins): replace with C++ API for calling functions, when that exists.
Output Call(Scope* scope, const string& op_name, const string& fn_name,
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 44dc6f9459..07c657a741 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -42,21 +42,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
+ thread::ThreadPool* default_thread_pool,
DistributedFunctionLibraryRuntime* parent)
: device_mgr_(device_mgr),
lib_def_(lib_def),
+ default_thread_pool_(default_thread_pool),
next_handle_(0),
parent_(parent) {
if (device_mgr == nullptr) {
- flr_map_[nullptr] =
- NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
- lib_def, optimizer_options, this);
+ flr_map_[nullptr] = NewFunctionLibraryRuntime(
+ nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool,
+ optimizer_options, this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
- flr_map_[d] =
- NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version,
- lib_def, optimizer_options, this);
+ flr_map_[d] = NewFunctionLibraryRuntime(
+ device_mgr, env, d, graph_def_version, lib_def, default_thread_pool,
+ optimizer_options, this);
}
}
@@ -65,21 +67,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
+ thread::ThreadPool* default_thread_pool,
DistributedFunctionLibraryRuntime* parent)
: device_mgr_(device_mgr),
lib_def_(lib_def),
+ default_thread_pool_(default_thread_pool),
next_handle_(0),
parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[nullptr] = NewFunctionLibraryRuntime(
- nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
- std::move(custom_kernel_creator), this);
+ nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool,
+ optimizer_options, std::move(custom_kernel_creator), this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
flr_map_[d] = NewFunctionLibraryRuntime(
- device_mgr, env, d, graph_def_version, lib_def, optimizer_options,
- custom_kernel_creator, this);
+ device_mgr, env, d, graph_def_version, lib_def, default_thread_pool,
+ optimizer_options, custom_kernel_creator, this);
}
}
@@ -370,7 +374,8 @@ Status ProcessFunctionLibraryRuntime::Clone(
out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
out_pflr->reset(new ProcessFunctionLibraryRuntime(
device_mgr_, env, graph_def_version, out_lib_def->get(),
- optimizer_options, std::move(custom_kernel_creator), parent_));
+ optimizer_options, std::move(custom_kernel_creator), default_thread_pool_,
+ parent_));
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 10619ba6ea..d69e8bc2a0 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -33,6 +33,7 @@ class ProcessFunctionLibraryRuntime {
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
+ thread::ThreadPool* thread_pool = nullptr,
DistributedFunctionLibraryRuntime* parent = nullptr);
// With `custom_kernel_creator`.
@@ -41,6 +42,7 @@ class ProcessFunctionLibraryRuntime {
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
+ thread::ThreadPool* thread_pool,
DistributedFunctionLibraryRuntime* parent);
// Sends `tensors_to_send` from `source_device` to `target_device` using
@@ -174,6 +176,7 @@ class ProcessFunctionLibraryRuntime {
const DeviceMgr* const device_mgr_;
const FunctionLibraryDefinition* lib_def_;
+ thread::ThreadPool* default_thread_pool_;
// Holds all the function invocations here.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
GUARDED_BY(mu_);
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index ab1f919852..2da67b084a 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -71,7 +71,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
cluster_flr_.reset(new TestClusterFLR());
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
- opts, cluster_flr_.get()));
+ opts, nullptr, cluster_flr_.get()));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
}
@@ -153,7 +153,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr(
new ProcessFunctionLibraryRuntime(
nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION,
- lib_def.get(), opts, nullptr /* cluster_flr */));
+ lib_def.get(), opts, nullptr, nullptr /* cluster_flr */));
FunctionLibraryRuntime* flr =
proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
EXPECT_NE(flr, nullptr);
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 7878ebb5f0..9768a244f2 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -134,7 +134,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_, worker_env_->env, gdef.versions().producer(),
- item->lib_def.get(), graph_options.optimizer_options(), cluster_flr));
+ item->lib_def.get(), graph_options.optimizer_options(),
+ worker_env_->compute_pool, cluster_flr));
// Constructs the graph out of "gdef".
Graph graph(OpRegistry::Global());