diff options
author | 2016-06-01 08:34:11 -0800 | |
---|---|---|
committer | 2016-06-01 09:48:06 -0700 | |
commit | caed96e99e87aa3aef8ba88ee8a799d50b0e597b (patch) | |
tree | dc3f0f6e7f837f28e2d6894a836a63ec8314d765 | |
parent | df9e479579a199db2de6721511073cf93b6409f5 (diff) |
Change function library to use a runner passed in from the Run call, instead of
one from library-construction time. This is to properly support inter-op thread
pools with the function library.
Also change testlib graph construction to pass through the Graph's op_registry.
Change: 123761344
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session_test.cc | 66 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/executor.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function.h | 6 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 18 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/graph_mgr.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 2 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 5 | ||||
-rw-r--r-- | tensorflow/core/graph/testlib.cc | 13 |
10 files changed, 102 insertions, 44 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 1f864b018b..062a6236b6 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -740,9 +740,6 @@ Status DirectSession::GetOrCreateExecutors( } } ek->items.reserve(graphs.size()); - auto runner = [this, pool](Executor::Args::Closure c) { - SchedClosure(pool, c); - }; const auto& optimizer_opts = options_.config.graph_options().optimizer_options(); GraphOptimizer optimizer(optimizer_opts); @@ -757,9 +754,9 @@ Status DirectSession::GetOrCreateExecutors( ek->items.resize(ek->items.size() + 1); auto* item = &(ek->items.back()); - item->flib = NewFunctionLibraryRuntime(device_mgr_.get(), device, runner, - graph_def_version, flib_def_.get(), - optimizer_opts); + item->flib = + NewFunctionLibraryRuntime(device_mgr_.get(), device, graph_def_version, + flib_def_.get(), optimizer_opts); LocalExecutorParams params; params.device = device; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 7a4cbc2e7e..5068d24df1 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -805,6 +805,7 @@ class BlockingOp : public OpKernel { void Compute(OpKernelContext* ctx) override { blocking_op_state->MoveToState(0, 1); blocking_op_state->AwaitState(2); + blocking_op_state->MoveToState(2, 3); Tensor* out = nullptr; const Tensor& in = ctx->input(0); @@ -815,14 +816,32 @@ class BlockingOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp); REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc(""); -TEST(DirectSessionTest, TestSessionInterOpThreads) { - Graph g(OpRegistry::Global()); +static void TestSessionInterOpThreadsImpl(bool use_function_lib) { + FunctionDefLibrary library_graph_def; + if (use_function_lib) { + const string lib = R"proto( + signature: { + name: "BlockingOpFn" input_arg: { name: "x" type: DT_FLOAT } + output_arg: { name: "y" type: DT_FLOAT }} + node: { ret: "y" op: "BlockingOp" arg: "x" })proto"; + CHECK(protobuf::TextFormat::ParseFromString( + lib, library_graph_def.add_function())); + } + + FunctionLibraryDefinition flib(library_graph_def); + Graph g(&flib); Tensor t(DT_FLOAT, TensorShape({})); t.scalar<float>()() = {1.2}; Node* x = test::graph::Constant(&g, t); - Node* y = test::graph::Unary(&g, "BlockingOp", x); + Node* y; + if (use_function_lib) { + y = test::graph::Unary(&g, "BlockingOpFn", x); + } else { + y = test::graph::Unary(&g, "BlockingOp", x); + } GraphDef def; test::graph::ToGraphDef(&g, &def); + *def.mutable_library() = library_graph_def; // Create session with two inter-op thread pools. SessionOptions options; @@ -832,9 +851,13 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) { ->mutable_optimizer_options() ->set_opt_level(OptimizerOptions_Level_L0); (*options.config.mutable_device_count())["CPU"] = 2; + + options.config.add_session_inter_op_thread_pool(); auto* p = options.config.add_session_inter_op_thread_pool(); - p->set_num_threads(1); // This one has only one thread. - p = options.config.add_session_inter_op_thread_pool(); + p->set_num_threads(1); + const int kLargePool = 0; + const int kSmallPool = 1; + std::unique_ptr<Session> session(NewSession(options)); ASSERT_TRUE(session != nullptr); TF_ASSERT_OK(session->Create(def)); @@ -862,19 +885,30 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) { // For blocking states: // - Starts at 0, BlockingOp::Compute will move to 1. // - This main thread will wait for 1, then move to 2 when other ops are done. - // Moving to 2 unblocks the blocking op. + // Moving to 2 unblocks the blocking op, which then moves to state 3. - // Launch 2 session run calls. Neither will finish until the blocking op is - // unblocked, because it is using all threads in inter_op pool #0. - thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5); + // Run the graph once on the non-limited pool. + thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 1); blocking_op_state = new BlockingOpState(); - add_session_run_call(tp1, y, 0 /* inter_op_pool */); + add_session_run_call(tp1, y, kLargePool); + blocking_op_state->AwaitState(1); + blocking_op_state->MoveToState(1, 2); + blocking_op_state->AwaitState(3); + blocking_op_state->MoveToState(3, 0); + delete tp1; + num_done = 0; + + tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5); + + // Launch 2 session run calls. Neither will finish until the blocking op is + // unblocked, because it is using all threads in the small pool. + add_session_run_call(tp1, y, kSmallPool); blocking_op_state->AwaitState(1); // Wait for the blocking op to Compute. // These will block on <BlockingOpState>. const int kBlockedThreads = 3; for (int i = 0; i < kBlockedThreads; ++i) { - add_session_run_call(tp1, x, 0 /* inter_op_pool */); + add_session_run_call(tp1, x, kSmallPool); } // Launch session calls using the other inter-op pool. These will finish @@ -882,7 +916,7 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) { thread::ThreadPool* tp2 = new thread::ThreadPool(Env::Default(), "tp2", 3); const int kUnblockedThreads = 4; for (int i = 0; i < kUnblockedThreads; ++i) { - add_session_run_call(tp2, x, 1 /* inter_op_pool */); + add_session_run_call(tp2, x, kLargePool); } delete tp2; EXPECT_EQ(kUnblockedThreads, num_done.load()); @@ -895,6 +929,14 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) { blocking_op_state = nullptr; } +TEST(DirectSessionTest, TestSessionInterOpThreads) { + TestSessionInterOpThreadsImpl(false /* use_function_lib */); +} + +TEST(DirectSessionTest, TestSessionInterOpThreadsWithFunctions) { + TestSessionInterOpThreadsImpl(true /* use_function_lib */); +} + TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) { Graph g(OpRegistry::Global()); Tensor t(DT_FLOAT, TensorShape({})); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 21765283fb..bfa4e36ce9 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -967,6 +967,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { params.inputs = &inputs; params.input_device_contexts = &input_device_contexts; params.input_alloc_attrs = &input_alloc_attrs; + params.runner = &runner_; Status s; NodeExecStats* stats = nullptr; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index bc6a56136f..34fd24e731 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -229,7 +229,7 @@ static const FunctionLibraryRuntime::Handle kInvalidHandle = -1; class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { public: FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Device* device, - Runner runner, int graph_def_version, + int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options); @@ -255,7 +255,6 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const DeviceMgr* const device_mgr_; Device* const device_; - Runner runner_ = nullptr; const int graph_def_version_; const FunctionLibraryDefinition* const lib_def_; GraphOptimizer optimizer_; @@ -293,12 +292,11 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { }; FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( - const DeviceMgr* dmgr, Device* device, Runner runner, int graph_def_version, + const DeviceMgr* dmgr, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options) : device_mgr_(dmgr), device_(device), - runner_(runner), graph_def_version_(graph_def_version), lib_def_(lib_def), optimizer_(optimizer_options) { @@ -334,6 +332,7 @@ class CallOp : public AsyncOpKernel { done); FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); + opts.runner = ctx->runner(); std::vector<Tensor> args; args.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { @@ -379,6 +378,7 @@ class SymbolicGradientOp : public AsyncOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); + opts.runner = ctx->runner(); std::vector<Tensor> args; args.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { @@ -660,12 +660,14 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, delete frame; return done(s); } + DCHECK(opts.runner != nullptr); + Executor::Args exec_args; // Inherit the step_id from the caller. exec_args.step_id = opts.step_id; exec_args.call_frame = frame; exec_args.cancellation_manager = opts.cancellation_manager; - exec_args.runner = runner_; + exec_args.runner = *opts.runner; // TODO(zhifengc): we can avoid creating rendez here if we know // there is no send/recv nodes in the graph. auto* rendez = new IntraProcessRendezvous(device_mgr_); @@ -693,10 +695,10 @@ bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { } FunctionLibraryRuntime* NewFunctionLibraryRuntime( - const DeviceMgr* dmgr, Device* device, Runner runner, int graph_def_version, + const DeviceMgr* dmgr, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options) { - return new FunctionLibraryRuntimeImpl(dmgr, device, runner, graph_def_version, + return new FunctionLibraryRuntimeImpl(dmgr, device, graph_def_version, lib_def, optimizer_options); } diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h index f9e072f5b4..26e8b7aa3d 100644 --- a/tensorflow/core/common_runtime/function.h +++ b/tensorflow/core/common_runtime/function.h @@ -33,11 +33,9 @@ namespace tensorflow { // The returned object does not take ownerships of "device" or // "lib_def". The caller must ensure "device" and "lib_def" outlives // the returned object. -typedef std::function<void()> Closure; -typedef std::function<void(Closure)> Runner; FunctionLibraryRuntime* NewFunctionLibraryRuntime( - const DeviceMgr* device_mgr, Device* device, Runner runner, - int graph_def_version, const FunctionLibraryDefinition* lib_def, + const DeviceMgr* device_mgr, Device* device, int graph_def_version, + const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options); // FunctionLibraryRuntime::GetFunctionBody returns a description of an diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 70d00411a5..3460597a8c 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" +#include <atomic> + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" @@ -149,8 +151,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { lib_def_ = new FunctionLibraryDefinition(proto); delete lib_; OptimizerOptions opts; - lib_ = NewFunctionLibraryRuntime(nullptr, device_, FunctionTestSchedClosure, - TF_GRAPH_DEF_VERSION, lib_def_, opts); + lib_ = NewFunctionLibraryRuntime(nullptr, device_, TF_GRAPH_DEF_VERSION, + lib_def_, opts); } Status Run(const string& name, InstantiateAttrValueSlice attrs, @@ -160,8 +162,17 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { if (!status.ok()) { return status; } + + std::atomic<int32> call_count(0); + std::function<void(std::function<void()>)> runner = + [&call_count](std::function<void()> fn) { + ++call_count; + FunctionTestSchedClosure(fn); + }; + Notification done; FunctionLibraryRuntime::Options opts; + opts.runner = &runner; std::vector<Tensor> out; lib_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { status = s; @@ -175,6 +186,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { for (size_t i = 0; i < rets.size(); ++i) { *rets[i] = out[i]; } + + EXPECT_GE(call_count, 1); // Test runner is used. + return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 87172dc321..c806bd8ac7 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -134,9 +134,6 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions)); } - thread::ThreadPool* pool = worker_env_->compute_pool; - auto runner = [pool](std::function<void()> fn) { pool->Schedule(fn); }; - LocalExecutorParams params; Status s; @@ -169,10 +166,9 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, opseg->AddHold(session); // Function library runtime. - unit->lib = - NewFunctionLibraryRuntime(worker_env_->device_mgr, unit->device, runner, - def->versions().producer(), item->lib_def, - graph_options.optimizer_options()); + unit->lib = NewFunctionLibraryRuntime( + worker_env_->device_mgr, unit->device, def->versions().producer(), + item->lib_def, graph_options.optimizer_options()); // Construct the root executor for the subgraph. params.device = unit->device; diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 71675e72a2..213a89ea94 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -328,6 +328,8 @@ class FunctionLibraryRuntime { CancellationManager* cancellation_manager = nullptr; // The id of the step that is calling this function. int64 step_id = 0; + + std::function<void(std::function<void()>)>* runner = nullptr; }; typedef std::function<void(const Status&)> DoneCallback; virtual void Run(const Options& opts, Handle handle, diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 5470cfddbb..6b5ca31f84 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -533,6 +533,7 @@ class OpKernelContext { // Function call supports. FunctionCallFrame* call_frame = nullptr; FunctionLibraryRuntime* function_library = nullptr; + std::function<void(std::function<void()>)>* runner = nullptr; // TensorSliceReaderCache support. checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; @@ -868,6 +869,10 @@ class OpKernelContext { return params_->function_library; } + std::function<void(std::function<void()>)>* runner() const { + return params_->runner; + } + // Shared resources accessible to this kernel. ResourceMgr* resource_manager() const { return params_->resource_manager; } diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index 5bf898d7d5..b789c408c9 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -134,7 +134,7 @@ Node* Assign(Graph* g, Node* var, Node* val) { Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, bool keep_dims) { Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce) + TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry()) .Input(data) .Input(axes) .Attr("keep_dims", keep_dims) @@ -168,7 +168,7 @@ Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, Node* RandomNumberGenerator(const string& op, Graph* g, Node* input, DataType dtype) { Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), op) + TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry()) .Input(input) .Attr("dtype", dtype) .Attr("seed", 0) @@ -190,14 +190,15 @@ Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) { Node* Unary(Graph* g, const string& func, Node* input, int index) { Node* ret; - TF_CHECK_OK( - NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret)); + TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) + .Input(input, index) + .Finalize(g, &ret)); return ret; } Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), func) + TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry()) .Input(in0) .Input(in1) .Finalize(g, &ret)); @@ -206,7 +207,7 @@ Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) { Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) { Node* ret; - auto b = NodeBuilder(g->NewName("n"), func); + auto b = NodeBuilder(g->NewName("n"), func, g->op_registry()); for (Node* n : ins) b = b.Input(n); TF_CHECK_OK(b.Finalize(g, &ret)); return ret; |