aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-10-04 09:26:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 09:35:51 -0700
commit1fb84c2e41c454939a02a69093cb214673eab343 (patch)
treee3ee1c19e3a73e1d1cddbc76d5573b7800b1048b /tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
parentac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (diff)
Add ability to vectorize nodes that do not derive from function arguments. (This indirectly handles "Const" outputs automagically, since they are always unstacked.)
PiperOrigin-RevId: 215749824
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc251
1 files changed, 251 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index 1ff62217dd..a958d706c1 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -670,6 +670,257 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
cast_node.input(1) == control_input);
}
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeConst) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Const:output:0"}});
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Const", *vectorized));
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node.name());
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | +------+ +------+ | |
+// | | |Const | |Const | | |
+// | | +---+--+ +---+--+ | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | |
+// | +------+ |
+// | +------+ |Const | |
+// | |Const | +---+--+ |
+// | +---+--+ | |
+// | : +---v--+ |
+// | ::::::> Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +Stack*+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2),
+ FunctionDefHelper::Const("ConstDep", 3)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64,
+ false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto find_const = [vectorized](int val) -> const NodeDef* {
+ for (const auto& n : vectorized->node_def()) {
+ if (n.attr().at("value").tensor().int_val(0) == val) {
+ return &n;
+ }
+ }
+ return nullptr;
+ };
+
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = find_const(2);
+ auto const_dep_node = find_const(3);
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node->name());
+ EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
+}
+
// TODO(rachelim): More test cases when we get around to implementing them:
// [] A badly defined converter, e.g. doesn't produce nodes that have the
// same number of outputs/inputs as the nodes to be converted