diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-04 09:26:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 09:35:51 -0700 |
commit | 1fb84c2e41c454939a02a69093cb214673eab343 (patch) | |
tree | e3ee1c19e3a73e1d1cddbc76d5573b7800b1048b /tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | |
parent | ac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (diff) |
Add ability to vectorize nodes that do not derive from function arguments. (This indirectly handles "Const" outputs automagically, since they are always unstacked.)
PiperOrigin-RevId: 215749824
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index 1ff62217dd..a958d706c1 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -670,6 +670,257 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { cast_node.input(1) == control_input); } +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeConst) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Const:output:0"}}); + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized)); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Const", *vectorized)); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node.name()); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | +------+ +------+ | | +// | | |Const | |Const | | | +// | | +---+--+ +---+--+ | | +// | | : +---v--+ | | +// | | ::::::> Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | | +// | +------+ | +// | +------+ |Const | | +// | |Const | +---+--+ | +// | +---+--+ | | +// | : +---v--+ | +// | ::::::> Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | +Stack*+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2), + FunctionDefHelper::Const("ConstDep", 3)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64, + false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto find_const = [vectorized](int val) -> const NodeDef* { + for (const auto& n : vectorized->node_def()) { + if (n.attr().at("value").tensor().int_val(0) == val) { + return &n; + } + } + return nullptr; + }; + + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = find_const(2); + auto const_dep_node = find_const(3); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node->name()); + EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name())); +} + // TODO(rachelim): More test cases when we get around to implementing them: // [] A badly defined converter, e.g. doesn't produce nodes that have the // same number of outputs/inputs as the nodes to be converted |