Diffstat (new files, lines added):
-rw-r--r--   139  tensorflow/core/grappler/graph_analyzer/BUILD
-rw-r--r--   148  tensorflow/core/grappler/graph_analyzer/gen_node.cc
-rw-r--r--   167  tensorflow/core/grappler/graph_analyzer/gen_node.h
-rw-r--r--   491  tensorflow/core/grappler/graph_analyzer/gen_node_test.cc
-rw-r--r--   341  tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
-rw-r--r--   154  tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
-rw-r--r--   569  tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc
-rw-r--r--    98  tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc
-rw-r--r--    31  tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h
-rw-r--r--    47  tensorflow/core/grappler/graph_analyzer/hash_tools.h
-rw-r--r--    46  tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc
-rw-r--r--    46  tensorflow/core/grappler/graph_analyzer/map_tools.h
-rw-r--r--   453  tensorflow/core/grappler/graph_analyzer/sig_node.cc
-rw-r--r--   304  tensorflow/core/grappler/graph_analyzer/sig_node.h
-rw-r--r--  1235  tensorflow/core/grappler/graph_analyzer/sig_node_test.cc
-rw-r--r--   235  tensorflow/core/grappler/graph_analyzer/subgraph.cc
-rw-r--r--   189  tensorflow/core/grappler/graph_analyzer/subgraph.h
-rw-r--r--   348  tensorflow/core/grappler/graph_analyzer/subgraph_test.cc
-rw-r--r--   296  tensorflow/core/grappler/graph_analyzer/test_tools.cc
-rw-r--r--   120  tensorflow/core/grappler/graph_analyzer/test_tools.h
-rw-r--r--    14  tensorflow/python/BUILD
-rw-r--r--    26  tensorflow/python/grappler/graph_analyzer.i
-rw-r--r--    46  tensorflow/python/grappler/graph_analyzer.py
-rw-r--r--     1  tensorflow/python/tensorflow.i
24 files changed, 5544 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/graph_analyzer/BUILD b/tensorflow/core/grappler/graph_analyzer/BUILD
new file mode 100644
index 0000000000..d56a08d3c8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/BUILD
@@ -0,0 +1,139 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "graph_analyzer_lib",
+ srcs = [
+ "gen_node.cc",
+ "graph_analyzer.cc",
+ "sig_node.cc",
+ "subgraph.cc",
+ ],
+ hdrs = [
+ "gen_node.h",
+ "graph_analyzer.h",
+ "hash_tools.h",
+ "map_tools.h",
+ "sig_node.h",
+ "subgraph.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+cc_library(
+ name = "graph_analyzer_tool",
+ srcs = ["graph_analyzer_tool.cc"],
+ hdrs = ["graph_analyzer_tool.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_analyzer_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:grappler_item",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "test_tools_lib",
+ testonly = 1,
+ srcs = [
+ "test_tools.cc",
+ ],
+ hdrs = [
+ "test_tools.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_analyzer_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:op_types",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+tf_cc_test(
+ name = "hash_tools_test",
+ testonly = 1,
+ srcs = [
+ "hash_tools_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "gen_node_test",
+ testonly = 1,
+ srcs = [
+ "gen_node_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "sig_node_test",
+ testonly = 1,
+ srcs = [
+ "sig_node_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "//tensorflow/core/grappler:utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "graph_analyzer_test",
+ testonly = 1,
+ srcs = [
+ "graph_analyzer_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "subgraph_test",
+ testonly = 1,
+ srcs = [
+ "subgraph_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.cc b/tensorflow/core/grappler/graph_analyzer/gen_node.cc
new file mode 100644
index 0000000000..f8c15fd50e
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node.cc
@@ -0,0 +1,148 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+GenNode::GenNode(const NodeDef* node) : node_(node), op_(nullptr) {}
+
+Status GenNode::BuildGraphInMap(const GraphDef& source, GenNodeMap* map) {
+ for (const auto& n : source.node()) {
+ const string& name = n.name();
+ if (map->find(name) != map->end()) {
+ // This error code looks more meaningful than ALREADY_EXISTS.
+ return Status(error::INVALID_ARGUMENT,
+ "Duplicate node name '" + name + "'.");
+ }
+ (*map)[name] = absl::make_unique<GenNode>(&n);
+ }
+ // Now parse the links.
+ for (const auto& mapit : *map) {
+ Status st = mapit.second->ParseInputs(map);
+ if (!st.ok()) {
+ return st;
+ }
+ }
+ return Status::OK();
+}
+
+Status GenNode::ParseInputs(const GenNodeMap* map) {
+ all_inputs_or_none_ = false;
+ Status st = OpRegistry::Global()->LookUpOpDef(opcode(), &op_);
+ if (!st.ok()) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat("Node '%s' contains an undefined operation '%s': %s",
+ name(), opcode(), st.error_message()));
+ }
+
+ int n_inputs = node_->input_size();
+
+ int n_named_inputs = op_->input_arg_size();
+
+ int n_multi_inputs = 0;
+ for (const auto& inarg : op_->input_arg()) {
+ if (!inarg.number_attr().empty() || !inarg.type_list_attr().empty()) {
+ ++n_multi_inputs;
+ }
+ }
+ bool is_commutative = grappler::IsCommutative(*node_);
+
+ if (n_multi_inputs > 1 || (n_multi_inputs > 0 && n_named_inputs > 1)) {
+ // Can't handle more than one multi-input at a time.
+ // And can't handle the commutativeness of only some arguments
+ // rather than all of them.
+ is_commutative = false;
+ }
+
+ if (is_commutative) {
+ // If truly commutative, can treat all the inputs as one multi-input.
+ // It's possible to just treat the commutative nodes as AllInputsOrNone
+ // but (1) this way is a bit more efficient and (2) I want to preserve this
+ // more efficient code path that does all-or-none by a single input and
+ // perhaps extend its use in the future.
+ n_named_inputs = 1;
+ all_inputs_or_none_ = false;
+ } else if (n_multi_inputs > 0) {
+ all_inputs_or_none_ = true;
+ }
+
+ for (int i = 0; i < n_inputs; ++i) {
+ int other_position;
+ string other_name = ParseNodeName(node_->input(i), &other_position);
+ auto other_it = map->find(other_name);
+ if (other_it == map->end()) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "Node '%s' input %d refers to a non-existing node '%s'.", name(),
+ i, other_name));
+ }
+ GenNode* other_node = other_it->second.get();
+
+ int this_position = other_position < 0 ? -1 : (is_commutative ? 0 : i);
+
+ if (this_position >= 0 && n_multi_inputs == 0 &&
+ this_position >= n_named_inputs) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "Node '%s' has a non-control input from '%s' at index %d but its "
+ "operation '%s' defines only %d inputs.",
+ name(), other_name, this_position, op_->name(), n_named_inputs));
+ }
+
+ Port this_port(/*inbound=*/true, this_position);
+ Port other_port(/*inbound=*/false, other_position);
+
+ links_[this_port].emplace_back(LinkTarget(other_node, other_port));
+ other_node->links_[other_port].emplace_back(LinkTarget(this, this_port));
+ }
+ return Status::OK();
+}
+
+bool GenNode::IsMultiInput(Port port) const {
+ if (!port.IsInbound()) {
+ return false;
+ }
+ auto it = links_.find(port);
+ if (it == links_.end()) {
+ return false; // Shouldn't happen.
+ }
+ return (it->second.size() > 1);
+}
+
+GenNode::Port::operator string() const {
+ string result = this->IsInbound() ? "i" : "o";
+ if (this->IsControl()) {
+ result.append("C");
+ } else {
+ result.append(absl::StrFormat("%d", this->Id()));
+ }
+ return result;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
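
Illustration (not part of the patch): the two-stage build described in gen_node.h is the main entry point of this file. BuildGraphInMap() first creates one GenNode per NodeDef, then ParseInputs() cross-links them through their ports. A minimal sketch of driving it, assuming a TensorFlow build environment; the helper name CountAllLinks is made up for the example, everything else comes from the patch:

    #include <cstdio>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/graph_analyzer/gen_node.h"

    namespace tensorflow {
    namespace grappler {
    namespace graph_analyzer {

    Status CountAllLinks(const GraphDef& graph) {
      GenNodeMap map;
      // Stage 1: one GenNode per NodeDef; stage 2: ParseInputs() cross-links them.
      Status st = GenNode::BuildGraphInMap(graph, &map);
      if (!st.ok()) return st;
      for (const auto& entry : map) {
        size_t n_links = 0;
        // Each key of links() is a local Port; each value lists the remote
        // (node, port) targets attached to that port.
        for (const auto& link : entry.second->links()) {
          n_links += link.second.size();
        }
        std::printf("%s: %zu link(s)\n", entry.second->name().c_str(), n_links);
      }
      return Status::OK();
    }

    }  // end namespace graph_analyzer
    }  // end namespace grappler
    }  // end namespace tensorflow
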
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.h b/tensorflow/core/grappler/graph_analyzer/gen_node.h
new file mode 100644
index 0000000000..faec9ecad8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node.h
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+class GenNode;
+
+// To find nodes by name.
+using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>;
+
+// One node in the graph, in the form convenient for traversal and generation of
+// subgraphs. It refers to the original NodeDef protobuf for most information
+// and adds the extra enrichment.
+//
+// The graph building is 2-stage: first match a GenNode with each NodeDef and
+// collect them into a map that finds them by name, then process the map,
+// deep-parse the underlying NodeDefs and connect the GenNodes together.
+class GenNode {
+ public:
+ // Will keep the pointer, so the underlying object must not be deleted while
+ // GenNode is alive.
+ explicit GenNode(const NodeDef* node);
+
+ // Access wrappers.
+ const string& name() const { return node_->name(); }
+ const string& opcode() const { return node_->op(); }
+ const NodeDef* node_def() const { return node_; }
+
+ // Parse the inputs of this node and update the map accordingly, creating the
+ // links (i.e. edges, connections between nodes) in itself and in the nodes
+ // it's linked to (the map itself is unchanged, only the nodes in it are
+ // updated).
+ Status ParseInputs(const GenNodeMap* map);
+
+ // Does the full 2-stage build of the graph. The map should be initially
+ // empty. The map keeps pointers to the nodes in source, so the source must
+ // not be destroyed before the map.
+ static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map);
+
+ // The enrichment that constitutes the point of this class.
+
+ // Representation of a connection on a node.
+ class Port {
+ public:
+ // A port may be inbound or outbound.
+ // Negative ids (canonically -1) mean a control port.
+ Port(bool inbound, int32_t id) : value_(id << 1) {
+ if (inbound) {
+ value_ |= 1;
+ }
+ }
+ Port(const Port&) = default;
+ Port& operator=(const Port&) = default;
+
+ bool IsInbound() const { return (value_ & 0x1); }
+
+ bool IsControl() const { return (value_ < 0); }
+
+ int32_t Id() const {
+ // Arithmetic shift preserves the sign.
+ return (value_ >> 1);
+ }
+
+ // Integer type used to represent the encoded port value.
+ using IntPort = int32_t;
+
+ // Returns the encoded form of this port, so that it can be used
+ // as various map indexes.
+ IntPort Encoded() const { return value_; }
+
+ static Port Decode(IntPort encoded) { return Port(encoded); }
+
+ bool operator==(const Port& other) const { return value_ == other.value_; }
+ bool operator<(const Port& other) const { return value_ < other.value_; }
+
+ struct Hasher {
+ size_t operator()(const Port& port) const noexcept {
+ return hasher(port.Encoded());
+ }
+ std::hash<int32_t> hasher;
+ };
+
+ // Convenient for printing. I've really wanted it to be implicit but
+ // ClangTidy insists on making it explicit.
+ explicit operator string() const;
+
+ private:
+ explicit Port(IntPort value) : value_(value) {}
+
+ IntPort value_;
+ };
+
+ struct LinkTarget {
+ GenNode* node; // Node where this link points.
+ Port port; // Port on the remote side of this link.
+
+ LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {}
+ };
+ // All the links that are connected to the same port of this node
+ // are collected in one vector. A link is an edge of the graph that connects
+ // 2 nodes. Each of the connected nodes has its own perspective on the link,
+ // seeing its local port, remote port and the remote node. The direction of
+ // the link is encoded in the ports, one port is always incoming and another
+ // one outgoing.
+ using LinkTargetVector = std::vector<LinkTarget>;
+ // Both inputs and outputs are stored in the same map.
+ using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>;
+
+ // Access to the link map.
+ const LinkMap& links() const { return links_; }
+
+ // Check whether the port is an input (including the controls) with multiple
+ // connections. Such inputs get handled in a special way when building the
+ // subgraphs, in an "all or nothing" fashion.
+ bool IsMultiInput(Port port) const;
+
+ // When building the subgraphs, must include either all non-control inputs of
+ // this node into the subgraph or none of them. This happens when at least one
+ // of the inputs is a multi-input (or if the opcode is commutative, thus
+ // treating all the inputs as one multi-input).
+ bool AllInputsOrNone() const { return all_inputs_or_none_; }
+
+ private:
+ const NodeDef* node_;
+ // Becomes valid only after ParseInputs().
+ const OpDef* op_;
+
+ // The opcode has a complicated structure of input args, with multi-input args
+ // that are not commutative. This means that to make sense, the subgraphs that
+ // include this node must also include either all its inputs or none of them.
+ bool all_inputs_or_none_ = false;
+
+ LinkMap links_;
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
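
Illustration (not part of the patch): the Port class above packs the direction into the low bit and keeps the (possibly negative) id in the remaining bits, so the sign of the encoded value still says whether the port is a control port, and an arithmetic right shift recovers the id. A standalone restatement of that arithmetic in plain C++, with no TensorFlow dependency; the encode lambda is just a stand-in for the Port constructor:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Mirrors GenNode::Port: value = (id << 1), then the low bit marks "inbound".
      auto encode = [](bool inbound, int32_t id) {
        int32_t value = id << 1;  // two's-complement targets keep the sign here
        if (inbound) value |= 1;
        return value;
      };

      int32_t inbound_control = encode(/*inbound=*/true, /*id=*/-1);
      assert((inbound_control & 1) == 1);    // IsInbound()
      assert(inbound_control < 0);           // IsControl(): the sign survives
      assert((inbound_control >> 1) == -1);  // Id(): arithmetic shift restores -1

      int32_t outbound_data = encode(/*inbound=*/false, /*id=*/100);
      assert((outbound_data & 1) == 0);      // outbound
      assert(outbound_data >= 0);            // not a control port
      assert((outbound_data >> 1) == 100);   // Id() == 100
      return 0;
    }
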
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc
new file mode 100644
index 0000000000..d77daf7849
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc
@@ -0,0 +1,491 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(GenNodeTest, Port) {
+ {
+ GenNode::Port p(true, 100);
+ EXPECT_THAT(p.IsInbound(), Eq(true));
+ EXPECT_THAT(p.IsControl(), Eq(false));
+ EXPECT_THAT(p.Id(), Eq(100));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(true));
+ EXPECT_THAT(p2.IsControl(), Eq(false));
+ EXPECT_THAT(p2.Id(), Eq(100));
+ }
+ {
+ GenNode::Port p(false, 0);
+ EXPECT_THAT(p.IsInbound(), Eq(false));
+ EXPECT_THAT(p.IsControl(), Eq(false));
+ EXPECT_THAT(p.Id(), Eq(0));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(false));
+ EXPECT_THAT(p2.IsControl(), Eq(false));
+ EXPECT_THAT(p2.Id(), Eq(0));
+ }
+ {
+ GenNode::Port p(true, -100);
+ EXPECT_THAT(p.IsInbound(), Eq(true));
+ EXPECT_THAT(p.IsControl(), Eq(true));
+ EXPECT_THAT(p.Id(), Eq(-100));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(true));
+ EXPECT_THAT(p2.IsControl(), Eq(true));
+ EXPECT_THAT(p2.Id(), Eq(-100));
+ }
+ {
+ GenNode::Port p(false, -1);
+ EXPECT_THAT(p.IsInbound(), Eq(false));
+ EXPECT_THAT(p.IsControl(), Eq(true));
+ EXPECT_THAT(p.Id(), Eq(-1));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(false));
+ EXPECT_THAT(p2.IsControl(), Eq(true));
+ EXPECT_THAT(p2.Id(), Eq(-1));
+ }
+}
+
+TEST(GenNodeTest, ParseNodeNoInputs) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK()));
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre());
+}
+
+// A general operation, and a control link.
+TEST(GenNodeTest, ParseNodeWithControl) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeSub("node3", "node1", "node2");
+ node3.add_input("^node1"); // The control link.
+ node3.add_input("^node2"); // The control link.
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]",
+ "oC: node3[iC]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]",
+ "oC: node3[iC]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "iC: node1[oC], node2[oC]"
+ ));
+ // clang-format on
+
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+
+ // This is a multi-control-input.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, -1)), Eq(true));
+
+ EXPECT_FALSE(gn1->AllInputsOrNone());
+ EXPECT_FALSE(gn2->AllInputsOrNone());
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+// Commutative nodes are treated as having a single input,
+// because their inputs are equivalent.
+TEST(GenNodeTest, ParseNodeCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ // TODO(babkin): grappler::IsCommutative() should return true for Add but
+ // apparently doesn't. So use Mul in the meantime.
+ NodeDef node3 = MakeNodeMul("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0], node2[o0]"
+ ));
+ // clang-format on
+
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true));
+
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeAddN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0], node2[o0]"
+ ));
+ // clang-format on
+
+ // This is a multi-output.
+ EXPECT_THAT(gn2->IsMultiInput(GenNode::Port(false, 0)), Eq(false));
+ // This is a multi-input.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true));
+
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputNotCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+ EXPECT_TRUE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputList) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeIdentityN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+ EXPECT_TRUE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiMultiInput) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeConst("node3");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ NodeDef node4 = MakeNodeConst("node4");
+ map["node4"] = absl::make_unique<GenNode>(&node4);
+
+ NodeDef node5 =
+ MakeNodeQuantizedConcat("node5", "node1", "node2", "node3", "node4");
+ map["node5"] = absl::make_unique<GenNode>(&node5);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ auto gn4 = map["node4"].get();
+ auto gn5 = map["node5"].get();
+ ASSERT_THAT(gn5->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node5[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node5[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "o0: node5[i2]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre(
+ "o0: node5[i3]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn5->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "i2: node3[o0]",
+ "i3: node4[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 1)), Eq(false));
+ EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 2)), Eq(false));
+ EXPECT_TRUE(gn5->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiOutput) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ NodeDef node4 = MakeNodeSub("node4", "node3:1", "node3:0");
+ map["node4"] = absl::make_unique<GenNode>(&node4);
+
+ auto gn4 = map["node4"].get();
+ ASSERT_THAT(gn4->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre(
+ "i0: node3[o1]",
+ "i1: node3[o0]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, ParseNodeUndefinedOp) {
+ GenNodeMap map;
+ NodeDef node1;
+ node1.set_name("node1");
+ node1.set_op("Zzzx");
+
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ const OpDef* opdef;
+ Status nested_error = OpRegistry::Global()->LookUpOpDef("Zzzx", &opdef);
+
+ auto gn = map["node1"].get();
+ ASSERT_THAT(
+ gn->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node1' contains an undefined operation 'Zzzx': " +
+ nested_error.error_message())));
+}
+
+TEST(GenNodeTest, ParseNodeUnexpectedInputs) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("node1");
+
+ auto gn1 = map["node1"].get();
+ EXPECT_THAT(gn1->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node1' has a non-control "
+ "input from 'node1' at index 0 but its operation "
+ "'Const' defines only 0 inputs.")));
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeSub("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+ node3.add_input("node1");
+
+ auto gn3 = map["node3"].get();
+ EXPECT_THAT(gn3->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node3' has a non-control "
+ "input from 'node1' at index 2 but its operation "
+ "'Sub' defines only 2 inputs.")));
+}
+
+// Even if an opcode defines no inputs, the node may still accept the control
+// inputs.
+TEST(GenNodeTest, ParseNodeControlInputsAlwaysOk) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("^node1");
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "iC: node1[oC]",
+ "oC: node1[iC]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, ParseNodeInvalidInput) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeAddN("node1", "node2", "node3");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("node1");
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(
+ gn1->ParseInputs(&map),
+ Eq(Status(
+ error::INVALID_ARGUMENT,
+ "Node 'node1' input 0 refers to a non-existing node 'node2'.")));
+}
+
+TEST(GenNodeTest, BuildGraphInMap) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ (*graph.add_node()) =
+ MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node1"), Ne(map.end()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ EXPECT_THAT(map["node1"]->name(), Eq("node1"));
+ EXPECT_THAT(map["node2"]->name(), Eq("node2"));
+ EXPECT_THAT(map["node3"]->name(), Eq("node3"));
+
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(map["node1"]->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(map["node2"]->links()), ElementsAre(
+ "i0: node3[o1]",
+ "i1: node3[o0]",
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(map["node3"]->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "o0: node2[i1]",
+ "o1: node2[i0]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, BuildGraphInMapDuplicateNode) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node1");
+ GenNodeMap map;
+ ASSERT_THAT(
+ GenNode::BuildGraphInMap(graph, &map),
+ Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'.")));
+}
+
+TEST(GenNodeTest, BuildGraphInMapParseError) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+
+ GenNodeMap map;
+ ASSERT_THAT(
+ GenNode::BuildGraphInMap(graph, &map),
+ Eq(Status(
+ error::INVALID_ARGUMENT,
+ "Node 'node2' input 0 refers to a non-existing node 'node3'.")));
+}
+
+} // end namespace
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
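
Reader's note (not part of the patch): the tests above rely on the NodeDef input-string convention that ParseInputs() decodes through grappler's ParseNodeName(): a bare "node3" means output 0 of node3, "node3:1" means output 1, and "^node1" is a control input reported as position -1, which is why those links print as oC/iC. A standalone restatement of the convention, separate from the real parser; ParseInput here is a hypothetical stand-in:

    #include <cassert>
    #include <string>

    // Hypothetical stand-in for grappler::ParseNodeName(): returns the node name
    // and sets *position to -1 for "^name", K for "name:K", and 0 for "name".
    static std::string ParseInput(const std::string& input, int* position) {
      if (!input.empty() && input[0] == '^') {
        *position = -1;
        return input.substr(1);
      }
      const auto colon = input.rfind(':');
      if (colon == std::string::npos) {
        *position = 0;
        return input;
      }
      *position = std::stoi(input.substr(colon + 1));
      return input.substr(0, colon);
    }

    int main() {
      int pos = 0;
      assert(ParseInput("node3:1", &pos) == "node3" && pos == 1);
      assert(ParseInput("node3", &pos) == "node3" && pos == 0);
      assert(ParseInput("^node1", &pos) == "node1" && pos == -1);
      return 0;
    }
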
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
new file mode 100644
index 0000000000..f3796fcf86
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <deque>
+#include <iostream>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+GraphAnalyzer::GraphAnalyzer(const GraphDef& graph, int subgraph_size)
+ : graph_(graph), subgraph_size_(subgraph_size) {}
+
+GraphAnalyzer::~GraphAnalyzer() {}
+
+Status GraphAnalyzer::Run() {
+ // The signature computation code would detect this too, but better
+ // to report it up front than spend time computing all the graphs first.
+ if (subgraph_size_ > Signature::kMaxGraphSize) {
+ return Status(error::INVALID_ARGUMENT,
+ absl::StrFormat("Subgraphs of %d nodes are not supported, "
+ "the maximal supported node count is %d.",
+ subgraph_size_, Signature::kMaxGraphSize));
+ }
+
+ Status st = BuildMap();
+ if (!st.ok()) {
+ return st;
+ }
+
+ FindSubgraphs();
+ DropInvalidSubgraphs();
+ st = CollateResult();
+ if (!st.ok()) {
+ return st;
+ }
+
+ return Status::OK();
+}
+
+Status GraphAnalyzer::BuildMap() {
+ nodes_.clear();
+ return GenNode::BuildGraphInMap(graph_, &nodes_);
+}
+
+void GraphAnalyzer::FindSubgraphs() {
+ result_.clear();
+
+ if (subgraph_size_ < 1) {
+ return;
+ }
+
+ partial_.clear();
+ todo_.clear(); // Just in case.
+
+ // Start with all subgraphs of size 1.
+ const Subgraph::Identity empty_parent;
+ for (const auto& node : nodes_) {
+ if (subgraph_size_ == 1) {
+ result_.ExtendParent(empty_parent, node.second.get());
+ } else {
+ // At this point ExtendParent() is guaranteed to not return nullptr.
+ todo_.push_back(partial_.ExtendParent(empty_parent, node.second.get()));
+ }
+ }
+
+ // Then extend the subgraphs until no more extensions are possible.
+ while (!todo_.empty()) {
+ ExtendSubgraph(todo_.front());
+ todo_.pop_front();
+ }
+
+ partial_.clear();
+}
+
+void GraphAnalyzer::ExtendSubgraph(Subgraph* parent) {
+ bool will_complete = (parent->id().size() + 1 == subgraph_size_);
+ SubgraphPtrSet& sg_set = will_complete ? result_ : partial_;
+
+ const GenNode* last_all_or_none_node = nullptr;
+ for (SubgraphIterator sit(parent); !sit.AtEnd(); sit.Next()) {
+ const GenNode* node = sit.GetNode();
+ GenNode::Port port = sit.GetPort();
+ const GenNode::LinkTarget& neighbor = sit.GetNeighbor();
+
+ if (node->AllInputsOrNone() && port.IsInbound() && !port.IsControl()) {
+ if (node != last_all_or_none_node) {
+ ExtendSubgraphAllOrNone(parent, node);
+ last_all_or_none_node = node;
+ }
+ sit.SkipPort();
+ } else if (neighbor.node->AllInputsOrNone() && !port.IsInbound() &&
+ !port.IsControl()) {
+ if (parent->id().find(neighbor.node) == parent->id().end()) {
+ // Not added yet.
+ ExtendSubgraphAllOrNone(parent, neighbor.node);
+ }
+ } else if (node->IsMultiInput(port)) {
+ ExtendSubgraphPortAllOrNone(parent, node, port);
+ sit.SkipPort();
+ } else if (neighbor.node->IsMultiInput(neighbor.port)) {
+ // Would need to add all inputs of the neighbor node at this port at
+ // once.
+ if (parent->id().find(neighbor.node) != parent->id().end()) {
+ continue; // Already added.
+ }
+ ExtendSubgraphPortAllOrNone(parent, neighbor.node, neighbor.port);
+ } else {
+ Subgraph* sg = sg_set.ExtendParent(parent->id(), neighbor.node);
+ if (!will_complete && sg != nullptr) {
+ todo_.push_back(sg);
+ }
+ }
+ }
+}
+
+void GraphAnalyzer::ExtendSubgraphAllOrNone(Subgraph* parent,
+ const GenNode* node) {
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ auto range_end = node->links().end();
+
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
+void GraphAnalyzer::ExtendSubgraphPortAllOrNone(Subgraph* parent,
+ const GenNode* node,
+ GenNode::Port port) {
+ auto nbit = node->links().find(port);
+ if (nbit == node->links().end()) {
+ return; // Should never happen.
+ }
+
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
+void GraphAnalyzer::AddExtendedSubgraph(Subgraph* parent,
+ const Subgraph::Identity& id) {
+ if (id.size() == parent->id().size()) {
+ return; // Nothing new was added.
+ }
+
+ auto sg = absl::make_unique<Subgraph>(id);
+ SubgraphPtrSet& spec_sg_set =
+ (id.size() == subgraph_size_) ? result_ : partial_;
+ if (spec_sg_set.find(sg) != spec_sg_set.end()) {
+ // This subgraph was already found by extending from a different path.
+ return;
+ }
+
+ if (id.size() != subgraph_size_) {
+ todo_.push_back(sg.get());
+ }
+ spec_sg_set.insert(std::move(sg));
+}
+
+void GraphAnalyzer::DropInvalidSubgraphs() {
+ auto resit = result_.begin();
+ while (resit != result_.end()) {
+ if (HasInvalidMultiInputs(resit->get())) {
+ auto delit = resit;
+ ++resit;
+ result_.erase(delit);
+ } else {
+ ++resit;
+ }
+ }
+}
+
+bool GraphAnalyzer::HasInvalidMultiInputs(Subgraph* sg) {
+ // Do the all-or-none-input nodes.
+ for (auto const& node : sg->id()) {
+ if (!node->AllInputsOrNone()) {
+ continue;
+ }
+
+ bool anyIn = false;
+ bool anyOut = false;
+
+ auto range_end = node->links().end();
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ if (sg->id().find(link.node) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ }
+ }
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+
+ // Do the multi-input ports.
+ for (SubgraphIterator sit(sg); !sit.AtEnd(); sit.Next()) {
+ if (sit.GetNode()->IsMultiInput(sit.GetPort())) {
+ bool anyIn = false;
+ bool anyOut = false;
+ do {
+ GenNode* peer = sit.GetNeighbor().node;
+ if (sg->id().find(peer) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ } while (sit.NextIfSamePort());
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+Status GraphAnalyzer::CollateResult() {
+ ordered_collation_.clear();
+ collation_map_.clear();
+
+ // Collate by the signatures of the graphs.
+ for (const auto& it : result_) {
+ auto sig = absl::make_unique<Signature>();
+ it->ExtractForSignature(&sig->map);
+ Status status = sig->Compute();
+ if (!status.ok()) {
+ return status;
+ }
+
+ auto& coll_entry = collation_map_[sig.get()];
+ if (coll_entry.sig == nullptr) {
+ coll_entry.sig = std::move(sig);
+ }
+ ++coll_entry.count;
+ }
+
+ // Then order them by the count.
+ for (auto& entry : collation_map_) {
+ ordered_collation_.insert(&entry.second);
+ }
+
+ result_.clear(); // Not needed after collation.
+
+ return Status::OK();
+}
+
+std::vector<string> GraphAnalyzer::DumpRawSubgraphs() {
+ std::vector<string> result;
+ for (const auto& it : result_) {
+ result.emplace_back(it->Dump());
+ }
+ return result;
+}
+
+std::vector<string> GraphAnalyzer::DumpSubgraphs() {
+ std::vector<string> result;
+ for (auto ptr : ordered_collation_) {
+ result.emplace_back(
+ absl::StrFormat("%d %s", ptr->count, ptr->sig->ToString()));
+ }
+ return result;
+}
+
+Status GraphAnalyzer::OutputSubgraphs() {
+ size_t total = 0;
+ for (auto ptr : ordered_collation_) {
+ std::cout << ptr->count << ' ' << ptr->sig->ToString() << '\n';
+ total += ptr->count;
+ }
+ std::cout << "Total: " << total << '\n';
+ if (std::cout.fail()) {
+ return Status(error::DATA_LOSS, "Failed to write to stdout");
+ } else {
+ return Status::OK();
+ }
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
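
Reader's note (not part of the patch): stripped of the multi-input and all-or-none special cases, the control links and the link directions, the enumeration in FindSubgraphs()/ExtendSubgraph() boils down to this pattern: grow every candidate node set by one neighbor at a time, deduplicate candidates by their set of nodes, collect the sets that reach the target size, and re-queue the smaller ones. A heavily simplified standalone restatement on a three-node chain a-b-c with target size 2; all names here are made up:

    #include <cstddef>
    #include <cstdio>
    #include <map>
    #include <set>
    #include <string>

    int main() {
      // Undirected adjacency of a tiny graph: a - b - c.
      const std::map<std::string, std::set<std::string>> g = {
          {"a", {"b"}}, {"b", {"a", "c"}}, {"c", {"b"}}};
      const std::size_t k = 2;  // target subgraph size

      std::set<std::set<std::string>> result;  // complete subgraphs, deduplicated
      std::set<std::set<std::string>> todo;    // partial subgraphs still to extend
      for (const auto& n : g) todo.insert({n.first});

      while (!todo.empty()) {
        const std::set<std::string> cur = *todo.begin();
        todo.erase(todo.begin());
        for (const auto& node : cur) {
          for (const auto& neighbor : g.at(node)) {
            std::set<std::string> ext = cur;
            ext.insert(neighbor);
            if (ext.size() == cur.size()) continue;   // nothing new was added
            if (ext.size() == k) result.insert(ext);  // complete, collect it
            else if (ext.size() < k) todo.insert(ext);  // keep growing it
          }
        }
      }
      // Prints: 2 subgraphs of size 2  (namely {a,b} and {b,c}).
      std::printf("%zu subgraphs of size %zu\n", result.size(), k);
      return 0;
    }
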
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
new file mode 100644
index 0000000000..26d38a4931
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
@@ -0,0 +1,154 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+namespace test {
+class GraphAnalyzerTest;
+} // end namespace test
+
+// Finds all the subgraphs of a given size and groups them by equivalence.
+class GraphAnalyzer {
+ public:
+ // Makes a copy of the graph.
+ GraphAnalyzer(const GraphDef& graph, int subgraph_size);
+
+ virtual ~GraphAnalyzer();
+
+ // Performs the analysis and collects the subgraphs.
+ Status Run();
+
+ // Returns the subgraphs found in Run() printed to text.
+ std::vector<string> DumpSubgraphs();
+
+ // Prints the subgraphs found in Run() to stdout.
+ Status OutputSubgraphs();
+
+ // TODO(babkin): add a way to extract the subgraphs as direct data
+ // structures and as protobufs, and to write protobufs to a RecordIO.
+
+ private:
+ GraphAnalyzer() = delete;
+ GraphAnalyzer(const GraphAnalyzer&) = delete;
+ void operator=(const GraphAnalyzer&) = delete;
+
+ friend class tensorflow::grappler::graph_analyzer::test::GraphAnalyzerTest;
+
+ // Builds the map of nodes from the original graph definition.
+ Status BuildMap();
+
+ // Using nodes_, finds all the subgraphs of size subgraph_size_ and places
+ // them into result_.
+ void FindSubgraphs();
+
+ // Deletes from result_ the unacceptable subgraphs. Those include the
+ // subgraphs where not all the inputs at a multi-input port are included (this
+ // could happen if some of these inputs were reached and included through
+ // different paths).
+ void DropInvalidSubgraphs();
+
+ // Deletes from result_ duplicate entries of equivalent topology.
+ Status CollateResult();
+
+ // Returns the raw subgraphs found in FindSubgraphs() printed to text.
+ std::vector<string> DumpRawSubgraphs();
+
+ // Finds and adds appropriately to either partial_ or result_ all the
+ // subgraphs that can be created by extending the parent subgraph by one node.
+ // Ignores the duplicates.
+ void ExtendSubgraph(Subgraph* parent);
+
+ // Extends the parent subgraph by adding another node (if it wasn't already
+ // added) and all its non-control inputs in the link map range at once.
+ // If the subgraph would grow over subgraph_size_, it gets ignored.
+ void ExtendSubgraphAllOrNone(Subgraph* parent, const GenNode* node);
+ // Same but adds one specific inbound port (even control) all-or-none.
+ void ExtendSubgraphPortAllOrNone(Subgraph* parent, const GenNode* node,
+ GenNode::Port port);
+ // The common final step called by ExtendSubgraph*AllOrNone() methods.
+ void AddExtendedSubgraph(Subgraph* parent, const Subgraph::Identity& id);
+
+ // Returns true if this subgraph has any multi-inputs that aren't all-in or
+ // all-out.
+ bool HasInvalidMultiInputs(Subgraph* sg);
+
+ // Graph to run the analysis on.
+ GraphDef graph_;
+ int subgraph_size_;
+
+ // The enriched graph of parsed nodes and connections.
+ GenNodeMap nodes_;
+ // The resulting set of subgraphs.
+ SubgraphPtrSet result_;
+ // The subgraphs of partial size, stored while finding the result.
+ SubgraphPtrSet partial_;
+ // The subgraphs of partial size (stored in partial_) that are still waiting
+ // to be extended.
+ //
+ // TODO(babkin): This is rather simple-minded, each subgraph is examined from
+ // scratch, which means that all its internal links get iterated too. But it's
+ // OK for the small subgraphs. This can be improved by keeping not just
+ // subgraphs but iterators on the list, each of them having the list not-yet
+ // examined nodes (and the link position of the next link to be examined for
+ // the first node). This would add extra constant overhead, so the break-even
+ // subgraph size is not clear yet.
+ std::deque<Subgraph*> todo_;
+
+ // The collation map by signature is designed to allow the removal of entries
+ // and moving of the signature references from the keys of this map to the
+ // outside world. Must be careful at inserting and removal: make sure that
+ // when a new entry is inserted, its signature reference gets populated with
+ // the same data as the key of the map, and that if a reference is moved out,
+ // the map entry gets removed before that reference gets destroyed.
+ struct CollationEntry {
+ std::shared_ptr<Signature> sig;
+ size_t count = 0;
+ };
+ using CollationMap =
+ std::unordered_map<Signature*, CollationEntry, HashAtPtr<Signature*>,
+ EqAtPtr<Signature*> >;
+ CollationMap collation_map_;
+
+ // The entries are owned by collation_map_, so must be removed from
+ // ordered_collation_ before removing them from collation_map_.
+ struct ReverseLessByCount {
+ bool operator()(CollationEntry* left, CollationEntry* right) {
+ return left->count > right->count; // Reverse order.
+ }
+ };
+ using CollationOrderByCount =
+ std::multiset<CollationEntry*, ReverseLessByCount>;
+ CollationOrderByCount ordered_collation_;
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
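
Illustration (not part of the patch): a minimal end-to-end sketch of driving the public interface declared above, assuming a TensorFlow build environment and a GraphDef obtained elsewhere; the function name AnalyzeExample is made up, the rest is the API from this header:

    #include <iostream>
    #include <string>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"

    tensorflow::Status AnalyzeExample(const tensorflow::GraphDef& graph) {
      using tensorflow::grappler::graph_analyzer::GraphAnalyzer;
      // Group all 3-node subgraphs of `graph` by signature equivalence.
      GraphAnalyzer analyzer(graph, /*subgraph_size=*/3);
      tensorflow::Status st = analyzer.Run();
      if (!st.ok()) return st;
      // DumpSubgraphs() yields one "<count> <signature>" line per equivalence
      // class, most frequent first; OutputSubgraphs() would print the same to
      // stdout along with a total.
      for (const std::string& line : analyzer.DumpSubgraphs()) {
        std::cout << line << '\n';
      }
      return tensorflow::Status::OK();
    }
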
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc
new file mode 100644
index 0000000000..e94c472056
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc
@@ -0,0 +1,569 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+
+#include <algorithm>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+class GraphAnalyzerTest : public ::testing::Test, protected TestGraphs {
+ protected:
+ Status BuildMap() { return gran_->BuildMap(); }
+
+ void FindSubgraphs() { gran_->FindSubgraphs(); }
+
+ void DropInvalidSubgraphs() { gran_->DropInvalidSubgraphs(); }
+
+ Status CollateResult() { return gran_->CollateResult(); }
+
+ void ExtendSubgraph(Subgraph* parent) { gran_->ExtendSubgraph(parent); }
+
+ void ExtendSubgraphPortAllOrNone(Subgraph* parent, GenNode* node,
+ GenNode::Port port) {
+ gran_->ExtendSubgraphPortAllOrNone(parent, node, port);
+ }
+
+ void ExtendSubgraphAllOrNone(Subgraph* parent, GenNode* node) {
+ gran_->ExtendSubgraphAllOrNone(parent, node);
+ }
+
+ std::vector<string> DumpRawSubgraphs() { return gran_->DumpRawSubgraphs(); }
+
+ std::vector<string> DumpPartials() {
+ std::vector<string> result;
+ for (const auto& it : gran_->partial_) {
+ result.emplace_back(it->Dump());
+ }
+ return result;
+ }
+
+ const GenNodeMap& GetNodes() { return gran_->nodes_; }
+
+ GenNode* GetNode(const string& name) { return gran_->nodes_.at(name).get(); }
+
+ SubgraphPtrSet& GetResult() { return gran_->result_; }
+ SubgraphPtrSet& GetPartial() { return gran_->partial_; }
+ std::deque<Subgraph*>& GetTodo() { return gran_->todo_; }
+
+ // Gets initialized by a particular test from a suitable GraphDef.
+ std::unique_ptr<GraphAnalyzer> gran_;
+};
+
+TEST_F(GraphAnalyzerTest, BuildMap) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ EXPECT_THAT(st, Eq(Status::OK()));
+
+ auto& map = GetNodes();
+ EXPECT_THAT(map.find("node1"), Ne(map.end()));
+ EXPECT_THAT(map.find("node2"), Ne(map.end()));
+ EXPECT_THAT(map.find("node3"), Ne(map.end()));
+}
+
+TEST_F(GraphAnalyzerTest, BuildMapError) {
+ // A duplicate node.
+ (*graph_3n_self_control_.add_node()) = MakeNodeConst("node1");
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ ASSERT_THAT(
+ st, Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'.")));
+}
+
+TEST_F(GraphAnalyzerTest, FindSubgraphs0) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 0);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ auto& subgraphs = GetResult();
+ EXPECT_THAT(subgraphs, SizeIs(0));
+ EXPECT_THAT(DumpRawSubgraphs(), ElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, FindSubgraphs1) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ auto& subgraphs = GetResult();
+ EXPECT_THAT(subgraphs, SizeIs(3));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: BroadcastGradientArgs(node3)",
+ "1: Const(node1)",
+ "1: Sub(node2)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// The required subgraphs are larger than the graph.
+TEST_F(GraphAnalyzerTest, FindSubgraphsTooLarge) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ EXPECT_THAT(DumpRawSubgraphs(), ElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+//===
+
+// Successfully propagate backwards through a multi-input link,
+// with the base (currently-extending) node already in the graph.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseIn) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link,
+// with the base (currently-extending) node not in the graph yet.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto parent = absl::make_unique<Subgraph>(Subgraph::Identity());
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(parent.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link,
+// where the target subgraph size is larger.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsIncomplete) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ // clang-format off
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Propagate backwards through a multi-input link, finding that the
+// resulting subgraph would be too large.
+TEST_F(GraphAnalyzerTest, MultiInputTooLargeBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Propagate backwards through a multi-input link, finding that nothing
+// would be added to the parent subgraph.
+TEST_F(GraphAnalyzerTest, MultiInputNothingAddedBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root = absl::make_unique<Subgraph>(
+ Subgraph::Identity({GetNode("add2"), GetNode("const2_1"),
+ GetNode("const2_2"), GetNode("const2_3")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate forwards through a multi-input link,
+// with the base (currently-extending) node not in the subgraph yet.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: AddN(add2), Sub(sub)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Successfully propagate forwards through a multi-input link.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsMulti) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ // A good one, multi-input is all-in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("const1_2"),
+ GetNode("add1"),
+ })));
+  // A good one, multi-input is all-out.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("add1"),
+ GetNode("add2"),
+ GetNode("sub"),
+ })));
+ // A bad one, multi-input is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("add1"),
+ GetNode("sub"),
+ })));
+ // A bad one, multi-input is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("add2"),
+ GetNode("const2_1"),
+ GetNode("const2_2"),
+ })));
+
+ DropInvalidSubgraphs();
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add1), AddN(add2), Sub(sub)",
+ "1: AddN(add1), Const(const1_1), Const(const1_2)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+//===
+
+// Successfully propagate backwards from an all-or-none-input node,
+// with the base (currently-extending) node already in the subgraph.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards from an all-or-none-input node,
+// but the control links don't propagate. It also tests the situation
+// where the target subgraph size is larger.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsNoControl) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass1"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: Const(const1_1), Const(const1_2), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// The control links propagate separately as all-or-none, even on the nodes
+// that are all-or-none for the normal inputs.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSeparateControl) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("pass1"),
+ GenNode::Port(true, -1));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Propagate backwards from all-or-none-input node, finding that the
+// resulting subgraph would be too large.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputTooLargeBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Propagate backwards from all-or-none-input node, finding that nothing
+// would be added to the parent subgraph.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputNothingAddedBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root = absl::make_unique<Subgraph>(
+ Subgraph::Identity({GetNode("pass2"), GetNode("const2_1"),
+ GetNode("const2_2"), GetNode("const2_3")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate forwards to all-or-none-input node,
+// with the base (currently-extending) node not in the subgraph yet.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards from all-or-none-input node.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: IdentityN(pass2), Sub(sub)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Successfully propagate forwards to all-or-none-input node. This includes
+// both all-or-none-input for the normal inputs, and multi-input by the
+// control path.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)",
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsAllOrNone) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ // A good one, all-or-none is all-in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("const1_2"),
+ GetNode("pass1"),
+ })));
+  // A good one, all-or-none is all-out.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("pass1"),
+ GetNode("pass2"),
+ GetNode("sub"),
+ })));
+ // A bad one, all-or-none is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("pass1"),
+ GetNode("sub"),
+ })));
+ // A bad one, all-or-none is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("pass2"),
+ GetNode("const2_1"),
+ GetNode("const2_2"),
+ })));
+
+ DropInvalidSubgraphs();
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: IdentityN(pass1), IdentityN(pass2), Sub(sub)",
+ "1: Const(const1_1), Const(const1_2), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc
new file mode 100644
index 0000000000..924ca11e61
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Dies on failure.
+static void LoadModel(const string& filename,
+ tensorflow::MetaGraphDef* metagraph) {
+ LOG(INFO) << "Loading model from " << filename;
+ Status st;
+ st = ReadBinaryProto(Env::Default(), filename, metagraph);
+ if (!st.ok()) {
+ LOG(WARNING) << "Failed to read a binary metagraph: " << st;
+ st = ReadTextProto(Env::Default(), filename, metagraph);
+ if (!st.ok()) {
+ LOG(FATAL) << "Failed to read a text metagraph: " << st;
+ }
+ }
+}
+
+// Prune the graph to only keep the transitive fanin part with respect to a set
+// of train ops (if provided).
+void MaybePruneGraph(const tensorflow::MetaGraphDef& metagraph,
+ tensorflow::GraphDef* graph) {
+  std::vector<string> fetch_nodes;
+  // The "train_op" collection is optional; Map::at() would die on a missing
+  // key, so look it up defensively.
+  const auto train_ops_it = metagraph.collection_def().find("train_op");
+  if (train_ops_it != metagraph.collection_def().end()) {
+    for (const auto& fetch : train_ops_it->second.node_list().value()) {
+      LOG(INFO) << "Fetch node: " << fetch;
+      fetch_nodes.push_back(fetch);
+    }
+  }
+ if (fetch_nodes.empty()) {
+ *graph = metagraph.graph_def();
+ } else {
+ std::vector<const tensorflow::NodeDef*> fanin_nodes =
+ tensorflow::grappler::ComputeTransitiveFanin(metagraph.graph_def(),
+ fetch_nodes);
+ for (const tensorflow::NodeDef* node : fanin_nodes) {
+ *(graph->add_node()) = *node;
+ }
+ LOG(INFO) << "Pruned "
+ << metagraph.graph_def().node_size() - graph->node_size()
+ << " nodes. Original graph size: "
+ << metagraph.graph_def().node_size()
+ << ". New graph size: " << graph->node_size() << ".";
+ }
+}
+
+void GraphAnalyzerTool(const string& file_name, int n) {
+ if (n < 1) {
+ LOG(FATAL) << "Invalid subgraph size " << n << ", must be at least 1";
+ }
+
+ tensorflow::MetaGraphDef metagraph;
+ LoadModel(file_name, &metagraph);
+ tensorflow::GraphDef graph;
+ MaybePruneGraph(metagraph, &graph);
+ tensorflow::grappler::graph_analyzer::GraphAnalyzer analyzer(graph, n);
+ LOG(INFO) << "Running the analysis";
+ tensorflow::Status st = analyzer.Run();
+ if (!st.ok()) {
+ LOG(FATAL) << "Analysis failed: " << st;
+ }
+
+ LOG(INFO) << "Printing the result";
+ st = analyzer.OutputSubgraphs();
+ if (!st.ok()) {
+ LOG(FATAL) << "Failed to print the result: " << st;
+ }
+
+ LOG(INFO) << "Completed";
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h
new file mode 100644
index 0000000000..5a91fe7dc8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
+
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+void GraphAnalyzerTool(const string& file_name, int n);
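+
+// For example (an illustrative sketch with a hypothetical path; not part of
+// this change), a command-line wrapper could call:
+//
+//   tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(
+//       "/tmp/model.metagraph", /*n=*/3);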
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/hash_tools.h b/tensorflow/core/grappler/graph_analyzer/hash_tools.h
new file mode 100644
index 0000000000..b0e79f9a68
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/hash_tools.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
+
+#include <cstddef>
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Unfortunately, std::hash provides no way to combine hashes, so everyone
+// is copying boost::hash_combine. This is a version that follows Google's
+// guidelines on the arguments, and contains only the combination, without
+// hashing.
+inline void CombineHash(size_t from, size_t* to) {
+ *to ^= from + 0x9e3779b9 + (*to << 6) + (*to >> 2);
+}
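+
+// A minimal usage sketch (illustrative only; the names are hypothetical):
+// fold the hash of a second field into the hash of the first one.
+//
+//   size_t hval = std::hash<string>()(name);
+//   CombineHash(std::hash<int>()(port_id), &hval);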
+
+// Combine two hashes in such a way that the order of combination doesn't matter
+// (so it's really both commutative and associative). The result is not a very
+// high-quality hash but can be used when the order of sub-elements must
+// not matter in the following comparison. An alternative would be to sort the
+// hashes of the sub-elements and then combine them normally in the sorted
+// order.
+inline void CombineHashCommutative(size_t from, size_t* to) {
+ *to = *to + from + 0x9e3779b9;
+}
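+
+// For example (a sketch, assuming two precomputed hashes h1 and h2), combining
+// them in either order produces the same value:
+//
+//   size_t a = h1, b = h2;
+//   CombineHashCommutative(h2, &a);
+//   CombineHashCommutative(h1, &b);
+//   // Now a == b.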
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc
new file mode 100644
index 0000000000..b5e9ce6b8e
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::Eq;
+
+TEST(HashToolsTest, CombineHashCommutative) {
+ size_t a = 0;
+ size_t b = 999;
+
+ size_t c = a;
+ CombineHashCommutative(b, &c);
+
+ size_t d = b;
+ CombineHashCommutative(a, &d);
+
+ EXPECT_THAT(c, Eq(d));
+}
+
+} // namespace
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/map_tools.h b/tensorflow/core/grappler/graph_analyzer/map_tools.h
new file mode 100644
index 0000000000..584062c5f2
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/map_tools.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
+
+#include <functional>
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Helpers for building maps of pointers.
+
+template <typename Ptr>
+struct LessAtPtr : std::binary_function<Ptr, Ptr, bool> {
+ bool operator()(const Ptr& x, const Ptr& y) const { return *x < *y; }
+};
+
+template <typename Ptr>
+struct EqAtPtr : std::binary_function<Ptr, Ptr, bool> {
+ bool operator()(const Ptr& x, const Ptr& y) const { return *x == *y; }
+};
+
+template <typename Ptr>
+struct HashAtPtr : std::unary_function<Ptr, size_t> {
+ size_t operator()(const Ptr& x) const { return x->Hash(); }
+};
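+
+// A usage sketch (illustrative; Foo is a hypothetical type that provides
+// operator<, operator== and Hash()): containers keyed by the pointed-to
+// values rather than by pointer identity.
+//
+//   std::set<const Foo*, LessAtPtr<const Foo*>> ordered_by_value;
+//   std::unordered_set<const Foo*, HashAtPtr<const Foo*>, EqAtPtr<const Foo*>>
+//       hashed_by_value;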
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.cc b/tensorflow/core/grappler/graph_analyzer/sig_node.cc
new file mode 100644
index 0000000000..b5cca6a512
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node.cc
@@ -0,0 +1,453 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+#include <algorithm>
+
+#include "absl/strings/str_format.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+static constexpr bool debug = false;
+
+//=== SigNode
+
+SigNode::SigNode(const NodeDef* node) : node_(node) {}
+
+void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) {
+ hash_to_link_.clear();
+ hashed_peers_.clear();
+
+ std::map<LinkTag, Link> link_map;
+ CopyLinksPass1(from, map, &link_map);
+ CopyLinksPass2(&link_map);
+}
+
+void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map,
+ std::map<LinkTag, Link>* link_map) {
+ LinkTag::Hasher link_hasher;
+
+ for (const auto& entry : from.links()) {
+ for (const auto& target : entry.second) {
+ auto nodeit = map.find(target.node);
+ if (nodeit == map.end()) {
+ // Node is not in the subgraph, ignore.
+ continue;
+ }
+
+ LinkTag tag(entry.first, target.port);
+ size_t hval = link_hasher(tag);
+
+ // This instantiates the entry if it was not present.
+ Link& map_entry = (*link_map)[tag];
+ if (map_entry.peers.empty()) {
+ map_entry.tag = tag;
+ map_entry.unique_hash = hval;
+ }
+ map_entry.peers.push_back(nodeit->second);
+ }
+ }
+}
+
+void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) {
+ for (auto& entry : *link_map) {
+ Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
+ // In case of a conflict, rehash. This should almost never happen.
+ // Because the order of iteration is predictable, the rehashed values
+ // will also be predictable.
+ while (!hl_entry_ptr->peers.empty()) {
+ CombineHash(1, &entry.second.unique_hash);
+ hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
+ }
+
+ for (const auto& peer : entry.second.peers) {
+ hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
+ }
+
+ hl_entry_ptr->tag = entry.second.tag;
+ hl_entry_ptr->unique_hash = entry.second.unique_hash;
+ hl_entry_ptr->peers.swap(entry.second.peers);
+ }
+}
+
+void SigNode::ComputeTopoHash0() {
+ topo_hash_.clear();
+ last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;
+
+  // TODO(babkin): include the attributes too, as an option.
+ size_t hval = std::hash<string>()(opcode());
+
+  // Getting the topology of the links into the hash early should get more
+ // conflicts resolved early.
+ for (const auto& entry : hashed_peers_) {
+ CombineHash(entry.link_hash, &hval);
+ }
+
+ topo_hash_.push_back(hval);
+}
+
+void SigNode::ComputeTopoHash(int distance) {
+ // The new starting point.
+ next_hashed_nodes_ = last_hashed_nodes_;
+ if (debug) {
+ LOG(INFO) << "DEBUG node " << name() << " mask=" << std::hex
+ << next_hashed_nodes_;
+ }
+
+ if (hash_is_final_) {
+ return;
+ }
+
+ CHECK(topo_hash_.size() == distance);
+
+ int prev = distance - 1;
+
+  // Start with the node's own local topology hash. This value is stable, so
+ // if the hashes of the surrounding nodes don't change on the following
+ // distances, the hash of this node won't change either.
+ size_t hval = topo_hash_[0];
+
+ if (!hashed_peers_.empty()) {
+ size_t last_link_hash = hashed_peers_[0].link_hash;
+ size_t comm_hash = 0;
+
+ for (const auto& entry : hashed_peers_) {
+ if (entry.link_hash != last_link_hash) {
+ CombineHash(last_link_hash, &hval);
+ CombineHash(comm_hash, &hval);
+ comm_hash = 0;
+ last_link_hash = entry.link_hash;
+ }
+
+ // The links in the same vector are commutative, so combine their
+ // hashes in a commutative way.
+ CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash);
+ next_hashed_nodes_ |= entry.peer->last_hashed_nodes_;
+ if (debug) {
+ LOG(INFO) << "DEBUG node " << name() << " += " << entry.peer->name()
+ << " mask=" << std::hex << next_hashed_nodes_;
+ }
+ }
+
+ // The last commutative group.
+ CombineHash(last_link_hash, &hval);
+ CombineHash(comm_hash, &hval);
+ }
+
+ topo_hash_.push_back(hval);
+}
+
+size_t SigNode::GetTopoHash(int distance) const {
+ CHECK(!topo_hash_.empty());
+ if (distance >= topo_hash_.size()) {
+ CHECK(hash_is_final_);
+ return topo_hash_.back();
+ } else {
+ return topo_hash_[distance];
+ }
+}
+
+bool SigNode::operator==(const SigNode& other) const {
+ // TODO(babkin): add attributes too.
+ if (opcode() != other.opcode()) {
+ return false;
+ }
+
+ // Normally the caller is expected to compare the nodes
+ // at the same rank in different graphs, but just in case...
+ if (unique_rank_ != other.unique_rank_) {
+ return false;
+ }
+
+ if (hashed_peers_.size() != other.hashed_peers_.size()) {
+ return false;
+ }
+
+ for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin();
+ it1 != hashed_peers_.end(); ++it1, ++it2) {
+ // TODO(babkin): might compare the actual values too
+ // but the hash is probably just as good.
+ if (it1->link_hash != it2->link_hash) {
+ return false;
+ }
+ if (it1->peer->unique_rank_ != it2->peer->unique_rank_) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+//=== Signature
+
+constexpr int Signature::kMaxGraphSize;
+
+string Signature::ToString() const {
+ string result;
+ for (size_t n = 0; n < nodes.size(); ++n) {
+ // TODO(babkin): add attributes too.
+ result += absl::StrFormat("%d:%s", n, nodes[n]->opcode());
+ for (const auto& entry : nodes[n]->hashed_peers_) {
+ const auto& link = nodes[n]->hash_to_link_[entry.link_hash];
+
+ // The link entries are already sorted, by tags and then by the
+ // node ranks.
+ if (link.tag.local.IsInbound()) {
+ result +=
+ absl::StrFormat("[%s:%s:%d]", string(link.tag.local),
+ string(link.tag.remote), entry.peer->unique_rank_);
+ }
+ }
+ result.push_back(',');
+ }
+ return result;
+}
+
+Status Signature::Compute() {
+ if (map.size() > kMaxGraphSize) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "A graph of %d nodes is too big for signature computation, "
+ "the maximal supported node count is %d.",
+ map.size(), kMaxGraphSize));
+ }
+
+ // The value that will be assigned next as the unique node id.
+ // This also means that all the entries in nodes at indexes less than this
+ // have been finalized and don't need to be touched any more.
+ size_t next_node_id = 0;
+
+ sig_short = 0;
+ sig_full.resize(0); // Keep the storage.
+
+ // The main signature generation.
+ PrepareNodes();
+ FindUniqueHashes(&next_node_id);
+ while (next_node_id < map.size()) {
+ ComputeOneRound(next_node_id);
+ FindUniqueHashes(&next_node_id);
+ }
+
+ OrderLinks();
+
+ return Status::OK();
+}
+
+void Signature::PrepareNodes() {
+ nodes.resize(0); // Keep the storage.
+
+ // Initialize the nodes.
+  // Use an unsigned mask so that shifting into the top bit stays well-defined.
+  uint64_t mask = 1;
+ for (const auto& entry : map) {
+ SigNode* node = entry.second.get();
+ node->last_hashed_nodes_ = node->node_mask_ = mask;
+ mask <<= 1;
+ node->unique_rank_ = ~0;
+ node->hash_is_final_ = false;
+ node->ComputeTopoHash0();
+ if (node->GetHighTopoHash() <= map.size()) {
+ // Would conflict with one of the reserved values.
+ node->ReHighTopoHash();
+ }
+
+ // The initial order is random.
+ nodes.emplace_back(node);
+ }
+}
+
+void Signature::FindUniqueHashes(size_t* next_node_id_p) {
+ // Start by sorting by the hash value.
+ std::sort(nodes.begin() + *next_node_id_p, nodes.end(),
+ SigNode::NodeOrderLess());
+
+ // At each call, if no nodes have unique hashes, one node that has a
+ // non-unique (shared) hash can be made unique by assigning a unique id.
+ // This node gets picked predictably by taking the last node.
+ // TODO(babkin): Technically, more than one node can be unshared,
+ // as long as their last_hashed_nodes_ overlap only by the nodes that
+ // already had the assigned ids before the current round. But it's not clear
+  // yet how often this would be beneficial, because it looks like for many
+ // subgraphs unsharing one node should be enough to untangle them. This
+ // would need more measurement before implementing.
+ bool found_unique = false;
+ for (size_t n = *next_node_id_p; n < nodes.size(); ++n) {
+ size_t cur_hash = nodes[n]->GetHighTopoHash();
+ if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) {
+ // A sequence of nodes sharing the same hash. Skip over it.
+ // TODO(babkin): check here for the arbitrary hash conflicts and resolve
+ // them.
+ for (++n;
+ n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash;
+ ++n) {
+ }
+ if (found_unique || n != nodes.size() - 1) {
+ // Either some unique nodes have already been found, or this is
+ // not the last chance, keep trying to find the unique nodes.
+ continue;
+ }
+ // Here we're at the last node and haven't found any unique ones.
+ // So fall through and make this last node unique.
+ }
+
+ found_unique = true;
+ size_t id = (*next_node_id_p)++;
+ nodes[n]->unique_rank_ = id;
+
+ size_t last_hash = nodes[n]->GetHighTopoHash();
+ CombineHash(last_hash, &sig_short);
+ sig_full.push_back(last_hash);
+
+ // Take the hash at 0 and mix the unique rank into it. After that it will
+ // stay fixed.
+ nodes[n]->topo_hash_.resize(1);
+ nodes[n]->topo_hash_[0] = id + 1; // Avoid the value of 0.
+
+ nodes[n]->hash_is_final_ = true;
+ nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_;
+ if (n != id) {
+ std::swap(nodes[id], nodes[n]);
+ }
+ }
+}
+
+void Signature::ComputeOneRound(size_t next_node_id) {
+ // Reset the state of the nodes.
+ int debug_i = 0;
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ // The hash at distance 0 never changes, so preserve it.
+ node->topo_hash_.resize(1);
+ node->last_hashed_nodes_ = node->node_mask_;
+ node->hash_is_final_ = false;
+ if (debug) {
+ LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " "
+ << node->name() << " mask=" << std::hex
+ << node->last_hashed_nodes_;
+ }
+ }
+
+ bool stop = false;
+ // The distance can reach up to nodes.size()+1, to include not only all the
+ // nodes but also all the redundant paths.
+ for (int distance = 1; !stop; ++distance) {
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ if (node->hash_is_final_) {
+ continue;
+ }
+ node->ComputeTopoHash(distance);
+ if (node->GetHighTopoHash() <= nodes.size()) {
+ // Would conflict with one of the reserved values.
+ node->ReHighTopoHash();
+ }
+ }
+
+    // Will be looking for indications to not stop.
+ stop = true;
+
+ debug_i = 0;
+ // The bitmasks get moved after all the hash computations are done.
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ if (debug) {
+ LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++
+ << " " << node->name() << " oldmask=" << std::hex
+ << node->last_hashed_nodes_ << " mask=" << std::hex
+ << node->next_hashed_nodes_;
+ }
+ if (node->last_hashed_nodes_ == node->next_hashed_nodes_) {
+        // Stopped growing: this part of the graph must be fully
+ // surrounded by nodes that already have the unique ids.
+ node->hash_is_final_ = true;
+ } else {
+ node->last_hashed_nodes_ = node->next_hashed_nodes_;
+ stop = false;
+ }
+ }
+ }
+}
+
+void Signature::OrderLinks() {
+ for (const auto& node : nodes) {
+ if (node->hashed_peers_.empty()) {
+ continue;
+ }
+
+ size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1;
+ int first_idx = -1;
+
+ int idx;
+ for (idx = 0; idx < node->hashed_peers_.size(); ++idx) {
+ auto& entry = node->hashed_peers_[idx];
+ if (entry.link_hash == cur_link_hash) {
+ continue;
+ }
+ if (idx - first_idx > 1) {
+ // Need to sort.
+ std::sort(node->hashed_peers_.begin() + first_idx,
+ node->hashed_peers_.begin() + idx,
+ SigNode::HashedPeer::LessByRank());
+ }
+
+ cur_link_hash = entry.link_hash;
+ first_idx = idx;
+ }
+ if (idx - first_idx > 1) {
+ // Sort the last bunch.
+ std::sort(node->hashed_peers_.begin() + first_idx,
+ node->hashed_peers_.begin() + idx,
+ SigNode::HashedPeer::LessByRank());
+ }
+ }
+}
+
+bool Signature::operator==(const Signature& other) const {
+ // Tries to find the differences as early as possible by
+ // comparing the hashes first.
+
+ if (sig_short != other.sig_short) {
+ return false;
+ }
+ if (sig_full.size() != other.sig_full.size()) {
+ return false;
+ }
+
+ for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin();
+ it1 != sig_full.end(); ++it1, ++it2) {
+ if (*it1 != *it2) {
+ return false;
+ }
+ }
+
+ if (nodes.size() != other.nodes.size()) {
+ return false;
+ }
+ for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end();
+ ++it1, ++it2) {
+ if (**it1 != **it2) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.h b/tensorflow/core/grappler/graph_analyzer/sig_node.h
new file mode 100644
index 0000000000..45c0ed3162
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node.h
@@ -0,0 +1,304 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+namespace test {
+class SigBaseTest;
+} // end namespace test
+
+class SigNode;
+
+// To find nodes by name. Having the map ordered makes the tests easier,
+// and it isn't used in production code often enough to get any win from
+// using an unordered map.
+using SigNodeMap = std::map<string, std::unique_ptr<SigNode>>;
+
+// One node in the graph, in the form convenient for generation of the signature
+// of the graph, and comparison of two (sub)graphs for equivalence. It refers to
+// the original NodeDef protobuf for most information and adds the extra
+// enrichment.
+//
+// The graph building is 2-stage: first match a SigNode with each NodeDef and
+// collect them into a map that finds them by name, then process the map,
+// deep-parse the underlying NodeDefs and connect the SigNodes together.
+class SigNode {
+ public:
+ friend struct Signature;
+
+  // Keeps the pointer to the underlying NodeDef, so the underlying
+  // object must not be deleted while the SigNode is alive.
+ explicit SigNode(const NodeDef* node);
+
+ // Access wrappers.
+ const string& name() const { return node_->name(); }
+ const string& opcode() const { return node_->op(); }
+ const NodeDef* node_def() const { return node_; }
+
+ // For extraction of subgraphs into a separate SigNodeMap, copies the links
+  // that point inside the subgraph from a full-graph GenNode to a subgraph
+  // SigNode. The translation map defines the subgraph and gives the mapping
+  // from the nodes in the full graph to the matching nodes in the subgraph.
+ using TranslationMap =
+ std::unordered_map<const GenNode* /*full_graph*/, SigNode* /*subgraph*/>;
+ void CopyLinks(const GenNode& from, const TranslationMap& map);
+
+ // A link is an edge of the graph that connects 2 nodes. Each of the connected
+ // nodes has its own perspective on the link, seeing its local port, remote
+ // port and the remote node. The direction of the link is encoded in the
+ // ports, one port is always incoming and another one outgoing.
+ //
+ // The link tag here contains both ports of the link viewed from the
+ // perspective of this node; consisting of both the local port (i.e. at this
+ // node) and remote port (i.e. on the other node), the local one going first.
+ struct LinkTag {
+ struct Hasher {
+ size_t operator()(const LinkTag& tag) const noexcept {
+ size_t hval = port_hasher(tag.local);
+ CombineHash(port_hasher(tag.remote), &hval);
+ return hval;
+ }
+ GenNode::Port::Hasher port_hasher;
+ };
+
+ LinkTag(GenNode::Port a_local, GenNode::Port a_remote)
+ : local(a_local), remote(a_remote) {}
+
+ // The default constructor is used for the default values in maps.
+ // (false, 99) is an arbitrary value that makes the uninitialized
+ // links easy to tell when debugging (they should never happen).
+ LinkTag() : local(false, 99), remote(false, 99) {}
+
+ // Port of the link on the local node.
+ GenNode::Port local;
+ // Port of the link on the remote node.
+ GenNode::Port remote;
+
+ bool operator==(const LinkTag& other) const {
+ return local == other.local && remote == other.remote;
+ }
+ bool operator<(const LinkTag& other) const {
+ return local < other.local ||
+ (local == other.local && remote < other.remote);
+ }
+ };
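+
+  // For instance (illustrative only; assuming, as the tests in this change
+  // suggest, that the first Port argument marks an inbound port), a link
+  // feeding this node's input 0 from output 1 of a peer is tagged on this
+  // side as
+  //   LinkTag(GenNode::Port(/*inbound=*/true, 0), GenNode::Port(false, 1))
+  // while the peer sees the mirrored tag with local and remote swapped.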
+
+ // Since the signature logic doesn't differentiate between the links
+ // with the same tag (other than by the "peer" nodes on their other ends),
+ // all the links with the same tag are grouped into a single structure.
+ struct Link {
+ LinkTag tag;
+ size_t unique_hash; // Hash of the tag after conflict resolution.
+    // The remote node(s) on the other side of the link(s).
+ using PeerVector = std::vector<SigNode*>;
+ PeerVector peers;
+ };
+
+ // A way to look up the link description by its hash.
+ using LinkHashMap = std::map<size_t, Link>;
+ const LinkHashMap& hash_to_link() const { return hash_to_link_; }
+
+ // The enumeration of all the peer nodes in a predictable order.
+ // Before the signature generation, only the link values determine the
+  // order; after the signature generation, the entries at the same
+ // links get further sorted by their peer node ranks.
+ struct HashedPeer {
+ HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {}
+
+ struct LessByRank {
+ bool operator()(const SigNode::HashedPeer& left,
+ const SigNode::HashedPeer& right) {
+ return left.peer->unique_rank_ < right.peer->unique_rank_;
+ }
+ };
+
+ size_t link_hash;
+ SigNode* peer;
+ };
+ using HashedPeerVector = std::vector<HashedPeer>;
+ const HashedPeerVector& hashed_peers() const { return hashed_peers_; }
+
+ // Compares two nodes in two different graphs for equivalence (two nodes in
+ // the same graph would never be equivalent). Expects that the signatures of
+ // the graphs have already been computed, so unique_rank_ is filled in and
+ // the hashed_peers_ properly ordered.
+ bool operator==(const SigNode& other) const;
+
+ bool operator!=(const SigNode& other) const { return !(*this == other); }
+
+ private:
+ friend class test::SigBaseTest;
+
+ // The CopyLinks code is split into 2 parts for testability.
+ // The first pass builds a map ordered by LinkTag for predictability.
+ void CopyLinksPass1(const GenNode& from, const TranslationMap& map,
+ std::map<LinkTag, Link>* link_map);
+ // The second pass converts to the map by hash value,
+ // resolves any hash conflicts, and builds the hashed peer vector.
+ void CopyLinksPass2(std::map<LinkTag, Link>* link_map);
+
+  // Computes the topological hash at distance 0. Resets the topo_hash_ vector
+  // and the hashed-node masks.
+ void ComputeTopoHash0();
+
+  // Computes the topological hash at the given distance. The hashes for all
+  // the lower distances must be already computed for all the nodes in the
+  // graph. Also computes next_hashed_nodes_ from last_hashed_nodes_.
+ void ComputeTopoHash(int distance);
+
+ // Get the hash value for a particular distance. It must be previously
+ // computed.
+ size_t GetTopoHash(int distance) const;
+
+  // Returns the hash value for the highest computed distance. It must be
+  // previously computed.
+ size_t GetHighTopoHash() const {
+ CHECK(!topo_hash_.empty());
+ return topo_hash_.back();
+ }
+
+ // Rehash the topmost hash, to avoid conflicts.
+ void ReHighTopoHash() {
+ CHECK(!topo_hash_.empty());
+ CombineHash(1, &topo_hash_.back());
+ }
+
+  // Ordering of the nodes by the highest available hash (which must be
+  // previously computed).
+ struct NodeOrderLess {
+ bool operator()(const SigNode* left, const SigNode* right) {
+ return left->topo_hash_.back() < right->topo_hash_.back();
+ }
+ };
+
+ private:
+ const NodeDef* node_;
+
+ // The bitmap mask with 1 bit set that represents this node in the set
+ // during the computation of the signature.
+ uint64_t node_mask_ = 0;
+
+ // The code that populates this map makes sure that there are no hash
+ // conflicts, rehashing if necessary.
+ LinkHashMap hash_to_link_;
+
+ // The enumeration of all the direct peers in the predictable order (which
+  // happens to be the order of their link tags, but the order of the hashes
+ // would do too). It is used for the quick enumeration during the signature
+ // computation. After the signature building is completed, the entries that
+ // have the same link tag get further sorted in the order of the ranks of
+ // their nodes.
+ HashedPeerVector hashed_peers_;
+
+ // The unique rank represents the order in which the node will be included
+  // into the signature. It gets assigned in order: either when the topo_hash_
+  // of this node becomes unique in the graph, or, when some nodes remain
+  // completely equivalent, by picking one of them at random for the next
+  // rank; the rest of the nodes then attempt to disambiguate based on that
+  // information.
+ size_t unique_rank_ = ~0;
+  // When hash_is_final_ is set, the topo_hash_ vector stops growing, and the
+ // last value from it is used for all the further hashes.
+ bool hash_is_final_ = false;
+ // The hashes that include the topology of the nodes up to the distance N. The
+ // hash for distance 0 is produced from the attributes of this node itself and
+ // its general connectivity properties but no information about the
+  // neighboring nodes. The hash for distance D+1 is built from hashes at level
+ // D of this node and of all its immediate neighbors. The neighbors that are
+ // connected by equivalent links are included in a commutative way.
+ std::vector<size_t> topo_hash_;
+ // The set of nodes that got included into the computation of the
+ // last topo_hash_ entry.
+ uint64_t last_hashed_nodes_ = 0;
+ // The next set of nodes that gets used for the current topo_hash entry.
+ uint64_t next_hashed_nodes_ = 0;
+};
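+
+// Usage sketch (illustrative): once the subgraph's GenNodes have been mapped
+// to freshly created SigNodes, the links are copied over through the
+// translation map, so that only the edges staying inside the subgraph survive.
+//
+//   SigNode::TranslationMap tmap;  // filled with GenNode* -> SigNode* pairs
+//   sig_node->CopyLinks(*gen_node, tmap);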
+
+// Signature of a graph. The computation is intertwined with the private methods
+// of SigNode, so keeping both in the same file looks more convenient.
+struct Signature {
+ friend class test::SigBaseTest;
+
+ // Maximal size of the graphs for which the signature can be computed.
+  // Changing this constant won't magically add the support for a larger size;
+  // the rest of the implementation would have to be extended. The value of 64
+  // is driven by the size of a bitset in a uint64_t, and should be enough for
+  // our purposes, while having a high efficiency of implementation.
+ static constexpr int kMaxGraphSize = 64;
+
+ // Using the map, computes the rest of the fields of a signature.
+  // Returns an error if the graph is too big.
+ Status Compute();
+
+ // Convert the computed signature to a string representation.
+ string ToString() const;
+
+ SigNodeMap map; // The nodes in the graph, accessible by name.
+ size_t sig_short = 0; // Hash of the signature, for the quick equality check.
+ // The full signature: hashes of the nodes in a predictable order.
+ std::vector<size_t> sig_full;
+ // The nodes in the same order as they go in the signature.
+ std::vector<SigNode*> nodes;
+
+ // For building the unordered maps.
+ size_t Hash() const { return sig_short; }
+
+ // Returns true if the graphs are equivalent. The signature must be already
+ // computed.
+ bool operator==(const Signature& other) const;
+
+ private:
+ // Populates the nodes vector from the map and initializes the state of the
+ // nodes for the signature computation.
+ void PrepareNodes();
+
+ // Finds the nodes with the hashes that are unique and assigns the unique ids
+ // to them. If there are nodes with non-unique hashes, exactly one node from
+ // the first such sequence (in the order of hash values) will be picked and
+ // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have
+ // been already assigned the unique ids. Advances next_node_id by at least 1.
+ void FindUniqueHashes(size_t* next_node_id_p);
+
+ // One round of the signature computation. Assumes that the
+ // nodes[0...(next_node_id-1)] have been already assigned the fixed
+ // positions, and thus computes the hashes only for the remaining nodes.
+ void ComputeOneRound(size_t next_node_id);
+
+ // Additional ordering of the hashed_peers_ links in the nodes, so that they
+ // can be compared and printed in a predictable order.
+ void OrderLinks();
+};
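+
+// A minimal end-to-end sketch (illustrative; assumes an already built Subgraph
+// named sg, as in the tests): extract the SigNodes, compute the signature, and
+// compare it with another one.
+//
+//   Signature sig;
+//   sg.ExtractForSignature(&sig.map);
+//   TF_CHECK_OK(sig.Compute());
+//   const bool equivalent = (sig == other_sig);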
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc
new file mode 100644
index 0000000000..4c6a9ba9e0
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc
@@ -0,0 +1,1235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Gt;
+using ::testing::Ne;
+using ::testing::SizeIs;
+
+//===
+
+TEST(SigNodeLinkTag, Compare) {
+ SigNode::LinkTag a(GenNode::Port(false, 1), GenNode::Port(false, 2));
+ SigNode::LinkTag b(GenNode::Port(false, 1), GenNode::Port(false, 2));
+ SigNode::LinkTag c(GenNode::Port(false, 2), GenNode::Port(false, 1));
+ SigNode::LinkTag d(GenNode::Port(false, 1), GenNode::Port(false, 3));
+ SigNode::LinkTag e(GenNode::Port(false, 2), GenNode::Port(false, 2));
+
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a == c);
+ EXPECT_FALSE(a == e);
+
+ EXPECT_FALSE(a < b);
+ EXPECT_FALSE(b < a);
+
+ EXPECT_TRUE(a < c);
+ EXPECT_FALSE(c < a);
+
+ EXPECT_TRUE(a < d);
+ EXPECT_FALSE(d < a);
+}
+
+//===
+
+class SigBaseTest : public ::testing::Test, protected TestGraphs {
+ protected:
+ void BuildSigMap(const GraphDef& graph) {
+ gen_map_.clear();
+ sig_.map.clear();
+ CHECK(GenNode::BuildGraphInMap(graph, &gen_map_).ok());
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+ }
+
+ static void CopyLinksPass2(
+ std::map<SigNode::LinkTag, SigNode::Link>* link_map, SigNode* node) {
+ node->CopyLinksPass2(link_map);
+ }
+
+ static void ComputeTopoHash0(SigNode* node) { node->ComputeTopoHash0(); }
+
+ static void ComputeTopoHash(int distance, SigNode* node) {
+ node->ComputeTopoHash(distance);
+ }
+
+ static size_t GetTopoHash(int distance, SigNode* node) {
+ return node->GetTopoHash(distance);
+ }
+
+ static size_t GetHighTopoHash(SigNode* node) {
+ return node->GetHighTopoHash();
+ }
+
+ static void ReHighTopoHash(SigNode* node) { node->ReHighTopoHash(); }
+
+ static SigNode::HashedPeerVector& RefHashedPeers(SigNode* node) {
+ return node->hashed_peers_;
+ }
+ static size_t& RefUniqueRank(SigNode* node) { return node->unique_rank_; }
+ static bool& RefHashIsFinal(SigNode* node) { return node->hash_is_final_; }
+ static std::vector<size_t>& RefTopoHash(SigNode* node) {
+ return node->topo_hash_;
+ }
+ static uint64_t& RefNodeMask(SigNode* node) { return node->node_mask_; }
+ static uint64_t& RefLastHashedNodes(SigNode* node) {
+ return node->last_hashed_nodes_;
+ }
+ static uint64_t& RefNextHashedNodes(SigNode* node) {
+ return node->next_hashed_nodes_;
+ }
+
+ static void PrepareNodes(Signature* signature) { signature->PrepareNodes(); }
+
+ static void FindUniqueHashes(size_t* next_node_id_p, Signature* signature) {
+ signature->FindUniqueHashes(next_node_id_p);
+ }
+
+ static void ComputeOneRound(size_t next_node_id, Signature* signature) {
+ signature->ComputeOneRound(next_node_id);
+ }
+
+ static void OrderLinks(Signature* signature) { signature->OrderLinks(); }
+
+ // These get initialized in BuildSigMap().
+ GenNodeMap gen_map_;
+ Signature sig_;
+};
+
+//===
+
+class SigNodeTest : public SigBaseTest {};
+
+// Tests that the duplicate hashes get resolved by rehashing.
+TEST_F(SigNodeTest, DuplicateHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ NodeDef node2 = MakeNodeConst("node2");
+ NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
+
+ SigNode sn1(&node1);
+ SigNode sn2(&node2);
+ SigNode sn3(&node3);
+
+ constexpr size_t kSameHash = 999;
+
+ SigNode::Link link1;
+ link1.tag = SigNode::LinkTag(GenNode::Port(true, 0), GenNode::Port(false, 0));
+ link1.unique_hash = kSameHash;
+ link1.peers.emplace_back(&sn1);
+
+ SigNode::Link link2;
+ link2.tag = SigNode::LinkTag(GenNode::Port(true, 1), GenNode::Port(false, 0));
+ link2.unique_hash = kSameHash;
+ link2.peers.emplace_back(&sn2);
+
+ SigNode::Link link3;
+ link3.tag = SigNode::LinkTag(GenNode::Port(true, 2), GenNode::Port(false, 0));
+ link3.unique_hash = kSameHash;
+ link3.peers.emplace_back(&sn3);
+
+ std::map<SigNode::LinkTag, SigNode::Link> link_map;
+ link_map[link1.tag] = link1;
+ link_map[link2.tag] = link2;
+ link_map[link3.tag] = link3;
+
+ CopyLinksPass2(&link_map, &sn3);
+ auto& hl = sn3.hash_to_link();
+ EXPECT_THAT(hl, SizeIs(3));
+
+  // Check that the hashes are self-consistent, and put the entries into
+ // another map with a known order.
+ std::map<SigNode::LinkTag, SigNode::Link> rehashed;
+ auto hlit = hl.begin();
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+ ++hlit;
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+ ++hlit;
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+
+ // Just in case.
+ ASSERT_THAT(rehashed, SizeIs(3));
+
+ auto rhit = rehashed.begin();
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link1.tag);
+ EXPECT_THAT(rhit->second.unique_hash, Eq(kSameHash));
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn1));
+
+ ++rhit;
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link2.tag);
+ // This hash must be rehashed.
+ EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
+ size_t hash2 = rhit->second.unique_hash;
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn2));
+
+ ++rhit;
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link3.tag);
+ // This hash must be rehashed.
+ EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
+ EXPECT_THAT(rhit->second.unique_hash, Ne(hash2));
+ size_t hash3 = rhit->second.unique_hash;
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn3));
+
+ auto& peers = sn3.hashed_peers();
+ EXPECT_THAT(peers, SizeIs(3));
+
+ auto peerit = peers.begin();
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(kSameHash));
+ EXPECT_THAT(peerit->peer, Eq(&sn1));
+
+ ++peerit;
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(hash2));
+ EXPECT_THAT(peerit->peer, Eq(&sn2));
+
+ ++peerit;
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(hash3));
+ EXPECT_THAT(peerit->peer, Eq(&sn3));
+}
+
+// The full CopyLinks() is tested in (SubgraphTest, ExtractForSignature).
+
+TEST_F(SigNodeTest, GetTopoHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake some hash values.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+
+ RefHashIsFinal(&sn1) = true;
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(456));
+
+ EXPECT_THAT(GetHighTopoHash(&sn1), Eq(456));
+}
+
+TEST_F(SigNodeTest, ReTopoHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake some hash values.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+
+ ReHighTopoHash(&sn1);
+
+ size_t expected_hash = 456;
+ CombineHash(1, &expected_hash);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(expected_hash));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHash0) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake a topology.
+ RefUniqueRank(&sn1) = 10;
+ RefNodeMask(&sn1) = 0x02;
+
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ // Fake a state.
+ RefLastHashedNodes(&sn1) = 0xFF;
+ RefNextHashedNodes(&sn1) = 0xFF;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(2, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr));
+
+ // Run the test.
+ ComputeTopoHash0(&sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x02));
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x02));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(1));
+
+ size_t exp_hval = std::hash<string>()(sn1.opcode());
+ CombineHash(1, &exp_hval);
+ CombineHash(1, &exp_hval);
+ CombineHash(2, &exp_hval);
+ CombineHash(3, &exp_hval);
+ CombineHash(3, &exp_hval);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(exp_hval));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHashNotFinal) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+
+ // Fake a topology.
+ RefUniqueRank(&sn1) = 0;
+ RefNodeMask(&sn1) = 0x01;
+ RefUniqueRank(&sn2) = 0;
+ RefNodeMask(&sn2) = 0x02;
+ RefUniqueRank(&sn3) = 0;
+ RefNodeMask(&sn3) = 0x04;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));
+
+ // Fake a state.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(321);
+
+ RefTopoHash(&sn2).emplace_back(456);
+ RefTopoHash(&sn2).emplace_back(654);
+
+ RefTopoHash(&sn3).emplace_back(789);
+ RefTopoHash(&sn3).emplace_back(987);
+
+ // These values are not realistic in that they don't include the bits from
+ // the masks of the nodes themselves, but that's the point of this test: only
+ // the previous nodes' node sets are used in the computation, not their own
+ // masks directly.
+ RefLastHashedNodes(&sn1) = 0x8;
+ RefLastHashedNodes(&sn2) = 0x10;
+ RefLastHashedNodes(&sn3) = 0x20;
+
+ // A scratch value to get overwritten.
+ RefNextHashedNodes(&sn1) = 0x100;
+
+ ComputeTopoHash(2, &sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged.
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x38));
+
+ // This computes the hash from the explicit numbers above.
+ size_t exp_hash = 123; // The 0th hash is the starting point.
+ size_t comm_hash;
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+ CombineHashCommutative(987, &comm_hash);
+
+ CombineHash(10, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+
+ CombineHash(20, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+ CombineHashCommutative(987, &comm_hash);
+
+ CombineHash(30, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(exp_hash));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(3));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHashFinal) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+
+ // Fake a topology - same as for ComputeTopoHashNotFinal.
+ RefUniqueRank(&sn1) = 0;
+ RefNodeMask(&sn1) = 0x01;
+ RefUniqueRank(&sn2) = 0;
+ RefNodeMask(&sn2) = 0x02;
+ RefUniqueRank(&sn3) = 0;
+ RefNodeMask(&sn3) = 0x04;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));
+
+ // Fake a state - mostly same as for ComputeTopoHashNotFinal.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(321);
+
+ RefTopoHash(&sn2).emplace_back(456);
+ RefTopoHash(&sn2).emplace_back(654);
+
+ RefTopoHash(&sn3).emplace_back(789);
+ RefTopoHash(&sn3).emplace_back(987);
+
+ // These values are not realistic in that they don't include the bits from
+ // the masks of the nodes themselves, but that's the point of this test: only
+ // the previous nodes' node sets are used in the computation, not their own
+ // masks directly.
+ RefLastHashedNodes(&sn1) = 0x8;
+ RefLastHashedNodes(&sn2) = 0x10;
+ RefLastHashedNodes(&sn3) = 0x20;
+
+ // A scratch value to get overwritten.
+ RefNextHashedNodes(&sn1) = 0x100;
+
+ // This is the difference in configuration.
+ RefHashIsFinal(&sn1) = true;
+
+ ComputeTopoHash(2, &sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged.
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x8));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(321));
+}
+
+TEST_F(SigNodeTest, EqualsOpcode) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+
+ EXPECT_TRUE(sn1 == sn2);
+ EXPECT_FALSE(sn1 != sn2);
+
+ node2.set_op("Mul");
+
+ EXPECT_TRUE(sn1 != sn2);
+ EXPECT_FALSE(sn1 == sn2);
+}
+
+TEST_F(SigNodeTest, EqualsRank) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+
+ EXPECT_TRUE(sn1 == sn2);
+ EXPECT_FALSE(sn1 != sn2);
+
+ RefUniqueRank(&sn1) = 1;
+ RefUniqueRank(&sn2) = 2;
+
+ EXPECT_TRUE(sn1 != sn2);
+ EXPECT_FALSE(sn1 == sn2);
+}
+
+// Checks that if the nodes have a different number of links,
+// they will be considered unequal.
+TEST_F(SigNodeTest, EqualsLinkSize) {
+ GraphDef graph1;
+ (*graph1.add_node()) = MakeNodeConst("node1");
+ (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
+
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ SigNodeMap sig_map1;
+ sg1.ExtractForSignature(&sig_map1);
+
+ GraphDef graph2;
+ (*graph2.add_node()) = MakeNodeConst("node1");
+ // The difference between graph1 and graph2: one more input.
+ auto node22 = graph2.add_node();
+ *node22 = MakeNodeMul("node2", "node1", "node1");
+ node22->add_input("node2");
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ SigNodeMap sig_map2;
+ sg2.ExtractForSignature(&sig_map2);
+
+ EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+ EXPECT_FALSE(*sig_map2["node2"] == *sig_map1["node2"]);
+}
+
+TEST_F(SigNodeTest, EqualsLinks) {
+ // Start with 2 copies of the same graph.
+ GraphDef graph1;
+ (*graph1.add_node()) = MakeNodeConst("node1");
+ (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
+
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ SigNodeMap sig_map1;
+ sg1.ExtractForSignature(&sig_map1);
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map2), Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ SigNodeMap sig_map2;
+ sg2.ExtractForSignature(&sig_map2);
+
+ EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
+ EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Alter the link hash of one of the nodes.
+ SigNode* sn2 = sig_map2["node2"].get();
+ ++RefHashedPeers(sn2)[0].link_hash;
+
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Restore back.
+ --RefHashedPeers(sn2)[0].link_hash;
+ EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Alter the unique rank of a referenced node.
+ ++RefUniqueRank(sig_map2["node1"].get());
+
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+}
+
+//===
+
+class SignatureTest : public SigBaseTest {
+ protected:
+ // Initializes the state used to generate the permutations of a given size.
+ static void InitPermutation(size_t size,
+ std::vector<size_t>* plain_permutation,
+ std::vector<size_t>* countdown) {
+ plain_permutation->clear();
+ countdown->clear();
+ for (size_t i = 0; i < size; ++i) {
+ plain_permutation->emplace_back(i);
+ countdown->emplace_back(size - 1 - i);
+ }
+ }
+
+ // Builds a permutation guided by the count-down value.
+ static void BuildPermutation(const std::vector<size_t>& plain_permutation,
+ const std::vector<size_t>& countdown,
+ std::vector<size_t>* result) {
+ *result = plain_permutation;
+ for (int i = 0; i < result->size(); ++i) {
+ std::swap((*result)[i], (*result)[i + countdown[i]]);
+ }
+ }
+
+ // Returns false when the count-down is finished.
+ static bool CountDown(std::vector<size_t>* countdown) {
+ // The last position always contains 0, so skip it.
+ int pos;
+ for (pos = countdown->size() - 2; pos >= 0; --pos) {
+ if ((*countdown)[pos] > 0) {
+ --(*countdown)[pos];
+ break;
+ }
+ (*countdown)[pos] = (countdown->size() - 1 - pos);
+ }
+
+ return pos >= 0;
+ }
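+
+ // Taken together, InitPermutation(), BuildPermutation() and CountDown()
+ // enumerate all size! permutations: countdown acts as a mixed-radix counter
+ // where position i ranges over 0..(size - 1 - i), and BuildPermutation()
+ // swaps element i with element i + countdown[i]. For example, for size 3
+ // the countdown runs (2,1,0), (2,0,0), (1,1,0), (1,0,0), (0,1,0), (0,0,0),
+ // giving the 6 = 3! distinct permutations; the Permutation test below
+ // checks the 5! case.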
+
+ // Permutes the nodes every which way and checks that all the signatures
+ // produced are the same. This is reasonable for graphs up to size 5, maybe
+ // 6 at a stretch. After that the number of permutations grows huge and the
+ // test becomes very slow.
+ void TestGraphEveryWay(const GraphDef& graph) {
+ size_t graph_size = graph.node_size();
+
+ gen_map_.clear();
+ sig_.map.clear();
+ Status result = GenNode::BuildGraphInMap(graph, &gen_map_);
+ ASSERT_THAT(result, Eq(Status::OK()));
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ std::vector<size_t> plain_permutation;
+ std::vector<size_t> countdown;
+ InitPermutation(graph_size, &plain_permutation, &countdown);
+
+ std::set<string> signatures;
+ std::vector<size_t> permutation;
+ do {
+ BuildPermutation(plain_permutation, countdown, &permutation);
+
+ constexpr bool kDebugPermutation = false;
+ if (kDebugPermutation) {
+ string p;
+ for (int i = 0; i < permutation.size(); ++i) {
+ p.push_back('0' + permutation[i]);
+ }
+ LOG(INFO) << "Permutation: " << p;
+ }
+
+ std::vector<std::unique_ptr<SigNode>> hold(graph_size);
+ int idx;
+
+ // Permute the nodes.
+ sig_.nodes.clear();
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes before permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ hold[idx++] = std::move(entry.second);
+ }
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes after permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ entry.second = std::move(hold[permutation[idx++]]);
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ // This is used to order the links per permutation.
+ sig_.nodes.emplace_back(entry.second.get());
+ RefUniqueRank(entry.second.get()) = idx;
+ }
+ // Order the links with the same tags per permutation.
+ OrderLinks(&sig_);
+
+ // The test as such.
+ ASSERT_THAT(sig_.Compute(), Eq(Status::OK()));
+
+ signatures.insert(sig_.ToString());
+
+ EXPECT_THAT(sig_.sig_full, SizeIs(graph_size));
+ size_t hval = 0;
+ for (size_t ih : sig_.sig_full) {
+ // The space 1..graph_size is reserved.
+ EXPECT_THAT(ih, Gt(graph_size));
+ CombineHash(ih, &hval);
+ }
+ EXPECT_THAT(sig_.sig_short, Eq(hval));
+
+ // Un-permute the nodes for the next iteration.
+ idx = 0;
+ for (auto& entry : sig_.map) {
+ hold[permutation[idx++]] = std::move(entry.second);
+ }
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes after un-permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ entry.second = std::move(hold[idx++]);
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ }
+ } while (CountDown(&countdown));
+
+ for (const auto& s : signatures) {
+ LOG(INFO) << "Signature: " << s;
+ }
+
+ // All the permutations should produce the same signature.
+ EXPECT_THAT(signatures, SizeIs(1));
+ }
+};
+
+TEST_F(SignatureTest, PrepareNodes) {
+ NodeDef node1 = MakeNodeConst("node1");
+ sig_.map["node1"] = absl::make_unique<SigNode>(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ sig_.map["node2"] = absl::make_unique<SigNode>(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ sig_.map["node3"] = absl::make_unique<SigNode>(&node3);
+
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(3));
+
+ int idx = 0;
+ for (const auto& entry : sig_.map) {
+ EXPECT_THAT(RefNodeMask(entry.second.get()), Eq(1 << idx))
+ << " at index " << idx;
+ EXPECT_THAT(RefUniqueRank(entry.second.get()), Eq(static_cast<size_t>(~0)))
+ << " at index " << idx;
+ EXPECT_THAT(RefHashIsFinal(entry.second.get()), false)
+ << " at index " << idx;
+ EXPECT_THAT(RefTopoHash(entry.second.get()), SizeIs(1))
+ << " at index " << idx;
+ ++idx;
+ }
+}
+
+TEST_F(SignatureTest, FindUniqueHashesAllDifferent) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+
+ // The last values in the arrays go in backwards order.
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(900);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(800);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(600);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+
+ size_t next = 1; // Skips over sn1.
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(4));
+
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn1));
+ // The nodes after the first one get sorted by the high hash.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn2));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ // Nodes that get finalized are marked as such.
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(true));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ ASSERT_THAT(RefTopoHash(&sn2), SizeIs(1));
+ ASSERT_THAT(RefTopoHash(&sn3), SizeIs(1));
+ ASSERT_THAT(RefTopoHash(&sn4), SizeIs(1));
+
+ EXPECT_THAT(RefTopoHash(&sn2)[0], Eq(4));
+ EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(3));
+ EXPECT_THAT(RefTopoHash(&sn4)[0], Eq(2));
+
+ EXPECT_THAT(sig_.sig_full, ElementsAre(600, 700, 800));
+
+ size_t exp_short_hash = 0;
+ CombineHash(600, &exp_short_hash);
+ CombineHash(700, &exp_short_hash);
+ CombineHash(800, &exp_short_hash);
+ EXPECT_THAT(sig_.sig_short, Eq(exp_short_hash));
+}
+
+TEST_F(SignatureTest, FindUniqueHashesDuplicatesExceptOne) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+ NodeDef node5 = MakeNodeConst("node5");
+ SigNode sn5(&node5);
+
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(600);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(600);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(800);
+
+ RefTopoHash(&sn5).emplace_back(500);
+ RefTopoHash(&sn5).emplace_back(800);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+ sig_.nodes.emplace_back(&sn5);
+
+ size_t next = 0;
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(1));
+
+ // The unique node goes first.
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn3));
+
+ // The rest of the nodes are assumed to be sorted in a stable order.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
+ // Node 1 gets swapped with node 3.
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn1));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[4], Eq(&sn5));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn5), Eq(false));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn3), SizeIs(1));
+ EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn5), SizeIs(2));
+
+ EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(1));
+}
+
+TEST_F(SignatureTest, FindUniqueHashesDuplicates) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+ NodeDef node5 = MakeNodeConst("node5");
+ SigNode sn5(&node5);
+
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(600);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(600);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(700);
+
+ RefTopoHash(&sn5).emplace_back(500);
+ RefTopoHash(&sn5).emplace_back(700);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+ sig_.nodes.emplace_back(&sn5);
+
+ size_t next = 0;
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(1));
+
+ // The last copy of the last duplicate wins.
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn5));
+
+ // The rest of the nodes are assumed to be sorted in a stable order.
+ // Node 1 gets swapped.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[4], Eq(&sn1));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn5), Eq(true));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn3), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn5), SizeIs(1));
+
+ EXPECT_THAT(RefTopoHash(&sn5)[0], Eq(1));
+}
+
+// On a circular topology.
+TEST_F(SignatureTest, ComputeOneRoundCircular) {
+ BuildSigMap(graph_circular_onedir_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // This skips FindUniqueHashes() which would pick one node, so that
+ // all the nodes are equivalent for ComputeOneRound().
+
+ ComputeOneRound(0, &sig_);
+
+ // All the nodes are the same, so the computed hashes will also be the same.
+ size_t hval = GetHighTopoHash(sig_.nodes[0]);
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(GetHighTopoHash(sig_.nodes[i]), Eq(hval)) << " at index " << i;
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ // The sets of hashed nodes go like this:
+ // Step 0: self.
+ // Step 1: self, previous (-1) and next (+1) node.
+ // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
+ // Step 3: still all 5 nodes in the graph
+ EXPECT_THAT(RefTopoHash(sig_.nodes[i]), SizeIs(4)) << " at index " << i;
+ }
+}
+
+// On a linear topology.
+TEST_F(SignatureTest, ComputeOneRoundLinear) {
+ BuildSigMap(graph_linear_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // This skips FindUniqueHashes() which would pick one node, so that
+ // all the nodes are equivalent for ComputeOneRound().
+
+ ComputeOneRound(0, &sig_);
+
+ std::vector<size_t> hash_size;
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
+ }
+
+ // The sets of hashed nodes for the central node go like this:
+ // Step 0: self.
+ // Step 1: self, previous (-1) and next (+1) node.
+ // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
+ // Step 3: still all 5 nodes in the graph
+ //
+ // The nodes one step closer to the ends require one more step. The end nodes
+ // require one more step yet.
+ std::sort(hash_size.begin(), hash_size.end());
+ EXPECT_THAT(hash_size, ElementsAre(4, 5, 5, 6, 6));
+}
+
+// On a linear topology where the central node has already been marked as
+// unique (yeah, not a very realistic case, but it tests the situation where
+// disconnected subgraphs get created).
+TEST_F(SignatureTest, ComputeOneRoundSplitLinear) {
+ BuildSigMap(graph_linear_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // This test relies on the order of SigNodeMap imposed on sig_.nodes.
+
+ // The middle node gets separated by moving it to the front.
+ std::swap(sig_.nodes[0], sig_.nodes[2]);
+ ASSERT_THAT(RefNodeMask(sig_.nodes[0]), Eq(0x04));
+ ASSERT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
+ ASSERT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
+ RefHashIsFinal(sig_.nodes[0]) = true;
+
+ ComputeOneRound(1, &sig_);
+
+ // These should stay unchanged.
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
+
+ std::vector<size_t> hash_size;
+ for (int i = 1; i < 5; ++i) {
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
+ }
+
+ std::sort(hash_size.begin(), hash_size.end());
+ // The end nodes take 4 steps, the nodes closer to the center take 3 steps.
+ EXPECT_THAT(hash_size, ElementsAre(3, 3, 4, 4));
+
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[1]), Eq(0x07));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[1]), Eq(0x07));
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[2]), Eq(0x07));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[2]), Eq(0x07));
+
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[3]), Eq(0x1C));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[3]), Eq(0x1C));
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[4]), Eq(0x1C));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[4]), Eq(0x1C));
+}
+
+TEST_F(SignatureTest, OrderLinks) {
+ gen_map_.clear();
+ sig_.map.clear();
+ Status result = GenNode::BuildGraphInMap(graph_for_link_order_, &gen_map_);
+ ASSERT_THAT(result, Eq(Status::OK()));
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ // Populate the fake signature and assign the ranks in reverse order.
+ for (auto it = sig_.map.rbegin(); it != sig_.map.rend(); ++it) {
+ auto& entry = *it;
+ RefUniqueRank(entry.second.get()) = sig_.nodes.size();
+ sig_.nodes.emplace_back(entry.second.get());
+ }
+
+ // How it was ordered in the original graph.
+ string before = sig_.ToString();
+ // clang-format off
+ EXPECT_THAT(before, Eq(
+ "0:Mul[i0:o0:5][i0:o0:4][i0:o1:4][i0:o2:3][i0:o2:2][i0:o3:2],"
+ "1:Mul[i0:o0:5][i0:o0:4][i0:o0:3][i0:o0:2],"
+ "2:Const,"
+ "3:Const,"
+ "4:Const,"
+ "5:Const,"
+ ));
+ // clang-format on
+
+ OrderLinks(&sig_);
+
+ string after = sig_.ToString();
+ // clang-format off
+ EXPECT_THAT(after, Eq(
+ "0:Mul[i0:o0:4][i0:o0:5][i0:o1:4][i0:o2:2][i0:o2:3][i0:o3:2],"
+ "1:Mul[i0:o0:2][i0:o0:3][i0:o0:4][i0:o0:5],"
+ "2:Const,"
+ "3:Const,"
+ "4:Const,"
+ "5:Const,"
+ ));
+ // clang-format on
+}
+
+TEST_F(SignatureTest, GraphTooBig) {
+ GraphDef graph;
+ for (int i = 0; i <= Signature::kMaxGraphSize; ++i) {
+ (*graph.add_node()) = MakeNodeConst(absl::StrFormat("node%d", i));
+ }
+
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_), Eq(Status::OK()));
+
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ ASSERT_THAT(sig_.Compute(),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "A graph of 65 nodes is too big for signature "
+ "computation, the maximal supported node count is "
+ "64.")));
+}
+
+TEST_F(SignatureTest, ToString) {
+ BuildSigMap(graph_circular_onedir_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // Fake the computation by assigning unique ranks in the initial order.
+ for (int i = 0; i < 5; ++i) {
+ RefUniqueRank(sig_.nodes[i]) = i;
+ RefHashIsFinal(sig_.nodes[i]) = true;
+ }
+
+ string result = sig_.ToString();
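+ // As seen in the expected value below, each node is printed as
+ // "<rank>:<opcode>" followed by one "[<local port>:<remote port>:<peer
+ // rank>]" entry per link, with a trailing comma after every node.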
+
+ // clang-format off
+ ASSERT_THAT(result, Eq(
+ "0:Mul[i0:o0:4][i0:o0:4],"
+ "1:Mul[i0:o0:0][i0:o0:0],"
+ "2:Mul[i0:o0:1][i0:o0:1],"
+ "3:Mul[i0:o0:2][i0:o0:2],"
+ "4:Mul[i0:o0:3][i0:o0:3],"
+ ));
+ // clang-format on
+}
+
+// This is a test of the permutation logic itself.
+TEST_F(SignatureTest, Permutation) {
+ std::vector<size_t> plain_permutation;
+ std::vector<size_t> countdown;
+ InitPermutation(5, &plain_permutation, &countdown);
+
+ std::set<string> results;
+
+ std::vector<size_t> permutation;
+ do {
+ BuildPermutation(plain_permutation, countdown, &permutation);
+ EXPECT_THAT(permutation, SizeIs(5));
+
+ string p;
+ for (int i = 0; i < permutation.size(); ++i) {
+ p.push_back('0' + permutation[i]);
+ }
+ LOG(INFO) << "Permutation: " << p;
+ results.insert(p);
+ } while (CountDown(&countdown));
+
+ EXPECT_THAT(results, SizeIs(5 * 4 * 3 * 2 * 1));
+}
+
+TEST_F(SignatureTest, ComputeCircularOneDir) {
+ TestGraphEveryWay(graph_circular_onedir_);
+}
+
+TEST_F(SignatureTest, ComputeCircularBiDir) {
+ TestGraphEveryWay(graph_circular_bidir_);
+}
+
+TEST_F(SignatureTest, ComputeLinear) { TestGraphEveryWay(graph_linear_); }
+
+TEST_F(SignatureTest, ComputeMultiInput) {
+ TestGraphEveryWay(graph_multi_input_);
+}
+
+TEST_F(SignatureTest, ComputeAllOrNone) {
+ TestGraphEveryWay(graph_all_or_none_);
+}
+
+TEST_F(SignatureTest, ComputeCross) { TestGraphEveryWay(graph_small_cross_); }
+
+TEST_F(SignatureTest, Equals) {
+ // Start with 2 copies of the same graph.
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map1),
+ Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ Signature sig1;
+ sg1.ExtractForSignature(&sig1.map);
+ ASSERT_THAT(sig1.Compute(), Eq(Status::OK()));
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map2),
+ Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ Signature sig2;
+ sg2.ExtractForSignature(&sig2.map);
+ ASSERT_THAT(sig2.Compute(), Eq(Status::OK()));
+
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Change the short hash.
+ ++sig2.sig_short;
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ --sig2.sig_short;
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Change the full hash.
+ ++sig2.sig_full[0];
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ --sig2.sig_full[0];
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Make the nodes different.
+ std::swap(sig2.nodes[0], sig2.nodes[1]);
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ std::swap(sig2.nodes[0], sig2.nodes[1]);
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Different number of nodes.
+ sig2.nodes.emplace_back(sig2.nodes[0]);
+ EXPECT_FALSE(sig1 == sig2);
+ EXPECT_FALSE(sig2 == sig1);
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.cc b/tensorflow/core/grappler/graph_analyzer/subgraph.cc
new file mode 100644
index 0000000000..28a91e0f84
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+
+#include <functional>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+//=== Subgraph::Identity
+
+Subgraph::Identity::Identity(InitializerList init) {
+ for (auto element : init) {
+ insert(element);
+ }
+}
+
+bool Subgraph::Identity::operator<(const Identity& other) const {
+ // Shorter sets go first.
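+ // Sets of equal size are compared lexicographically by the node pointer
+ // values.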
+ if (this->size() < other.size()) {
+ return true;
+ }
+ if (this->size() > other.size()) {
+ return false;
+ }
+ for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
+ ++lit, ++rit) {
+ if (*lit < *rit) {
+ return true;
+ }
+ if (*lit > *rit) {
+ return false;
+ }
+ }
+ return false; // Equal.
+}
+
+bool Subgraph::Identity::operator==(const Identity& other) const {
+ if (this->size() != other.size()) {
+ return false;
+ }
+ for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
+ ++lit, ++rit) {
+ if (*lit != *rit) {
+ return false;
+ }
+ }
+ return true; // Equal.
+}
+
+size_t Subgraph::Identity::Hash() const {
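+ // Combines the hashes of the node pointers in the set's iteration order.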
+ std::hash<const GenNode*> hasher;
+ size_t result = 0;
+ for (auto ptr : *this) {
+ CombineHash(hasher(ptr), &result);
+ }
+ return result;
+}
+
+string Subgraph::Dump() {
+ // TODO(babkin): this is simplified for now.
+ std::vector<string> nodes;
+ for (const auto& n : id_) {
+ if (specific_) {
+ nodes.emplace_back(absl::StrFormat("%s(%s)", n->opcode(), n->name()));
+ } else {
+ nodes.emplace_back(n->opcode());
+ }
+ }
+ std::sort(nodes.begin(), nodes.end());
+
+ return absl::StrFormat("%d: ", collation_count_) + absl::StrJoin(nodes, ", ");
+}
+
+void Subgraph::ExtractForSignature(SigNodeMap* result) {
+ // Mapping of nodes from the original graph to the new one.
+ SigNode::TranslationMap full_to_new;
+
+ for (auto node : id_) {
+ auto newnode_ref = absl::make_unique<SigNode>(node->node_def());
+ auto newnode = newnode_ref.get();
+ (*result)[node->name()] = std::move(newnode_ref);
+ full_to_new[node] = newnode;
+ }
+
+ for (const auto& mapping : full_to_new) {
+ mapping.second->CopyLinks(*mapping.first, full_to_new);
+ }
+}
+
+//=== Subgraph
+
+Subgraph::Subgraph(const Identity& parent_id, GenNode* add_node)
+ : id_(parent_id) {
+ id_.insert(add_node);
+ hash_ = id_.Hash();
+}
+
+//=== SubgraphIterator
+
+SubgraphIterator::SubgraphIterator(const Subgraph::Identity* id)
+ : id_(id), id_it_(id_->begin()) {
+ if (!id_->empty()) {
+ link_map_it_ = (*id_it_)->links().begin();
+ // In case the node has no links.
+ while (link_map_it_ == (*id_it_)->links().end()) {
+ if (++id_it_ == id_->end()) {
+ return;
+ }
+ link_map_it_ = (*id_it_)->links().begin();
+ }
+ link_idx_ = 0;
+ // The LinkTargetVector should never be empty, but safeguard against that
+ // just in case.
+ PropagateNext();
+ }
+}
+
+bool SubgraphIterator::Next() {
+ if (AtEnd()) {
+ return false;
+ }
+ ++link_idx_;
+ return PropagateNext();
+}
+
+bool SubgraphIterator::NextIfSamePort() {
+ if (AtEnd()) {
+ return false;
+ }
+ if (link_idx_ + 1 < link_map_it_->second.size()) {
+ ++link_idx_;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void SubgraphIterator::SkipPort() {
+ if (AtEnd()) {
+ return;
+ }
+ link_idx_ = link_map_it_->second.size() - 1;
+}
+
+void SubgraphIterator::SkipNode() {
+ if (AtEnd()) {
+ return;
+ }
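+ // Advance link_map_it_ to the last port (link map entry) of the current
+ // node, then point link_idx_ at the last link of that port.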
+ for (auto next = link_map_it_; next != (*id_it_)->links().end(); ++next) {
+ link_map_it_ = next;
+ }
+ link_idx_ = link_map_it_->second.size() - 1;
+}
+
+bool SubgraphIterator::PropagateNext() {
+ // Loops are used to skip over the empty entries.
+ while (link_idx_ >= link_map_it_->second.size()) {
+ ++link_map_it_;
+ while (link_map_it_ == (*id_it_)->links().end()) {
+ if (++id_it_ == id_->end()) {
+ return false;
+ }
+ link_map_it_ = (*id_it_)->links().begin();
+ }
+ link_idx_ = 0;
+ }
+ return true;
+}
+
+bool SubgraphIterator::operator==(const SubgraphIterator& other) const {
+ if (id_ != other.id_) {
+ return false;
+ }
+ if (id_it_ != other.id_it_) {
+ return false;
+ }
+ // When AtEnd(), the rest of the fields are not valid.
+ if (AtEnd()) {
+ return true;
+ }
+ if (link_map_it_ != other.link_map_it_) {
+ return false;
+ }
+ if (link_idx_ != other.link_idx_) {
+ return false;
+ }
+ return true;
+}
+
+//=== SubgraphPtrSet
+
+Subgraph* SubgraphPtrSet::ExtendParent(const Subgraph::Identity& parent_id,
+ GenNode* node) {
+ if (parent_id.find(node) != parent_id.end()) {
+ // This was another link to the node that is already in the parent.
+ return nullptr;
+ }
+
+ // Constructing an object just to check that an equivalent one is already
+ // present is kind of ugly, but storing the references rather than the
+ // objects in the set avoids the need to make the objects copyable.
+ auto sg = absl::make_unique<Subgraph>(parent_id, node);
+ if (find(sg) != end()) {
+ // This subgraph was already found by extending from a different path.
+ return nullptr;
+ }
+
+ Subgraph* ptr = sg.get();
+ insert(std::move(sg));
+ return ptr;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.h b/tensorflow/core/grappler/graph_analyzer/subgraph.h
new file mode 100644
index 0000000000..4de31d5dfa
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph.h
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
+
+#include <initializer_list>
+#include <set>
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// The description of a single subgraph for processing.
+class Subgraph {
+ public:
+ // Identity of a single subgraph as a set of nodes.
+ class Identity : public gtl::FlatSet<const GenNode*> {
+ public:
+ using InitializerList = std::initializer_list<GenNode*>;
+
+ Identity() = default;
+ Identity(InitializerList init);
+ bool operator<(const Identity& other) const;
+ bool operator==(const Identity& other) const;
+
+ // Compute the hash.
+ size_t Hash() const;
+ };
+
+ explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {}
+
+ // Construct by extending the parent identity with an extra node.
+ Subgraph(const Identity& parent_id, GenNode* add_node);
+
+ Subgraph() = delete;
+ Subgraph(const Subgraph& other) = delete;
+ void operator=(const Subgraph& other) = delete;
+
+ // Order for building sets of subgraphs.
+ bool operator<(const Subgraph& other) const { return this->id_ < other.id_; }
+ // Support for hashed sets.
+ bool operator==(const Subgraph& other) const {
+ return this->id_ == other.id_;
+ }
+ size_t Hash() const { return hash_; }
+
+ // Dump the subgraph information to a string.
+ string Dump();
+
+ // Extract this subgraph into a separate graph representation for signature
+ // building; it includes only the links between the nodes in the subgraph
+ // and drops all the external links. The result map should be empty before
+ // the call.
+ void ExtractForSignature(SigNodeMap* result);
+
+ const Identity& id() const { return id_; }
+ bool specific() const { return specific_; }
+ void SetSpecific(bool value) { specific_ = value; }
+ int32_t collation_count() const { return collation_count_; }
+ void AddCollation(int32_t n = 1) { collation_count_ += n; }
+ void ResetCollation() { collation_count_ = 1; }
+ void MergeCollation(const Subgraph& other) {
+ collation_count_ += other.collation_count_;
+ }
+
+ private:
+ // Identity also serves as the list of nodes. It never changes throughout
+ // the life of the subgraph.
+ Identity id_;
+ size_t hash_; // Cached from the identity.
+ // Whether the dump should include the specific names of the nodes. The
+ // non-specific (i.e. generic) subgraphs represent a collation of multiple
+ // subgraphs.
+ bool specific_ = true;
+ // How many collated subgraphs are represented by this subgraph.
+ int32_t collation_count_ = 1;
+};
+
+// Iteration of all links in a subgraph. This is more like Java iterators than
+// the normal C++ iterators. It's simpler this way and there seems to be no
+// major reason to make it a proper C++ iterator.
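+//
+// A typical traversal (the same pattern is used in the unit tests):
+//
+//   for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+//     const GenNode::LinkTarget& neighbor = sit.GetNeighbor();
+//     ...
+//   }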
+class SubgraphIterator {
+ public:
+ // Obviously an iterator is valid only until the original object
+ // gets destroyed.
+ explicit SubgraphIterator(const Subgraph::Identity* id);
+ explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {}
+
+ // Check whether the built-in iterator is at the end.
+ bool AtEnd() const { return id_it_ == id_->end(); }
+
+ // Get the neighbor at the current iterator.
+ // MUST NOT be called when AtEnd().
+ const GenNode::LinkTarget& GetNeighbor() const {
+ return link_map_it_->second[link_idx_];
+ }
+
+ // Get the node at the current iterator.
+ // MUST NOT be called when AtEnd().
+ const GenNode* GetNode() const { return *id_it_; }
+
+ // Get the port leading to the neighbor at the current iterator.
+ // MUST NOT be called when AtEnd().
+ GenNode::Port GetPort() const { return link_map_it_->first; }
+
+ // Increases the iterator.
+ // Returns true if NOT AtEnd() after increasing the iterator.
+ // Safe to call if already AtEnd().
+ bool Next();
+
+ // If there are more links at the same port, increases the iterator and
+ // returns true. Otherwise leaves the iterator unchanged and returns false.
+ bool NextIfSamePort();
+
+ // Increases the iterator directly to the last position on the current port
+ // (or if already there then doesn't increase). Equivalent to calling
+ // NextIfSamePort() while it returns true, but faster.
+ // Safe to call if already AtEnd().
+ void SkipPort();
+
+ // Increases the iterator directly to the last position on the current node.
+ // Safe to call if already AtEnd().
+ void SkipNode();
+
+ // Returns true if the iterators are exactly the same.
+ bool operator==(const SubgraphIterator& other) const;
+ bool operator!=(const SubgraphIterator& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ // After link_idx_ has been increased, make sure that it points to the
+ // next valid element (or end) by increasing the higher levels of iteration if
+ // needed.
+ // Returns true if NOT AtEnd() after increasing the iterator.
+ // NOT safe to call if already AtEnd().
+ bool PropagateNext();
+
+ // Identity of the subgraph being iterated over.
+ const Subgraph::Identity* id_;
+
+ // The current position, allowing iteration through the links (see the
+ // reasoning for it in the public section).
+ //
+ // (1) Iterator of the nodes in the subgraph.
+ Subgraph::Identity::const_iterator id_it_;
+ // (2) Iterator in the link map of the node.
+ GenNode::LinkMap::const_iterator link_map_it_;
+ // (3) Index in the vector of the links.
+ int32_t link_idx_;
+};
+
+// A convenient way to store subgraphs: in a set of unique_ptrs. This way the
+// addresses of subgraph objects will stay stable, and the objects themselves
+// won't be copied.
+class SubgraphPtrSet
+ : public std::unordered_set<std::unique_ptr<Subgraph>,
+ HashAtPtr<std::unique_ptr<Subgraph>>,
+ EqAtPtr<std::unique_ptr<Subgraph>>> {
+ public:
+ // Attempts to extend the set by adding a new subgraph that gets created by
+ // adding one node to the parent subgraph. If such a subgraph already exists,
+ // returns nullptr, otherwise returns the pointer to the new subgraph.
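+ //
+ // A typical call, with illustrative names (see the ExtendSet test):
+ //
+ //   Subgraph* sg = set.ExtendParent(parent->id(), neighbor_node);
+ //   if (sg != nullptr) {
+ //     // A new, previously unseen subgraph was created and inserted.
+ //   }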
+ Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node);
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc
new file mode 100644
index 0000000000..0f90dc8f0d
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc
@@ -0,0 +1,348 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(SubgraphTest, Comparison) {
+ GraphDef graph;
+ // Just two disconnected nodes.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ ASSERT_THAT(gn1, Ne(nullptr));
+ ASSERT_THAT(gn2, Ne(nullptr));
+
+ Subgraph::Identity id1;
+ Subgraph::Identity id2;
+
+ id1.insert(gn1);
+ id2.insert(gn2);
+
+ Subgraph sg1(id1);
+ Subgraph sg2(id2);
+
+ EXPECT_TRUE(id1 == sg1.id());
+ EXPECT_TRUE(id2 == sg2.id());
+
+ EXPECT_THAT(sg1 < sg2, Eq(id1 < id2));
+}
+
+TEST(SubgraphTest, EmptyIteration) {
+ NodeDef node1 = MakeNodeConst("node1");
+ auto gn1 = absl::make_unique<GenNode>(&node1);
+ Subgraph::Identity id1;
+ id1.insert(gn1.get());
+ Subgraph sg1(id1);
+ SubgraphIterator sit(&sg1);
+
+ EXPECT_TRUE(sit.AtEnd());
+ EXPECT_FALSE(sit.Next());
+ EXPECT_TRUE(sit.AtEnd());
+
+ SubgraphIterator sit2(&sg1);
+ EXPECT_TRUE(sit == sit2);
+}
+
+TEST(SubgraphTest, Iteration) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ // node3 has 2 incoming data links, 2 outgoing data links, 1 incoming
+ // control link and 1 outgoing control link, for a total of 6.
+ {
+ SubgraphIterator sit(&sg);
+ EXPECT_FALSE(sit.AtEnd()); // 1
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 2
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 3
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 4
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 5
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 6
+ EXPECT_FALSE(sit.Next());
+ EXPECT_TRUE(sit.AtEnd());
+ }
+
+ // Now get the values out. And more equality testing along the way.
+ {
+ SubgraphIterator sit(&sg);
+ SubgraphIterator sit2(&sg);
+ std::vector<string> links;
+ for (; !sit.AtEnd(); sit.Next()) {
+ EXPECT_TRUE(sit == sit2);
+ sit2.Next();
+ EXPECT_FALSE(sit == sit2);
+
+ links.push_back(absl::StrFormat("[%s,%s,%s]", string(sit.GetPort()),
+ sit.GetNeighbor().node->name(),
+ string(sit.GetNeighbor().port)));
+ }
+ EXPECT_TRUE(sit == sit2);
+
+ std::sort(links.begin(), links.end());
+ // clang-format off
+ EXPECT_THAT(links, ElementsAre(
+ "[i0,node1,o0]",
+ "[i1,node2,o0]",
+ "[iC,node3,oC]",
+ "[o0,node2,i1]",
+ "[o1,node2,i0]",
+ "[oC,node3,iC]"
+ ));
+ // clang-format on
+ }
+}
+
+TEST(SubgraphTest, IterationSamePort) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3");
+ (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ int total_links = 0;
+ for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+ ++total_links;
+ }
+
+ // Initialize the port as control, which doesn't occur in this graph.
+ GenNode::Port last_port(false, -1);
+ int steps_total_same_port = 0;
+ int steps_with_same_port = 0;
+ for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+ GenNode::Port new_port = sit.GetPort();
+ EXPECT_THAT(last_port.Encoded(), Ne(new_port.Encoded()))
+ << "At step " << steps_total_same_port;
+ last_port = new_port;
+
+ ++steps_total_same_port;
+
+ SubgraphIterator sit2(sit);
+ sit2.SkipPort();
+
+ while (sit.NextIfSamePort()) {
+ new_port = sit.GetPort();
+ EXPECT_THAT(last_port.Encoded(), Eq(new_port.Encoded()))
+ << "At step " << steps_total_same_port;
+ ++steps_total_same_port;
+ ++steps_with_same_port;
+ }
+
+ EXPECT_TRUE(sit == sit2);
+ }
+
+ EXPECT_THAT(steps_total_same_port, Eq(total_links));
+ // There is one 2-way input and one 2-way output.
+ EXPECT_THAT(steps_with_same_port, Eq(2));
+}
+
+TEST(SubgraphTest, IterationSameNode) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3");
+ (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ const GenNode* last_node = nullptr;
+ SubgraphIterator sit(&sg);
+ while (!sit.AtEnd()) {
+ const GenNode* new_node = sit.GetNode();
+
+ EXPECT_THAT(new_node, Ne(last_node)) << "At node " << new_node->name();
+
+ SubgraphIterator sit2(sit);
+ sit2.SkipNode();
+
+ ASSERT_FALSE(sit2.AtEnd());
+ EXPECT_THAT(sit2.GetNode(), Eq(new_node))
+ << "At expected node " << new_node->name() << ", got "
+ << sit2.GetNode()->name();
+
+ while (sit != sit2 && !sit.AtEnd()) {
+ sit.Next();
+ }
+
+ ASSERT_FALSE(sit.AtEnd());
+ EXPECT_THAT(sit.GetNode(), Eq(new_node))
+ << "At expected node " << new_node->name() << ", got "
+ << sit2.GetNode()->name();
+
+ sit.Next();
+
+ last_node = new_node;
+ }
+
+ // Check that it doesn't fail if already at end.
+ sit.SkipNode();
+ EXPECT_TRUE(sit.AtEnd());
+}
+
+TEST(SubgraphTest, ExtendSet) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id_empty;
+
+ Subgraph::Identity id3;
+ id3.insert(map["node3"].get());
+
+ Subgraph::Identity id23 = id3;
+ id23.insert(map["node2"].get());
+
+ Subgraph* sg;
+ SubgraphPtrSet set;
+
+ // Extend an empty identity.
+ sg = set.ExtendParent(id_empty, map["node3"].get());
+ EXPECT_THAT(set.size(), Eq(1));
+ ASSERT_THAT(sg, Ne(nullptr));
+ EXPECT_TRUE(sg->id() == id3);
+
+ // Extend with a node that is already in the parent.
+ sg = set.ExtendParent(id3, map["node3"].get());
+ EXPECT_THAT(set.size(), Eq(1));
+ EXPECT_THAT(sg, Eq(nullptr));
+
+ // Extend to a 2-node subgraph.
+ sg = set.ExtendParent(id3, map["node2"].get());
+ EXPECT_THAT(set.size(), Eq(2));
+ ASSERT_THAT(sg, Ne(nullptr));
+ EXPECT_TRUE(sg->id() == id23);
+
+ // The second insert of the same node gets ignored.
+ sg = set.ExtendParent(id3, map["node2"].get());
+ EXPECT_THAT(set.size(), Eq(2));
+ EXPECT_THAT(sg, Eq(nullptr));
+}
+
+TEST(SubgraphTest, ExtractForSignature) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node1");
+ node3->add_input("^node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node1"), Ne(map.end()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node1"].get());
+ id.insert(map["node3"].get());
+
+ Subgraph sg(id);
+
+ SigNodeMap map2;
+ sg.ExtractForSignature(&map2);
+ ASSERT_THAT(map2.find("node1"), Ne(map2.end()));
+ ASSERT_THAT(map2.find("node2"), Eq(map2.end()));
+ ASSERT_THAT(map2.find("node3"), Ne(map2.end()));
+
+ // clang-format off
+ EXPECT_THAT(DumpLinkHashMap(map2["node1"]->hash_to_link()), ElementsAre(
+ "oC:iC: node3",
+ "o0:i0: node3"
+ ));
+ EXPECT_THAT(DumpHashedPeerVector(map2["node1"]->hashed_peers()), ElementsAre(
+ "node3",
+ "node3"
+ ));
+ EXPECT_THAT(DumpLinkHashMap(map2["node3"]->hash_to_link()), ElementsAre(
+ "oC:iC: node3",
+ "iC:oC: node1, node3",
+ "i0:o0: node1"
+ ));
+ EXPECT_THAT(DumpHashedPeerVector(map2["node3"]->hashed_peers()), ElementsAre(
+ "node3",
+ "node1",
+ "node3",
+ "node1"
+ ));
+ // clang-format on
+}
+
+} // end namespace
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.cc b/tensorflow/core/grappler/graph_analyzer/test_tools.cc
new file mode 100644
index 0000000000..fc9495bc7d
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/test_tools.cc
@@ -0,0 +1,296 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+//=== Helper methods to construct the nodes.
+
+NodeDef MakeNodeConst(const string& name) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op("Const");
+ return n;
+}
+
+NodeDef MakeNode2Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(opcode);
+ n.add_input(arg1);
+ n.add_input(arg2);
+ return n;
+}
+
+NodeDef MakeNode4Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2, const string& arg3,
+ const string& arg4) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(opcode);
+ n.add_input(arg1);
+ n.add_input(arg2);
+ n.add_input(arg3);
+ n.add_input(arg4);
+ return n;
+}
+
+// Not really a 2-argument op, but convenient to construct.
+NodeDef MakeNodeShapeN(const string& name, const string& arg1,
+ const string& arg2) {
+ // This opcode is multi-input but not commutative.
+ return MakeNode2Arg(name, "ShapeN", arg1, arg2);
+}
+
+// Not really a 2-argument op, but convenient to construct.
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
+ const string& arg2) {
+ // The argument is of a list type.
+ return MakeNode2Arg(name, "IdentityN", arg1, arg2);
+}
+
+NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
+ const string& arg2, const string& arg3,
+ const string& arg4) {
+ // This opcode has multiple multi-inputs.
+ return MakeNode4Arg(name, "QuantizedConcat", arg1, arg2, arg3, arg4);
+}
+
+//=== Helper methods for analysing the structures.
+
+std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map) {
+ // This will order the entries first.
+ std::map<string, string> ordered;
+ for (const auto& link : link_map) {
+ string key = string(link.first);
+
+ // Order the other sides too. They may be repeating, so store them
+ // in a multiset.
+ std::multiset<string> others;
+ for (const auto& other : link.second) {
+ others.emplace(
+ absl::StrFormat("%s[%s]", other.node->name(), string(other.port)));
+ }
+ ordered[key] = absl::StrJoin(others, ", ");
+ }
+ // Now dump the result in a predictable order.
+ std::vector<string> result;
+ result.reserve(ordered.size());
+ for (const auto& link : ordered) {
+ result.emplace_back(link.first + ": " + link.second);
+ }
+ return result;
+}
+
+std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map) {
+ // The entries in this map are ordered by hash value which might change
+ // at any point. Re-order them by the link tag.
+ std::map<SigNode::LinkTag, size_t> tags;
+ for (const auto& entry : link_hash_map) {
+ tags[entry.second.tag] = entry.first;
+ }
+
+ std::vector<string> result;
+ for (const auto& id : tags) {
+ // For predictability, the nodes need to be sorted.
+ std::vector<string> nodes;
+ for (const auto& peer : link_hash_map.at(id.second).peers) {
+ nodes.emplace_back(peer->name());
+ }
+ std::sort(nodes.begin(), nodes.end());
+ result.emplace_back(string(id.first.local) + ":" + string(id.first.remote) +
+ ": " + absl::StrJoin(nodes, ", "));
+ }
+ return result;
+}
+
+std::vector<string> DumpHashedPeerVector(
+ const SigNode::HashedPeerVector& hashed_peers) {
+ std::vector<string> result;
+
+ // Each subset of nodes with the same hash has to be sorted by name.
+ // Other than that, the vector is already ordered by full tags.
+ size_t last_hash = 0;
+ // Index, since iterators may get invalidated on append.
+ size_t subset_start = 0;
+
+ for (const auto& entry : hashed_peers) {
+ if (entry.link_hash != last_hash) {
+ std::sort(result.begin() + subset_start, result.end());
+ subset_start = result.size();
+ }
+ result.emplace_back(entry.peer->name());
+ }
+ std::sort(result.begin() + subset_start, result.end());
+
+ return result;
+}
+
+TestGraphs::TestGraphs() {
+ {
+ GraphDef& graph = graph_3n_self_control_;
+ // The topology includes a loop and a link to self.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+ }
+ {
+ GraphDef& graph = graph_multi_input_;
+ // A topology with multi-input nodes.
+ (*graph.add_node()) = MakeNodeConst("const1_1");
+ (*graph.add_node()) = MakeNodeConst("const1_2");
+ (*graph.add_node()) = MakeNodeAddN("add1", "const1_1", "const1_2");
+
+ (*graph.add_node()) = MakeNodeConst("const2_1");
+ (*graph.add_node()) = MakeNodeConst("const2_2");
+ (*graph.add_node()) = MakeNodeConst("const2_3");
+
+ auto add2 = graph.add_node();
+ *add2 = MakeNodeAddN("add2", "const2_1", "const2_2");
+ // The 3rd node is connected twice, for 4 links total.
+ add2->add_input("const2_3");
+ add2->add_input("const2_3");
+
+ (*graph.add_node()) = MakeNodeSub("sub", "add1", "add2");
+ }
+ {
+ GraphDef& graph = graph_all_or_none_;
+ // The topology includes all-or-none (IdentityN) inputs and control links.
+ (*graph.add_node()) = MakeNodeConst("const1_1");
+ (*graph.add_node()) = MakeNodeConst("const1_2");
+ auto pass1 = graph.add_node();
+ *pass1 = MakeNodeIdentityN("pass1", "const1_1", "const1_2");
+
+ (*graph.add_node()) = MakeNodeConst("const2_1");
+ (*graph.add_node()) = MakeNodeConst("const2_2");
+ (*graph.add_node()) = MakeNodeConst("const2_3");
+
+ auto pass2 = graph.add_node();
+ *pass2 = MakeNodeIdentityN("pass2", "const2_1", "const2_2");
+ // The 3rd input node is connected twice, for a total of 4 input links.
+ pass2->add_input("const2_3");
+ pass2->add_input("const2_3");
+
+ // Add the control links; they are handled separately from the normal
+ // links.
+ pass1->add_input("^const2_1");
+ pass1->add_input("^const2_2");
+ pass1->add_input("^const2_3");
+
+ (*graph.add_node()) = MakeNodeSub("sub", "pass1", "pass2");
+ }
+ {
+ GraphDef& graph = graph_circular_onedir_;
+ (*graph.add_node()) = MakeNodeMul("node1", "node5", "node5");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
+ }
+ {
+ GraphDef& graph = graph_circular_bidir_;
+ // The left and right links are intentionally mixed up.
+ (*graph.add_node()) = MakeNodeMul("node1", "node5", "node2");
+ (*graph.add_node()) = MakeNodeMul("node2", "node3", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node4");
+ (*graph.add_node()) = MakeNodeMul("node4", "node5", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node1");
+ }
+ {
+ GraphDef& graph = graph_linear_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
+ }
+ {
+ GraphDef& graph = graph_cross_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeConst("node5");
+ (*graph.add_node()) = MakeNodeMul("node6", "node5", "node5");
+ (*graph.add_node()) = MakeNodeConst("node7");
+ (*graph.add_node()) = MakeNodeMul("node8", "node7", "node7");
+
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node9", "node2", "node4");
+ center->add_input("node6");
+ center->add_input("node8");
+ }
+ {
+ GraphDef& graph = graph_small_cross_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node5", "node1", "node2");
+ center->add_input("node3");
+ center->add_input("node4");
+ }
+ {
+ GraphDef& graph = graph_for_link_order_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+
+ // One group of equivalent links.
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node5", "node1", "node2");
+ center->add_input("node3");
+ center->add_input("node4");
+
+ // Multiple groups, separated by unique links.
+ auto center2 = graph.add_node();
+ *center2 = MakeNodeMul("node6", "node1", "node2");
+ center2->add_input("node2:1");
+ center2->add_input("node3:2");
+ center2->add_input("node4:2");
+ center2->add_input("node4:3");
+ }
+ {
+ GraphDef& graph = graph_sun_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+ (*graph.add_node()) = MakeNodeConst("node5");
+ (*graph.add_node()) = MakeNodeSub("node6", "node1", "node10");
+ (*graph.add_node()) = MakeNodeSub("node7", "node2", "node6");
+ (*graph.add_node()) = MakeNodeSub("node8", "node3", "node7");
+ (*graph.add_node()) = MakeNodeSub("node9", "node4", "node8");
+ (*graph.add_node()) = MakeNodeSub("node10", "node5", "node9");
+ }
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.h b/tensorflow/core/grappler/graph_analyzer/test_tools.h
new file mode 100644
index 0000000000..98e269d57e
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/test_tools.h
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/grappler/op_types.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+//=== Helper methods to construct the nodes.
+
+NodeDef MakeNodeConst(const string& name);
+
+NodeDef MakeNode2Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2);
+
+NodeDef MakeNode4Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2, const string& arg3,
+ const string& arg4);
+
+inline NodeDef MakeNodeMul(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "Mul", arg1, arg2);
+}
+
+// Not really a 2-argument op, but it's convenient to construct this way.
+inline NodeDef MakeNodeAddN(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "AddN", arg1, arg2);
+}
+
+inline NodeDef MakeNodeSub(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "Sub", arg1, arg2);
+}
+
+// Has 2 genuine outputs.
+inline NodeDef MakeNodeBroadcastGradientArgs(const string& name,
+ const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "BroadcastGradientArgs", arg1, arg2);
+}
+
+NodeDef MakeNodeShapeN(const string& name, const string& arg1,
+ const string& arg2);
+
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
+ const string& arg2);
+
+NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
+ const string& arg2, const string& arg3,
+ const string& arg4);
+
+//=== A container of pre-constructed graphs.
+
+class TestGraphs {
+ public:
+ TestGraphs();
+
+ // Graph with 3 nodes and a control link to self (which is not valid in
+ // reality but adds excitement to the tests).
+ GraphDef graph_3n_self_control_;
+ // Graph that has multi-input links.
+ GraphDef graph_multi_input_;
+ // Graph that has nodes with all-or-none (IdentityN) inputs.
+ GraphDef graph_all_or_none_;
+ // All the nodes are connected in a circle that goes in one direction.
+ GraphDef graph_circular_onedir_;
+ // All the nodes are connected in a circle that goes in both directions.
+ GraphDef graph_circular_bidir_;
+ // The nodes are connected in a line.
+ GraphDef graph_linear_;
+ // The nodes are connected in a cross shape.
+ GraphDef graph_cross_;
+ GraphDef graph_small_cross_;
+ // For testing the ordering of links at the end of signature generation,
+ // a variation of a cross.
+ GraphDef graph_for_link_order_;
+ // Sun-shaped, a ring with "rays".
+ GraphDef graph_sun_;
+};
+
+//=== Helper methods for analysing the structures.
+
+std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map);
+
+// Also checks for the consistency of hash values.
+std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map);
+
+std::vector<string> DumpHashedPeerVector(
+ const SigNode::HashedPeerVector& hashed_peers);
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index f62a927925..5af6437c56 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3777,6 +3777,7 @@ tf_py_wrap_cc(
"framework/python_op_gen.i",
"grappler/cluster.i",
"grappler/cost_analyzer.i",
+ "grappler/graph_analyzer.i",
"grappler/item.i",
"grappler/model_analyzer.i",
"grappler/tf_optimizer.i",
@@ -3835,6 +3836,7 @@ tf_py_wrap_cc(
"//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_memory",
+ "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core:lib",
"//tensorflow/core:reader_base",
@@ -5536,6 +5538,18 @@ py_test(
],
 )
 
+py_binary(
+ name = "graph_analyzer",
+ srcs = [
+ "grappler/graph_analyzer.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_for_generated_wrappers",
+ ":pywrap_tensorflow_internal",
+ ],
+)
+
pyx_library(
name = "framework_fast_tensor_util",
srcs = ["framework/fast_tensor_util.pyx"],
diff --git a/tensorflow/python/grappler/graph_analyzer.i b/tensorflow/python/grappler/graph_analyzer.i
new file mode 100644
index 0000000000..cc7b5358eb
--- /dev/null
+++ b/tensorflow/python/grappler/graph_analyzer.i
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%{
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
+%}
+
+%{
+void GraphAnalyzer(const string& file_path, int n) {
+ tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n);
+}
+%}
+
+void GraphAnalyzer(const string& file_path, int n);
diff --git a/tensorflow/python/grappler/graph_analyzer.py b/tensorflow/python/grappler/graph_analyzer.py
new file mode 100644
index 0000000000..ec5544e38e
--- /dev/null
+++ b/tensorflow/python/grappler/graph_analyzer.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""A tool that finds all subgraphs of a given size in a TF graph.
+
+The subgraph patterns are sorted by occurrence, and only the transitive fanin
+part of the graph with regard to the fetch nodes is considered.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.python import pywrap_tensorflow as tf_wrap
+from tensorflow.python.platform import app
+
+
+def main(_):
+ tf_wrap.GraphAnalyzer(FLAGS.input, FLAGS.n)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input",
+ type=str,
+ default=None,
+ help="Input file path for a TensorFlow MetaGraphDef.")
+ parser.add_argument(
+ "--n", type=int, default=None, help="The size of the subgraphs.")
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
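
As a point of reference, here is a minimal, hypothetical sketch of exercising the tool end to end with the wrapper above: build a tiny graph, write it out as a MetaGraphDef (the input format the script expects), and make the same GraphAnalyzer call that main() makes. The graph contents, file path, and subgraph size are illustrative only, and whether fetch nodes must also be recorded in the MetaGraphDef is decided by graph_analyzer_tool.cc, not by this script.

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow as tf_wrap

# Build a small graph and export it as a MetaGraphDef (TF 1.x API).
with tf.Graph().as_default():
  a = tf.constant(1.0, name="a")
  b = tf.constant(2.0, name="b")
  tf.add(a, b, name="sum")
  tf.train.export_meta_graph(filename="/tmp/example.metagraph")

# The same call the script makes in main(): find subgraphs of 2 nodes.
tf_wrap.GraphAnalyzer("/tmp/example.metagraph", 2)

With the py_binary target added to tensorflow/python/BUILD above, an equivalent run is the graph_analyzer script invoked with --input=/tmp/example.metagraph --n=2.
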
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 26e8acd897..39174fa589 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -54,4 +54,5 @@ limitations under the License.
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
+%include "tensorflow/python/grappler/graph_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"