author    A. Unique TensorFlower <gardener@tensorflow.org>   2018-04-03 11:15:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-04-03 11:18:00 -0700
commit 86235e48fe39f2b9318f01e963499a555ea88084 (patch)
tree   ac263de02903081bcb69b73040c8ae092d964133
parent f654b0d15af364d6f43d22a179fa05d20650fe9a (diff)
Turn no-op Split/SplitV operators into Identity.
PiperOrigin-RevId: 191469655
-rw-r--r--  tensorflow/core/grappler/op_types.cc                            2
-rw-r--r--  tensorflow/core/grappler/op_types.h                             1
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc        10
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc   76
4 files changed, 89 insertions(+), 0 deletions(-)
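
For context: a Split whose num_split attribute is 1, or a SplitV asked for a single piece, produces exactly one output equal to its input, so the node does no work and can be bypassed. A minimal sketch of such a no-op graph, written against the same TensorFlow C++ ops API used by the tests in this commit (the names "in", "split_dim", and "s" are illustrative, not taken from the commit):

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

int main() {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // A 1-D constant and a scalar split dimension.
  auto in = tensorflow::ops::Const(scope.WithOpName("in"), {1.0f, 2.0f});
  auto split_dim = tensorflow::ops::Const(scope.WithOpName("split_dim"), 0);
  // num_split == 1: the single output s[0] carries the full contents of
  // `in`. This is exactly the pattern the change below rewrites to Identity.
  tensorflow::ops::Split s(scope.WithOpName("s"), split_dim, in, 1);
  return 0;
}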
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 584008b0c1..a24d2dbd9f 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -365,6 +365,8 @@ bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
+bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
+
bool IsVariable(const NodeDef& node) {
const auto& op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index aa6750d5c3..8667f72c7e 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -140,6 +140,7 @@ bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
bool IsTruncateMod(const NodeDef& node);
+bool IsUnpack(const NodeDef& node);
bool IsVariable(const NodeDef& node);
bool IsZeta(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 7de544de52..87052c7ba0 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1542,6 +1542,16 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
for (int i = 0; i < optimized_graph->node_size(); ++i) {
NodeDef* node = optimized_graph->mutable_node(i);
+ if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(1, node, optimized_graph);
+ continue;
+ }
+
+ if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(0, node, optimized_graph);
+ continue;
+ }
+
// Remove Shuffle or Reverse op over scalar values.
if (use_shape_info &&
!properties->GetInputProperties(node->name()).empty() &&
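
The forwarded input index differs because of each op's input ordering: Split takes (split_dim, value), so the data tensor is input 1, while SplitV takes (value, size_splits, split_dim), so the data tensor is input 0. Below is a rough, hypothetical sketch of the effect of ReplaceOperationWithIdentity, inferred from the expected graphs in the tests that follow rather than copied from the real helper in constant_folding.cc: the op becomes Identity, the chosen data input moves to position 0, and every other input is demoted to a control dependency so execution-order constraints survive the rewrite.

#include <string>
#include <vector>
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical sketch only; the real helper lives in constant_folding.cc.
void ReplaceWithIdentitySketch(int input_to_forward, tensorflow::NodeDef* node) {
  node->set_op("Identity");
  std::vector<std::string> inputs;
  inputs.push_back(node->input(input_to_forward));  // forwarded data input
  for (int i = 0; i < node->input_size(); ++i) {
    if (i == input_to_forward) continue;
    const std::string& in = node->input(i);
    // A leading '^' marks a control dependency in a NodeDef input list.
    inputs.push_back(in[0] == '^' ? in : "^" + in);
  }
  node->clear_input();
  for (const std::string& in : inputs) node->add_input(in);
}

For the Split test below, this turns inputs (split_dim, in1) into ("in1", "^split_dim"); for SplitV, inputs (in1, size_splits1, split_dim) become ("in1", "^size_splits1", "^split_dim"), which is what the expected graphs assert.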
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 1db4fb9de7..7faa68a657 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1335,6 +1335,82 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
EXPECT_EQ(2, out_idx.flat<int32>()(0));
}
+TEST_F(ConstantFoldingTest, SplitRemoval) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 =
+ ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
+ Output in2 =
+ ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT);
+ auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
+ ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1);
+ ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2);
+
+ ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
+
+ GrapplerItem item;
+ item.fetch = {"out"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("split_dim", "Const", {}, {}, &want);
+ AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {},
+ &want);
+ AddNode("s2", "Split", {"in2", "split_dim"}, {}, &want);
+ AddNode("out", "Add", {"s1", "s2"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, SplitVRemoval) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 =
+ ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT);
+ Output in2 =
+ ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT);
+ auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {});
+ auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1});
+ auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2});
+ ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
+ ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
+
+ LOG(INFO) << s1.output.size();
+ LOG(INFO) << s2.output.size();
+ ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
+
+ GrapplerItem item;
+ item.fetch = {"out"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("split_dim", "Const", {}, {}, &want);
+ AddNode("size_splits1", "Const", {}, {}, &want);
+ AddNode("size_splits2", "Const", {}, {}, &want);
+ AddNode("s1", "Identity",
+ {"in1", AsControlDependency("size_splits1"),
+ AsControlDependency("split_dim")},
+ {}, &want);
+ AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want);
+ AddNode("out", "Add", {"s1", "s2"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();