aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-03-09 12:20:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 12:24:51 -0800
commit20dfc25c378c600fac683e62dc8a1ed2a522711c (patch)
treedf618d2f7874bcc08ea3de65297f7ee259f65a51 /tensorflow/core/common_runtime/function_test.cc
parent61a744fffbcc68e453aafc6eaa2c7ff2318a3584 (diff)
Allowing for FunctionLibraryRuntime::Run calls to not be provided with a runner to execute kernels with. In that case, it defaults to using the threadpool provided by the device.
Also makes sure each device has a default threadpool to fall back on. PiperOrigin-RevId: 188520648
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc116
1 file changed, 92 insertions(+), 24 deletions(-)
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 63ad0d231c..d7e5f0018e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -135,7 +136,8 @@ TEST_F(FunctionTest, WXPlusB) {
class FunctionLibraryRuntimeTest : public ::testing::Test {
protected:
- void Init(const std::vector<FunctionDef>& flib) {
+ void Init(const std::vector<FunctionDef>& flib,
+ thread::ThreadPool* default_thread_pool = nullptr) {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
@@ -149,7 +151,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
device_mgr_.reset(new DeviceMgr(devices_));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
- opts, nullptr /* cluster_flr */));
+ opts, default_thread_pool, nullptr /* cluster_flr */));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
@@ -158,16 +160,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+ bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
-
+ if (add_runner) {
+ opts.runner = &runner;
+ } else {
+ opts.runner = nullptr;
+ }
Notification done;
- opts.runner = &runner;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
@@ -183,7 +189,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
*rets[i] = out[i];
}
- EXPECT_GE(call_count, 1); // Test runner is used.
+ if (add_runner) {
+ EXPECT_GE(call_count, 1); // Test runner is used.
+ }
return Status::OK();
}
@@ -204,24 +212,25 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const std::vector<Tensor>& args,
- std::vector<Tensor*> rets) {
+ std::vector<Tensor*> rets, bool add_runner = true) {
return InstantiateAndRun(flr, name, attrs,
FunctionLibraryRuntime::InstantiateOptions(), args,
- std::move(rets));
+ std::move(rets), add_runner);
}
Status InstantiateAndRun(
FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets,
+ bool add_runner = true) {
FunctionLibraryRuntime::Handle handle;
Status status = flr->Instantiate(name, attrs, options, &handle);
if (!status.ok()) {
return status;
}
FunctionLibraryRuntime::Options opts;
- status = Run(flr, handle, opts, args, rets);
+ status = Run(flr, handle, opts, args, rets, add_runner);
if (!status.ok()) return status;
// Release the handle and try running again. It should not succeed.
@@ -237,16 +246,20 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
- FunctionLibraryRuntime::Options opts, CallFrameInterface* frame) {
+ FunctionLibraryRuntime::Options opts, CallFrameInterface* frame,
+ bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
-
+ if (add_runner) {
+ opts.runner = &runner;
+ } else {
+ opts.runner = nullptr;
+ }
Notification done;
- opts.runner = &runner;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
@@ -258,7 +271,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
- EXPECT_GE(call_count, 1); // Test runner is used.
+ if (add_runner) {
+ EXPECT_GE(call_count, 1); // Test runner is used.
+ }
return Status::OK();
}
@@ -447,7 +462,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
{
// Simple case: instantiating with no state_handle.
for (int32 expected : {6, 4}) {
- TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -460,7 +475,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated));
EXPECT_EQ(handle, handle_non_isolated);
for (int32 expected : {0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -475,7 +490,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -490,7 +505,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
@@ -507,7 +522,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
- TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated));
@@ -515,6 +530,59 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
}
}
+TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) {
+ using test::function::blocking_op_state;
+ using test::function::BlockingOpState;
+
+ thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "FLRTest", 1);
+ Init({test::function::BlockingOpFn(), test::function::XTimesTwo()}, tp);
+
+ auto x = test::AsScalar<float>(1.3);
+ Tensor y;
+ blocking_op_state = new BlockingOpState();
+
+ thread::ThreadPool* tp1 = new thread::ThreadPool(Env::Default(), "tp1", 5);
+ bool finished_running = false;
+ tp1->Schedule([&x, &y, &finished_running, this]() {
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "BlockingOpFn", {}, {x}, {&y},
+ false /* add_runner */));
+ finished_running = true;
+ });
+
+ // InstantiateAndRun shouldn't finish because BlockingOpFn should be blocked.
+ EXPECT_FALSE(finished_running);
+
+ FunctionLibraryRuntime::Handle h;
+ TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h));
+
+ auto x1 = test::AsTensor<float>({1, 2, 3, 4});
+ Tensor y1;
+ std::atomic<int32> num_done(0);
+ FunctionLibraryRuntime::Options opts;
+ for (int i = 0; i < 4; ++i) {
+ tp1->Schedule([&h, &x1, &y1, &opts, &num_done, this]() {
+ TF_CHECK_OK(Run(flr0_, h, opts, {x1}, {&y1}, false /* add_runner */));
+ num_done.fetch_add(1);
+ });
+ }
+ // All the 4 Run() calls should be blocked because the runner is occupied.
+ EXPECT_EQ(0, num_done.load());
+
+ blocking_op_state->AwaitState(1);
+ blocking_op_state->MoveToState(1, 2);
+ // Now the runner should be unblocked and all the other Run() calls should
+ // proceed.
+ blocking_op_state->AwaitState(3);
+ blocking_op_state->MoveToState(3, 0);
+ delete tp1;
+ EXPECT_TRUE(finished_running);
+ EXPECT_EQ(4, num_done.load());
+
+ delete blocking_op_state;
+ blocking_op_state = nullptr;
+ delete tp;
+}
+
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@@ -787,7 +855,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__6")
+ s.WithOpName("x4/x2/scale/_12__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -993,13 +1061,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__11")
+ s.WithOpName("scale/_6__cf__15")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__10")
+ s.WithOpName("Func/_1/sy/_5__cf__14")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
@@ -1247,14 +1315,14 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
- TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<string>(
y,
test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},
TensorShape({})));
opts.remote_execution = true;
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
- TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
+ TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<string>(
y,
test::AsTensor<string>({"/job:localhost/replica:0/task:0/device:CPU:1"},