diff options
author | Rohan Jain <rohanj@google.com> | 2017-08-17 17:20:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-17 17:24:41 -0700 |
commit | 19a55725af8102d72d4e081c5139f0e4bd5a4bb7 (patch) | |
tree | 971673c250a44e0c4cfa4ab634a7c4c96f8ebd33 /tensorflow/core/common_runtime/function_test.cc | |
parent | 8c0853db731cf80cfeec9dfb4edab95961aaa585 (diff) |
Allowing functions to run across devices. This change expands the ProcessFunctionLibraryRuntime library to Instantiate and Run functions on different devices. When a FunctionLibraryRuntime encounters a function with a target that is another device, it delegates Instantiate() and Run() calls to the ProcessFunctionLibraryRuntime.
This change also moves the table_ containing all function instantiations to the PFLR instead of the FunctionLibraryRuntime.
PiperOrigin-RevId: 165651194
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 203 |
1 files changed, 119 insertions, 84 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 3ca4457b00..a9f06c4df0 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include <atomic> +#include <utility> #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -34,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -49,40 +50,18 @@ Status GetOpSig(const string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } -void FunctionTestSchedClosure(std::function<void()> fn) { - static thread::ThreadPool* w = - new thread::ThreadPool(Env::Default(), "Test", 8); - w->Schedule(std::move(fn)); -} - void HasError(const Status& s, const string& substr) { EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) << s << ", expected substring " << substr; } -// A helper class to make AttrSlice from initializer lists -class Attrs { - public: - Attrs(const std::initializer_list< // NOLINT(runtime/explicit) - std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { - for (const auto& aval : attrs) { - map_.insert({aval.first, aval.second.proto}); - } - } - - operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) - - private: - AttrValueMap map_; -}; - class FunctionTest : public ::testing::Test { protected: FunctionTest() : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - void Create(const FunctionDef& fdef, Attrs attrs) { + void Create(const FunctionDef& fdef, test::function::Attrs attrs) { exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -117,7 +96,7 @@ class FunctionTest : public ::testing::Test { TF_CHECK_OK(frame.SetArgs(args)); Executor::Args exec_args; exec_args.call_frame = &frame; - exec_args.runner = FunctionTestSchedClosure; + exec_args.runner = test::function::FunctionTestSchedClosure; TF_CHECK_OK(exec_->Run(exec_args)); std::vector<Tensor> computed; TF_CHECK_OK(frame.GetRetvals(&computed)); @@ -154,41 +133,42 @@ TEST_F(FunctionTest, WXPlusB) { class FunctionLibraryRuntimeTest : public ::testing::Test { protected: - FunctionLibraryRuntimeTest() - : device_(DeviceFactory::NewDevice("CPU", {}, - "/job:localhost/replica:0/task:0")) {} - void Init(const std::vector<FunctionDef>& flib) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 3}); + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices_)); + FunctionDefLibrary proto; for (const auto& fdef : flib) *(proto.add_function()) = fdef; lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); OptimizerOptions opts; - lib_ = - NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(), - TF_GRAPH_DEF_VERSION, lib_def_.get(), opts); + device_mgr_.reset(new DeviceMgr(devices_)); + pflr_.reset(new ProcessFunctionLibraryRuntime( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts)); + flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1"); + flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2"); fdef_lib_ = lib_def_->ToProto(); } - Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args, - std::vector<Tensor*> rets) { - FunctionLibraryRuntime::Handle handle; - Status status = lib_->Instantiate(name, attrs, &handle); - if (!status.ok()) { - return status; - } - + Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, + const std::vector<Tensor>& args, std::vector<Tensor*> rets) { std::atomic<int32> call_count(0); std::function<void(std::function<void()>)> runner = [&call_count](std::function<void()> fn) { ++call_count; - FunctionTestSchedClosure(fn); + test::function::FunctionTestSchedClosure(fn); }; Notification done; FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector<Tensor> out; - lib_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { + Status status; + flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) { status = s; done.Notify(); }); @@ -206,28 +186,54 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) { + Status Instantiate(FunctionLibraryRuntime* flr, const string& name, + test::function::Attrs attrs, + FunctionLibraryRuntime::Handle* handle) { + Status status = flr->Instantiate(name, attrs, handle); + if (!status.ok()) { + return status; + } + return Status::OK(); + } + + Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name, + test::function::Attrs attrs, + const std::vector<Tensor>& args, + std::vector<Tensor*> rets) { + FunctionLibraryRuntime::Handle handle; + Status status = flr->Instantiate(name, attrs, &handle); + if (!status.ok()) { + return status; + } + return Run(flr, handle, args, std::move(rets)); + } + + std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr, + const string& name, + test::function::Attrs attrs) { FunctionLibraryRuntime::Handle handle; - Status status = lib_->Instantiate(name, attrs, &handle); + Status status = flr->Instantiate(name, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = lib_->GetFunctionBody(handle); + const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr<Graph> ret(new Graph(lib_def_.get())); CopyGraph(*fbody->graph, ret.get()); return ret; } - std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) { + std::unique_ptr<Graph> GetGradBody(FunctionLibraryRuntime* flr, + const string& func, + test::function::Attrs attrs) { FunctionLibraryRuntime::Handle handle; - Status status = lib_->Instantiate(func, attrs, &handle); + Status status = flr->Instantiate(func, attrs, &handle); if (!status.ok()) { LOG(ERROR) << status; return nullptr; } - const FunctionBody* fbody = lib_->GetFunctionBody(handle); + const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK_NOTNULL(fbody); std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody)); CHECK_NOTNULL(gbody); @@ -236,24 +242,29 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return ret; } - std::unique_ptr<Device> device_; + FunctionLibraryRuntime* flr0_; + FunctionLibraryRuntime* flr1_; + FunctionLibraryRuntime* flr2_; + std::vector<Device*> devices_; + std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<FunctionLibraryDefinition> lib_def_; - std::unique_ptr<FunctionLibraryRuntime> lib_; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; FunctionDefLibrary fdef_lib_; }; TEST_F(FunctionLibraryRuntimeTest, IsStateful) { Init({}); - EXPECT_TRUE(lib_->IsStateful("Variable")); - EXPECT_TRUE(lib_->IsStateful("VariableV2")); - EXPECT_FALSE(lib_->IsStateful("Matmul")); + EXPECT_TRUE(flr0_->IsStateful("Variable")); + EXPECT_TRUE(flr0_->IsStateful("VariableV2")); + EXPECT_FALSE(flr0_->IsStateful("Matmul")); } TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) { Init({test::function::XTimesTwo()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); } @@ -262,11 +273,14 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesN) { test::function::XTimes16()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8})); - TF_CHECK_OK(Run("XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16})); - TF_CHECK_OK(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); + TF_CHECK_OK( + InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64})); } @@ -294,7 +308,7 @@ Output Call(Scope* scope, const string& op_name, const string& fn_name, TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); { @@ -312,7 +326,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -334,7 +348,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); GraphDef e2; { Scope s = Scope::NewRootScope(); @@ -373,7 +387,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { } // No further inlining. - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { GraphDef actual; g->ToGraphDef(&actual); @@ -425,7 +439,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_ASSERT_OK(s.ToGraph(g.get())); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -449,7 +463,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); @@ -495,10 +509,10 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) { TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); - ExpandInlineFunctions(lib_.get(), g.get()); - OptimizeGraph(lib_.get(), &g); + ExpandInlineFunctions(flr0_, g.get()); + OptimizeGraph(flr0_, &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -541,9 +555,9 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { // Return {{"o", "g:output"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g = GetFuncBody("ManySwapsNodeDef", {}); + std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsNodeDef", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); const char* e0 = R"P( (n3:float, n2:float) -> (n3:float) { } @@ -574,9 +588,9 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}}, {{"o", "o:z:0"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g = GetFuncBody("ManySwapsFirst", {}); + std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsFirst", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); // NOTE: We can remove func0, func1, func2, func9 with a control edge n8->n5. // But we don't have a pass doing that. @@ -609,7 +623,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { Init({test::function::XTimesTwo(), test::function::XTimesFour()}); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - HasError(Run("Foo", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(InstantiateAndRun(flr0_, "Foo", {{"T", DT_FLOAT}}, {x}, {&y}), "Not found: Function Foo is not defined."); } @@ -632,25 +646,27 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { // Instantiating "XTimesTwo" should fail. FunctionLibraryRuntime::Handle handle; - HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle), + HasError(flr0_->Instantiate( + "XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle), "Not found: type attr not found"); // But XTimesFour and XTimes16 instantiation should succeed. Only // when they run, they fail because XTimesTwo is bad. - TF_CHECK_OK( - lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle)); - TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(flr0_->Instantiate( + "XTimesFour", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); + TF_CHECK_OK(flr0_->Instantiate( + "XTimes16", test::function::Attrs({{"T", DT_FLOAT}}), &handle)); auto x = test::AsTensor<float>({1, 2, 3, 4}); Tensor y; - HasError(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), + HasError(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), "type attr not found"); } TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> f = GetFuncBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -666,7 +682,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - std::unique_ptr<Graph> g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetGradBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}); { Scope s = Scope::NewRootScope(); @@ -690,7 +706,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); { Scope s = Scope::NewRootScope(); @@ -726,7 +742,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { Init({}); auto T = DT_FLOAT; std::unique_ptr<Graph> g = GetFuncBody( - "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); + flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -756,7 +772,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) { Init({}); auto T = DT_FLOAT; std::unique_ptr<Graph> g = GetFuncBody( - "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); + flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -812,7 +828,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { Init({test, grad}); - std::unique_ptr<Graph> g = GetFuncBody("TestGrad", {}); + std::unique_ptr<Graph> g = GetFuncBody(flr0_, "TestGrad", {}); ASSERT_TRUE(g != nullptr); { Scope s = Scope::NewRootScope(); @@ -836,7 +852,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - ExpandInlineFunctions(lib_.get(), g.get()); + ExpandInlineFunctions(flr0_, g.get()); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -888,7 +904,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { TF_EXPECT_GRAPH_EQ(expected, actual); } - OptimizeGraph(lib_.get(), &g); + OptimizeGraph(flr0_, &g); { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); @@ -939,6 +955,25 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { } } +TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { + Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Handle handle; + TF_CHECK_OK(Instantiate( + flr0_, "FindDevice", + {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); + + Tensor y; + // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. + TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); + test::ExpectTensorEqual<string>( + y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, + TensorShape({}))); + TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); + test::ExpectTensorEqual<string>( + y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"}, + TensorShape({}))); +} + namespace { bool DoNothing(Graph* g) { return false; } |