aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar Piotr Padlewski <prazek@google.com>2018-09-23 18:30:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-23 18:37:06 -0700
commitfcd7840fbf49802be4bb7f67671465338b7b78a4 (patch)
tree017562cbd2b66b462a3562d667f3dae2a99c0ee5 /tensorflow/core/grappler
parent167272ead245ac9e0183da807d996ba9d6e401b0 (diff)
Fix noop elimination optimization.
Fix for b/116169724 Only remove noops if they refer to const nodes. PiperOrigin-RevId: 214199200
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h3
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc43
4 files changed, 65 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index b3f60e34f9..2dd9ee822e 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -88,6 +88,16 @@ NodeDef* AddScalarConstNodeHelper(
} // namespace
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
+ NodeDef node;
+ node.set_op("Placeholder");
+ SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
+ (*node.mutable_attr())["dtype"].set_type(dtype);
+ TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
+ shape->set_unknown_rank(false);
+ return graph->AddNode(std::move(node));
+}
+
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 1652afcd9e..b117482db2 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,6 +37,9 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
+// Adds Placeholder node for given type.
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
+
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
index a26f1000a3..cf5a19bab1 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -33,25 +33,27 @@ namespace {
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
if (take_node.op() != "TakeDataset") return false;
- const NodeDef& count_node = *graph.GetNode(take_node.input(1));
+ const auto& count_node = *graph.GetNode(take_node.input(1));
+ if (count_node.op() != "Const") return false;
// We are looking only for 'take' with negative count.
return count_node.attr().at("value").tensor().int64_val(0) < 0;
}
+bool IsConstNodeWithValue(const NodeDef& node, int value) {
+ if (node.op() != "Const") return false;
+ return node.attr().at("value").tensor().int64_val(0) == value;
+}
+
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
if (skip_node.op() != "SkipDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(skip_node.input(1));
// We are looking only for skip(0) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 0;
+ return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
}
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
if (repeat_node.op() != "RepeatDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(repeat_node.input(1));
// We are looking only for repeat(1) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 1;
+ return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
}
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
index f445e75aa7..be1a66df75 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
@@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node,
GetCommonAttributes(), graph);
}
+NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node,
+ MutableGraphView *graph) {
+ NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph);
+ return graph_utils::AddNode("", node_type,
+ {std::move(input_node), node_count->name()},
+ GetCommonAttributes(), graph);
+}
+
NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) {
NodeDef *node_filename =
graph_utils::AddScalarConstNode<StringPiece>("", graph);
@@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(*kTakeNode, *kSkipNode,
*kRepeatNode)));
+struct NoOpPlaceholdersTest
+ : ::testing::TestWithParam<std::tuple<string, string>> {};
+
+TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) {
+ GrapplerItem item;
+ MutableGraphView graph(&item.graph);
+
+ static_assert(std::tuple_size<NodesTypes>::value == 2,
+ "Make sure to include everything in the test");
+ const std::vector<string> noop_nodes = {std::get<0>(GetParam()),
+ std::get<1>(GetParam())};
+ NodeDef *range_node = MakeRangeNode(&graph);
+ std::vector<string> nodes_to_keep;
+ nodes_to_keep.reserve(noop_nodes.size());
+ NodeDef *previous = range_node;
+
+ for (const auto &noop_node : noop_nodes) {
+ NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph);
+ nodes_to_keep.push_back(node->name());
+ previous = node;
+ }
+
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ for (const auto &noop_node_name : nodes_to_keep)
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DoNotRemovePlaceholders, NoOpPlaceholdersTest,
+ ::testing::Combine(
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"),
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset")));
+
} // namespace
} // namespace grappler
} // namespace tensorflow