diff options
author | Rohan Jain <rohanj@google.com> | 2017-08-17 11:33:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-17 11:36:50 -0700 |
commit | 935ff49201edd7a6297b313fb9545d1299b9a28d (patch) | |
tree | 36486015014d33efa99d7fd0875eb1545bd518cb /tensorflow/core/common_runtime/function_test.cc | |
parent | d94dca2174f0c05dfa03796c3ae31d345813d025 (diff) |
Automated g4 rollback of changelist 165521057
PiperOrigin-RevId: 165604864
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 203 |
1 files changed, 84 insertions, 119 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index a9f06c4df0..3ca4457b00 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include <atomic> -#include <utility> #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,7 +24,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" -#include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -36,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -50,18 +49,40 @@ Status GetOpSig(const string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } +void FunctionTestSchedClosure(std::function<void()> fn) { + static thread::ThreadPool* w = + new thread::ThreadPool(Env::Default(), "Test", 8); + w->Schedule(std::move(fn)); +} + void HasError(const Status& s, const string& substr) { EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) << s << ", expected substring " << substr; } +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + class FunctionTest : public ::testing::Test { protected: FunctionTest() : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - void Create(const FunctionDef& fdef, test::function::Attrs attrs) { + void Create(const FunctionDef& fdef, Attrs attrs) { exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -96,7 +117,7 @@ class FunctionTest : public ::testing::Test { TF_CHECK_OK(frame.SetArgs(args)); Executor::Args exec_args; exec_args.call_frame = &frame; - exec_args.runner = test::function::FunctionTestSchedClosure; + exec_args.runner = FunctionTestSchedClosure; TF_CHECK_OK(exec_->Run(exec_args)); std::vector<Tensor> computed; TF_CHECK_OK(frame.GetRetvals(&computed)); @@ -133,42 +154,41 @@ TEST_F(FunctionTest, WXPlusB) { class FunctionLibraryRuntimeTest : public ::testing::Test { protected: - void Init(const std::vector<FunctionDef>& flib) { - SessionOptions options; - auto* device_count = options.config.mutable_device_count(); - device_count->insert({"CPU", 3}); - TF_CHECK_OK(DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices_)); + FunctionLibraryRuntimeTest() + : device_(DeviceFactory::NewDevice("CPU", {}, + "/job:localhost/replica:0/task:0")) {} + void Init(const std::vector<FunctionDef>& flib) { FunctionDefLibrary proto; for (const auto& fdef : flib) *(proto.add_function()) = fdef; lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); OptimizerOptions opts; - device_mgr_.reset(new DeviceMgr(devices_)); - pflr_.reset(new ProcessFunctionLibraryRuntime( - device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), - opts)); - flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); - flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); - flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); + lib_ = + NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(), + TF_GRAPH_DEF_VERSION, lib_def_.get(), opts); fdef_lib_ = lib_def_->ToProto(); } - Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, - const std::vector<Tensor>& args, std::vector<Tensor*> rets) { + Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args, + std::vector<Tensor*> rets) { + FunctionLibraryRuntime::Handle handle; + Status status = lib_->Instantiate(name, attrs, &handle); + if (!status.ok()) { + return status; + } + std::atomic<int32> call_count(0); std::function<void(std::function<void()>)> runner = [&call_count](std::function<void()> fn) { ++call_count; - test::function::FunctionTestSchedClosure(fn); + FunctionTestSchedClosure(fn); }; Notification done; FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector<Tensor> out; - Status status; - flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) { + lib_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { status = s; done.Notify(); }); @@ -186,54 +206,28 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - Status Instantiate(FunctionLibraryRuntime* flr, const string& name, - test::function::Attrs attrs, - FunctionLibraryRuntime::Handle* handle) { - Status status = flr->Instantiate(name, attrs, handle); - if (!status.ok()) { - return status; - } - return Status::OK(); - } - - Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name, - test::function::Attrs attrs, - const std::vector<Tensor>& args, - std::vector<Tensor*> rets) { - FunctionLibraryRuntime::Handle handle; - Status status = flr->Instantiate(name, attrs, &handle); - if (!status.ok()) { - return status; - } - return Run(flr, handle, args, std::move(rets)); - } - - std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr, - const string& name, - test::function::Attrs attrs) { + std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) { FunctionLibraryRuntime::Handle handle; - Status status = flr->Instantiate(name, attrs, &handle); + Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = flr->GetFunctionBody(handle); + const FunctionBody* fbody = lib_->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr<Graph> ret(new Graph(lib_def_.get())); CopyGraph(*fbody->graph, ret.get()); return ret; } - std::unique_ptr<Graph> GetGradBody(FunctionLibraryRuntime* flr, - const string& func, - test::function::Attrs attrs) { + std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) { FunctionLibraryRuntime::Handle handle; - Status status = flr->Instantiate(func, attrs, &handle); + Status status = lib_->Instantiate(func, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = flr->GetFunctionBody(handle); + const FunctionBody* fbody = lib_->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody)); CHECK_NOTNULL(gbody); @@ -242,29 +236,24 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return ret; } - FunctionLibraryRuntime* flr0_; - FunctionLibraryRuntime* flr1_; - FunctionLibraryRuntime* flr2_; - std::vector<Device*> devices_; - std::unique_ptr<DeviceMgr> device_mgr_; + std::unique_ptr<Device> device_; std::unique_ptr<FunctionLibraryDefinition> lib_def_; - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + std::unique_ptr<FunctionLibraryRuntime> lib_; FunctionDefLibrary fdef_lib_; }; TEST_F(FunctionLibraryRuntimeTest, IsStateful) { Init({}); - EXPECT_TRUE(flr0_->IsStateful("Variable")); - EXPECT_TRUE(flr0_->IsStateful("VariableV2")); - EXPECT_FALSE(flr0_->IsStateful("Matmul")); + EXPECT_TRUE(lib_->IsStateful("Variable")); + EXPECT_TRUE(lib_->IsStateful("VariableV2")); + EXPECT_FALSE(lib_->IsStateful("Matmul")); } TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) { Init({test::function::XTimesTwo()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); } @@ -273,14 +262,11 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesN) { test::function::XTimes16()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); - TF_CHECK_OK( - InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64})); } @@ -308,7 +294,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name, TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); { @@ -326,7 +312,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -348,7 +334,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); GraphDef e2; { Scope s = Scope::NewRootScope(); @@ -387,7 +373,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { } // No further inlining. - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { GraphDef actual; g->ToGraphDef(&actual); @@ -439,7 +425,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_ASSERT_OK(s.ToGraph(g.get())); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -463,7 +449,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -509,10 +495,10 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); - ExpandInlineFunctions(flr0_, g.get()); - OptimizeGraph(flr0_, &g); + ExpandInlineFunctions(lib_.get(), g.get()); + OptimizeGraph(lib_.get(), &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -555,9 +541,9 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { // Return {{"o", "g:output"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsNodeDef", {}); + std::unique_ptr<Graph> g = GetFuncBody("ManySwapsNodeDef", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); const char* e0 = R"P( (n3:float, n2:float) -> (n3:float) { } @@ -588,9 +574,9 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}}, {{"o", "o:z:0"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsFirst", {}); + std::unique_ptr<Graph> g = GetFuncBody("ManySwapsFirst", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); // NOTE: We can remove func0, func1, func2, func9 with a control edge n8->n5. // But we don't have a pass doing that. @@ -623,7 +609,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { Init({test::function::XTimesTwo(), test::function::XTimesFour()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - HasError(InstantiateAndRun(flr0_, "Foo", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(Run("Foo", {{"T", DT_FLOAT}}, {x}, {&y}), "Not found: Function Foo is not defined."); } @@ -646,27 +632,25 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { // Instantiating "XTimesTwo" should fail. FunctionLibraryRuntime::Handle handle; - HasError(flr0_->Instantiate( - "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle), + HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle), "Not found: type attr not found"); // But XTimesFour and XTimes16 instantiation should succeed. Only // when they run, they fail because XTimesTwo is bad. - TF_CHECK_OK(flr0_->Instantiate( - "XTimesFour", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); - TF_CHECK_OK(flr0_->Instantiate( - "XTimes16", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK( + lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle)); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - HasError(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), "type attr not found"); } TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> f = GetFuncBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -682,7 +666,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - std::unique_ptr<Graph> g = GetGradBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); @@ -706,7 +690,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); { Scope s = Scope::NewRootScope(); @@ -742,7 +726,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { Init({}); auto T = DT_FLOAT; std::unique_ptr<Graph> g = GetFuncBody( - flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); + "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -772,7 +756,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) { Init({}); auto T = DT_FLOAT; std::unique_ptr<Graph> g = GetFuncBody( - flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); + "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -828,7 +812,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { Init({test, grad}); - std::unique_ptr<Graph> g = GetFuncBody(flr0_, "TestGrad", {}); + std::unique_ptr<Graph> g = GetFuncBody("TestGrad", {}); ASSERT_TRUE(g != nullptr); { Scope s = Scope::NewRootScope(); @@ -852,7 +836,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(flr0_, g.get()); + ExpandInlineFunctions(lib_.get(), g.get()); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -904,7 +888,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(flr0_, &g); + OptimizeGraph(lib_.get(), &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -955,25 +939,6 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { } } -TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { - Init({test::function::FindDevice()}); - FunctionLibraryRuntime::Handle handle; - TF_CHECK_OK(Instantiate( - flr0_, "FindDevice", - {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); - - Tensor y; - // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. - TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, - TensorShape({}))); - TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); - test::ExpectTensorEqual<string>( - y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, - TensorShape({}))); -} - namespace { bool DoNothing(Graph* g) { return false; } |