diff options
author | 2018-04-03 11:15:29 -0700 | |
---|---|---|
committer | 2018-04-03 11:18:00 -0700 | |
commit | 86235e48fe39f2b9318f01e963499a555ea88084 (patch) | |
tree | ac263de02903081bcb69b73040c8ae092d964133 | |
parent | f654b0d15af364d6f43d22a179fa05d20650fe9a (diff) |
Turn no-op split/splitv operators into identity.
PiperOrigin-RevId: 191469655
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 76 |
4 files changed, 89 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 584008b0c1..a24d2dbd9f 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -365,6 +365,8 @@ bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; } bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; } +bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; } + bool IsVariable(const NodeDef& node) { const auto& op = node.op(); return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index aa6750d5c3..8667f72c7e 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -140,6 +140,7 @@ bool IsTile(const NodeDef& node); bool IsTranspose(const NodeDef& node); bool IsTruncateDiv(const NodeDef& node); bool IsTruncateMod(const NodeDef& node); +bool IsUnpack(const NodeDef& node); bool IsVariable(const NodeDef& node); bool IsZeta(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 7de544de52..87052c7ba0 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1542,6 +1542,16 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, for (int i = 0; i < optimized_graph->node_size(); ++i) { NodeDef* node = optimized_graph->mutable_node(i); + if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { + ReplaceOperationWithIdentity(1, node, optimized_graph); + continue; + } + + if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) { + ReplaceOperationWithIdentity(0, node, optimized_graph); + continue; + } + // Remove Shuffle or Reverse op over scalar values. if (use_shape_info && !properties->GetInputProperties(node->name()).empty() && diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 1db4fb9de7..7faa68a657 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -1335,6 +1335,82 @@ TEST_F(ConstantFoldingTest, MergeNodes) { EXPECT_EQ(2, out_idx.flat<int32>()(0)); } +TEST_F(ConstantFoldingTest, SplitRemoval) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = + ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT); + Output in2 = + ops::Variable(scope.WithOpName("in2"), TensorShape({4}), DT_FLOAT); + auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {}); + ops::Split s1(scope.WithOpName("s1"), split_dim, in1, 1); + ops::Split s2(scope.WithOpName("s2"), split_dim, in2, 2); + + ops::Add out(scope.WithOpName("out"), s1[0], s2[0]); + + GrapplerItem item; + item.fetch = {"out"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("split_dim", "Const", {}, {}, &want); + AddNode("s1", "Identity", {"in1", AsControlDependency("split_dim")}, {}, + &want); + AddNode("s2", "Split", {"in2", "split_dim"}, {}, &want); + AddNode("out", "Add", {"s1", "s2"}, {}, &want); + + CompareGraphs(want, got); +} + +TEST_F(ConstantFoldingTest, SplitVRemoval) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = + ops::Variable(scope.WithOpName("in1"), TensorShape({2}), DT_FLOAT); + Output in2 = + ops::Variable(scope.WithOpName("in2"), TensorShape({5}), DT_FLOAT); + auto split_dim = ops::Const(scope.WithOpName("split_dim"), {0}, {}); + auto size_splits1 = ops::Const(scope.WithOpName("size_splits1"), {2}, {1}); + auto size_splits2 = ops::Const(scope.WithOpName("size_splits2"), {2, 3}, {2}); + ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1); + ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2); + + LOG(INFO) << s1.output.size(); + LOG(INFO) << s2.output.size(); + ops::Add out(scope.WithOpName("out"), s1[0], s2[0]); + + GrapplerItem item; + item.fetch = {"out"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("split_dim", "Const", {}, {}, &want); + AddNode("size_splits1", "Const", {}, {}, &want); + AddNode("size_splits2", "Const", {}, {}, &want); + AddNode("s1", "Identity", + {"in1", AsControlDependency("size_splits1"), + AsControlDependency("split_dim")}, + {}, &want); + AddNode("s2", "SplitV", {"in2", "size_splits2", "split_dim"}, {}, &want); + AddNode("out", "Add", {"s1", "s2"}, {}, &want); + + CompareGraphs(want, got); +} + TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); |