diff options
author | 2018-02-21 12:57:26 -0800 | |
---|---|---|
committer | 2018-02-21 13:04:56 -0800 | |
commit | 7e8b4a09416e453555073a88b0fd47625e0c5036 (patch) | |
tree | 106e6f307bcf540e125595177f2632b08f54204b /tensorflow/core/grappler/utils | |
parent | 9dfb73b26c846038ef8101b2624de3b2cbf49c61 (diff) |
Change node to Identity operation for shuffle/reverse operations on scalar values, but not
directly removing those nodes from the graph.
PiperOrigin-RevId: 186505857
Diffstat (limited to 'tensorflow/core/grappler/utils')
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.cc | 38 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/grappler_test.h | 5 |
2 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 813f65f825..fed46c05fb 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -35,5 +35,43 @@ std::vector<Tensor> GrapplerTest::EvaluateNodes( return output_tensors; } +void GrapplerTest::AddNode(const string& name, const string& op, + const std::vector<string>& inputs, GraphDef* graph) { + auto* node = graph->add_node(); + node->set_name(name); + node->set_op(op); + for (const auto& input : inputs) { + node->add_input(input); + } +} + +void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) { + auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool { + return n1.name() < n2.name(); + }; + std::sort(want.mutable_node()->begin(), want.mutable_node()->end(), + comparator); + std::sort(got.mutable_node()->begin(), got.mutable_node()->end(), comparator); + + for (int i = 0; i < want.node_size(); ++i) { + std::sort(want.mutable_node(i)->mutable_input()->begin(), + want.mutable_node(i)->mutable_input()->end()); + } + for (int i = 0; i < got.node_size(); ++i) { + std::sort(got.mutable_node(i)->mutable_input()->begin(), + got.mutable_node(i)->mutable_input()->end()); + } + + ASSERT_EQ(want.node_size(), got.node_size()); + for (int i = 0; i < want.node_size(); ++i) { + EXPECT_EQ(want.node(i).op(), got.node(i).op()); + EXPECT_EQ(want.node(i).name(), got.node(i).name()); + ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size()); + for (int j = 0; j < want.node(i).input_size(); ++j) { + EXPECT_TRUE(IsSameInput(want.node(i).input(j), got.node(i).input(j))); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 46ce47c8c3..042b616aa4 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -29,6 +29,11 @@ class GrapplerTest : public ::testing::Test { protected: std::vector<Tensor> EvaluateNodes(const GraphDef& graph, const std::vector<string>& node_names); + + void AddNode(const string& name, const string& op, + const std::vector<string>& inputs, GraphDef* graph); + + void CompareGraphs(GraphDef want, GraphDef got); }; } // end namespace grappler |