diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-28 16:10:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 16:17:55 -0700 |
commit | 478d370eb116ad2294134d75a886637a7d6da225 (patch) | |
tree | 279ef8e8a2c9abeeda583393a986f055b9be314c /tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | |
parent | a98bac521406bedef3ff2b9af9564b21ddda4d82 (diff) |
[tf.data] Use Graph instead of GraphDef/FunctionDef for vectorization transforms
PiperOrigin-RevId: 215011835
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | 205 |
1 files changed, 145 insertions, 60 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index e129fa9237..1ff62217dd 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" @@ -60,6 +61,11 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, return node; } +string GetRetval(const FunctionDef& function_def, int index) { + return function_def.ret().at( + function_def.signature().output_arg(index).name()); +} + // TODO(rachelim): Use FunctionDefHelper::Create instead FunctionDef CreateFunction( StringPiece name, const std::vector<std::pair<string, DataType>>& inputs, @@ -85,7 +91,6 @@ FunctionDef CreateFunction( return func; } -TEST(FunctionDefInputDescTest, ConstructedCorrectly) {} // Before: // @@ -133,10 +138,15 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - EXPECT_EQ(outer.ret().at("mapdefun"), "ret0"); - EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1"); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_EQ(GetRetval(*vectorized, 0), "ret0"); + EXPECT_EQ(GetRetval(*vectorized, 1), "ret1"); } // Before: @@ -149,12 +159,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // | +-----------+ Arg0 +---+ Arg1 +----+ | // | | +---+--+ +---+--+ | | // | | | | | | -// | | +------+ | +---v--+ | | -// | | |Const | | | Op0 | | | -// | | +---v--+ | +---+--+ | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | // | | | | | | | // | | | +---v--+ +---v--+ | | -// | | +---| XOp1 | | XOp2 | | | +// | | +---| XOp1 | | Cast | | | // | | +---+--+ +---+--+ | | // | | | | | | // | | MapDefun +---v--+ +---v--+ | | @@ -165,23 +175,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // +---------------+ Ret0 +---+ Ret1 +--------+ // +------+ +------+ // -// where XOp1 and XOp2 are not convertible. +// where XOp1 is not convertible. // // After: // -// No change because the ops are not convertible. +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ | | +// | +-----------+ Arg0 +-+ | | +// | | +---+--+ | | | +// | | | | | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | +// | | | | | | | +// | | | +---v--+ | +---v--+ | +// | | +---| XOp1 | | | Cast | | +// | | +---+--+ | +---+--+ | +// | | | | | | +// | | MapDefun +---v--+ | | | +// | +-----------+ Ret0 +-+ | | +// | +---+--+ | | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ // TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, - {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}}); + {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}}); + // TODO(rachelim): If we ever write a converter for MatMul, we have to + // change this test. NodeDef* x_op1 = - function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner); + function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner); CHECK_NOTNULL(x_op1); + graph_transforms::SetNodeAttr("T", DT_INT32, x_op1); - NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner); - CHECK_NOTNULL(x_op2); + NodeDef* cast_node = + AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner); + CHECK_NOTNULL(cast_node); FunctionDef outer = CreateFunction( "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}}, @@ -193,12 +230,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); - // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); - EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto map_defun_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized)); + // The Cast node should be converted just fine. + EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0"); + + // The inner function should only have one retval. + FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1); } // Before: @@ -257,14 +304,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -330,16 +382,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -411,21 +468,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), "x"); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -486,7 +548,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { {"ret1", "MyUnstack:output:1"}, {"ret2", "MyUnstack:output:2"}}); NodeDef* cast_op = - AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner); CHECK_NOTNULL(cast_op); NodeDef* unstack_op = AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner); @@ -505,25 +567,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0")); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 2); + EXPECT_EQ(vectorized->node_def_size(), 2); } // Before: @@ -561,9 +628,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}}, {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); - // The attrs aren't relevant - NodeDef* print_op = - function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner); + NodeDef* print_op = function_utils::AddNode( + "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner); + graph_transforms::SetNodeAttr("T", DT_INT32, print_op); + graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}), + print_op); CHECK_NOTNULL(print_op); NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64, false, &inner); @@ -578,11 +647,27 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); + // We check this somewhat manually as the names of nodes may have changed + EXPECT_EQ(vectorized->node_def_size(), 1); + const NodeDef& map_defun_node = vectorized->node_def(0); + EXPECT_EQ(map_defun_node.op(), "MapDefun"); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + + const NodeDef& print_node = map_defun_fn->node_def( + function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn)); + const NodeDef& cast_node = map_defun_fn->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn)); + string control_input = strings::StrCat("^", print_node.name()); + EXPECT_TRUE(cast_node.input(0) == control_input || + cast_node.input(1) == control_input); } // TODO(rachelim): More test cases when we get around to implementing them: |