diff options
author | Rohan Jain <rohanj@google.com> | 2018-03-09 12:20:32 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-09 12:24:51 -0800 |
commit | 20dfc25c378c600fac683e62dc8a1ed2a522711c (patch) | |
tree | df618d2f7874bcc08ea3de65297f7ee259f65a51 | |
parent | 61a744fffbcc68e453aafc6eaa2c7ff2318a3584 (diff) |
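Allow FunctionLibraryRuntime::Run to be called without a runner for executing kernels. When no runner is provided, Run falls back to a default runner that schedules closures on the device's own threadpool or, if the device has none, on the default threadpool passed to the FunctionLibraryRuntime.
Also makes sure each device has a default threadpool to fall back on: ProcessFunctionLibraryRuntime now accepts a default threadpool and forwards it to the FunctionLibraryRuntime it creates for every device.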
PiperOrigin-RevId: 188520648
12 files changed, 221 insertions, 96 deletions
diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 643153058c..4f75d27887 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -41,7 +41,7 @@ class TestEnv { device_mgr_.reset(new DeviceMgr({device})); flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(), device, TF_GRAPH_DEF_VERSION, - &flib_def_, {}, nullptr); + &flib_def_, nullptr, {}, nullptr); } FunctionLibraryRuntime* function_library_runtime() const { diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index ecbffcbf6c..9def58cb9c 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1181,7 +1181,7 @@ Status DirectSession::GetOrCreateExecutors( } func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), options_.env, graph_def_version, - func_info->flib_def.get(), optimizer_opts)); + func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first)); GraphOptimizer optimizer(optimizer_opts); for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index b75a4f76d9..6fe0cba1e5 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" @@ -868,59 +869,14 @@ TEST(DirectSessionTest, TestTimeoutCleanShutdown) { TF_ASSERT_OK(session->Close()); } -class BlockingOpState { - public: - void AwaitState(int awaiting_state) { - mutex_lock ml(mu_); - while (state_ != awaiting_state) { - cv_.wait(ml); - } - } - void MoveToState(int expected_current, int next) { - mutex_lock ml(mu_); - CHECK_EQ(expected_current, state_); - state_ = next; - cv_.notify_all(); - } - - private: - mutex mu_; - condition_variable cv_; - int state_ = 0; -}; -static BlockingOpState* blocking_op_state = nullptr; - -// BlockingOp blocks on the global <blocking_op_state's> state, -// and also updates it when it is unblocked and finishing computation. 
-class BlockingOp : public OpKernel { - public: - explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* ctx) override { - blocking_op_state->MoveToState(0, 1); - blocking_op_state->AwaitState(2); - blocking_op_state->MoveToState(2, 3); - - Tensor* out = nullptr; - const Tensor& in = ctx->input(0); - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); - out->flat<float>() = in.flat<float>(); - } -}; -REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp); -REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc(""); - static void TestSessionInterOpThreadsImpl(bool use_function_lib, bool use_global_pools) { + using test::function::blocking_op_state; + using test::function::BlockingOpState; + FunctionDefLibrary library_graph_def; if (use_function_lib) { - const string lib = R"proto( - signature: { - name: "BlockingOpFn" input_arg: { name: "x" type: DT_FLOAT } - output_arg: { name: "y" type: DT_FLOAT }} - node_def: { name: "y" op: "BlockingOp" input: "x" } - ret: { key: "y" value: "y:y:0" } )proto"; - CHECK(protobuf::TextFormat::ParseFromString( - lib, library_graph_def.add_function())); + *library_graph_def.add_function() = test::function::BlockingOpFn(); } FunctionLibraryDefinition flib(OpRegistry::Global(), library_graph_def); diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 3e937ceb64..effe53c961 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "tensorflow/core/graph/gradients.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/optimizer_cse.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/macros.h" @@ -141,6 +142,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, + thread::ThreadPool* default_thread_pool, const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, ProcessFunctionLibraryRuntime* parent); @@ -194,6 +196,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const FunctionLibraryDefinition* const base_lib_def_; GraphOptimizer optimizer_; const CustomKernelCreator custom_kernel_creator_; + Executor::Args::Runner default_runner_; const string device_name_; std::function<Status(const string&, const OpDef**)> get_func_sig_; @@ -243,6 +246,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, + thread::ThreadPool* default_thread_pool, const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, ProcessFunctionLibraryRuntime* parent) @@ -253,6 +257,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( base_lib_def_(lib_def), optimizer_(optimizer_options), custom_kernel_creator_(std::move(custom_kernel_creator)), + default_runner_(nullptr), device_name_(device_ == nullptr ? 
ProcessFunctionLibraryRuntime::kDefaultFLRDevice : device_->name()), @@ -264,6 +269,18 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) { return CreateKernel(ndef, kernel); }; + thread::ThreadPool* pool = nullptr; + if (device_ != nullptr) { + pool = device_->tensorflow_device_thread_pool(); + } + if (pool == nullptr) { + pool = default_thread_pool; + } + if (pool != nullptr) { + default_runner_ = [pool](Executor::Args::Closure c) { + pool->Schedule(std::move(c)); + }; + } } FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() { @@ -768,6 +785,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, return; } + if (run_opts.runner == nullptr) { + run_opts.runner = &default_runner_; + } DCHECK(run_opts.runner != nullptr); Executor::Args* exec_args = new Executor::Args; @@ -854,6 +874,9 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, done(s); return; } + if (run_opts.runner == nullptr) { + run_opts.runner = &default_runner_; + } DCHECK(run_opts.runner != nullptr); Executor::Args* exec_args = new Executor::Args; @@ -942,21 +965,21 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) { std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, + thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, ProcessFunctionLibraryRuntime* parent) { return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl( - device_mgr, env, device, graph_def_version, lib_def, optimizer_options, - std::move(custom_kernel_creator), parent)); + device_mgr, env, device, graph_def_version, lib_def, thread_pool, + optimizer_options, std::move(custom_kernel_creator), parent)); } 
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, + thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, ProcessFunctionLibraryRuntime* parent) { return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version, - lib_def, optimizer_options, + lib_def, thread_pool, optimizer_options, GetCustomCreatorSingleton()->Get(), parent); } diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index 477340d87a..a0f9fcae0a 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -55,7 +55,7 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb); std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, + thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, ProcessFunctionLibraryRuntime* parent); @@ -65,7 +65,7 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, + thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, ProcessFunctionLibraryRuntime* parent); // FunctionLibraryRuntime::GetFunctionBody returns a description of an diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 63ad0d231c..d7e5f0018e 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ 
b/tensorflow/core/common_runtime/function_test.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -135,7 +136,8 @@ TEST_F(FunctionTest, WXPlusB) { class FunctionLibraryRuntimeTest : public ::testing::Test { protected: - void Init(const std::vector<FunctionDef>& flib) { + void Init(const std::vector<FunctionDef>& flib, + thread::ThreadPool* default_thread_pool = nullptr) { SessionOptions options; auto* device_count = options.config.mutable_device_count(); device_count->insert({"CPU", 3}); @@ -149,7 +151,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { device_mgr_.reset(new DeviceMgr(devices_)); pflr_.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, nullptr /* cluster_flr */)); + opts, default_thread_pool, nullptr /* cluster_flr */)); flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); @@ -158,16 +160,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, FunctionLibraryRuntime::Options opts, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + const std::vector<Tensor>& args, std::vector<Tensor*> rets, + bool add_runner = true) { std::atomic<int32> call_count(0); std::function<void(std::function<void()>)> runner = [&call_count](std::function<void()> fn) { ++call_count; test::function::FunctionTestSchedClosure(fn); }; - + if (add_runner) { + opts.runner = &runner; + } else { + opts.runner = nullptr; + } 
Notification done; - opts.runner = &runner; std::vector<Tensor> out; Status status; flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) { @@ -183,7 +189,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { *rets[i] = out[i]; } - EXPECT_GE(call_count, 1); // Test runner is used. + if (add_runner) { + EXPECT_GE(call_count, 1); // Test runner is used. + } return Status::OK(); } @@ -204,24 +212,25 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name, test::function::Attrs attrs, const std::vector<Tensor>& args, - std::vector<Tensor*> rets) { + std::vector<Tensor*> rets, bool add_runner = true) { return InstantiateAndRun(flr, name, attrs, FunctionLibraryRuntime::InstantiateOptions(), args, - std::move(rets)); + std::move(rets), add_runner); } Status InstantiateAndRun( FunctionLibraryRuntime* flr, const string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& options, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + const std::vector<Tensor>& args, std::vector<Tensor*> rets, + bool add_runner = true) { FunctionLibraryRuntime::Handle handle; Status status = flr->Instantiate(name, attrs, options, &handle); if (!status.ok()) { return status; } FunctionLibraryRuntime::Options opts; - status = Run(flr, handle, opts, args, rets); + status = Run(flr, handle, opts, args, rets, add_runner); if (!status.ok()) return status; // Release the handle and try running again. It should not succeed. 
@@ -237,16 +246,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, - FunctionLibraryRuntime::Options opts, CallFrameInterface* frame) { + FunctionLibraryRuntime::Options opts, CallFrameInterface* frame, + bool add_runner = true) { std::atomic<int32> call_count(0); std::function<void(std::function<void()>)> runner = [&call_count](std::function<void()> fn) { ++call_count; test::function::FunctionTestSchedClosure(fn); }; - + if (add_runner) { + opts.runner = &runner; + } else { + opts.runner = nullptr; + } Notification done; - opts.runner = &runner; std::vector<Tensor> out; Status status; flr->Run(opts, handle, frame, [&status, &done](const Status& s) { @@ -258,7 +271,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return status; } - EXPECT_GE(call_count, 1); // Test runner is used. + if (add_runner) { + EXPECT_GE(call_count, 1); // Test runner is used. + } return Status::OK(); } @@ -447,7 +462,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { { // Simple case: instantiating with no state_handle. 
for (int32 expected : {6, 4}) { - TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y})); + TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}, true)); test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected})); } } @@ -460,7 +475,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated)); EXPECT_EQ(handle, handle_non_isolated); for (int32 expected : {0, 1}) { - TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y})); + TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}, true)); test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected})); } } @@ -475,7 +490,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { &handle_isolated)); EXPECT_NE(handle, handle_isolated); for (int32 expected : {6, 4, 0, 1}) { - TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y})); + TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true)); test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected})); } } @@ -490,7 +505,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { &handle_isolated)); EXPECT_NE(handle, handle_isolated); for (int32 expected : {6, 4, 0, 1}) { - TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y})); + TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true)); test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected})); } } @@ -507,7 +522,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { &handle_isolated)); EXPECT_NE(handle, handle_isolated); for (int32 expected : {6, 4, 0, 1}) { - TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y})); + TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true)); test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected})); } TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated)); @@ -515,6 +530,59 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { } } +TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) { + using test::function::blocking_op_state; + using test::function::BlockingOpState; + + 
thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "FLRTest", 1); + Init({test::function::BlockingOpFn(), test::function::XTimesTwo()}, tp); + + auto x = test::AsScalar<float>(1.3); + Tensor y; + blocking_op_state = new BlockingOpState(); + + thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5); + bool finished_running = false; + tp1->Schedule([&x, &y, &finished_running, this]() { + TF_CHECK_OK(InstantiateAndRun(flr0_, "BlockingOpFn", {}, {x}, {&y}, + false /* add_runner */)); + finished_running = true; + }); + + // InstantiateAndRun shouldn't finish because BlockingOpFn should be blocked. + EXPECT_FALSE(finished_running); + + FunctionLibraryRuntime::Handle h; + TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h)); + + auto x1 = test::AsTensor<float>({1, 2, 3, 4}); + Tensor y1; + std::atomic<int32> num_done(0); + FunctionLibraryRuntime::Options opts; + for (int i = 0; i < 4; ++i) { + tp1->Schedule([&h, &x1, &y1, &opts, &num_done, this]() { + TF_CHECK_OK(Run(flr0_, h, opts, {x1}, {&y1}, false /* add_runner */)); + num_done.fetch_add(1); + }); + } + // All the 4 Run() calls should be blocked because the runner is occupied. + EXPECT_EQ(0, num_done.load()); + + blocking_op_state->AwaitState(1); + blocking_op_state->MoveToState(1, 2); + // Now the runner should be unblocked and all the other Run() calls should + // proceed. 
+ blocking_op_state->AwaitState(3); + blocking_op_state->MoveToState(3, 0); + delete tp1; + EXPECT_TRUE(finished_running); + EXPECT_EQ(4, num_done.load()); + + delete blocking_op_state; + blocking_op_state = nullptr; + delete tp; +} + TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); @@ -787,7 +855,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const<float>( - s.WithOpName("x4/x2/scale/_12__cf__6") + s.WithOpName("x4/x2/scale/_12__cf__10") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -993,13 +1061,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_6__cf__11") + s.WithOpName("scale/_6__cf__15") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_5__cf__10") + s.WithOpName("Func/_1/sy/_5__cf__14") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( @@ -1247,14 +1315,14 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get()); opts.source_device = "/device:CPU:1"; // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. 
- TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y})); + TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true)); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"}, TensorShape({}))); opts.remote_execution = true; opts.source_device = "/job:localhost/replica:0/task:0/cpu:2"; - TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y})); + TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}, true)); test::ExpectTensorEqual<string>( y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"}, diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc index 87733ed2db..1720ee64c0 100644 --- a/tensorflow/core/common_runtime/function_testlib.cc +++ b/tensorflow/core/common_runtime/function_testlib.cc @@ -58,6 +58,59 @@ FunctionDef FindDevice() { {{{"device_name"}, "FindDeviceOp", {}, {}}}); } +void BlockingOpState::AwaitState(int awaiting_state) { + mutex_lock ml(mu_); + while (state_ != awaiting_state) { + cv_.wait(ml); + } +} + +void BlockingOpState::MoveToState(int expected_current, int next) { + mutex_lock ml(mu_); + CHECK_EQ(expected_current, state_); + state_ = next; + cv_.notify_all(); +} + +BlockingOpState* blocking_op_state = nullptr; + +// BlockingOp blocks on the global <blocking_op_state's> state, +// and also updates it when it is unblocked and finishing computation. 
+class BlockingOp : public OpKernel { + public: + explicit BlockingOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + blocking_op_state->MoveToState(0, 1); + blocking_op_state->AwaitState(2); + blocking_op_state->MoveToState(2, 3); + + Tensor* out = nullptr; + const Tensor& in = ctx->input(0); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); + out->flat<float>() = in.flat<float>(); + } +}; +REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp); +REGISTER_OP("BlockingOp") + .Input("x: float") + .Output("y: float") + .Doc("") + .SetShapeFn(shape_inference::UnknownShape); + +FunctionDef BlockingOpFn() { + return FDH::Define( + // Name + "BlockingOpFn", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attr def + {}, + // Nodes + {{{"y"}, "BlockingOp", {"x"}, {}}}); +} + // TODO(phawkins): replace with C++ API for calling functions, when that exists. Output Call(Scope* scope, const string& op_name, const string& fn_name, gtl::ArraySlice<Input> inputs) { diff --git a/tensorflow/core/common_runtime/function_testlib.h b/tensorflow/core/common_runtime/function_testlib.h index 3ddb26de92..fb967a6123 100644 --- a/tensorflow/core/common_runtime/function_testlib.h +++ b/tensorflow/core/common_runtime/function_testlib.h @@ -25,6 +25,22 @@ namespace function { // {} -> y:DT_STRING (device where this op runs). FunctionDef FindDevice(); +class BlockingOpState { + public: + void AwaitState(int awaiting_state); + + void MoveToState(int expected_current, int next); + + private: + mutex mu_; + condition_variable cv_; + int state_ = 0; +}; + +extern BlockingOpState* blocking_op_state; + +FunctionDef BlockingOpFn(); + // Adds a function call to the given scope and returns the output for the node. // TODO(phawkins): replace with C++ API for calling functions, when that exists. 
Output Call(Scope* scope, const string& op_name, const string& fn_name, diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 44dc6f9459..07c657a741 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -42,21 +42,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, + thread::ThreadPool* default_thread_pool, DistributedFunctionLibraryRuntime* parent) : device_mgr_(device_mgr), lib_def_(lib_def), + default_thread_pool_(default_thread_pool), next_handle_(0), parent_(parent) { if (device_mgr == nullptr) { - flr_map_[nullptr] = - NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version, - lib_def, optimizer_options, this); + flr_map_[nullptr] = NewFunctionLibraryRuntime( + nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool, + optimizer_options, this); return; } for (Device* d : device_mgr->ListDevices()) { - flr_map_[d] = - NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version, - lib_def, optimizer_options, this); + flr_map_[d] = NewFunctionLibraryRuntime( + device_mgr, env, d, graph_def_version, lib_def, default_thread_pool, + optimizer_options, this); } } @@ -65,21 +67,23 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, + thread::ThreadPool* default_thread_pool, DistributedFunctionLibraryRuntime* parent) : device_mgr_(device_mgr), lib_def_(lib_def), + default_thread_pool_(default_thread_pool), next_handle_(0), parent_(parent) { if (device_mgr == nullptr) { flr_map_[nullptr] = NewFunctionLibraryRuntime( - nullptr, env, nullptr, 
graph_def_version, lib_def, optimizer_options, - std::move(custom_kernel_creator), this); + nullptr, env, nullptr, graph_def_version, lib_def, default_thread_pool, + optimizer_options, std::move(custom_kernel_creator), this); return; } for (Device* d : device_mgr->ListDevices()) { flr_map_[d] = NewFunctionLibraryRuntime( - device_mgr, env, d, graph_def_version, lib_def, optimizer_options, - custom_kernel_creator, this); + device_mgr, env, d, graph_def_version, lib_def, default_thread_pool, + optimizer_options, custom_kernel_creator, this); } } @@ -370,7 +374,8 @@ Status ProcessFunctionLibraryRuntime::Clone( out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_)); out_pflr->reset(new ProcessFunctionLibraryRuntime( device_mgr_, env, graph_def_version, out_lib_def->get(), - optimizer_options, std::move(custom_kernel_creator), parent_)); + optimizer_options, std::move(custom_kernel_creator), default_thread_pool_, + parent_)); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 10619ba6ea..d69e8bc2a0 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -33,6 +33,7 @@ class ProcessFunctionLibraryRuntime { const DeviceMgr* device_mgr, Env* env, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, + thread::ThreadPool* thread_pool = nullptr, DistributedFunctionLibraryRuntime* parent = nullptr); // With `custom_kernel_creator`. 
@@ -41,6 +42,7 @@ class ProcessFunctionLibraryRuntime { const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, + thread::ThreadPool* thread_pool, DistributedFunctionLibraryRuntime* parent); // Sends `tensors_to_send` from `source_device` to `target_device` using @@ -174,6 +176,7 @@ class ProcessFunctionLibraryRuntime { const DeviceMgr* const device_mgr_; const FunctionLibraryDefinition* lib_def_; + thread::ThreadPool* default_thread_pool_; // Holds all the function invocations here. std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ GUARDED_BY(mu_); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index ab1f919852..2da67b084a 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -71,7 +71,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { cluster_flr_.reset(new TestClusterFLR()); proc_flr_.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts, cluster_flr_.get())); + opts, nullptr, cluster_flr_.get())); rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); } @@ -153,7 +153,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) { std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr( new ProcessFunctionLibraryRuntime( nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION, - lib_def.get(), opts, nullptr /* cluster_flr */)); + lib_def.get(), opts, nullptr, nullptr /* cluster_flr */)); FunctionLibraryRuntime* flr = proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); EXPECT_NE(flr, nullptr); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 7878ebb5f0..9768a244f2 100644 --- 
a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -134,7 +134,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, item->proc_flr.reset(new ProcessFunctionLibraryRuntime( device_mgr_, worker_env_->env, gdef.versions().producer(), - item->lib_def.get(), graph_options.optimizer_options(), cluster_flr)); + item->lib_def.get(), graph_options.optimizer_options(), + worker_env_->compute_pool, cluster_flr)); // Constructs the graph out of "gdef". Graph graph(OpRegistry::Global()); |