aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-28 16:10:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 16:17:55 -0700
commit478d370eb116ad2294134d75a886637a7d6da225 (patch)
tree279ef8e8a2c9abeeda583393a986f055b9be314c /tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
parenta98bac521406bedef3ff2b9af9564b21ddda4d82 (diff)
[tf.data] Use Graph instead of GraphDef/FunctionDef for vectorization transforms
PiperOrigin-RevId: 215011835
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc205
1 files changed, 145 insertions, 60 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index e129fa9237..1ff62217dd 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -60,6 +61,11 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
return node;
}
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
+
// TODO(rachelim): Use FunctionDefHelper::Create instead
FunctionDef CreateFunction(
StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
@@ -85,7 +91,6 @@ FunctionDef CreateFunction(
return func;
}
-TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
// Before:
//
@@ -133,10 +138,15 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
- EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
}
// Before:
@@ -149,12 +159,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// | +-----------+ Arg0 +---+ Arg1 +----+ |
// | | +---+--+ +---+--+ | |
// | | | | | |
-// | | +------+ | +---v--+ | |
-// | | |Const | | | Op0 | | |
-// | | +---v--+ | +---+--+ | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
// | | | | | | |
// | | | +---v--+ +---v--+ | |
-// | | +---| XOp1 | | XOp2 | | |
+// | | +---| XOp1 | | Cast | | |
// | | +---+--+ +---+--+ | |
// | | | | | |
// | | MapDefun +---v--+ +---v--+ | |
@@ -165,23 +175,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// +---------------+ Ret0 +---+ Ret1 +--------+
// +------+ +------+
//
-// where XOp1 and XOp2 are not convertible.
+// where XOp1 is not convertible.
//
// After:
//
-// No change because the ops are not convertible.
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
{{"ret0", DT_INT32}, {"ret1", DT_INT32}},
- {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
NodeDef* x_op1 =
- function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
- NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
- CHECK_NOTNULL(x_op2);
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
FunctionDef outer = CreateFunction(
"outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
@@ -193,12 +230,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
- // They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
- EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
}
// Before:
@@ -257,14 +304,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -330,16 +382,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -411,21 +468,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), "x");
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -486,7 +548,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{"ret1", "MyUnstack:output:1"},
{"ret2", "MyUnstack:output:2"}});
NodeDef* cast_op =
- AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
CHECK_NOTNULL(cast_op);
NodeDef* unstack_op =
AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
@@ -505,25 +567,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 2);
+ EXPECT_EQ(vectorized->node_def_size(), 2);
}
// Before:
@@ -561,9 +628,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}},
{{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
- // The attrs aren't relevant
- NodeDef* print_op =
- function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
CHECK_NOTNULL(print_op);
NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
false, &inner);
@@ -578,11 +647,27 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ // We check this somewhat manually as the names of nodes may have changed
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
}
// TODO(rachelim): More test cases when we get around to implementing them: