author    A. Unique TensorFlower <nobody@tensorflow.org>    2016-06-01 08:34:11 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-06-01 09:48:06 -0700
commit    caed96e99e87aa3aef8ba88ee8a799d50b0e597b (patch)
tree      dc3f0f6e7f837f28e2d6894a836a63ec8314d765
parent    df9e479579a199db2de6721511073cf93b6409f5 (diff)
Change function library to use a runner passed in from the Run call, instead
of one from library-construction time. This is to properly support inter-op
thread pools with the function library. Also change testlib graph
construction to pass through the Graph's op_registry.

Change: 123761344
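
For context, a minimal self-contained C++ sketch of the convention the diff
below adopts: the runner moves out of NewFunctionLibraryRuntime() and into a
per-call options field. Runtime and RunOptions here are illustrative
stand-ins, not TensorFlow's classes.

    // Sketch of a runner supplied per Run() call rather than captured at
    // construction time, so each call can use the inter-op thread pool the
    // session selected for it.
    #include <functional>
    #include <iostream>
    #include <utility>

    using Closure = std::function<void()>;
    using Runner = std::function<void(Closure)>;

    struct RunOptions {
      // Borrowed for the duration of the call; mirrors the new
      // FunctionLibraryRuntime::Options::runner pointer in this change.
      Runner* runner = nullptr;
    };

    class Runtime {
     public:
      // Note: no Runner constructor parameter anymore.
      void Run(const RunOptions& opts, Closure work) {
        // Mirrors the DCHECK(opts.runner != nullptr) added in function.cc.
        (*opts.runner)(std::move(work));
      }
    };

    int main() {
      Runtime rt;
      // A real runner would schedule onto the session's chosen pool, e.g.
      // [pool](Closure c) { pool->Schedule(c); }; this one runs inline.
      Runner inline_runner = [](Closure c) { c(); };
      RunOptions opts;
      opts.runner = &inline_runner;
      rt.Run(opts, [] { std::cout << "ran via caller-supplied runner\n"; });
      return 0;
    }
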
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc       |  9
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc  | 66
-rw-r--r--  tensorflow/core/common_runtime/executor.cc             |  1
-rw-r--r--  tensorflow/core/common_runtime/function.cc             | 16
-rw-r--r--  tensorflow/core/common_runtime/function.h              |  6
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc        | 18
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.cc       | 10
-rw-r--r--  tensorflow/core/framework/function.h                   |  2
-rw-r--r--  tensorflow/core/framework/op_kernel.h                  |  5
-rw-r--r--  tensorflow/core/graph/testlib.cc                       | 13
10 files changed, 102 insertions, 44 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 1f864b018b..062a6236b6 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -740,9 +740,6 @@ Status DirectSession::GetOrCreateExecutors(
}
}
ek->items.reserve(graphs.size());
- auto runner = [this, pool](Executor::Args::Closure c) {
- SchedClosure(pool, c);
- };
const auto& optimizer_opts =
options_.config.graph_options().optimizer_options();
GraphOptimizer optimizer(optimizer_opts);
@@ -757,9 +754,9 @@ Status DirectSession::GetOrCreateExecutors(
ek->items.resize(ek->items.size() + 1);
auto* item = &(ek->items.back());
- item->flib = NewFunctionLibraryRuntime(device_mgr_.get(), device, runner,
- graph_def_version, flib_def_.get(),
- optimizer_opts);
+ item->flib =
+ NewFunctionLibraryRuntime(device_mgr_.get(), device, graph_def_version,
+ flib_def_.get(), optimizer_opts);
LocalExecutorParams params;
params.device = device;
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 7a4cbc2e7e..5068d24df1 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -805,6 +805,7 @@ class BlockingOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
blocking_op_state->MoveToState(0, 1);
blocking_op_state->AwaitState(2);
+ blocking_op_state->MoveToState(2, 3);
Tensor* out = nullptr;
const Tensor& in = ctx->input(0);
@@ -815,14 +816,32 @@ class BlockingOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BlockingOp").Device(DEVICE_CPU), BlockingOp);
REGISTER_OP("BlockingOp").Input("x: float").Output("y: float").Doc("");
-TEST(DirectSessionTest, TestSessionInterOpThreads) {
- Graph g(OpRegistry::Global());
+static void TestSessionInterOpThreadsImpl(bool use_function_lib) {
+ FunctionDefLibrary library_graph_def;
+ if (use_function_lib) {
+ const string lib = R"proto(
+ signature: {
+ name: "BlockingOpFn" input_arg: { name: "x" type: DT_FLOAT }
+ output_arg: { name: "y" type: DT_FLOAT }}
+ node: { ret: "y" op: "BlockingOp" arg: "x" })proto";
+ CHECK(protobuf::TextFormat::ParseFromString(
+ lib, library_graph_def.add_function()));
+ }
+
+ FunctionLibraryDefinition flib(library_graph_def);
+ Graph g(&flib);
Tensor t(DT_FLOAT, TensorShape({}));
t.scalar<float>()() = {1.2};
Node* x = test::graph::Constant(&g, t);
- Node* y = test::graph::Unary(&g, "BlockingOp", x);
+ Node* y;
+ if (use_function_lib) {
+ y = test::graph::Unary(&g, "BlockingOpFn", x);
+ } else {
+ y = test::graph::Unary(&g, "BlockingOp", x);
+ }
GraphDef def;
test::graph::ToGraphDef(&g, &def);
+ *def.mutable_library() = library_graph_def;
// Create session with two inter-op thread pools.
SessionOptions options;
@@ -832,9 +851,13 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) {
->mutable_optimizer_options()
->set_opt_level(OptimizerOptions_Level_L0);
(*options.config.mutable_device_count())["CPU"] = 2;
+
+ options.config.add_session_inter_op_thread_pool();
auto* p = options.config.add_session_inter_op_thread_pool();
- p->set_num_threads(1); // This one has only one thread.
- p = options.config.add_session_inter_op_thread_pool();
+ p->set_num_threads(1);
+ const int kLargePool = 0;
+ const int kSmallPool = 1;
+
std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@@ -862,19 +885,30 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) {
// For blocking states:
// - Starts at 0, BlockingOp::Compute will move to 1.
// - This main thread will wait for 1, then move to 2 when other ops are done.
- // Moving to 2 unblocks the blocking op.
+ // Moving to 2 unblocks the blocking op, which then moves to state 3.
- // Launch 2 session run calls. Neither will finish until the blocking op is
- // unblocked, because it is using all threads in inter_op pool #0.
- thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
+ // Run the graph once on the non-limited pool.
+ thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 1);
blocking_op_state = new BlockingOpState();
- add_session_run_call(tp1, y, 0 /* inter_op_pool */);
+ add_session_run_call(tp1, y, kLargePool);
+ blocking_op_state->AwaitState(1);
+ blocking_op_state->MoveToState(1, 2);
+ blocking_op_state->AwaitState(3);
+ blocking_op_state->MoveToState(3, 0);
+ delete tp1;
+ num_done = 0;
+
+ tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
+
+ // Launch 2 session run calls. Neither will finish until the blocking op is
+ // unblocked, because it is using all threads in the small pool.
+ add_session_run_call(tp1, y, kSmallPool);
blocking_op_state->AwaitState(1); // Wait for the blocking op to Compute.
// These will block on <BlockingOpState>.
const int kBlockedThreads = 3;
for (int i = 0; i < kBlockedThreads; ++i) {
- add_session_run_call(tp1, x, 0 /* inter_op_pool */);
+ add_session_run_call(tp1, x, kSmallPool);
}
// Launch session calls using the other inter-op pool. These will finish
@@ -882,7 +916,7 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) {
thread::ThreadPool* tp2 = new thread::ThreadPool(Env::Default(), "tp2", 3);
const int kUnblockedThreads = 4;
for (int i = 0; i < kUnblockedThreads; ++i) {
- add_session_run_call(tp2, x, 1 /* inter_op_pool */);
+ add_session_run_call(tp2, x, kLargePool);
}
delete tp2;
EXPECT_EQ(kUnblockedThreads, num_done.load());
@@ -895,6 +929,14 @@ TEST(DirectSessionTest, TestSessionInterOpThreads) {
blocking_op_state = nullptr;
}
+TEST(DirectSessionTest, TestSessionInterOpThreads) {
+ TestSessionInterOpThreadsImpl(false /* use_function_lib */);
+}
+
+TEST(DirectSessionTest, TestSessionInterOpThreadsWithFunctions) {
+ TestSessionInterOpThreadsImpl(true /* use_function_lib */);
+}
+
TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
Graph g(OpRegistry::Global());
Tensor t(DT_FLOAT, TensorShape({}));
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 21765283fb..bfa4e36ce9 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -967,6 +967,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
params.inputs = &inputs;
params.input_device_contexts = &input_device_contexts;
params.input_alloc_attrs = &input_alloc_attrs;
+ params.runner = &runner_;
Status s;
NodeExecStats* stats = nullptr;
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index bc6a56136f..34fd24e731 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -229,7 +229,7 @@ static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
public:
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Device* device,
- Runner runner, int graph_def_version,
+ int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options);
@@ -255,7 +255,6 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const DeviceMgr* const device_mgr_;
Device* const device_;
- Runner runner_ = nullptr;
const int graph_def_version_;
const FunctionLibraryDefinition* const lib_def_;
GraphOptimizer optimizer_;
@@ -293,12 +292,11 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
};
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
- const DeviceMgr* dmgr, Device* device, Runner runner, int graph_def_version,
+ const DeviceMgr* dmgr, Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options)
: device_mgr_(dmgr),
device_(device),
- runner_(runner),
graph_def_version_(graph_def_version),
lib_def_(lib_def),
optimizer_(optimizer_options) {
@@ -334,6 +332,7 @@ class CallOp : public AsyncOpKernel {
done);
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
+ opts.runner = ctx->runner();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
@@ -379,6 +378,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
+ opts.runner = ctx->runner();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
@@ -660,12 +660,14 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
delete frame;
return done(s);
}
+ DCHECK(opts.runner != nullptr);
+
Executor::Args exec_args;
// Inherit the step_id from the caller.
exec_args.step_id = opts.step_id;
exec_args.call_frame = frame;
exec_args.cancellation_manager = opts.cancellation_manager;
- exec_args.runner = runner_;
+ exec_args.runner = *opts.runner;
// TODO(zhifengc): we can avoid creating rendez here if we know
// there is no send/recv nodes in the graph.
auto* rendez = new IntraProcessRendezvous(device_mgr_);
@@ -693,10 +695,10 @@ bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
}
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
- const DeviceMgr* dmgr, Device* device, Runner runner, int graph_def_version,
+ const DeviceMgr* dmgr, Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options) {
- return new FunctionLibraryRuntimeImpl(dmgr, device, runner, graph_def_version,
+ return new FunctionLibraryRuntimeImpl(dmgr, device, graph_def_version,
lib_def, optimizer_options);
}
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index f9e072f5b4..26e8b7aa3d 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -33,11 +33,9 @@ namespace tensorflow {
// The returned object does not take ownerships of "device" or
// "lib_def". The caller must ensure "device" and "lib_def" outlives
// the returned object.
-typedef std::function<void()> Closure;
-typedef std::function<void(Closure)> Runner;
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
- const DeviceMgr* device_mgr, Device* device, Runner runner,
- int graph_def_version, const FunctionLibraryDefinition* lib_def,
+ const DeviceMgr* device_mgr, Device* device, int graph_def_version,
+ const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options);
// FunctionLibraryRuntime::GetFunctionBody returns a description of an
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 70d00411a5..3460597a8c 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
+#include <atomic>
+
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
@@ -149,8 +151,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
lib_def_ = new FunctionLibraryDefinition(proto);
delete lib_;
OptimizerOptions opts;
- lib_ = NewFunctionLibraryRuntime(nullptr, device_, FunctionTestSchedClosure,
- TF_GRAPH_DEF_VERSION, lib_def_, opts);
+ lib_ = NewFunctionLibraryRuntime(nullptr, device_, TF_GRAPH_DEF_VERSION,
+ lib_def_, opts);
}
Status Run(const string& name, InstantiateAttrValueSlice attrs,
@@ -160,8 +162,17 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
if (!status.ok()) {
return status;
}
+
+ std::atomic<int32> call_count(0);
+ std::function<void(std::function<void()>)> runner =
+ [&call_count](std::function<void()> fn) {
+ ++call_count;
+ FunctionTestSchedClosure(fn);
+ };
+
Notification done;
FunctionLibraryRuntime::Options opts;
+ opts.runner = &runner;
std::vector<Tensor> out;
lib_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
status = s;
@@ -175,6 +186,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
for (size_t i = 0; i < rets.size(); ++i) {
*rets[i] = out[i];
}
+
+ EXPECT_GE(call_count, 1); // Test runner is used.
+
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 87172dc321..c806bd8ac7 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -134,9 +134,6 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
}
- thread::ThreadPool* pool = worker_env_->compute_pool;
- auto runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
-
LocalExecutorParams params;
Status s;
@@ -169,10 +166,9 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
opseg->AddHold(session);
// Function library runtime.
- unit->lib =
- NewFunctionLibraryRuntime(worker_env_->device_mgr, unit->device, runner,
- def->versions().producer(), item->lib_def,
- graph_options.optimizer_options());
+ unit->lib = NewFunctionLibraryRuntime(
+ worker_env_->device_mgr, unit->device, def->versions().producer(),
+ item->lib_def, graph_options.optimizer_options());
// Construct the root executor for the subgraph.
params.device = unit->device;
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 71675e72a2..213a89ea94 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -328,6 +328,8 @@ class FunctionLibraryRuntime {
CancellationManager* cancellation_manager = nullptr;
// The id of the step that is calling this function.
int64 step_id = 0;
+
+ std::function<void(std::function<void()>)>* runner = nullptr;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void Run(const Options& opts, Handle handle,
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 5470cfddbb..6b5ca31f84 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -533,6 +533,7 @@ class OpKernelContext {
// Function call supports.
FunctionCallFrame* call_frame = nullptr;
FunctionLibraryRuntime* function_library = nullptr;
+ std::function<void(std::function<void()>)>* runner = nullptr;
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -868,6 +869,10 @@ class OpKernelContext {
return params_->function_library;
}
+ std::function<void(std::function<void()>)>* runner() const {
+ return params_->runner;
+ }
+
// Shared resources accessible to this kernel.
ResourceMgr* resource_manager() const { return params_->resource_manager; }
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index 5bf898d7d5..b789c408c9 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -134,7 +134,7 @@ Node* Assign(Graph* g, Node* var, Node* val) {
Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
bool keep_dims) {
Node* ret;
- TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce)
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce, g->op_registry())
.Input(data)
.Input(axes)
.Attr("keep_dims", keep_dims)
@@ -168,7 +168,7 @@ Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
DataType dtype) {
Node* ret;
- TF_CHECK_OK(NodeBuilder(g->NewName("n"), op)
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), op, g->op_registry())
.Input(input)
.Attr("dtype", dtype)
.Attr("seed", 0)
@@ -190,14 +190,15 @@ Node* TruncatedNormal(Graph* g, Node* input, DataType dtype) {
Node* Unary(Graph* g, const string& func, Node* input, int index) {
Node* ret;
- TF_CHECK_OK(
- NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret));
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
+ .Input(input, index)
+ .Finalize(g, &ret));
return ret;
}
Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
Node* ret;
- TF_CHECK_OK(NodeBuilder(g->NewName("n"), func)
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), func, g->op_registry())
.Input(in0)
.Input(in1)
.Finalize(g, &ret));
@@ -206,7 +207,7 @@ Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
Node* ret;
- auto b = NodeBuilder(g->NewName("n"), func);
+ auto b = NodeBuilder(g->NewName("n"), func, g->op_registry());
for (Node* n : ins) b = b.Input(n);
TF_CHECK_OK(b.Finalize(g, &ret));
return ret;