From 9b8fe3cec9b6ec2e872b068ccd230733106e19fe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 24 Jan 2016 21:28:56 -0800 Subject: Adds function_test.cc for testing functions. Change: 112917269 --- tensorflow/core/common_runtime/function_test.cc | 884 ++++++++++++++++++++++++ 1 file changed, 884 insertions(+) create mode 100644 tensorflow/core/common_runtime/function_test.cc diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc new file mode 100644 index 0000000000..b150f5cd2a --- /dev/null +++ b/tensorflow/core/common_runtime/function_test.cc @@ -0,0 +1,884 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/function.h" + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +Status GetOpSig(const string& op, const OpDef** sig) { + Status s; + *sig = OpRegistry::Global()->LookUp(op, &s); + return s; +} + +void FunctionTestSchedClosure(std::function fn) { + static thread::ThreadPool* w = + new thread::ThreadPool(Env::Default(), "Test", 8); + w->Schedule(fn); +} + +void HasError(const Status& s, const string& substr) { + EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + << s << ", expected substring " << substr; +} + +class FunctionTest : public ::testing::Test { + protected: + FunctionTest() + : device_(DeviceFactory::NewDevice("CPU", {}, + "/job:localhost/replica:0/task:0")) {} + + ~FunctionTest() override { + delete exec_; + delete device_; + } + + void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) { + delete exec_; + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); + + arg_types_ = result.arg_types; + ret_types_ = result.ret_types; + + Graph* g = new Graph(OpRegistry::Global()); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + opts.expect_device_spec = false; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g)); + + const int version = g->version(); + LocalExecutorParams params; + params.device = device_; + params.create_kernel = [this, version](const NodeDef& ndef, + OpKernel** kernel) { + return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); + }; + params.delete_kernel = [](OpKernel* kernel) { + DeleteNonCachedKernel(kernel); + }; + TF_CHECK_OK(NewLocalExecutor(params, g, &exec_)); + } + + void Run(const std::vector& args, std::vector rets) { + FunctionCallFrame frame(arg_types_, ret_types_); + TF_CHECK_OK(frame.SetArgs(args)); + Executor::Args exec_args; + exec_args.call_frame = &frame; + exec_args.runner = FunctionTestSchedClosure; + TF_CHECK_OK(exec_->Run(exec_args)); + std::vector computed; + TF_CHECK_OK(frame.GetRetvals(&computed)); + CHECK_EQ(computed.size(), rets.size()); + for (int i = 0; i < rets.size(); ++i) { + *(rets[i]) = computed[i]; + } + } + + Device* device_ = nullptr; + Executor* exec_ = nullptr; + DataTypeVector arg_types_; + DataTypeVector ret_types_; +}; + +TEST_F(FunctionTest, XTimesTwo) { + Create(test::function::XTimesTwo(), {{"T", DT_FLOAT}}); + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + Run({x}, {&y}); + test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); +} + +TEST_F(FunctionTest, WXPlusB) { + Create(test::function::WXPlusB(), {{"T", DT_FLOAT}}); + auto w = test::AsTensor({1., 2., 3., 4.}, {2, 2}); + auto x = test::AsTensor({1., 3., 2., 4.}, {2, 2}); + auto b = test::AsTensor({0.5, 2.5}, {2}); + Tensor y; + Run({w, x, b}, {&y}); + test::ExpectTensorEqual( + y, test::AsTensor({5.5, 13.5, 11.5, 27.5}, {2, 2})); +} + +class FunctionLibraryRuntimeTest : public ::testing::Test { + protected: + FunctionLibraryRuntimeTest() + : device_(DeviceFactory::NewDevice("CPU", {}, + "/job:localhost/replica:0/task:0")) {} + + ~FunctionLibraryRuntimeTest() override { + delete lib_; + delete lib_def_; + delete device_; + } + + void Init(const std::vector& flib) { + FunctionDefLibrary proto; + for (auto fdef : flib) *(proto.add_function()) = fdef; + delete lib_def_; + lib_def_ = new FunctionLibraryDefinition(proto); + delete lib_; + lib_ = NewFunctionLibraryRuntime(device_, FunctionTestSchedClosure, + TF_GRAPH_DEF_VERSION, lib_def_); + } + + Status Run(const string& name, InstantiateAttrValueSlice attrs, + const std::vector& args, std::vector rets) { + FunctionLibraryRuntime::Handle handle; + Status status = lib_->Instantiate(name, attrs, &handle); + if (!status.ok()) { + return status; + } + Notification done; + FunctionLibraryRuntime::Options opts; + std::vector out; + lib_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { + status = s; + done.Notify(); + }); + done.WaitForNotification(); + if (!status.ok()) { + return status; + } + CHECK_EQ(rets.size(), out.size()); + for (int i = 0; i < rets.size(); ++i) { + *rets[i] = out[i]; + } + return Status::OK(); + } + + Graph* GetFuncBody(const string& name, InstantiateAttrValueSlice attrs) { + FunctionLibraryRuntime::Handle handle; + Status status = lib_->Instantiate(name, attrs, &handle); + if (!status.ok()) { + LOG(ERROR) << status; + return nullptr; + } + const FunctionBody* fbody = lib_->GetFunctionBody(handle); + CHECK_NOTNULL(fbody); + Graph* ret = new Graph(lib_def_); + CopyGraph(*fbody->graph, ret); + return ret; + } + + Graph* GetGradBody(const string& func, InstantiateAttrValueSlice attrs) { + FunctionLibraryRuntime::Handle handle; + Status status = lib_->Instantiate(func, attrs, &handle); + if (!status.ok()) { + LOG(ERROR) << status; + return nullptr; + } + const FunctionBody* fbody = lib_->GetFunctionBody(handle); + CHECK_NOTNULL(fbody); + FunctionBody* gbody = SymbolicGradient(*fbody); + CHECK_NOTNULL(gbody); + Graph* ret = new Graph(lib_def_); + CopyGraph(*gbody->graph, ret); + delete gbody; + return ret; + } + + Device* device_ = nullptr; + FunctionLibraryDefinition* lib_def_ = nullptr; + FunctionLibraryRuntime* lib_ = nullptr; +}; + +TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) { + Init({test::function::XTimesTwo()}); + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); +} + +TEST_F(FunctionLibraryRuntimeTest, XTimesN) { + Init({test::function::XTimesTwo(), test::function::XTimesFour(), + test::function::XTimes16()}); + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + TF_CHECK_OK(Run("XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y})); + test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); + TF_CHECK_OK(Run("XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y})); + test::ExpectTensorEqual(y, test::AsTensor({4, 8, 12, 16})); + TF_CHECK_OK(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y})); + test::ExpectTensorEqual(y, test::AsTensor({16, 32, 48, 64})); +} + +TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { + Init({test::function::XTimesTwo(), test::function::XTimesFour(), + test::function::XTimes16()}); + Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + CHECK(g); + const char* e0 = R"P( +(n2:float) -> (n4:float) { + n3 = XTimesFour[T=float](n2) + n4 = XTimesFour[T=float](n3) +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + + ExpandInlineFunctions(lib_, g); + const char* e1 = R"P( +(n2:float) -> (n17:float) { + n10 = Identity[T=float](n2) + n7 = XTimesTwo[T=float](n10) + n8 = XTimesTwo[T=float](n7) + n11 = Identity[T=float](n8) + n16 = Identity[T=float](n11) + n13 = XTimesTwo[T=float](n16) + n14 = XTimesTwo[T=float](n13) + n17 = Identity[T=float](n14) +} +)P"; + EXPECT_EQ(e1, DebugString(g)); + + ExpandInlineFunctions(lib_, g); + const char* e2 = R"P( +(n2:float) -> (n17:float) { + n18 = Const[dtype=int64, value=Tensor]() + n25 = Const[dtype=int64, value=Tensor]() + n32 = Const[dtype=int64, value=Tensor]() + n39 = Const[dtype=int64, value=Tensor]() + n19 = Cast[DstT=float, SrcT=int64](n18) + n26 = Cast[DstT=float, SrcT=int64](n25) + n33 = Cast[DstT=float, SrcT=int64](n32) + n40 = Cast[DstT=float, SrcT=int64](n39) + n10 = Identity[T=float](n2) + n23 = Identity[T=float](n10) + n21 = Mul[T=float](n23, n19) + n24 = Identity[T=float](n21) + n30 = Identity[T=float](n24) + n28 = Mul[T=float](n30, n26) + n31 = Identity[T=float](n28) + n11 = Identity[T=float](n31) + n16 = Identity[T=float](n11) + n37 = Identity[T=float](n16) + n35 = Mul[T=float](n37, n33) + n38 = Identity[T=float](n35) + n44 = Identity[T=float](n38) + n42 = Mul[T=float](n44, n40) + n45 = Identity[T=float](n42) + n17 = Identity[T=float](n45) +} +)P"; + EXPECT_EQ(e2, DebugString(g)); + + // No further inlining. + ExpandInlineFunctions(lib_, g); + EXPECT_EQ(e2, DebugString(g)); + + // Get rid of redunant Identity nodes. + RemoveIdentityNodes(g); + const char* e3 = R"P( +(n2:float) -> (n42:float) { + n18 = Const[dtype=int64, value=Tensor]() + n25 = Const[dtype=int64, value=Tensor]() + n32 = Const[dtype=int64, value=Tensor]() + n39 = Const[dtype=int64, value=Tensor]() + n19 = Cast[DstT=float, SrcT=int64](n18) + n26 = Cast[DstT=float, SrcT=int64](n25) + n33 = Cast[DstT=float, SrcT=int64](n32) + n40 = Cast[DstT=float, SrcT=int64](n39) + n21 = Mul[T=float](n2, n19) + n28 = Mul[T=float](n21, n26) + n35 = Mul[T=float](n28, n33) + n42 = Mul[T=float](n35, n40) +} +)P"; + EXPECT_EQ(e3, DebugString(g)); + delete g; +} + +TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { + Init({test::function::XTimesTwo(), test::function::XTimesFour(), + test::function::XTimes16()}); + Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + CHECK(g); + ExpandInlineFunctions(lib_, g); + OptimizeGraph(lib_, &g); + const char* e0 = R"P( +(n2:float) -> (n9:float) { + n7 = Const[dtype=int64, value=Tensor]() + n8 = Cast[DstT=float, SrcT=int64](n7) + n4 = Mul[T=float](n2, n8) + n5 = Mul[T=float](n4, n8) + n6 = Mul[T=float](n5, n8) + n9 = Mul[T=float](n6, n8) +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + delete g; +} + +TEST_F(FunctionLibraryRuntimeTest, ManySwaps) { + auto func = FDH::Define( + // Name + "ManySwapsFirst", + // Args + {"x: float", "y: float"}, + // Return values + {"o: float"}, + // attr def + {}, + // Nodes + {{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}}, + {{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}}, + {{"a2", "b2"}, "Swap", {"a1", "b1"}, {{"T", DT_FLOAT}}}, + {{"a3", "b3"}, "Swap", {"a2", "b2"}, {{"T", DT_FLOAT}}}, + {{"a4", "b4"}, "Swap", {"a3", "b3"}, {{"T", DT_FLOAT}}}, + {{"a5", "b5"}, "Swap", {"a4", "b4"}, {{"T", DT_FLOAT}}}, + {{"o"}, "Identity", {"a5"}, {{"T", DT_FLOAT}}}}); + Init({test::function::Swap(), func}); + Graph* g = GetFuncBody("ManySwapsFirst", {{"T", DT_FLOAT}}); + CHECK(g); + OptimizeGraph(lib_, &g); + const char* e0 = R"P( +(n3:float, n2:float) -> (n3:float) { +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + delete g; +} + +TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { + auto func = FDH::Define( + // Name + "ManySwapsFirst", + // Args + {"x: float", "y: float"}, + // Return values + {"o: float"}, + // attr def + {}, + // Nodes + // + // o = x*x + y*y. Furthermore, The 1st swap depends on x2, and + // y2 depends on the 2nd swap. The 2nd swap has data dependency + // on the 1st swap. The optimization should maintain the control + // dependencies. + {{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}}, + {{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}}, + {{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}}, + {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}}, + {{"o"}, "Add", {"x2", "y2"}, {{"T", DT_FLOAT}}}}); + Init({test::function::Swap(), func}); + Graph* g = GetFuncBody("ManySwapsFirst", {{"T", DT_FLOAT}}); + CHECK(g); + OptimizeGraph(lib_, &g); + + // NOTE: We can remove n8, n9, n10, n11 with a control edge n8->n5. + // But we don't have a pass doing that. + const char* e0 = R"P( +(n3:float, n2:float) -> (n6:float) { + n4 = Mul[T=float](n3, n3) + n8 = NoOp() @ n4 + n9 = Identity[T=float](n3) @ n8 + n10 = Identity[T=float](n2) @ n8 + n11 = NoOp() @ n10, n9 + n5 = Mul[T=float](n2, n2) @ n11 + n6 = Add[T=float](n4, n5) +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + delete g; +} + +TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { + Init({test::function::XTimesTwo(), test::function::XTimesFour()}); + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + HasError(Run("Foo", {{"T", DT_FLOAT}}, {x}, {&y}), + "Not found: Function Foo is not defined."); +} + +TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { + auto bad_x_times_two = FDH::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"y"}, "Add", {"x", "x"}, {{"no_T", "$T"}}}, + }); + Init({bad_x_times_two, test::function::XTimesFour(), + test::function::XTimes16()}); + + // Instantiating "XTimesTwo" should fail. + FunctionLibraryRuntime::Handle handle; + HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle), + "Not found: type attr not found"); + + // But XTimesFour and XTimes16 instantiation should succeed. Only + // when they run, they fail because XTimesTwo is bad. + TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle)); + TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle)); + + auto x = test::AsTensor({1, 2, 3, 4}); + Tensor y; + HasError(Run("XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}), + "type attr not found"); +} + +TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { + Init({test::function::XTimesTwo(), test::function::XTimesFour(), + test::function::XTimes16()}); + auto f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}}); + const char* e0 = R"P( +(n4:float) -> (n5:float) { + n2 = Const[dtype=int64, value=Tensor]() + n3 = Cast[DstT=float, SrcT=int64](n2) + n5 = Mul[T=float](n4, n3) +} +)P"; + EXPECT_EQ(e0, DebugString(f)); + delete f; + auto g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}); + const char* e1 = R"P( +(n4:float, n6:float) -> (n7:float) { + n2 = Const[dtype=int64, value=Tensor]() + n3 = Cast[DstT=float, SrcT=int64](n2) + n5 = Mul[T=float](n4, n3) + n7 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Mul[T=float]](n4, n3, n6) +} +)P"; + EXPECT_EQ(e1, DebugString(g)); + + OptimizeGraph(lib_, &g); + const char* e2 = R"P( +(n4:float, n5:float) -> (n12:float) { + n2 = Const[dtype=int64, value=Tensor]() + n3 = Cast[DstT=float, SrcT=int64](n2) + n9 = Shape[T=float](n4) + n8 = Shape[T=float](n3) + n7 = Mul[T=float](n5, n3) + n10 = BroadcastGradientArgs(n9, n8) + n11 = Sum[T=float, keep_dims=false](n7, n10) + n12 = Reshape[T=float](n11, n9) +} +)P"; + EXPECT_EQ(e2, DebugString(g)); + + delete g; +} + +TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { + Init({}); + auto T = DT_FLOAT; + auto g = GetFuncBody("SymbolicGradient", + {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); + const char* e0 = R"P( +(n7:float, n5:float, n2:float) -> (n14:float, n11:float) { + n3 = Identity[T=float](n2) + n4 = Identity[T=float](n2) + n6 = Shape[T=float](n5) + n8 = Shape[T=float](n7) + n9 = BroadcastGradientArgs(n8, n6) + n10 = Sum[T=float, keep_dims=false](n3, n9:1) + n13 = Sum[T=float, keep_dims=false](n4, n9) + n11 = Reshape[T=float](n10, n6) + n14 = Reshape[T=float](n13, n8) +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + delete g; +} + +TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) { + Init({}); + auto T = DT_FLOAT; + auto g = GetFuncBody("SymbolicGradient", + {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); + const char* e0 = R"P( +(n6:float, n3:float, n2:float) -> (n14:float, n11:float) { + n4 = Mul[T=float](n2, n3) + n5 = Shape[T=float](n3) + n7 = Mul[T=float](n6, n2) + n8 = Shape[T=float](n6) + n9 = BroadcastGradientArgs(n8, n5) + n10 = Sum[T=float, keep_dims=false](n7, n9:1) + n13 = Sum[T=float, keep_dims=false](n4, n9) + n11 = Reshape[T=float](n10, n5) + n14 = Reshape[T=float](n13, n8) +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + delete g; +} + +TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { + // Sum(Add(x, y)) + auto T = DT_FLOAT; + auto test = FDH::Define("Test", {"x:float", "y:float"}, {"l:float"}, {}, + { + {{"z"}, "Add", {"x", "y"}, {{"T", T}}}, + FDH::Const("zero", 0), + FDH::Const("one", 1), + {{"r"}, "Rank", {"z"}, {{"T", T}}}, + {{"indices"}, "Range", {"zero", "r", "one"}}, + {{"l"}, "Sum", {"z", "indices"}, {{"T", T}}}, + }); + + // TestGrad = Test'(x, y) + auto grad = + FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"}, + {}, {FDH::Const("dz", 1), + {{"grad"}, + "SymbolicGradient", + {"x", "y", "dz"}, + { + {"f", FDH::FunctionRef("Test")}, + {"Tin", DataTypeSlice{T, T, T}}, + {"Tout", DataTypeSlice{T, T}}, + }}, + {{"dx"}, "Identity", {"grad:0"}, {{"T", DT_FLOAT}}}, + {{"dy"}, "Identity", {"grad:1"}, {{"T", DT_FLOAT}}}}); + + Init({test, grad}); + + Graph* g = GetFuncBody("TestGrad", {}); + CHECK(g); + const char* e0 = R"P( +(n4:float, n3:float) -> (n8:float, n6:float) { + n2 = Const[dtype=float, value=Tensor]() + n5 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Test](n4, n3, n2) + n6 = Identity[T=float](n5:1) + n8 = Identity[T=float](n5) +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + + ExpandInlineFunctions(lib_, g); + const char* e1 = R"P( +(n4:float, n3:float) -> (n8:float, n6:float) { + n10 = Const[dtype=int32, value=Tensor]() + n11 = Const[dtype=int32, value=Tensor]() + n2 = Const[dtype=float, value=Tensor]() + n26 = Identity[T=float](n2) + n25 = Identity[T=float](n3) + n24 = Identity[T=float](n4) + n14 = Add[T=float](n24, n25) + n15 = Rank[T=float](n14) + n16 = Range(n11, n15, n10) + n20 = ZerosLike[T=int32](n15) + n17 = Sum[T=float, keep_dims=false](n14, n16) + n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, keep_dims=false]](n14, n16, n26) + n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19) + n28 = Identity[T=float](n21:1) + n27 = Identity[T=float](n21) + n6 = Identity[T=float](n28) + n8 = Identity[T=float](n27) +} +)P"; + EXPECT_EQ(e1, DebugString(g)); + + OptimizeGraph(lib_, &g); + const char* e2 = R"P( +(n4:float, n3:float) -> (n25:float, n23:float) { + n11 = Const[dtype=int32, value=Tensor]() + n2 = Const[dtype=float, value=Tensor]() + n7 = Const[dtype=int32, value=Tensor]() + n19 = Shape[T=float](n3) + n8 = Add[T=float](n4, n3) + n20 = Shape[T=float](n4) + n9 = Rank[T=float](n8) + n14 = Shape[T=float](n8) + n21 = BroadcastGradientArgs(n20, n19) + n10 = Range(n7, n9, n11) + n12 = Shape[T=int32](n10) + n13 = Fill[T=int32](n12, n11) + n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13) + n16 = Reshape[T=float](n2, n15) + n17 = Div[T=int32](n14, n15) + n18 = Tile[T=float](n16, n17) + n24 = Sum[T=float, keep_dims=false](n18, n21) + n22 = Sum[T=float, keep_dims=false](n18, n21:1) + n25 = Reshape[T=float](n24, n20) + n23 = Reshape[T=float](n22, n19) +} +)P"; + EXPECT_EQ(e2, DebugString(g)); + delete g; +} + +namespace { + +bool DoNothing(Graph* g) { return false; } + +string Optimize(std::function pass, const FunctionDef& fdef) { + InstantiationResult result; + InstantiateAttrValueMap empty; + TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result)); + Graph* g = new Graph(OpRegistry::Global()); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + opts.expect_device_spec = false; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g)); + pass(g); + Graph* g1 = new Graph(OpRegistry::Global()); + CopyGraph(*g, g1); + delete g; + GraphDef gdef; + g1->ToGraphDef(&gdef); + delete g1; + return DebugString(gdef); +} + +} // end namespace + +TEST(OptimizationTest, RemoveDeadNodes) { + auto T = DT_INT32; + auto func = FDH::Define( + // Name + "F", + // Args + {"x: int32"}, + // Return values + {"y: int32"}, + // Attrs + {}, + // Nodes + {// a = Square(x) + {{"a"}, "Square", {"x"}, {{"T", T}}}, + // 1 + FDH::Const("o", 1), + // A bunch of extra arithmatic that y doesn't depend on + {{"x1"}, "Add", {"o", "o"}, {{"T", T}}}, + {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}}, + {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}}, + // A stateful node. + {{"keep_me"}, "RandomUniform", {"o"}, {{"T", T}, {"dtype", DT_FLOAT}}}, + // y = Add(a, o) + {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); + const char* e0 = R"S( +(n0:int32) -> (n7:int32) { + n2 = Const[dtype=int32, value=Tensor]() + n6 = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](n2) + n3 = Add[T=int32](n2, n2) + n1 = Square[T=int32](n0) + n7 = Add[T=int32](n1, n2) + n4 = Mul[T=int32](n1, n3) + n5 = Mul[T=int32](n3, n4) +} +)S"; + EXPECT_EQ(Optimize(DoNothing, func), e0); + + // TODO(zhifengc): Comes up another test case. + EXPECT_EQ(Optimize(::tensorflow::RemoveDeadNodes, func), e0); +} + +TEST(OptimizationTest, RemoveIdentityNodes) { + auto T = DT_INT32; + auto func = FDH::Define( + // Name + "F", + // Args + {"x: int32"}, + // Return values + {"y: int32"}, + // Attrs + {}, + // Nodes + {// a = Square(x) + {{"a"}, "Square", {"x"}, {{"T", T}}}, + // 1 + FDH::Const("o", 1), + // A bunch of extra arithmatic that y doesn't depend on + {{"x1"}, "Identity", {"a"}, {{"T", T}}}, + {{"x2"}, "Identity", {"x1"}, {{"T", T}}}, + {{"x3"}, "Identity", {"x2"}, {{"T", T}}}, + // A stateful node. + {{"keep_me"}, + "RandomUniform", + {"o"}, + {{"T", T}, {"dtype", DT_FLOAT}}, + {"x3"}}, + // y = Add(a, o) + {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); + const char* e0 = R"S( +(n0:int32) -> (n7:int32) { + n2 = Const[dtype=int32, value=Tensor]() + n1 = Square[T=int32](n0) + n7 = Add[T=int32](n1, n2) + n3 = Identity[T=int32](n1) + n4 = Identity[T=int32](n3) + n5 = Identity[T=int32](n4) + n6 = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](n2) @ n5 +} +)S"; + EXPECT_EQ(Optimize(DoNothing, func), e0); + + const char* e1 = R"S( +(n0:int32) -> (n7:int32) { + n2 = Const[dtype=int32, value=Tensor]() + n1 = Square[T=int32](n0) + n7 = Add[T=int32](n1, n2) + n6 = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](n2) @ n1 +} +)S"; + EXPECT_EQ(Optimize(::tensorflow::RemoveIdentityNodes, func), e1); +} + +TEST(OptimizationTest, RemoveListArrayConverter) { + auto func = FDH::Define( + // Name + "Test", + // Args + {"i: float"}, + // Return values + {"o: float"}, + // Attrs + {}, + // Nodes + {FDH::Const("zero", 0), + {{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}}, + {{"a"}, + "_ArrayToList", + {"s"}, + {{"N", 4}, + {"T", DT_FLOAT}, + {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}}, + {{"l"}, "Mul", {"a:0", "a:1"}, {{"T", DT_FLOAT}}}, + {{"r"}, "Mul", {"a:2", "a:3"}, {{"T", DT_FLOAT}}}, + {{"x"}, + "_ListToArray", + {"l", "r"}, + {{"N", 2}, + {"T", DT_FLOAT}, + {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, + {{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}}); + + const char* e0 = R"P( +(n0:float) -> (n7:float) { + n1 = Const[dtype=int32, value=Tensor]() + n2 = Split[T=float, num_split=4](n1, n0) + n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3) + n5 = Mul[T=float](n3:2, n3:3) + n4 = Mul[T=float](n3, n3:1) + n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5) + n7 = AddN[N=2, T=float](n6, n6:1) +} +)P"; + EXPECT_EQ(Optimize(DoNothing, func), e0); + + const char* e1 = R"P( +(n0:float) -> (n7:float) { + n1 = Const[dtype=int32, value=Tensor]() + n2 = Split[T=float, num_split=4](n1, n0) + n5 = Mul[T=float](Func/_2, Func/_3) + n4 = Mul[T=float](Func/_0, Func/_1) + n7 = AddN[N=2, T=float](Func/_4, Func/_5) + Func/_0 = Identity[T=float](n2) + Func/_1 = Identity[T=float](n2:1) + Func/_2 = Identity[T=float](n2:2) + Func/_3 = Identity[T=float](n2:3) + Func/_4 = Identity[T=float](n4) + Func/_5 = Identity[T=float](n5) +} +)P"; + EXPECT_EQ(Optimize(RemoveListArrayConverter, func), e1); + + const char* e2 = R"P( +(n0:float) -> (n7:float) { + n1 = Const[dtype=int32, value=Tensor]() + n2 = Split[T=float, num_split=4](n1, n0) + n5 = Mul[T=float](n2:2, n2:3) + n4 = Mul[T=float](n2, n2:1) + n7 = AddN[N=2, T=float](n4, n5) +} +)P"; + auto remove_listarray_and_identity = [](Graph* g) { + return RemoveListArrayConverter(g) && RemoveIdentityNodes(g); + }; + EXPECT_EQ(Optimize(remove_listarray_and_identity, func), e2); +} + +TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { + auto func = FDH::Define( + // Name + "Test", + // Args + {"i: float"}, + // Return values + {"o: float"}, + // Attrs + {}, + // Nodes + {FDH::Const("dummy", 0), + {{"x"}, + "_ListToArray", + {"i", "i"}, + {{"N", 2}, {"T", DT_FLOAT}, {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}, + // Control dep + {"dummy"}}, + {{"o"}, + "AddN", + {"x"}, + {{"N", 2}, {"T", DT_FLOAT}}, + // Control dep + {"x"}}}); + + const char* e0 = R"P( +(n0:float) -> (n3:float) { + n1 = Const[dtype=int32, value=Tensor]() + n2 = _ListToArray[N=2, T=float, Tin={float, float}](n0, n0) @ n1 + n3 = AddN[N=2, T=float](n2, n2:1) @ n2 +} +)P"; + EXPECT_EQ(Optimize(DoNothing, func), e0); + + const char* e1 = R"P( +(n0:float) -> (n3:float) { + n1 = Const[dtype=int32, value=Tensor]() + n3 = AddN[N=2, T=float](Func/_0, Func/_1) @ Func/_3 + Func/_0 = Identity[T=float](n0) @ Func/_2 + Func/_1 = Identity[T=float](n0) @ Func/_2 + Func/_2 = NoOp() @ n1 + Func/_3 = NoOp() @ Func/_0, Func/_1 +} +)P"; + EXPECT_EQ(Optimize(RemoveListArrayConverter, func), e1); + + auto remove_listarray_and_identity = [](Graph* g) { + return RemoveListArrayConverter(g) && RemoveIdentityNodes(g); + }; + // NOTE: We are not removing Identity nodes with any control + // dependencies yet. + EXPECT_EQ(Optimize(remove_listarray_and_identity, func), e1); +} + +} // end namespace tensorflow -- cgit v1.2.3