aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-21 12:57:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-21 13:04:56 -0800
commit7e8b4a09416e453555073a88b0fd47625e0c5036 (patch)
tree106e6f307bcf540e125595177f2632b08f54204b /tensorflow/core/grappler/utils
parent9dfb73b26c846038ef8101b2624de3b2cbf49c61 (diff)
Change node to Identity operation for shuffle/reverse operations on scalar values, but not
directly removing those nodes from the graph. PiperOrigin-RevId: 186505857
Diffstat (limited to 'tensorflow/core/grappler/utils')
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc38
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.h5
2 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 813f65f825..fed46c05fb 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -35,5 +35,43 @@ std::vector<Tensor> GrapplerTest::EvaluateNodes(
return output_tensors;
}
+void GrapplerTest::AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs, GraphDef* graph) {
+ auto* node = graph->add_node();
+ node->set_name(name);
+ node->set_op(op);
+ for (const auto& input : inputs) {
+ node->add_input(input);
+ }
+}
+
+void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) {
+ auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool {
+ return n1.name() < n2.name();
+ };
+ std::sort(want.mutable_node()->begin(), want.mutable_node()->end(),
+ comparator);
+ std::sort(got.mutable_node()->begin(), got.mutable_node()->end(), comparator);
+
+ for (int i = 0; i < want.node_size(); ++i) {
+ std::sort(want.mutable_node(i)->mutable_input()->begin(),
+ want.mutable_node(i)->mutable_input()->end());
+ }
+ for (int i = 0; i < got.node_size(); ++i) {
+ std::sort(got.mutable_node(i)->mutable_input()->begin(),
+ got.mutable_node(i)->mutable_input()->end());
+ }
+
+ ASSERT_EQ(want.node_size(), got.node_size());
+ for (int i = 0; i < want.node_size(); ++i) {
+ EXPECT_EQ(want.node(i).op(), got.node(i).op());
+ EXPECT_EQ(want.node(i).name(), got.node(i).name());
+ ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size());
+ for (int j = 0; j < want.node(i).input_size(); ++j) {
+ EXPECT_TRUE(IsSameInput(want.node(i).input(j), got.node(i).input(j)));
+ }
+ }
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h
index 46ce47c8c3..042b616aa4 100644
--- a/tensorflow/core/grappler/utils/grappler_test.h
+++ b/tensorflow/core/grappler/utils/grappler_test.h
@@ -29,6 +29,11 @@ class GrapplerTest : public ::testing::Test {
protected:
std::vector<Tensor> EvaluateNodes(const GraphDef& graph,
const std::vector<string>& node_names);
+
+ void AddNode(const string& name, const string& op,
+ const std::vector<string>& inputs, GraphDef* graph);
+
+ void CompareGraphs(GraphDef want, GraphDef got);
};
} // end namespace grappler