diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-05-03 14:03:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-03 15:38:21 -0700 |
commit | 965d620104d375c5fd2b18881f353eb41d9a63a2 (patch) | |
tree | a801f0e211bf6ad5eb81536eea9343edf0544dfa /tensorflow/core/common_runtime/function_test.cc | |
parent | 7828637e07b0081a37dfdc66ff912dd1d6ff3228 (diff) |
Internal change.
Change: 155009390
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 1036 |
1 files changed, 630 insertions, 406 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 8f70ab8783..af1ff6aec0 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -17,6 +17,10 @@ limitations under the License. #include <atomic> +#include "tensorflow/cc/ops/array_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" @@ -28,10 +32,12 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -58,13 +64,8 @@ class FunctionTest : public ::testing::Test { : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - ~FunctionTest() override { - delete exec_; - delete device_; - } - void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) { - delete exec_; + exec_ = nullptr; InstantiationResult result; TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result)); @@ -79,15 +80,18 @@ class FunctionTest : public ::testing::Test { const int version = g->versions().producer(); LocalExecutorParams params; - params.device = device_; + params.device = device_.get(); params.create_kernel = [this, version](const NodeDef& ndef, OpKernel** kernel) { - return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); + return CreateNonCachedKernel(device_.get(), nullptr, ndef, version, + kernel); }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); }; - TF_CHECK_OK(NewLocalExecutor(params, g, &exec_)); + Executor* exec; + TF_CHECK_OK(NewLocalExecutor(params, g, &exec)); + exec_.reset(exec); } void Run(const std::vector<Tensor>& args, std::vector<Tensor*> rets) { @@ -105,8 +109,8 @@ class FunctionTest : public ::testing::Test { } } - Device* device_ = nullptr; - Executor* exec_ = nullptr; + std::unique_ptr<Device> device_; + std::unique_ptr<Executor> exec_; DataTypeVector arg_types_; DataTypeVector ret_types_; }; @@ -136,21 +140,15 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { : device_(DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0")) {} - ~FunctionLibraryRuntimeTest() override { - delete lib_; - delete lib_def_; - delete device_; - } - void Init(const std::vector<FunctionDef>& flib) { FunctionDefLibrary proto; for (const auto& fdef : flib) *(proto.add_function()) = fdef; - delete lib_def_; - lib_def_ = new FunctionLibraryDefinition(OpRegistry::Global(), proto); - delete lib_; + lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto)); OptimizerOptions opts; - lib_ = NewFunctionLibraryRuntime(nullptr, Env::Default(), device_, - TF_GRAPH_DEF_VERSION, lib_def_, opts); + lib_.reset(NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(), + TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts)); + fdef_lib_ = lib_def_->ToProto(); } Status Run(const string& name, InstantiateAttrValueSlice attrs, @@ -190,7 +188,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return Status::OK(); } - Graph* GetFuncBody(const string& name, InstantiateAttrValueSlice attrs) { + std::unique_ptr<Graph> GetFuncBody(const string& name, + InstantiateAttrValueSlice attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -199,12 +198,13 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } const FunctionBody* fbody = lib_->GetFunctionBody(handle); CHECK_NOTNULL(fbody); - Graph* ret = new Graph(lib_def_); - CopyGraph(*fbody->graph, ret); + std::unique_ptr<Graph> ret(new Graph(lib_def_.get())); + CopyGraph(*fbody->graph, ret.get()); return ret; } - Graph* GetGradBody(const string& func, InstantiateAttrValueSlice attrs) { + std::unique_ptr<Graph> GetGradBody(const string& func, + InstantiateAttrValueSlice attrs) { FunctionLibraryRuntime::Handle handle; Status status = lib_->Instantiate(func, attrs, &handle); if (!status.ok()) { @@ -213,17 +213,17 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } const FunctionBody* fbody = lib_->GetFunctionBody(handle); CHECK_NOTNULL(fbody); - FunctionBody* gbody = SymbolicGradient(*fbody); + std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody)); CHECK_NOTNULL(gbody); - Graph* ret = new Graph(lib_def_); - CopyGraph(*gbody->graph, ret); - delete gbody; + std::unique_ptr<Graph> ret(new Graph(lib_def_.get())); + CopyGraph(*gbody->graph, ret.get()); return ret; } - Device* device_ = nullptr; - FunctionLibraryDefinition* lib_def_ = nullptr; - FunctionLibraryRuntime* lib_ = nullptr; + std::unique_ptr<Device> device_; + std::unique_ptr<FunctionLibraryDefinition> lib_def_; + std::unique_ptr<FunctionLibraryRuntime> lib_; + FunctionDefLibrary fdef_lib_; }; TEST_F(FunctionLibraryRuntimeTest, IsStateful) { @@ -254,113 +254,174 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesN) { test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64})); } +// Adds a function call to 'scope. +// TODO(phawkins): replace with C++ API for calling functions, when that exists. +Output Call(Scope* scope, const string& op_name, const string& fn_name, + gtl::ArraySlice<Input> inputs) { + NodeDef def; + NodeDefBuilder builder(op_name, fn_name, scope->graph()->op_registry()); + for (const Input& input : inputs) { + builder.Input(input.node()->name(), input.index(), + input.node()->output_type(input.index())); + } + TF_CHECK_OK(builder.Finalize(&def)); + Status status; + Node* n = scope->graph()->AddNode(def, &status); + TF_CHECK_OK(status); + for (int i = 0; i < inputs.size(); ++i) { + scope->graph()->AddEdge(inputs[i].node(), inputs[i].index(), n, i); + } + return Output(n); +} + TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); + std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); - const char* e0 = R"P( -(n2:float) -> (n4:float) { - n3 = XTimesFour[T=float](n2) - n4 = XTimesFour[T=float](n3) -} -)P"; - EXPECT_EQ(e0, DebugString(g)); - - ExpandInlineFunctions(lib_, g); - const char* e1 = R"P( -(n2:float) -> (n17:float) { - n10 = Identity[T=float](n2) - n7 = XTimesTwo[T=float](n10) - n8 = XTimesTwo[T=float](n7) - n11 = Identity[T=float](n8) - n16 = Identity[T=float](n11) - n13 = XTimesTwo[T=float](n16) - n14 = XTimesTwo[T=float](n13) - n17 = Identity[T=float](n14) -} -)P"; - EXPECT_EQ(e1, DebugString(g)); - - ExpandInlineFunctions(lib_, g); - const char* e2 = R"P( -(n2:float) -> (n17:float) { - n18 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n25 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n32 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n39 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n19 = Cast[DstT=float, SrcT=int64](n18) - n26 = Cast[DstT=float, SrcT=int64](n25) - n33 = Cast[DstT=float, SrcT=int64](n32) - n40 = Cast[DstT=float, SrcT=int64](n39) - n10 = Identity[T=float](n2) - n23 = Identity[T=float](n10) - n21 = Mul[T=float](n23, n19) - n24 = Identity[T=float](n21) - n30 = Identity[T=float](n24) - n28 = Mul[T=float](n30, n26) - n31 = Identity[T=float](n28) - n11 = Identity[T=float](n31) - n16 = Identity[T=float](n11) - n37 = Identity[T=float](n16) - n35 = Mul[T=float](n37, n33) - n38 = Identity[T=float](n35) - n44 = Identity[T=float](n38) - n42 = Mul[T=float](n44, n40) - n45 = Identity[T=float](n42) - n17 = Identity[T=float](n45) -} -)P"; - EXPECT_EQ(e2, DebugString(g)); + + { + Scope s = Scope::NewRootScope(); + TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); + auto arg = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto a = Call(&s, "x4", "XTimesFour", {arg}); + auto b = Call(&s, "y", "XTimesFour", {a}); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), b, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } + + ExpandInlineFunctions(lib_.get(), g.get()); + { + Scope s = Scope::NewRootScope(); + TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_)); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto func0 = ops::Identity(s.WithOpName("Func/_0"), x); + auto x4_x2 = Call(&s, "x4/x2", "XTimesTwo", {func0}); + auto x4_y = Call(&s, "x4/y", "XTimesTwo", {x4_x2}); + auto func1 = ops::Identity(s.WithOpName("Func/_1"), x4_y); + auto func2 = ops::Identity(s.WithOpName("Func/_2"), func1); + auto y_x2 = Call(&s, "y/x2", "XTimesTwo", {func2}); + auto y_y = Call(&s, "y/y", "XTimesTwo", {y_x2}); + auto func3 = ops::Identity(s.WithOpName("Func/_3"), y_y); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } + + ExpandInlineFunctions(lib_.get(), g.get()); + GraphDef e2; + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto x4_x2_two = ops::Const<int64>(s.WithOpName("x4/x2/two"), 2LL); + auto x4_y_two = ops::Const<int64>(s.WithOpName("x4/y/two"), 2LL); + auto y_x2_two = ops::Const<int64>(s.WithOpName("y/x2/two"), 2LL); + auto y_y_two = ops::Const<int64>(s.WithOpName("y/y/two"), 2LL); + auto x4_x2_scale = + ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT); + auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT); + auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT); + auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT); + auto func0 = ops::Identity(s.WithOpName("Func/_0"), x); + auto func4 = ops::Identity(s.WithOpName("Func/_4"), func0); + auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), func4, x4_x2_scale); + auto func5 = ops::Identity(s.WithOpName("Func/_5"), x4_x2_y); + auto func6 = ops::Identity(s.WithOpName("Func/_6"), func5); + auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), func6, x4_y_scale); + auto func7 = ops::Identity(s.WithOpName("Func/_7"), x4_y_y); + auto func1 = ops::Identity(s.WithOpName("Func/_1"), func7); + auto func2 = ops::Identity(s.WithOpName("Func/_2"), func1); + auto func8 = ops::Identity(s.WithOpName("Func/_8"), func2); + auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), func8, y_x2_scale); + auto func9 = ops::Identity(s.WithOpName("Func/_9"), y_x2_y); + auto func10 = ops::Identity(s.WithOpName("Func/_10"), func9); + auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), func10, y_y_scale); + auto func11 = ops::Identity(s.WithOpName("Func/_11"), y_y_y); + auto func3 = ops::Identity(s.WithOpName("Func/_3"), func11); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0); + TF_ASSERT_OK(s.ToGraphDef(&e2)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(e2, actual); + } // No further inlining. - ExpandInlineFunctions(lib_, g); - EXPECT_EQ(e2, DebugString(g)); + ExpandInlineFunctions(lib_.get(), g.get()); + { + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(e2, actual); + } // Get rid of redundant Identity nodes. - RemoveIdentityNodes(g); - const char* e3 = R"P( -(n2:float) -> (n42:float) { - n18 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n25 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n32 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n39 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n19 = Cast[DstT=float, SrcT=int64](n18) - n26 = Cast[DstT=float, SrcT=int64](n25) - n33 = Cast[DstT=float, SrcT=int64](n32) - n40 = Cast[DstT=float, SrcT=int64](n39) - n21 = Mul[T=float](n2, n19) - n28 = Mul[T=float](n21, n26) - n35 = Mul[T=float](n28, n33) - n42 = Mul[T=float](n35, n40) -} -)P"; - EXPECT_EQ(e3, DebugString(g)); - delete g; + RemoveIdentityNodes(g.get()); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto x4_x2_two = ops::Const<int64>(s.WithOpName("x4/x2/two"), 2LL); + auto x4_y_two = ops::Const<int64>(s.WithOpName("x4/y/two"), 2LL); + auto y_x2_two = ops::Const<int64>(s.WithOpName("y/x2/two"), 2LL); + auto y_y_two = ops::Const<int64>(s.WithOpName("y/y/two"), 2LL); + auto x4_x2_scale = + ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT); + auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT); + auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT); + auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT); + auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); + auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_y_scale); + auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, y_x2_scale); + auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, y_y_scale); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } } TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - std::unique_ptr<Graph> g(GetFuncBody("XTimes16", {{"T", DT_FLOAT}})); + std::unique_ptr<Graph> g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); ASSERT_TRUE(g != nullptr); - ExpandInlineFunctions(lib_, g.get()); - OptimizeGraph(lib_, &g); - const char* e0 = R"P( -(n2:float) -> (n7:float) { - n8 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]() - n4 = Mul[T=float](n2, n8) - n5 = Mul[T=float](n4, n8) - n6 = Mul[T=float](n5, n8) - n7 = Mul[T=float](n6, n8) -} -)P"; - EXPECT_EQ(e0, DebugString(g.get())); + ExpandInlineFunctions(lib_.get(), g.get()); + OptimizeGraph(lib_.get(), &g); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto x4_x2_scale = ops::Const<float>( + s.WithOpName("x4/x2/scale/_12__cf__2") + .WithDevice("/job:localhost/replica:0/task:0/cpu:0"), + 2.0f); + auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); + auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_x2_scale); + auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, x4_x2_scale); + auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, x4_x2_scale); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } } TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { auto func = FDH::Create( // Creates a FunctionDef using NodeDefs - // Name + // Name "ManySwapsNodeDef", // Input {"x: float", "y: float"}, @@ -379,9 +440,9 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { // Return {{"o", "g:output"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g(GetFuncBody("ManySwapsNodeDef", {})); + std::unique_ptr<Graph> g = GetFuncBody("ManySwapsNodeDef", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(lib_, &g); + OptimizeGraph(lib_.get(), &g); const char* e0 = R"P( (n3:float, n2:float) -> (n3:float) { } @@ -412,24 +473,35 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}}, {{"o", "o:z:0"}}); Init({test::function::Swap(), func}); - std::unique_ptr<Graph> g(GetFuncBody("ManySwapsFirst", {})); + std::unique_ptr<Graph> g = GetFuncBody("ManySwapsFirst", {}); ASSERT_TRUE(g != nullptr); - OptimizeGraph(lib_, &g); + OptimizeGraph(lib_.get(), &g); - // NOTE: We can remove n8, n9, n10, n11 with a control edge n8->n5. + // NOTE: We can remove func0, func1, func2, func9 with a control edge n8->n5. // But we don't have a pass doing that. - const char* e0 = R"P( -(n3:float, n2:float) -> (n6:float) { - n4 = Mul[T=float](n3, n3) - n8 = NoOp() @ n4 - n9 = Identity[T=float](n3) @ n8 - n10 = Identity[T=float](n2) @ n8 - n11 = NoOp() @ n9, n10 - n5 = Mul[T=float](n2, n2) @ n11 - n6 = Add[T=float](n4, n5) -} -)P"; - EXPECT_EQ(e0, DebugString(g.get())); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1); + auto x2 = ops::Mul(s.WithOpName("x2"), x, x); + auto func0 = ops::NoOp(s.WithOpName("Func/_0").WithControlDependencies(x2)); + auto func1 = ops::Identity( + s.WithOpName("Func/_1").WithControlDependencies({func0}), x); + auto func2 = ops::Identity( + s.WithOpName("Func/_2").WithControlDependencies({func0}), y); + auto func9 = ops::NoOp(s.WithOpName("Func/_9").WithControlDependencies( + {func1.output.op(), func2.output.op()})); + auto y2 = + ops::Mul(s.WithOpName("y2").WithControlDependencies({func9}), y, y); + auto o = ops::Add(s.WithOpName("o"), x2, y2); + auto ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } } TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) { @@ -476,84 +548,136 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) { TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); - auto f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}}); - const char* e0 = R"P( -(n4:float) -> (n5:float) { - n2 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n3 = Cast[DstT=float, SrcT=int64](n2) - n5 = Mul[T=float](n4, n3) -} -)P"; - EXPECT_EQ(e0, DebugString(f)); - delete f; - std::unique_ptr<Graph> g(GetGradBody("XTimesTwo", {{"T", DT_FLOAT}})); - const char* e1 = R"P( -(n4:float, n6:float) -> (n7:float) { - n2 = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() - n3 = Cast[DstT=float, SrcT=int64](n2) - n5 = Mul[T=float](n4, n3) - n7 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Mul[T=float]](n4, n3, n6) -} -)P"; - EXPECT_EQ(e1, DebugString(g.get())); - - OptimizeGraph(lib_, &g); - const char* e2 = R"P( -(n2:float, n3:float) -> (n9:float) { - n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]() - n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]() - n6 = Shape[T=float, out_type=int32](n2) - n5 = Mul[T=float](n3, n10) - n7 = BroadcastGradientArgs[T=int32](n6, n11) - n8 = Sum[T=float, Tidx=int32, keep_dims=false](n5, n7) - n9 = Reshape[T=float, Tshape=int32](n8, n6) -} -)P"; - EXPECT_EQ(e2, DebugString(g.get())); + std::unique_ptr<Graph> f = GetFuncBody("XTimesTwo", {{"T", DT_FLOAT}}); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto two = ops::Const(s.WithOpName("two"), 2LL); + auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT); + auto y = ops::Mul(s.WithOpName("y"), x, scale); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + f->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } + + std::unique_ptr<Graph> g = GetGradBody("XTimesTwo", {{"T", DT_FLOAT}}); + + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); + auto two = ops::Const(s.WithOpName("two"), 2LL); + auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT); + auto y = ops::Mul(s.WithOpName("y"), x, scale); + NameAttrList fn; + fn.set_name("Mul"); + (*fn.mutable_attr())["T"].set_type(DT_FLOAT); + auto func1 = ops::SymbolicGradient( + s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0}, + {DT_FLOAT, DT_FLOAT}, fn); + auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1[0], 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } + + OptimizeGraph(lib_.get(), &g); + + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); + auto scale = + ops::Const(s.WithOpName("scale/_5__cf__6") + .WithDevice("/job:localhost/replica:0/task:0/cpu:0"), + 2.0f); + auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); + auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x); + auto const0 = + ops::Const(s.WithOpName("Func/_1/sy/_6__cf__7") + .WithDevice("/job:localhost/replica:0/task:0/cpu:0"), + 0, {0}); + auto func1_rx = ops::internal::BroadcastGradientArgs( + s.WithOpName("Func/_1/rx"), func1_sx, const0); + auto func1_sum_gx = + ops::Sum(s.WithOpName("Func/_1/sum_gx"), func1_gx, func1_rx.r0); + auto func1_dx = + ops::Reshape(s.WithOpName("Func/_1/dx"), func1_sum_gx, func1_sx); + auto func2 = ops::_Retval(s.WithOpName("Func/_2"), func1_dx, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } } TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { Init({}); auto T = DT_FLOAT; - auto g = GetFuncBody("SymbolicGradient", - {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); - const char* e0 = R"P( -(n7:float, n5:float, n2:float) -> (n14:float, n11:float) { - n3 = Identity[T=float](n2) - n4 = Identity[T=float](n2) - n6 = Shape[T=float, out_type=int32](n5) - n8 = Shape[T=float, out_type=int32](n7) - n9 = BroadcastGradientArgs[T=int32](n8, n6) - n10 = Sum[T=float, Tidx=int32, keep_dims=false](n3, n9:1) - n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9) - n11 = Reshape[T=float, Tshape=int32](n10, n6) - n14 = Reshape[T=float, Tshape=int32](n13, n8) -} -)P"; - EXPECT_EQ(e0, DebugString(g)); - delete g; + std::unique_ptr<Graph> g = GetFuncBody( + "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}}); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1); + auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2); + auto gx = ops::Identity(s.WithOpName("gx"), dz); + auto gy = ops::Identity(s.WithOpName("gy"), dz); + auto sx = ops::Shape(s.WithOpName("sx"), x); + auto sy = ops::Shape(s.WithOpName("sy"), y); + auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy); + auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0); + auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1); + auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx); + auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy); + auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0); + auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } } TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) { Init({}); auto T = DT_FLOAT; - auto g = GetFuncBody("SymbolicGradient", - {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); - const char* e0 = R"P( -(n6:float, n3:float, n2:float) -> (n14:float, n11:float) { - n4 = Mul[T=float](n2, n3) - n5 = Shape[T=float, out_type=int32](n3) - n7 = Mul[T=float](n6, n2) - n8 = Shape[T=float, out_type=int32](n6) - n9 = BroadcastGradientArgs[T=int32](n8, n5) - n10 = Sum[T=float, Tidx=int32, keep_dims=false](n7, n9:1) - n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9) - n11 = Reshape[T=float, Tshape=int32](n10, n5) - n14 = Reshape[T=float, Tshape=int32](n13, n8) -} -)P"; - EXPECT_EQ(e0, DebugString(g)); - delete g; + std::unique_ptr<Graph> g = GetFuncBody( + "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}}); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1); + auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2); + auto gx = ops::Mul(s.WithOpName("gx"), dz, y); + auto sx = ops::Shape(s.WithOpName("sx"), x); + auto gy = ops::Mul(s.WithOpName("gy"), x, dz); + auto sy = ops::Shape(s.WithOpName("sy"), y); + auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy); + auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0); + auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1); + auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx); + auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy); + auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0); + auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } } TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { @@ -570,108 +694,170 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { }); // TestGrad = Test'(x, y) - auto grad = - FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"}, - {}, {FDH::Const<float>("dz", 1), - {{"grad0", "grad1"}, - "SymbolicGradient", - {"x", "y", "dz"}, - { - {"f", FDH::FunctionRef("Test")}, - {"Tin", DataTypeSlice{T, T, T}}, - {"Tout", DataTypeSlice{T, T}}, - }}, - {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}}, - {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}}); + auto grad = FDH::Define("TestGrad", {"x:float", "y:float"}, + {"dx:float", "dy:float"}, {}, + {FDH::Const<float>("dz", 1), + {{"grad0", "grad1"}, + "SymbolicGradient", + {"x", "y", "dz"}, + { + {"f", FDH::FunctionRef("Test")}, + {"Tin", DataTypeSlice{T, T, T}}, + {"Tout", DataTypeSlice{T, T}}, + }}, + {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}}, + {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}}); Init({test, grad}); - std::unique_ptr<Graph> g(GetFuncBody("TestGrad", {})); + std::unique_ptr<Graph> g = GetFuncBody("TestGrad", {}); ASSERT_TRUE(g != nullptr); - const char* e0 = R"P( -(n4:float, n3:float) -> (n8:float, n6:float) { - n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]() - n5 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Test](n4, n3, n2) - n6 = Identity[T=float](n5:1) - n8 = Identity[T=float](n5) -} -)P"; - EXPECT_EQ(e0, DebugString(g.get())); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1); + auto dz = ops::Const(s.WithOpName("dz"), 1.0f); + NameAttrList fn; + fn.set_name("Test"); + auto grad0 = ops::SymbolicGradient(s.WithOpName("grad0"), + std::initializer_list<Input>{x, y, dz}, + {DT_FLOAT, DT_FLOAT}, fn); + auto dx = ops::Identity(s.WithOpName("dx"), grad0[0]); + auto dy = ops::Identity(s.WithOpName("dy"), grad0[1]); + auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0); + auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } - ExpandInlineFunctions(lib_, g.get()); - const char* e1 = R"P( -(n4:float, n3:float) -> (n8:float, n6:float) { - n10 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() - n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]() - n26 = Identity[T=float](n2) - n25 = Identity[T=float](n3) - n24 = Identity[T=float](n4) - n14 = Add[T=float](n24, n25) - n15 = Rank[T=float](n14) - n16 = Range[Tidx=int32](n11, n15, n10) - n20 = ZerosLike[T=int32](n15) - n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16) - n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26) - n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19) - n27 = Identity[T=float](n21) - n28 = Identity[T=float](n21:1) - n8 = Identity[T=float](n27) - n6 = Identity[T=float](n28) -} -)P"; - EXPECT_EQ(e1, DebugString(g.get())); - - OptimizeGraph(lib_, &g); - const char* e2 = R"P( -(n4:float, n3:float) -> (n25:float, n23:float) { - n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]() - n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() - n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - n19 = Shape[T=float, out_type=int32](n3) - n9 = Add[T=float](n4, n3) - n20 = Shape[T=float, out_type=int32](n4) - n10 = Rank[T=float](n9) - n14 = Shape[T=float, out_type=int32](n9) - n21 = BroadcastGradientArgs[T=int32](n20, n19) - n11 = Range[Tidx=int32](n8, n10, n7) - n12 = Shape[T=int32, out_type=int32](n11) - n13 = Fill[T=int32](n12, n7) - n15 = DynamicStitch[N=2, T=int32](n11, n11, n14, n13) - n16 = Reshape[T=float, Tshape=int32](n2, n15) - n17 = Div[T=int32](n14, n15) - n18 = Tile[T=float, Tmultiples=int32](n16, n17) - n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1) - n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21) - n23 = Reshape[T=float, Tshape=int32](n22, n19) - n25 = Reshape[T=float, Tshape=int32](n24, n20) -} -)P"; - EXPECT_EQ(e2, DebugString(g.get())); + ExpandInlineFunctions(lib_.get(), g.get()); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1); + auto dz = ops::Const(s.WithOpName("dz"), 1.0f); + auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0); + auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1); + auto func0 = ops::Identity(s.WithOpName("Func/_0"), x); + auto func1 = ops::Identity(s.WithOpName("Func/_1"), y); + auto func2 = ops::Identity(s.WithOpName("Func/_2"), dz); + auto grad0_z = ops::Add(s.WithOpName("grad0/z"), func0, func1); + auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z); + auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero, + grad0_r, grad0_one); + auto grad0_l = ops::Sum(s.WithOpName("grad0/l"), grad0_z, grad0_indices); + + NameAttrList sum; + sum.set_name("Sum"); + (*sum.mutable_attr())["T"].set_type(DT_FLOAT); + (*sum.mutable_attr())["Tidx"].set_type(DT_INT32); + (*sum.mutable_attr())["keep_dims"].set_b(false); + auto grad0_func1 = ops::SymbolicGradient( + s.WithOpName("grad0/Func/_1"), + std::initializer_list<Input>{grad0_z, grad0_indices, func2}, + {DT_FLOAT, DT_INT32}, sum); + + auto grad0_func2 = ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_r); + + NameAttrList add; + add.set_name("Add"); + (*add.mutable_attr())["T"].set_type(DT_FLOAT); + auto grad0_func3 = ops::SymbolicGradient( + s.WithOpName("grad0/Func/_3"), + std::initializer_list<Input>{func0, func1, grad0_func1[0]}, + {DT_FLOAT, DT_FLOAT}, add); + + auto func3 = ops::Identity(s.WithOpName("Func/_3"), grad0_func3[0]); + auto func4 = ops::Identity(s.WithOpName("Func/_4"), grad0_func3[1]); + auto dx = ops::Identity(s.WithOpName("dx"), func3); + auto dy = ops::Identity(s.WithOpName("dy"), func4); + auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0); + auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1); + + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } + + OptimizeGraph(lib_.get(), &g); + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); + auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1); + auto dz = ops::Const(s.WithOpName("dz"), 1.0f); + auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0); + auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1); + auto grad0_z = ops::Add(s.WithOpName("grad0/z"), x, y); + auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z); + auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero, + grad0_r, grad0_one); + auto i_shape = + ops::Shape(s.WithOpName("grad0/Func/_1/i_shape"), grad0_indices); + auto stitch_val = ops::Fill(s.WithOpName("grad0/Func/_1/stitch_val1"), + i_shape, grad0_one); + auto x_shape = ops::Shape(s.WithOpName("grad0/Func/_1/x_shape"), grad0_z); + auto y_shape = ops::DynamicStitch( + s.WithOpName("grad0/Func/_1/y_shape"), + std::initializer_list<Input>{grad0_indices, grad0_indices}, + std::initializer_list<Input>{x_shape, stitch_val}); + auto dy_reshaped = + ops::Reshape(s.WithOpName("grad0/Func/_1/dy_reshaped"), dz, y_shape); + auto tile_scaling = + ops::Div(s.WithOpName("grad0/Func/_1/tile_scaling"), x_shape, y_shape); + auto func1_dx = + ops::Tile(s.WithOpName("grad0/Func/_1/dx"), dy_reshaped, tile_scaling); + + auto sx = ops::Shape(s.WithOpName("grad0/Func/_3/sx"), x); + auto sy = ops::Shape(s.WithOpName("grad0/Func/_3/sy"), y); + auto rx = ops::internal::BroadcastGradientArgs( + s.WithOpName("grad0/Func/_3/rx"), sx, sy); + auto sum_gx = + ops::Sum(s.WithOpName("grad0/Func/_3/sum_gx"), func1_dx, rx.r0); + auto sum_gy = + ops::Sum(s.WithOpName("grad0/Func/_3/sum_gy"), func1_dx, rx.r1); + auto dx = ops::Reshape(s.WithOpName("grad0/Func/_3/dx"), sum_gx, sx); + auto dy = ops::Reshape(s.WithOpName("grad0/Func/_3/dy"), sum_gy, sy); + + auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0); + auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1); + + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + + GraphDef actual; + g->ToGraphDef(&actual); + TF_EXPECT_GRAPH_EQ(expected, actual); + } } namespace { bool DoNothing(Graph* g) { return false; } -string Optimize(const std::function<bool(Graph* g)>& pass, - const FunctionDef& fdef) { +GraphDef Optimize(const std::function<bool(Graph* g)>& pass, + const FunctionDef& fdef) { InstantiationResult result; InstantiateAttrValueMap empty; TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result)); - Graph* g = new Graph(OpRegistry::Global()); + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g)); - pass(g); - Graph* g1 = new Graph(OpRegistry::Global()); - CopyGraph(*g, g1); - delete g; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, result.gdef, g.get())); + pass(g.get()); + std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global())); + CopyGraph(*g, g1.get()); + g = nullptr; GraphDef gdef; g1->ToGraphDef(&gdef); - delete g1; - return DebugString(gdef); + return gdef; } } // end namespace @@ -700,21 +886,25 @@ TEST(OptimizationTest, RemoveDeadNodes) { {{"keep_me"}, "RandomUniform", {"o"}, {{"T", T}, {"dtype", DT_FLOAT}}}, // y = Add<T>(a, o) {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); - const char* e0 = R"S( -(x:int32) -> (y:int32) { - o = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() - keep_me = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](o) - x1 = Add[T=int32](o, o) - a = Square[T=int32](x) - y = Add[T=int32](a, o) - x2 = Mul[T=int32](a, x1) - x3 = Mul[T=int32](x1, x2) -} -)S"; - EXPECT_EQ(Optimize(DoNothing, func), e0); + + GraphDef expected; + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); + auto o = ops::Const(s.WithOpName("o"), 1); + auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT); + auto x1 = ops::Add(s.WithOpName("x1"), o, o); + auto a = ops::Square(s.WithOpName("a"), x); + auto y = ops::Add(s.WithOpName("y"), a, o); + auto x2 = ops::Mul(s.WithOpName("x2"), a, x1); + auto x3 = ops::Mul(s.WithOpName("x3"), x1, x2); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0); + TF_ASSERT_OK(s.ToGraphDef(&expected)); + } + TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func)); // TODO(zhifengc): Comes up another test case. - EXPECT_EQ(Optimize(::tensorflow::RemoveDeadNodes, func), e0); + TF_EXPECT_GRAPH_EQ(expected, Optimize(::tensorflow::RemoveDeadNodes, func)); } TEST(OptimizationTest, RemoveIdentityNodes_Ref) { @@ -735,23 +925,19 @@ TEST(OptimizationTest, RemoveIdentityNodes_Ref) { {{"v_read"}, "Identity", {"v"}, {{"T", T}}}, // returns v + v {{"ret"}, "Add", {"v_read", "v_read"}, {{"T", T}}}}); - const char* e0 = R"S( -() -> (ret:float) { - v = VariableV2[container="", dtype=float, shape=[], shared_name=""]() - v_read = Identity[T=float](v) - ret = Add[T=float](v_read, v_read) -} -)S"; - EXPECT_EQ(Optimize(DoNothing, func), e0); - - const char* e1 = R"S( -() -> (ret:float) { - v = VariableV2[container="", dtype=float, shape=[], shared_name=""]() - v_read = Identity[T=float](v) - ret = Add[T=float](v_read, v_read) -} -)S"; - EXPECT_EQ(Optimize(::tensorflow::RemoveIdentityNodes, func), e1); + + GraphDef expected; + { + Scope s = Scope::NewRootScope(); + auto v = ops::Variable(s.WithOpName("v"), PartialTensorShape({}), DT_FLOAT); + auto v_read = ops::Identity(s.WithOpName("v_read"), v); + auto ret = ops::Add(s.WithOpName("ret"), v_read, v_read); + auto ret_retval = ops::_Retval(s.WithOpName("ret_RetVal"), ret, 0); + TF_ASSERT_OK(s.ToGraphDef(&expected)); + } + TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func)); + TF_EXPECT_GRAPH_EQ(expected, + Optimize(::tensorflow::RemoveIdentityNodes, func)); } TEST(OptimizationTest, RemoveIdentityNodes) { @@ -782,28 +968,38 @@ TEST(OptimizationTest, RemoveIdentityNodes) { {"x3"}}, // y = Add<T>(a, o) {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); - const char* e0 = R"S( -(x:int32) -> (y:int32) { - o = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() - a = Square[T=int32](x) - y = Add[T=int32](a, o) - x1 = Identity[T=int32](a) - x2 = Identity[T=int32](x1) - x3 = Identity[T=int32](x2) - keep_me = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](o) @ x3 -} -)S"; - EXPECT_EQ(Optimize(DoNothing, func), e0); - - const char* e1 = R"S( -(x:int32) -> (y:int32) { - o = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() - a = Square[T=int32](x) - y = Add[T=int32](a, o) - keep_me = RandomUniform[T=int32, dtype=float, seed2=0, seed=0](o) @ a -} -)S"; - EXPECT_EQ(Optimize(::tensorflow::RemoveIdentityNodes, func), e1); + + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); + auto o = ops::Const(s.WithOpName("o"), 1); + auto a = ops::Square(s.WithOpName("a"), x); + auto y = ops::Add(s.WithOpName("y"), a, o); + auto x1 = ops::Identity(s.WithOpName("x1"), a); + auto x2 = ops::Identity(s.WithOpName("x2"), x1); + auto x3 = ops::Identity(s.WithOpName("x3"), x2); + auto keep_me = ops::RandomUniform( + s.WithOpName("keep_me").WithControlDependencies(x3), {o}, DT_FLOAT); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func)); + } + + { + Scope s = Scope::NewRootScope(); + auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0); + auto o = ops::Const(s.WithOpName("o"), 1); + auto a = ops::Square(s.WithOpName("a"), x); + auto y = ops::Add(s.WithOpName("y"), a, o); + auto keep_me = ops::RandomUniform( + s.WithOpName("keep_me").WithControlDependencies(a), {o}, DT_FLOAT); + auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, + Optimize(::tensorflow::RemoveIdentityNodes, func)); + } } TEST(OptimizationTest, RemoveListArrayConverter) { @@ -840,49 +1036,63 @@ TEST(OptimizationTest, RemoveListArrayConverter) { // Return values {{"o", "o:sum"}}); - const char* e0 = R"P( -(i:float) -> (o:float) { - zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - s = Split[T=float, num_split=4](zero, i) - a = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s, s:1, s:2, s:3) - r = Mul[T=float](a:2, a:3) - l = Mul[T=float](a, a:1) - x = _ListToArray[N=2, T=float, Tin={float, float}](l, r) - o = AddN[N=2, T=float](x, x:1) -} -)P"; - EXPECT_EQ(Optimize(DoNothing, func), e0); - - const char* e1 = R"P( -(i:float) -> (o:float) { - zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - s = Split[T=float, num_split=4](zero, i) - r = Mul[T=float](Func/_2, Func/_3) - l = Mul[T=float](Func/_0, Func/_1) - o = AddN[N=2, T=float](Func/_4, Func/_5) - Func/_0 = Identity[T=float](s) - Func/_1 = Identity[T=float](s:1) - Func/_2 = Identity[T=float](s:2) - Func/_3 = Identity[T=float](s:3) - Func/_4 = Identity[T=float](l) - Func/_5 = Identity[T=float](r) -} -)P"; - EXPECT_EQ(Optimize(RemoveListArrayConverter, func), e1); - - const char* e2 = R"P( -(i:float) -> (o:float) { - zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - s = Split[T=float, num_split=4](zero, i) - r = Mul[T=float](s:2, s:3) - l = Mul[T=float](s, s:1) - o = AddN[N=2, T=float](l, r) -} -)P"; - auto remove_listarray_and_identity = [](Graph* g) { - return RemoveListArrayConverter(g) && RemoveIdentityNodes(g); - }; - EXPECT_EQ(Optimize(remove_listarray_and_identity, func), e2); + { + Scope scope = Scope::NewRootScope(); + auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0); + auto zero = ops::Const(scope.WithOpName("zero"), 0); + auto s = ops::Split(scope.WithOpName("s"), zero, i, 4); + auto a = ops::_ArrayToList(scope.WithOpName("a"), s.output, + {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}); + auto r = ops::Mul(scope.WithOpName("r"), a[2], a[3]); + auto l = ops::Mul(scope.WithOpName("l"), a[0], a[1]); + auto x = ops::_ListToArray(scope.WithOpName("x"), + std::initializer_list<Input>{l, r}, DT_FLOAT, 2); + auto o = ops::AddN(scope.WithOpName("o"), x.output); + auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func)); + } + + { + Scope scope = Scope::NewRootScope(); + auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0); + auto zero = ops::Const(scope.WithOpName("zero"), 0); + auto s = ops::Split(scope.WithOpName("s"), zero, i, 4); + auto func_0 = ops::Identity(scope.WithOpName("Func/_0"), s[0]); + auto func_1 = ops::Identity(scope.WithOpName("Func/_1"), s[1]); + auto func_2 = ops::Identity(scope.WithOpName("Func/_2"), s[2]); + auto func_3 = ops::Identity(scope.WithOpName("Func/_3"), s[3]); + auto r = ops::Mul(scope.WithOpName("r"), func_2, func_3); + auto l = ops::Mul(scope.WithOpName("l"), func_0, func_1); + auto func_4 = ops::Identity(scope.WithOpName("Func/_4"), l); + auto func_5 = ops::Identity(scope.WithOpName("Func/_5"), r); + auto o = ops::AddN(scope.WithOpName("o"), + std::initializer_list<Input>{func_4, func_5}); + auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func)); + } + + { + Scope scope = Scope::NewRootScope(); + auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0); + auto zero = ops::Const(scope.WithOpName("zero"), 0); + auto s = ops::Split(scope.WithOpName("s"), zero, i, 4); + auto r = ops::Mul(scope.WithOpName("r"), s[2], s[3]); + auto l = ops::Mul(scope.WithOpName("l"), s[0], s[1]); + auto o = + ops::AddN(scope.WithOpName("o"), std::initializer_list<Input>{l, r}); + auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + auto remove_listarray_and_identity = [](Graph* g) { + return RemoveListArrayConverter(g) && RemoveIdentityNodes(g); + }; + TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func)); + } } TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { @@ -911,33 +1121,47 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { {"x"}}}, {{"o", "o:sum"}}); - const char* e0 = R"P( -(i:float) -> (o:float) { - dummy = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - x = _ListToArray[N=2, T=float, Tin={float, float}](i, i) @ dummy - o = AddN[N=2, T=float](x, x:1) @ x -} -)P"; - EXPECT_EQ(Optimize(DoNothing, func), e0); - - const char* e1 = R"P( -(i:float) -> (o:float) { - dummy = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - o = AddN[N=2, T=float](Func/_0, Func/_1) @ Func/_3 - Func/_0 = Identity[T=float](i) @ Func/_2 - Func/_1 = Identity[T=float](i) @ Func/_2 - Func/_2 = NoOp() @ dummy - Func/_3 = NoOp() @ Func/_0, Func/_1 -} -)P"; - EXPECT_EQ(Optimize(RemoveListArrayConverter, func), e1); + { + Scope s = Scope::NewRootScope(); + auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0); + auto dummy = ops::Const(s.WithOpName("dummy"), 0); + auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy), + std::initializer_list<Input>{i, i}, DT_FLOAT, 2); + auto o = + ops::AddN(s.WithOpName("o").WithControlDependencies({x.output[0].op()}), + x.output); + auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0); + GraphDef expected; + TF_ASSERT_OK(s.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func)); + } + + GraphDef expected; + { + Scope s = Scope::NewRootScope(); + auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0); + auto dummy = ops::Const(s.WithOpName("dummy"), 0); + auto func_2 = + ops::NoOp(s.WithOpName("Func/_2").WithControlDependencies(dummy)); + auto func_0 = ops::Identity( + s.WithOpName("Func/_0").WithControlDependencies({func_2}), i); + auto func_1 = ops::Identity( + s.WithOpName("Func/_1").WithControlDependencies({func_2}), i); + auto func_3 = ops::NoOp(s.WithOpName("Func/_3").WithControlDependencies( + {func_0.output.op(), func_1.output.op()})); + auto o = ops::AddN(s.WithOpName("o").WithControlDependencies({func_3}), + std::initializer_list<Input>{func_0, func_1}); + auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0); + TF_ASSERT_OK(s.ToGraphDef(&expected)); + } + TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func)); auto remove_listarray_and_identity = [](Graph* g) { return RemoveListArrayConverter(g) && RemoveIdentityNodes(g); }; // NOTE: We are not removing Identity nodes with any control // dependencies yet. - EXPECT_EQ(Optimize(remove_listarray_and_identity, func), e1); + TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func)); } } // end namespace tensorflow |