author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-11 17:29:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-11 17:31:36 -0700
commit  d62a5a11e99b391f2e61e80c4f0a80def6ff6508 (patch)
tree    e30b2b12d64e6c814888b6bd38226d3dce73e625
parent  81a9ceaf7290b2260f636609a83b01b9ab2224d7 (diff)
Automated g4 rollback of changelist 192516190
PiperOrigin-RevId: 192536085
-rw-r--r--  tensorflow/core/grappler/op_types.cc                           8
-rw-r--r--  tensorflow/core/grappler/op_types.h                            1
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc       95
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc  80
4 files changed, 16 insertions(+), 168 deletions(-)
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index cfe1329dbf..9c45aed62f 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -249,10 +249,6 @@ bool IsPrint(const NodeDef& node) { return node.op() == "Print"; }
bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
-bool IsRandomShuffle(const NodeDef& node) {
- return node.op() == "RandomShuffle";
-}
-
bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
@@ -302,7 +298,9 @@ bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
-bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
+bool IsShuffle(const NodeDef& node) {
+ return node.op() == "Shuffle" || node.op() == "RandomShuffle";
+}
bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 0573b02604..79fd05e187 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -98,7 +98,6 @@ bool IsPolygamma(const NodeDef& node);
bool IsPrint(const NodeDef& node);
bool IsProd(const NodeDef& node);
bool IsPow(const NodeDef& node);
-bool IsRandomShuffle(const NodeDef& node);
bool IsReal(const NodeDef& node);
bool IsRealDiv(const NodeDef& node);
bool IsRelu6Grad(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 17d8b7421c..b2a1ce6ab6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1574,99 +1574,24 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
continue;
}
- // Remove Shuffle or Transpose op over dimensions of size 1.
- if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
- !properties->GetInputProperties(node->name()).empty()) {
- const auto& shape =
- properties->GetInputProperties(node->name())[0].shape();
- if (shape.unknown_rank()) {
- // Not optimizable.
- continue;
- }
- const auto& p = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor perm(p.dtype(), p.shape());
- if (!perm.FromProto(p.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
- }
- std::vector<int> permutation;
- for (int j = 0; j < perm.NumElements(); ++j) {
- if (perm.dtype() == DT_INT64) {
- permutation.push_back(perm.vec<int64>()(j));
- } else {
- permutation.push_back(perm.vec<int>()(j));
- }
- }
- if (permutation.size() != shape.dim_size()) {
- // Number of elements in perm should be same as dim_size. Skip if not.
- continue;
- }
- // The node is replaceable iff
- // dim_size == 0 || all dims have size 1 ||
- // all dims with > 1 size are not permuted.
- bool replaceable = true;
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, node, optimized_graph);
- continue;
- }
- }
- }
-
- // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
- if (use_shape_info && IsRandomShuffle(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
+ // Remove Shuffle or Reverse op over scalar values.
+ if (use_shape_info &&
+ !properties->GetInputProperties(node->name()).empty() &&
+ (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
const auto& shape =
properties->GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
- // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
- if (!shape.unknown_rank() &&
- (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
+ // unknown_rank == false && (dim_size == 0 || all dims have size 1)
+ bool replaceable = !shape.unknown_rank();
+ for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() == 1;
+ }
+ if (replaceable) {
ReplaceOperationWithIdentity(0, node, optimized_graph);
continue;
}
}
- // Remove Reverse op over dimensions with size 1.
- if (use_shape_info && IsReverse(*node) &&
- !properties->GetInputProperties(node->name()).empty()) {
- const auto& shape =
- properties->GetInputProperties(node->name())[0].shape();
- const auto& a = properties->GetInputProperties(node->name())[1];
- if (TensorShape::IsValid(a.shape()) && a.has_value()) {
- Tensor axis(a.dtype(), a.shape());
- if (!axis.FromProto(a.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- a.value().DebugString());
- }
- std::set<int> target_axes;
- for (int j = 0; j < axis.NumElements(); ++j) {
- if (axis.dtype() == DT_INT64) {
- target_axes.insert(axis.vec<int64>()(j));
- } else {
- target_axes.insert(axis.vec<int>()(j));
- }
- }
-
- // The node is replaceable iff
- // unknown_rank == false &&
- // (dim_size == 0 || all dims have size 1 ||
- // all dims with > 1 size are not in target_axes)
- bool replaceable = !shape.unknown_rank();
- for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
- replaceable &= shape.dim(j).size() == 1 ||
- target_axes.find(j) == target_axes.end();
- }
- if (replaceable) {
- ReplaceOperationWithIdentity(0, node, optimized_graph);
- continue;
- }
- }
- }
-
if (use_shape_info && IsSlice(*node) &&
properties->GetInputProperties(node->name()).size() == 3) {
const auto& input = properties->GetInputProperties(node->name())[0];
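The constant_folding.cc hunk above trades the earlier per-axis checks for Transpose and Reverse for one coarser rule: a Shuffle, RandomShuffle, Reverse, or Transpose node is rewritten to an Identity only when the input rank is known and every dimension has size 1 (this includes scalars, whose dim_size is 0). A minimal standalone sketch of that predicate, with a hypothetical Shape struct standing in for the TensorShapeProto that constant folding reads via GetInputProperties:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the shape data that constant folding reads
// from GetInputProperties(node->name())[0].shape().
struct Shape {
  bool unknown_rank;
  std::vector<int64_t> dims;  // empty for a scalar (rank 0)
};

// Mirrors the restored check: replaceable iff the rank is known and
// every dimension has size 1.
bool ReplaceableWithIdentity(const Shape& shape) {
  bool replaceable = !shape.unknown_rank;
  for (size_t j = 0; replaceable && j < shape.dims.size(); ++j) {
    replaceable &= shape.dims[j] == 1;
  }
  return replaceable;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << ReplaceableWithIdentity({false, {}}) << "\n";           // true: scalar
  std::cout << ReplaceableWithIdentity({false, {1, 1, 1}}) << "\n";    // true: all dims are 1
  std::cout << ReplaceableWithIdentity({false, {1, 2, 4, 1}}) << "\n"; // false: dims of size > 1
  std::cout << ReplaceableWithIdentity({true, {}}) << "\n";            // false: unknown rank
}

Under this narrowed rule an input of shape {1, 2, 4, 1} is no longer foldable even when the permutation or reverse axes leave the size-2 and size-4 dimensions untouched, which is why the TransposeOnSize1DimsRemoval and ReverseOnSize1DimsRemoval tests are deleted in the test diff below.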
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 7453fb6731..31abe43846 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1389,6 +1389,8 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) {
ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
+ LOG(INFO) << s1.output.size();
+ LOG(INFO) << s2.output.size();
ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
GrapplerItem item;
@@ -1416,45 +1418,7 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) {
CompareGraphs(want, got);
}
-TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
- tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-
- Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
- DT_FLOAT);
- Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
- Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
- DT_FLOAT);
- Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
- ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
- ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
- p2);
-
- ops::Add out1(scope.WithOpName("out1"), t1, t2);
-
- GrapplerItem item;
- item.fetch = {"out1"};
- TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef got;
- Status status = optimizer.Optimize(nullptr, item, &got);
- TF_EXPECT_OK(status);
-
- GraphDef want;
- AddNode("in1", "VariableV2", {}, {}, &want);
- AddNode("in2", "VariableV2", {}, {}, &want);
- AddNode("p1", "Const", {}, {}, &want);
- AddNode("p2", "Const", {}, {}, &want);
- AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
- AddNode("t2", "Identity",
- {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
- &want);
- AddNode("out1", "Add", {"t1", "t2"}, {}, &want);
-
- CompareGraphs(want, got);
-}
-
-TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
+TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output in1 =
@@ -1488,44 +1452,6 @@ TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
CompareGraphs(want, got);
}
-TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
- tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-
- Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
- DT_FLOAT);
- Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
- Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
- DT_FLOAT);
- Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
- ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
- ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
- a2);
-
- ops::Add out1(scope.WithOpName("out1"), r1, r2);
-
- GrapplerItem item;
- item.fetch = {"out1"};
- TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef got;
- Status status = optimizer.Optimize(nullptr, item, &got);
- TF_EXPECT_OK(status);
-
- GraphDef want;
- AddNode("in1", "VariableV2", {}, {}, &want);
- AddNode("in2", "VariableV2", {}, {}, &want);
- AddNode("a1", "Const", {}, {}, &want);
- AddNode("a2", "Const", {}, {}, &want);
- AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
- AddNode("r2", "Identity",
- {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
- &want);
- AddNode("out1", "Add", {"r1", "r2"}, {}, &want);
-
- CompareGraphs(want, got);
-}
-
TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
{ // size = {3, 5}
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();