diff options
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.cc | 38 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.h | 5 |
6 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index fdf4540540..e225e99a9e 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -256,6 +256,10 @@ bool IsRestore(const NodeDef& node) { node.op() == "RestoreSlice"); } +bool IsReverse(const NodeDef& node) { + return node.op() == "Reverse" || node.op() == "ReverseV2"; +} + bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; } bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; } @@ -272,6 +276,10 @@ bool IsShape(const NodeDef& node) { return node.op() == "Shape"; } bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; } +bool IsShuffle(const NodeDef& node) { + return node.op() == "Shuffle" || node.op() == "RandomShuffle"; +} + bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; } bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 9cda40c0a6..1fa43a9b66 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -100,6 +100,7 @@ bool IsRecv(const NodeDef& node); bool IsReduction(const NodeDef& node); bool IsReshape(const NodeDef& node); bool IsRestore(const NodeDef& node); +bool IsReverse(const NodeDef& node); bool IsReverseV2(const NodeDef& node); bool IsRsqrtGrad(const NodeDef& node); bool IsSelect(const NodeDef& node); @@ -108,6 +109,7 @@ bool IsSend(const NodeDef& node); bool IsSlice(const NodeDef& node); bool IsShape(const NodeDef& node); bool IsShapeN(const NodeDef& node); +bool IsShuffle(const NodeDef& node); bool IsSigmoidGrad(const NodeDef& node); bool IsSoftplusGrad(const NodeDef& node); bool IsSoftsignGrad(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 7a621bd95d..95eaa31a46 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1446,6 +1446,20 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; for (int i = 0; i < output->node_size(); ++i) { NodeDef* node = output->mutable_node(i); + // Remove Shuffle or Reverse op over scalar values. + if (use_shape_info && + (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) { + const auto& shape = + properties.GetInputProperties(node->name())[0].shape(); + // The node is replaceable iff + // unknown_rank == false && (dim_size == 0 || all dims have size 1) + bool replaceable = !shape.unknown_rank(); + for (int j = 0; j < shape.dim_size(); ++j) { + replaceable &= shape.dim(j).size() == 1; + } + if (replaceable) ReplaceOperationWithIdentity(0, node, output); + } + if (IsSimplifiableReduction(*node)) { // Replace the reduction node with an identity node, that can be further // optimized by the model pruner. @@ -1713,6 +1727,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, TF_RETURN_IF_ERROR(FoldGraph(output)); node_map_.reset(new NodeMap(output)); TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info)); + return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index d8df19fe6a..3afc176402 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -1177,6 +1177,40 @@ TEST_F(ConstantFoldingTest, MergeNodes) { EXPECT_EQ(2, out_idx.flat<int32>()(0)); } +TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = + ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT); + Output in2 = + ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT); + ops::RandomShuffle s1(scope.WithOpName("s1"), in1); + ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}), + in2); + + ops::Add out1(scope.WithOpName("out1"), s1, s2); + ops::Identity out2(scope.WithOpName("out2"), s2); + + GrapplerItem item; + item.fetch = {"out1", "out2"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef got; + Status status = fold.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, &want); + AddNode("in2", "VariableV2", {}, &want); + AddNode("s1", "Identity", {"in1"}, &want); + AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, &want); + AddNode("out1", "Add", {"s1", "s2"}, &want); + AddNode("out2", "Identity", {"s2"}, &want); + + CompareGraphs(want, got); +} + TEST_F(ConstantFoldingTest, NoOpReduction) { // Build a simple graph with a reduction that can be reduced to the identity. tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 813f65f825..fed46c05fb 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -35,5 +35,43 @@ std::vector<Tensor> GrapplerTest::EvaluateNodes( return output_tensors; } +void GrapplerTest::AddNode(const string& name, const string& op, + const std::vector<string>& inputs, GraphDef* graph) { + auto* node = graph->add_node(); + node->set_name(name); + node->set_op(op); + for (const auto& input : inputs) { + node->add_input(input); + } +} + +void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) { + auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool { + return n1.name() < n2.name(); + }; + std::sort(want.mutable_node()->begin(), want.mutable_node()->end(), + comparator); + std::sort(got.mutable_node()->begin(), got.mutable_node()->end(), comparator); + + for (int i = 0; i < want.node_size(); ++i) { + std::sort(want.mutable_node(i)->mutable_input()->begin(), + want.mutable_node(i)->mutable_input()->end()); + } + for (int i = 0; i < got.node_size(); ++i) { + std::sort(got.mutable_node(i)->mutable_input()->begin(), + got.mutable_node(i)->mutable_input()->end()); + } + + ASSERT_EQ(want.node_size(), got.node_size()); + for (int i = 0; i < want.node_size(); ++i) { + EXPECT_EQ(want.node(i).op(), got.node(i).op()); + EXPECT_EQ(want.node(i).name(), got.node(i).name()); + ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size()); + for (int j = 0; j < want.node(i).input_size(); ++j) { + EXPECT_TRUE(IsSameInput(want.node(i).input(j), got.node(i).input(j))); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 46ce47c8c3..042b616aa4 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -29,6 +29,11 @@ class GrapplerTest : public ::testing::Test { protected: std::vector<Tensor> EvaluateNodes(const GraphDef& graph, const std::vector<string>& node_names); + + void AddNode(const string& name, const string& op, + const std::vector<string>& inputs, GraphDef* graph); + + void CompareGraphs(GraphDef want, GraphDef got); }; } // end namespace grappler |