aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-21 12:57:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-21 13:04:56 -0800
commit7e8b4a09416e453555073a88b0fd47625e0c5036 (patch)
tree106e6f307bcf540e125595177f2632b08f54204b /tensorflow
parent9dfb73b26c846038ef8101b2624de3b2cbf49c61 (diff)
Change node to Identity operation for shuffle/reverse operations on scalar values, but not
directly removing those nodes from the graph. PiperOrigin-RevId: 186505857
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/grappler/op_types.cc8
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc15
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc34
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc38
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.h5
6 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index fdf4540540..e225e99a9e 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -256,6 +256,10 @@ bool IsRestore(const NodeDef& node) {
node.op() == "RestoreSlice");
}
+bool IsReverse(const NodeDef& node) {
+ return node.op() == "Reverse" || node.op() == "ReverseV2";
+}
+
bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
@@ -272,6 +276,10 @@ bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
+bool IsShuffle(const NodeDef& node) {
+ return node.op() == "Shuffle" || node.op() == "RandomShuffle";
+}
+
bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 9cda40c0a6..1fa43a9b66 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -100,6 +100,7 @@ bool IsRecv(const NodeDef& node);
bool IsReduction(const NodeDef& node);
bool IsReshape(const NodeDef& node);
bool IsRestore(const NodeDef& node);
+bool IsReverse(const NodeDef& node);
bool IsReverseV2(const NodeDef& node);
bool IsRsqrtGrad(const NodeDef& node);
bool IsSelect(const NodeDef& node);
@@ -108,6 +109,7 @@ bool IsSend(const NodeDef& node);
bool IsSlice(const NodeDef& node);
bool IsShape(const NodeDef& node);
bool IsShapeN(const NodeDef& node);
+bool IsShuffle(const NodeDef& node);
bool IsSigmoidGrad(const NodeDef& node);
bool IsSoftplusGrad(const NodeDef& node);
bool IsSoftsignGrad(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 7a621bd95d..95eaa31a46 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1446,6 +1446,20 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
for (int i = 0; i < output->node_size(); ++i) {
NodeDef* node = output->mutable_node(i);
+ // Remove Shuffle or Reverse op over scalar values.
+ if (use_shape_info &&
+ (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
+ const auto& shape =
+ properties.GetInputProperties(node->name())[0].shape();
+ // The node is replaceable iff
+ // unknown_rank == false && (dim_size == 0 || all dims have size 1)
+ bool replaceable = !shape.unknown_rank();
+ for (int j = 0; j < shape.dim_size(); ++j) {
+ replaceable &= shape.dim(j).size() == 1;
+ }
+ if (replaceable) ReplaceOperationWithIdentity(0, node, output);
+ }
+
if (IsSimplifiableReduction(*node)) {
// Replace the reduction node with an identity node, that can be further
// optimized by the model pruner.
@@ -1713,6 +1727,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
TF_RETURN_IF_ERROR(FoldGraph(output));
node_map_.reset(new NodeMap(output));
TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index d8df19fe6a..3afc176402 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1177,6 +1177,40 @@ TEST_F(ConstantFoldingTest, MergeNodes) {
EXPECT_EQ(2, out_idx.flat<int32>()(0));
}
+TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 =
+ ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT);
+ Output in2 =
+ ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT);
+ ops::RandomShuffle s1(scope.WithOpName("s1"), in1);
+ ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}),
+ in2);
+
+ ops::Add out1(scope.WithOpName("out1"), s1, s2);
+ ops::Identity out2(scope.WithOpName("out2"), s2);
+
+ GrapplerItem item;
+ item.fetch = {"out1", "out2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = fold.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, &want);
+ AddNode("in2", "VariableV2", {}, &want);
+ AddNode("s1", "Identity", {"in1"}, &want);
+ AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, &want);
+ AddNode("out1", "Add", {"s1", "s2"}, &want);
+ AddNode("out2", "Identity", {"s2"}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, NoOpReduction) {
// Build a simple graph with a reduction that can be reduced to the identity.
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 813f65f825..fed46c05fb 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -35,5 +35,43 @@ std::vector<Tensor> GrapplerTest::EvaluateNodes(
return output_tensors;
}
+void GrapplerTest::AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs, GraphDef* graph) {
+ auto* node = graph->add_node();
+ node->set_name(name);
+ node->set_op(op);
+ for (const auto& input : inputs) {
+ node->add_input(input);
+ }
+}
+
+void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) {
+ auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool {
+ return n1.name() < n2.name();
+ };
+ std::sort(want.mutable_node()->begin(), want.mutable_node()->end(),
+ comparator);
+ std::sort(got.mutable_node()->begin(), got.mutable_node()->end(), comparator);
+
+ for (int i = 0; i < want.node_size(); ++i) {
+ std::sort(want.mutable_node(i)->mutable_input()->begin(),
+ want.mutable_node(i)->mutable_input()->end());
+ }
+ for (int i = 0; i < got.node_size(); ++i) {
+ std::sort(got.mutable_node(i)->mutable_input()->begin(),
+ got.mutable_node(i)->mutable_input()->end());
+ }
+
+ ASSERT_EQ(want.node_size(), got.node_size());
+ for (int i = 0; i < want.node_size(); ++i) {
+ EXPECT_EQ(want.node(i).op(), got.node(i).op());
+ EXPECT_EQ(want.node(i).name(), got.node(i).name());
+ ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size());
+ for (int j = 0; j < want.node(i).input_size(); ++j) {
+ EXPECT_TRUE(IsSameInput(want.node(i).input(j), got.node(i).input(j)));
+ }
+ }
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h
index 46ce47c8c3..042b616aa4 100644
--- a/tensorflow/core/grappler/utils/grappler_test.h
+++ b/tensorflow/core/grappler/utils/grappler_test.h
@@ -29,6 +29,11 @@ class GrapplerTest : public ::testing::Test {
protected:
std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
const std::vector<string>& node_names);
+
+ void AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs, GraphDef* graph);
+
+ void CompareGraphs(GraphDef want, GraphDef got);
};
} // end namespace grappler