aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc103
1 files changed, 93 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index a6020e36bb..be498d150b 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -145,7 +145,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
FunctionDef* vectorized;
Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
LOG(ERROR) << s;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
@@ -237,7 +237,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
auto map_defun_node = vectorized->node_def(
function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
@@ -311,7 +311,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@@ -389,7 +389,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@@ -475,7 +475,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& unpack_node = vectorized->node_def(
@@ -574,7 +574,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@@ -654,7 +654,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
// They should be unchanged
// We check this somewhat manually as the names of nodes may have changed
EXPECT_EQ(vectorized->node_def_size(), 1);
@@ -738,7 +738,7 @@ TEST(VectorizeMapDefunTest, VectorizeConst) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
@@ -817,7 +817,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
auto const_node = vectorized->node_def(
@@ -902,7 +902,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
*lib.add_function() = inner;
FunctionDef* vectorized;
- EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
auto find_const = [vectorized](int val) -> const NodeDef* {
for (const auto& n : vectorized->node_def()) {
@@ -924,6 +924,89 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
}
+// Before:
+//
+// +------+
+// +-----------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | +-----+ | |
+// | | | |Const| | |
+// | | | +-+---+ | |
+// | | | | | |
+// | | | +--------+ | |
+// | | | | | |
+// | | +-v---v-+ | |
+// | | | Add | | |
+// | | +-+-----+ | |
+// | | | | |
+// | | | | |
+// | | MapDefun +-v----+ | |
+// | +---------------| Ret |----------------+ |
+// | +--v---+ |
+// | | |
+// | | |
+// | +--v---- |
+// +-------------------| Ret |--------------------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | | +-----+ |
+// | | |Const| |
+// | +-v---------+ +--+--+ |
+// | |ExpandDims*| | |
+// | +-----+-----+ | |
+// | | | |
+// | +-----+ +-----+ |
+// | | | |
+// | +-v-v-+ |
+// | | Add | |
+// | +--+--+ |
+// | | |
+// | +---v--+ |
+// +-----------------------+ Ret +-----------+
+// +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunAdd) {
+ // Note that this checks that the "Add" vectorizer is successful, but does not
+ // check that the transformed function is correct (i.e. produces the same
+ // output as the unvectorized map defun). For the latter, the tests are in
+ // tensorflow/python/data/experimental/kernel_tests/optimization/
+ // map_vectorization_test.py
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2),
+ {{"Add"}, "Add", {"arg0", "Const:output:0"}, {{"T", DT_INT32}}}},
+ {{"ret0", "Add:z:0"}});
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+}
+
// TODO(rachelim): More test cases when we get around to implementing them:
// [] A badly defined converter, e.g. doesn't produce nodes that have the
// same number of outputs/inputs as the nodes to be converted