aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-22 18:01:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-22 19:17:57 -0700
commite058f2d1382aae39c88b6d3a71f25ec6114ffb31 (patch)
treeb854b40232abffe4511662ac2aa8f51ffc6a04d5 /tensorflow/core/common_runtime/function_test.cc
parent946ef5dbcf6becffc29a8028956d740ab6e9cc51 (diff)
Begin transition to use NodeDef in FunctionDef instead of
FunctionDef.Node. Change: 131009401
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc51
1 files changed, 42 insertions, 9 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index c4d1bb59e8..2f5507a0c5 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -257,7 +257,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
- CHECK(g);
+ ASSERT_TRUE(g != nullptr);
const char* e0 = R"P(
(n2:float) -> (n4:float) {
n3 = XTimesFour[T=float](n2)
@@ -342,7 +342,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
Graph* g = GetFuncBody("XTimes16", {{"T", DT_FLOAT}});
- CHECK(g);
+ ASSERT_TRUE(g != nullptr);
ExpandInlineFunctions(lib_, g);
OptimizeGraph(lib_, &g);
const char* e0 = R"P(
@@ -358,8 +358,8 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
delete g;
}
-TEST_F(FunctionLibraryRuntimeTest, ManySwaps) {
- auto func = FDH::Define(
+TEST_F(FunctionLibraryRuntimeTest, ManySwapsOld) {
+ auto func = FDH::Define( // Creates a FunctionDef using FunctionDef::Nodes
// Name
"ManySwapsFirst",
// Args
@@ -377,8 +377,41 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwaps) {
{{"a5", "b5"}, "Swap", {"a4", "b4"}, {{"T", DT_FLOAT}}},
{{"o"}, "Identity", {"a5"}, {{"T", DT_FLOAT}}}});
Init({test::function::Swap(), func});
- Graph* g = GetFuncBody("ManySwapsFirst", {{"T", DT_FLOAT}});
- CHECK(g);
+ Graph* g = GetFuncBody("ManySwapsFirst", {});
+ ASSERT_TRUE(g != nullptr);
+ OptimizeGraph(lib_, &g);
+ const char* e0 = R"P(
+(n3:float, n2:float) -> (n3:float) {
+}
+)P";
+ EXPECT_EQ(e0, DebugString(g));
+ delete g;
+}
+
+// Like the above test, but using NodeDefs in the FunctionDef.
+TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
+ auto func = FDH::Create( // Creates a FunctionDef using NodeDefs
+ // Name
+ "ManySwapsNodeDef",
+ // Input
+ {"x: float", "y: float"},
+ // Output
+ {"o: float"},
+ // Attr
+ {},
+ // Nodes
+ {{{"a"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
+ {{"b"}, "Swap", {"a:o0", "a:o1"}, {{"T", DT_FLOAT}}},
+ {{"c"}, "Swap", {"b:o0", "b:o1"}, {{"T", DT_FLOAT}}},
+ {{"d"}, "Swap", {"c:o0", "c:o1"}, {{"T", DT_FLOAT}}},
+ {{"e"}, "Swap", {"d:o0", "d:o1"}, {{"T", DT_FLOAT}}},
+ {{"f"}, "Swap", {"e:o0", "e:o1"}, {{"T", DT_FLOAT}}},
+ {{"g"}, "Identity", {"f:o0"}, {{"T", DT_FLOAT}}}},
+ // Return
+ {{"o", "g:output"}});
+ Init({test::function::Swap(), func});
+ Graph* g = GetFuncBody("ManySwapsNodeDef", {});
+ ASSERT_TRUE(g != nullptr);
OptimizeGraph(lib_, &g);
const char* e0 = R"P(
(n3:float, n2:float) -> (n3:float) {
@@ -410,8 +443,8 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
{{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
{{"o"}, "Add", {"x2", "y2"}, {{"T", DT_FLOAT}}}});
Init({test::function::Swap(), func});
- Graph* g = GetFuncBody("ManySwapsFirst", {{"T", DT_FLOAT}});
- CHECK(g);
+ Graph* g = GetFuncBody("ManySwapsFirst", {});
+ ASSERT_TRUE(g != nullptr);
OptimizeGraph(lib_, &g);
// NOTE: We can remove n8, n9, n10, n11 with a control edge n8->n5.
@@ -588,7 +621,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
Init({test, grad});
Graph* g = GetFuncBody("TestGrad", {});
- CHECK(g);
+ ASSERT_TRUE(g != nullptr);
const char* e0 = R"P(
(n4:float, n3:float) -> (n8:float, n6:float) {
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()