diff options
author | 2018-04-11 17:29:32 -0700 | |
---|---|---|
committer | 2018-04-11 17:31:36 -0700 | |
commit | d62a5a11e99b391f2e61e80c4f0a80def6ff6508 (patch) | |
tree | e30b2b12d64e6c814888b6bd38226d3dce73e625 | |
parent | 81a9ceaf7290b2260f636609a83b01b9ab2224d7 (diff) |
Automated g4 rollback of changelist 192516190
PiperOrigin-RevId: 192536085
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 95 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 80 |
4 files changed, 16 insertions, 168 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index cfe1329dbf..9c45aed62f 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -249,10 +249,6 @@ bool IsPrint(const NodeDef& node) { return node.op() == "Print"; } bool IsProd(const NodeDef& node) { return node.op() == "Prod"; } -bool IsRandomShuffle(const NodeDef& node) { - return node.op() == "RandomShuffle"; -} - bool IsReal(const NodeDef& node) { return node.op() == "Real"; } bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; } @@ -302,7 +298,9 @@ bool IsShape(const NodeDef& node) { return node.op() == "Shape"; } bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; } -bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; } +bool IsShuffle(const NodeDef& node) { + return node.op() == "Shuffle" || node.op() == "RandomShuffle"; +} bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 0573b02604..79fd05e187 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -98,7 +98,6 @@ bool IsPolygamma(const NodeDef& node); bool IsPrint(const NodeDef& node); bool IsProd(const NodeDef& node); bool IsPow(const NodeDef& node); -bool IsRandomShuffle(const NodeDef& node); bool IsReal(const NodeDef& node); bool IsRealDiv(const NodeDef& node); bool IsRelu6Grad(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 17d8b7421c..b2a1ce6ab6 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1574,99 +1574,24 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, continue; } - // Remove Shuffle or Transpose op over dimensions of size 1. - if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && - !properties->GetInputProperties(node->name()).empty()) { - const auto& shape = - properties->GetInputProperties(node->name())[0].shape(); - if (shape.unknown_rank()) { - // Not optimizable. - continue; - } - const auto& p = properties->GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(p.shape()) && p.has_value()) { - Tensor perm(p.dtype(), p.shape()); - if (!perm.FromProto(p.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - p.value().DebugString()); - } - std::vector<int> permutation; - for (int j = 0; j < perm.NumElements(); ++j) { - if (perm.dtype() == DT_INT64) { - permutation.push_back(perm.vec<int64>()(j)); - } else { - permutation.push_back(perm.vec<int>()(j)); - } - } - if (permutation.size() != shape.dim_size()) { - // Number of elements in perm should be same as dim_size. Skip if not. - continue; - } - // The node is replaceable iff - // dim_size == 0 || all dims have size 1 || - // all dims with > 1 size are not permuted. - bool replaceable = true; - for (int j = 0; replaceable && j < shape.dim_size(); ++j) { - replaceable &= shape.dim(j).size() == 1 || j == permutation[j]; - } - if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); - continue; - } - } - } - - // Remove RandomShuffle op if it is scalar or first dimension is of size 1. - if (use_shape_info && IsRandomShuffle(*node) && - !properties->GetInputProperties(node->name()).empty()) { + // Remove Shuffle or Reverse op over scalar values. + if (use_shape_info && + !properties->GetInputProperties(node->name()).empty() && + (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) { const auto& shape = properties->GetInputProperties(node->name())[0].shape(); // The node is replaceable iff - // unknown_rank == false && (dim_size == 0 || first dim is of size 1) - if (!shape.unknown_rank() && - (shape.dim_size() == 0 || shape.dim(0).size() == 1)) { + // unknown_rank == false && (dim_size == 0 || all dims have size 1) + bool replaceable = !shape.unknown_rank(); + for (int j = 0; replaceable && j < shape.dim_size(); ++j) { + replaceable &= shape.dim(j).size() == 1; + } + if (replaceable) { ReplaceOperationWithIdentity(0, node, optimized_graph); continue; } } - // Remove Reverse op over dimensions with size 1. - if (use_shape_info && IsReverse(*node) && - !properties->GetInputProperties(node->name()).empty()) { - const auto& shape = - properties->GetInputProperties(node->name())[0].shape(); - const auto& a = properties->GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(a.shape()) && a.has_value()) { - Tensor axis(a.dtype(), a.shape()); - if (!axis.FromProto(a.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - a.value().DebugString()); - } - std::set<int> target_axes; - for (int j = 0; j < axis.NumElements(); ++j) { - if (axis.dtype() == DT_INT64) { - target_axes.insert(axis.vec<int64>()(j)); - } else { - target_axes.insert(axis.vec<int>()(j)); - } - } - - // The node is replaceable iff - // unknown_rank == false && - // (dim_size == 0 || all dims have size 1 || - // all dims with > 1 size are not in target_axes) - bool replaceable = !shape.unknown_rank(); - for (int j = 0; replaceable && j < shape.dim_size(); ++j) { - replaceable &= shape.dim(j).size() == 1 || - target_axes.find(j) == target_axes.end(); - } - if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); - continue; - } - } - } - if (use_shape_info && IsSlice(*node) && properties->GetInputProperties(node->name()).size() == 3) { const auto& input = properties->GetInputProperties(node->name())[0]; diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 7453fb6731..31abe43846 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -1389,6 +1389,8 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) { ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1); ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2); + LOG(INFO) << s1.output.size(); + LOG(INFO) << s2.output.size(); ops::Add out(scope.WithOpName("out"), s1[0], s2[0]); GrapplerItem item; @@ -1416,45 +1418,7 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) { CompareGraphs(want, got); } -TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) { - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}), - DT_FLOAT); - Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4}); - Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}), - DT_FLOAT); - Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4}); - ops::Transpose t1(scope.WithOpName("t1"), in1, p1); - ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2, - p2); - - ops::Add out1(scope.WithOpName("out1"), t1, t2); - - GrapplerItem item; - item.fetch = {"out1"}; - TF_CHECK_OK(scope.ToGraphDef(&item.graph)); - - ConstantFolding optimizer(nullptr /* cpu_device */); - GraphDef got; - Status status = optimizer.Optimize(nullptr, item, &got); - TF_EXPECT_OK(status); - - GraphDef want; - AddNode("in1", "VariableV2", {}, {}, &want); - AddNode("in2", "VariableV2", {}, {}, &want); - AddNode("p1", "Const", {}, {}, &want); - AddNode("p2", "Const", {}, {}, &want); - AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want); - AddNode("t2", "Identity", - {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {}, - &want); - AddNode("out1", "Add", {"t1", "t2"}, {}, &want); - - CompareGraphs(want, got); -} - -TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) { +TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); Output in1 = @@ -1488,44 +1452,6 @@ TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) { CompareGraphs(want, got); } -TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) { - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}), - DT_FLOAT); - Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4}); - Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}), - DT_FLOAT); - Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2}); - ops::Reverse r1(scope.WithOpName("r1"), in1, a1); - ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2, - a2); - - ops::Add out1(scope.WithOpName("out1"), r1, r2); - - GrapplerItem item; - item.fetch = {"out1"}; - TF_CHECK_OK(scope.ToGraphDef(&item.graph)); - - ConstantFolding optimizer(nullptr /* cpu_device */); - GraphDef got; - Status status = optimizer.Optimize(nullptr, item, &got); - TF_EXPECT_OK(status); - - GraphDef want; - AddNode("in1", "VariableV2", {}, {}, &want); - AddNode("in2", "VariableV2", {}, {}, &want); - AddNode("a1", "Const", {}, {}, &want); - AddNode("a2", "Const", {}, {}, &want); - AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want); - AddNode("r2", "Identity", - {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {}, - &want); - AddNode("out1", "Add", {"r1", "r2"}, {}, &want); - - CompareGraphs(want, got); -} - TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) { { // size = {3, 5} tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); |