diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-22 18:01:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-22 19:17:57 -0700 |
commit | e058f2d1382aae39c88b6d3a71f25ec6114ffb31 (patch) | |
tree | b854b40232abffe4511662ac2aa8f51ffc6a04d5 /tensorflow/core/common_runtime/function_test.cc | |
parent | 946ef5dbcf6becffc29a8028956d740ab6e9cc51 (diff) |
Begin transition to use NodeDef in FunctionDef instead of
FunctionDef.Node.
Change: 131009401
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 51 |
1 files changed, 42 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index c4d1bb59e8..2f5507a0c5 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -257,7 +257,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); - CHECK(g); + ASSERT_TRUE(g != nullptr); const char* e0 = R"P( (n2:float) -> (n4:float) { n3 = XTimesFour[T=float](n2) @@ -342,7 +342,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Init({test::function::XTimesTwo(), test::function::XTimesFour(), test::function::XTimes16()}); Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}}); - CHECK(g); + ASSERT_TRUE(g != nullptr); ExpandInlineFunctions(lib_, g); OptimizeGraph(lib_, &g); const char* e0 = R"P( @@ -358,8 +358,8 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { delete g; } -TEST_F(FunctionLibraryRuntimeTest, ManySwaps) { - auto func = FDH::Define( +TEST_F(FunctionLibraryRuntimeTest, ManySwapsOld) { + auto func = FDH::Define( // Creates a FunctionDef using FunctionDef::Nodes // Name "ManySwapsFirst", // Args @@ -377,8 +377,41 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwaps) { {{"a5", "b5"}, "Swap", {"a4", "b4"}, {{"T", DT_FLOAT}}}, {{"o"}, "Identity", {"a5"}, {{"T", DT_FLOAT}}}}); Init({test::function::Swap(), func}); - Graph* g = GetFuncBody("ManySwapsFirst", {{"T", DT_FLOAT}}); - CHECK(g); + Graph* g = GetFuncBody("ManySwapsFirst", {}); + ASSERT_TRUE(g != nullptr); + OptimizeGraph(lib_, &g); + const char* e0 = R"P( +(n3:float, n2:float) -> (n3:float) { +} +)P"; + EXPECT_EQ(e0, DebugString(g)); + delete g; +} + +// Like the above test, but using NodeDefs in the FunctionDef. +TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { + auto func = FDH::Create( // Creates a FunctionDef using NodeDefs + // Name + "ManySwapsNodeDef", + // Input + {"x: float", "y: float"}, + // Output + {"o: float"}, + // Attr + {}, + // Nodes + {{{"a"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}}, + {{"b"}, "Swap", {"a:o0", "a:o1"}, {{"T", DT_FLOAT}}}, + {{"c"}, "Swap", {"b:o0", "b:o1"}, {{"T", DT_FLOAT}}}, + {{"d"}, "Swap", {"c:o0", "c:o1"}, {{"T", DT_FLOAT}}}, + {{"e"}, "Swap", {"d:o0", "d:o1"}, {{"T", DT_FLOAT}}}, + {{"f"}, "Swap", {"e:o0", "e:o1"}, {{"T", DT_FLOAT}}}, + {{"g"}, "Identity", {"f:o0"}, {{"T", DT_FLOAT}}}}, + // Return + {{"o", "g:output"}}); + Init({test::function::Swap(), func}); + Graph* g = GetFuncBody("ManySwapsNodeDef", {}); + ASSERT_TRUE(g != nullptr); OptimizeGraph(lib_, &g); const char* e0 = R"P( (n3:float, n2:float) -> (n3:float) { @@ -410,8 +443,8 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}}, {{"o"}, "Add", {"x2", "y2"}, {{"T", DT_FLOAT}}}}); Init({test::function::Swap(), func}); - Graph* g = GetFuncBody("ManySwapsFirst", {{"T", DT_FLOAT}}); - CHECK(g); + Graph* g = GetFuncBody("ManySwapsFirst", {}); + ASSERT_TRUE(g != nullptr); OptimizeGraph(lib_, &g); // NOTE: We can remove n8, n9, n10, n11 with a control edge n8->n5. @@ -588,7 +621,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { Init({test, grad}); Graph* g = GetFuncBody("TestGrad", {}); - CHECK(g); + ASSERT_TRUE(g != nullptr); const char* e0 = R"P( (n4:float, n3:float) -> (n8:float, n6:float) { n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]() |