diff options
24 files changed, 5544 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/graph_analyzer/BUILD b/tensorflow/core/grappler/graph_analyzer/BUILD new file mode 100644 index 0000000000..d56a08d3c8 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/BUILD @@ -0,0 +1,139 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "graph_analyzer_lib", + srcs = [ + "gen_node.cc", + "graph_analyzer.cc", + "sig_node.cc", + "subgraph.cc", + ], + hdrs = [ + "gen_node.h", + "graph_analyzer.h", + "hash_tools.h", + "map_tools.h", + "sig_node.h", + "subgraph.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "graph_analyzer_tool", + srcs = ["graph_analyzer_tool.cc"], + hdrs = ["graph_analyzer_tool.h"], + visibility = ["//visibility:public"], + deps = [ + ":graph_analyzer_lib", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/grappler:grappler_item", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "test_tools_lib", + testonly = 1, + srcs = [ + "test_tools.cc", + ], + hdrs = [ + "test_tools.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_analyzer_lib", + "//tensorflow/core:framework", + "//tensorflow/core:tensorflow", + "//tensorflow/core/grappler:op_types", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "hash_tools_test", + testonly = 1, + srcs = [ + "hash_tools_test.cc", + ], + deps = [ + ":graph_analyzer_lib", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "gen_node_test", + testonly = 1, 
+ srcs = [ + "gen_node_test.cc", + ], + deps = [ + ":graph_analyzer_lib", + ":test_tools_lib", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "sig_node_test", + testonly = 1, + srcs = [ + "sig_node_test.cc", + ], + deps = [ + ":graph_analyzer_lib", + ":test_tools_lib", + "//tensorflow/core/grappler:utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "graph_analyzer_test", + testonly = 1, + srcs = [ + "graph_analyzer_test.cc", + ], + deps = [ + ":graph_analyzer_lib", + ":test_tools_lib", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "subgraph_test", + testonly = 1, + srcs = [ + "subgraph_test.cc", + ], + deps = [ + ":graph_analyzer_lib", + ":test_tools_lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.cc b/tensorflow/core/grappler/graph_analyzer/gen_node.cc new file mode 100644 index 0000000000..f8c15fd50e --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/gen_node.cc @@ -0,0 +1,148 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +GenNode::GenNode(const NodeDef* node) : node_(node), op_(nullptr) {} + +Status GenNode::BuildGraphInMap(const GraphDef& source, GenNodeMap* map) { + for (const auto& n : source.node()) { + const string& name = n.name(); + if (map->find(name) != map->end()) { + // This error code looks more meaningful than ALREADY_EXISTS. + return Status(error::INVALID_ARGUMENT, + "Duplicate node name '" + name + "'."); + } + (*map)[name] = absl::make_unique<GenNode>(&n); + } + // Now parse the links. + for (const auto& mapit : *map) { + Status st = mapit.second->ParseInputs(map); + if (!st.ok()) { + return st; + } + } + return Status::OK(); +} + +Status GenNode::ParseInputs(const GenNodeMap* map) { + all_inputs_or_none_ = false; + Status st = OpRegistry::Global()->LookUpOpDef(opcode(), &op_); + if (!st.ok()) { + return Status( + error::INVALID_ARGUMENT, + absl::StrFormat("Node '%s' contains an undefined operation '%s': %s", + name(), opcode(), st.error_message())); + } + + int n_inputs = node_->input_size(); + + int n_named_inputs = op_->input_arg_size(); + + int n_multi_inputs = 0; + for (const auto& inarg : op_->input_arg()) { + if (!inarg.number_attr().empty() || !inarg.type_list_attr().empty()) { + ++n_multi_inputs; + } + } + bool is_commutative = grappler::IsCommutative(*node_); + + if (n_multi_inputs > 1 || (n_multi_inputs > 0 && n_named_inputs > 1)) { + // Can't handle more than one multi-input at a time. 
+ // And can't handle the commutativeness of only some arguments + // rather than all of them. + is_commutative = false; + } + + if (is_commutative) { + // If truly commutative, can treat all the inputs as one multi-input. + // It's possible to just treat the commutative nodes as AllInputsOrNone + // but (1) this way is a bit more efficient and (2) I want to preserve this + // more efficient code path that does all-or-none by a single input and + // perhaps extend its use in the future. + n_named_inputs = 1; + all_inputs_or_none_ = false; + } else if (n_multi_inputs > 0) { + all_inputs_or_none_ = true; + } + + for (int i = 0; i < n_inputs; ++i) { + int other_position; + string other_name = ParseNodeName(node_->input(i), &other_position); + auto other_it = map->find(other_name); + if (other_it == map->end()) { + return Status( + error::INVALID_ARGUMENT, + absl::StrFormat( + "Node '%s' input %d refers to a non-existing node '%s'.", name(), + i, other_name)); + } + GenNode* other_node = other_it->second.get(); + + int this_position = other_position < 0 ? -1 : (is_commutative ? 0 : i); + + if (this_position >= 0 && n_multi_inputs == 0 && + this_position >= n_named_inputs) { + return Status( + error::INVALID_ARGUMENT, + absl::StrFormat( + "Node '%s' has a non-control input from '%s' at index %d but its " + "operation '%s' defines only %d inputs.", + name(), other_name, this_position, op_->name(), n_named_inputs)); + } + + Port this_port(/*inbound=*/true, this_position); + Port other_port(/*inbound=*/false, other_position); + + links_[this_port].emplace_back(LinkTarget(other_node, other_port)); + other_node->links_[other_port].emplace_back(LinkTarget(this, this_port)); + } + return Status::OK(); +} + +bool GenNode::IsMultiInput(Port port) const { + if (!port.IsInbound()) { + return false; + } + auto it = links_.find(port); + if (it == links_.end()) { + return false; // Shouldn't happen. 
+ } + return (it->second.size() > 1); +} + +GenNode::Port::operator string() const { + string result = this->IsInbound() ? "i" : "o"; + if (this->IsControl()) { + result.append("C"); + } else { + result.append(absl::StrFormat("%d", this->Id())); + } + return result; +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.h b/tensorflow/core/grappler/graph_analyzer/gen_node.h new file mode 100644 index 0000000000..faec9ecad8 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/gen_node.h @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ + +#include <map> +#include <memory> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +class GenNode; + +// To find nodes by name. 
+using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>; + +// One node in the graph, in the form convenient for traversal and generation of +// subgraphs. It refers to the original NodeDef protobuf for most information +// and adds the extra enrichment. +// +// The graph building is 2-stage: first match a GenNode with each NodeDef and +// collect them into a map that finds them by name, then process the map, +// deep-parse the underlying NodeDefs and connect the GenNodes together. +class GenNode { + public: + // Will keep the pointer, so the underlying object must not be deleted while + // GenNode is alive. + explicit GenNode(const NodeDef* node); + + // Access wrappers. + const string& name() const { return node_->name(); } + const string& opcode() const { return node_->op(); } + const NodeDef* node_def() const { return node_; } + + // Parse the inputs of this node and update the map accordingly, creating the + // links (i.e. edges, connections between nodes) in itself and in the nodes + // it's linked to (the map itself is unchanged, only the nodes in it are + // updated). + Status ParseInputs(const GenNodeMap* map); + + // Does the full 2-stage build of the graph. The map should be initially + // empty. The map keeps pointers to the nodes in source, so the source must + // not be destroyed before the map. + static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map); + + // The enrichment that constitutes the point of this class. + + // Representation of a connection on a node. + class Port { + public: + // A port may be inbound or outbound. + // Negative ids (canonically -1) mean a control port. + Port(bool inbound, int32_t id) : value_(id << 1) { + if (inbound) { + value_ |= 1; + } + } + Port(const Port&) = default; + Port& operator=(const Port&) = default; + + bool IsInbound() const { return (value_ & 0x1); } + + bool IsControl() const { return (value_ < 0); } + + int32_t Id() const { + // Arithmetic shift preserves the sign. 
+ return (value_ >> 1); + } + + // Integer type used to represent the encoded port value. + using IntPort = int32_t; + + // Returns the encoded form of this port, so that it can be used + // as various map indexes. + IntPort Encoded() const { return value_; } + + static Port Decode(IntPort encoded) { return Port(encoded); } + + bool operator==(const Port& other) const { return value_ == other.value_; } + bool operator<(const Port& other) const { return value_ < other.value_; } + + struct Hasher { + size_t operator()(const Port& port) const noexcept { + return hasher(port.Encoded()); + } + std::hash<int32_t> hasher; + }; + + // Convenient for printing. I've really wanted it to be implicit but + // ClangTidy insists on making it explicit. + explicit operator string() const; + + private: + explicit Port(IntPort value) : value_(value) {} + + IntPort value_; + }; + + struct LinkTarget { + GenNode* node; // Node where this link points. + Port port; // Port on the remote side of this link. + + LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {} + }; + // All the links that are connected to the same port of this node + // are collected in one vector. A link is an edge of the graph that connects + // 2 nodes. Each of the connected nodes has its own perspective on the link, + // seeing its local port, remote port and the remote node. The direction of + // the link is encoded in the ports, one port is always incoming and another + // one outgoing. + using LinkTargetVector = std::vector<LinkTarget>; + // Both inputs and outputs are stored in the same map. + using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>; + + // Access to the link map. + const LinkMap& links() const { return links_; } + + // Check whether the port is an input (including the controls) with multiple + // connections. Such inputs get handled in a special way when building the + // subgraphs, in an "all or nothing" fashion. 
+ bool IsMultiInput(Port port) const; + + // When building the subgraphs, must include either all non-control inputs of + // this node into the subgraph or none of them. This happens when at least one + // of the inputs is a multi-input (or if the opcode is commutative, thus + // treating all the inputs as one multi-input). + bool AllInputsOrNone() const { return all_inputs_or_none_; } + + private: + const NodeDef* node_; + // Becomes valid only after ParseInputs(). + const OpDef* op_; + + // The opcode has a complicated structure of input args, with multi-input args + // that are not commutative. This means that to make sense, the subgraphs that + // include this node must also include either all its inputs or none of them. + bool all_inputs_or_none_ = false; + + LinkMap links_; +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc new file mode 100644 index 0000000000..d77daf7849 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc @@ -0,0 +1,491 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "absl/memory/memory.h" +#include "tensorflow/core/grappler/graph_analyzer/test_tools.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { +namespace test { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Ne; + +TEST(GenNodeTest, Port) { + { + GenNode::Port p(true, 100); + EXPECT_THAT(p.IsInbound(), Eq(true)); + EXPECT_THAT(p.IsControl(), Eq(false)); + EXPECT_THAT(p.Id(), Eq(100)); + GenNode::Port p2 = GenNode::Port::Decode(p.Encoded()); + EXPECT_THAT(p2.IsInbound(), Eq(true)); + EXPECT_THAT(p2.IsControl(), Eq(false)); + EXPECT_THAT(p2.Id(), Eq(100)); + } + { + GenNode::Port p(false, 0); + EXPECT_THAT(p.IsInbound(), Eq(false)); + EXPECT_THAT(p.IsControl(), Eq(false)); + EXPECT_THAT(p.Id(), Eq(0)); + GenNode::Port p2 = GenNode::Port::Decode(p.Encoded()); + EXPECT_THAT(p2.IsInbound(), Eq(false)); + EXPECT_THAT(p2.IsControl(), Eq(false)); + EXPECT_THAT(p2.Id(), Eq(0)); + } + { + GenNode::Port p(true, -100); + EXPECT_THAT(p.IsInbound(), Eq(true)); + EXPECT_THAT(p.IsControl(), Eq(true)); + EXPECT_THAT(p.Id(), Eq(-100)); + GenNode::Port p2 = GenNode::Port::Decode(p.Encoded()); + EXPECT_THAT(p2.IsInbound(), Eq(true)); + EXPECT_THAT(p2.IsControl(), Eq(true)); + EXPECT_THAT(p2.Id(), Eq(-100)); + } + { + GenNode::Port p(false, -1); + EXPECT_THAT(p.IsInbound(), Eq(false)); + EXPECT_THAT(p.IsControl(), Eq(true)); + EXPECT_THAT(p.Id(), Eq(-1)); + GenNode::Port p2 = GenNode::Port::Decode(p.Encoded()); + EXPECT_THAT(p2.IsInbound(), Eq(false)); + EXPECT_THAT(p2.IsControl(), Eq(true)); + EXPECT_THAT(p2.Id(), Eq(-1)); + } +} + +TEST(GenNodeTest, ParseNodeNoInputs) { + GenNodeMap map; + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + auto gn1 = map["node1"].get(); + 
ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK())); + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre()); +} + +// A general operation, and a control link. +TEST(GenNodeTest, ParseNodeWithControl) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + NodeDef node3 = MakeNodeSub("node3", "node1", "node2"); + node3.add_input("^node1"); // The control link. + node3.add_input("^node2"); // The control link. + map["node3"] = absl::make_unique<GenNode>(&node3); + + auto gn1 = map["node1"].get(); + auto gn2 = map["node2"].get(); + auto gn3 = map["node3"].get(); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( + "o0: node3[i0]", + "oC: node3[iC]" + )); + EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre( + "o0: node3[i1]", + "oC: node3[iC]" + )); + EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre( + "i0: node1[o0]", + "i1: node2[o0]", + "iC: node1[oC], node2[oC]" + )); + // clang-format on + + EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false)); + + // This is a multi-control-input. + EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, -1)), Eq(true)); + + EXPECT_FALSE(gn1->AllInputsOrNone()); + EXPECT_FALSE(gn2->AllInputsOrNone()); + EXPECT_FALSE(gn3->AllInputsOrNone()); +} + +// Commutative nodes are treated as having a single input, +// because their inputs are equivalent. +TEST(GenNodeTest, ParseNodeCommutative) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + // TODO(babkin): grappler::IsCommutative() should return true for Add but + // apparently doesn't. So use Mul in the meantime. 
+ NodeDef node3 = MakeNodeMul("node3", "node1", "node2"); + map["node3"] = absl::make_unique<GenNode>(&node3); + + auto gn1 = map["node1"].get(); + auto gn2 = map["node2"].get(); + auto gn3 = map["node3"].get(); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( + "o0: node3[i0]" + )); + EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre( + "o0: node3[i0]" + )); + EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre( + "i0: node1[o0], node2[o0]" + )); + // clang-format on + + EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true)); + + EXPECT_FALSE(gn3->AllInputsOrNone()); +} + +TEST(GenNodeTest, ParseNodeMultiInputCommutative) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + NodeDef node3 = MakeNodeAddN("node3", "node1", "node2"); + map["node3"] = absl::make_unique<GenNode>(&node3); + + auto gn1 = map["node1"].get(); + auto gn2 = map["node2"].get(); + auto gn3 = map["node3"].get(); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( + "o0: node3[i0]" + )); + EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre( + "o0: node3[i0]" + )); + EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre( + "i0: node1[o0], node2[o0]" + )); + // clang-format on + + // This is a multi-output. + EXPECT_THAT(gn2->IsMultiInput(GenNode::Port(false, 0)), Eq(false)); + // This is a multi-input. 
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true)); + + EXPECT_FALSE(gn3->AllInputsOrNone()); +} + +TEST(GenNodeTest, ParseNodeMultiInputNotCommutative) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2"); + map["node3"] = absl::make_unique<GenNode>(&node3); + + auto gn1 = map["node1"].get(); + auto gn2 = map["node2"].get(); + auto gn3 = map["node3"].get(); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( + "o0: node3[i0]" + )); + EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre( + "o0: node3[i1]" + )); + EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre( + "i0: node1[o0]", + "i1: node2[o0]" + )); + // clang-format on + + // Non-commutative multi-input doesn't count. + EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false)); + EXPECT_TRUE(gn3->AllInputsOrNone()); +} + +TEST(GenNodeTest, ParseNodeMultiInputList) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + NodeDef node3 = MakeNodeIdentityN("node3", "node1", "node2"); + map["node3"] = absl::make_unique<GenNode>(&node3); + + auto gn1 = map["node1"].get(); + auto gn2 = map["node2"].get(); + auto gn3 = map["node3"].get(); + ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( + "o0: node3[i0]" + )); + EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre( + "o0: node3[i1]" + )); + EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre( + "i0: node1[o0]", + "i1: node2[o0]" + )); + // clang-format on + + // Non-commutative multi-input doesn't count. 
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false)); + EXPECT_TRUE(gn3->AllInputsOrNone()); +} + +TEST(GenNodeTest, ParseNodeMultiMultiInput) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + NodeDef node3 = MakeNodeConst("node3"); + map["node3"] = absl::make_unique<GenNode>(&node3); + + NodeDef node4 = MakeNodeConst("node4"); + map["node4"] = absl::make_unique<GenNode>(&node4); + + NodeDef node5 = + MakeNodeQuantizedConcat("node5", "node1", "node2", "node3", "node4"); + map["node5"] = absl::make_unique<GenNode>(&node5); + + auto gn1 = map["node1"].get(); + auto gn2 = map["node2"].get(); + auto gn3 = map["node3"].get(); + auto gn4 = map["node4"].get(); + auto gn5 = map["node5"].get(); + ASSERT_THAT(gn5->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( + "o0: node5[i0]" + )); + EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre( + "o0: node5[i1]" + )); + EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre( + "o0: node5[i2]" + )); + EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre( + "o0: node5[i3]" + )); + EXPECT_THAT(DumpLinkMap(gn5->links()), ElementsAre( + "i0: node1[o0]", + "i1: node2[o0]", + "i2: node3[o0]", + "i3: node4[o0]" + )); + // clang-format on + + // Non-commutative multi-input doesn't count. 
+ EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 1)), Eq(false)); + EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 2)), Eq(false)); + EXPECT_TRUE(gn5->AllInputsOrNone()); +} + +TEST(GenNodeTest, ParseNodeMultiOutput) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + NodeDef node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2"); + map["node3"] = absl::make_unique<GenNode>(&node3); + + NodeDef node4 = MakeNodeSub("node4", "node3:1", "node3:0"); + map["node4"] = absl::make_unique<GenNode>(&node4); + + auto gn4 = map["node4"].get(); + ASSERT_THAT(gn4->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre( + "i0: node3[o1]", + "i1: node3[o0]" + )); + // clang-format on +} + +TEST(GenNodeTest, ParseNodeUndefinedOp) { + GenNodeMap map; + NodeDef node1; + node1.set_name("node1"); + node1.set_op("Zzzx"); + + map["node1"] = absl::make_unique<GenNode>(&node1); + + const OpDef* opdef; + Status nested_error = OpRegistry::Global()->LookUpOpDef("Zzzx", &opdef); + + auto gn = map["node1"].get(); + ASSERT_THAT( + gn->ParseInputs(&map), + Eq(Status(error::INVALID_ARGUMENT, + "Node 'node1' contains an undefined operation 'Zzzx': " + + nested_error.error_message()))); +} + +TEST(GenNodeTest, ParseNodeUnexpectedInputs) { + GenNodeMap map; + + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + node1.add_input("node1"); + + auto gn1 = map["node1"].get(); + EXPECT_THAT(gn1->ParseInputs(&map), + Eq(Status(error::INVALID_ARGUMENT, + "Node 'node1' has a non-control " + "input from 'node1' at index 0 but its operation " + "'Const' defines only 0 inputs."))); + + NodeDef node2 = MakeNodeConst("node2"); + map["node2"] = absl::make_unique<GenNode>(&node2); + + NodeDef node3 = MakeNodeSub("node3", "node1", 
"node2"); + map["node3"] = absl::make_unique<GenNode>(&node3); + node3.add_input("node1"); + + auto gn3 = map["node3"].get(); + EXPECT_THAT(gn3->ParseInputs(&map), + Eq(Status(error::INVALID_ARGUMENT, + "Node 'node3' has a non-control " + "input from 'node1' at index 2 but its operation " + "'Sub' defines only 2 inputs."))); +} + +// Even if an opcode defines no inputs, the node may still accept the control +// inputs. +TEST(GenNodeTest, ParseNodeControlInputsAlwaysOk) { + GenNodeMap map; + NodeDef node1 = MakeNodeConst("node1"); + map["node1"] = absl::make_unique<GenNode>(&node1); + node1.add_input("^node1"); + auto gn1 = map["node1"].get(); + ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK())); + // clang-format off + EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre( + "iC: node1[oC]", + "oC: node1[iC]" + )); + // clang-format on +} + +TEST(GenNodeTest, ParseNodeInvalidInput) { + GenNodeMap map; + NodeDef node1 = MakeNodeAddN("node1", "node2", "node3"); + map["node1"] = absl::make_unique<GenNode>(&node1); + node1.add_input("node1"); + auto gn1 = map["node1"].get(); + ASSERT_THAT( + gn1->ParseInputs(&map), + Eq(Status( + error::INVALID_ARGUMENT, + "Node 'node1' input 0 refers to a non-existing node 'node2'."))); +} + +TEST(GenNodeTest, BuildGraphInMap) { + GraphDef graph; + // A topology with a loop. 
+ (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0"); + (*graph.add_node()) = + MakeNodeBroadcastGradientArgs("node3", "node1", "node2"); + + GenNodeMap map; + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK())); + ASSERT_THAT(map.find("node1"), Ne(map.end())); + ASSERT_THAT(map.find("node2"), Ne(map.end())); + ASSERT_THAT(map.find("node3"), Ne(map.end())); + + EXPECT_THAT(map["node1"]->name(), Eq("node1")); + EXPECT_THAT(map["node2"]->name(), Eq("node2")); + EXPECT_THAT(map["node3"]->name(), Eq("node3")); + + // clang-format off + EXPECT_THAT(DumpLinkMap(map["node1"]->links()), ElementsAre( + "o0: node3[i0]" + )); + EXPECT_THAT(DumpLinkMap(map["node2"]->links()), ElementsAre( + "i0: node3[o1]", + "i1: node3[o0]", + "o0: node3[i1]" + )); + EXPECT_THAT(DumpLinkMap(map["node3"]->links()), ElementsAre( + "i0: node1[o0]", + "i1: node2[o0]", + "o0: node2[i1]", + "o1: node2[i0]" + )); + // clang-format on +} + +TEST(GenNodeTest, BuildGraphInMapDuplicateNode) { + GraphDef graph; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeConst("node1"); + GenNodeMap map; + ASSERT_THAT( + GenNode::BuildGraphInMap(graph, &map), + Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'."))); +} + +TEST(GenNodeTest, BuildGraphInMapParseError) { + GraphDef graph; + // A topology with a loop. 
+ (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0"); + + GenNodeMap map; + ASSERT_THAT( + GenNode::BuildGraphInMap(graph, &map), + Eq(Status( + error::INVALID_ARGUMENT, + "Node 'node2' input 0 refers to a non-existing node 'node3'."))); +} + +} // end namespace +} // end namespace test +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc new file mode 100644 index 0000000000..f3796fcf86 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <deque> +#include <iostream> + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" +#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h" +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +GraphAnalyzer::GraphAnalyzer(const GraphDef& graph, int subgraph_size) + : graph_(graph), subgraph_size_(subgraph_size) {} + +GraphAnalyzer::~GraphAnalyzer() {} + +Status GraphAnalyzer::Run() { + // The signature computation code would detect this too, but better + // to report it up front than spend time computing all the graphs first. + if (subgraph_size_ > Signature::kMaxGraphSize) { + return Status(error::INVALID_ARGUMENT, + absl::StrFormat("Subgraphs of %d nodes are not supported, " + "the maximal supported node count is %d.", + subgraph_size_, Signature::kMaxGraphSize)); + } + + Status st = BuildMap(); + if (!st.ok()) { + return st; + } + + FindSubgraphs(); + DropInvalidSubgraphs(); + st = CollateResult(); + if (!st.ok()) { + return st; + } + + return Status::OK(); +} + +Status GraphAnalyzer::BuildMap() { + nodes_.clear(); + return GenNode::BuildGraphInMap(graph_, &nodes_); +} + +void GraphAnalyzer::FindSubgraphs() { + result_.clear(); + + if (subgraph_size_ < 1) { + return; + } + + partial_.clear(); + todo_.clear(); // Just in case. + + // Start with all subgraphs of size 1. + const Subgraph::Identity empty_parent; + for (const auto& node : nodes_) { + if (subgraph_size_ == 1) { + result_.ExtendParent(empty_parent, node.second.get()); + } else { + // At this point ExtendParent() is guaranteed to not return nullptr. + todo_.push_back(partial_.ExtendParent(empty_parent, node.second.get())); + } + } + + // Then extend the subgraphs until no more extensions are possible. 
+ while (!todo_.empty()) { + ExtendSubgraph(todo_.front()); + todo_.pop_front(); + } + + partial_.clear(); +} + +void GraphAnalyzer::ExtendSubgraph(Subgraph* parent) { + bool will_complete = (parent->id().size() + 1 == subgraph_size_); + SubgraphPtrSet& sg_set = will_complete ? result_ : partial_; + + const GenNode* last_all_or_none_node = nullptr; + for (SubgraphIterator sit(parent); !sit.AtEnd(); sit.Next()) { + const GenNode* node = sit.GetNode(); + GenNode::Port port = sit.GetPort(); + const GenNode::LinkTarget& neighbor = sit.GetNeighbor(); + + if (node->AllInputsOrNone() && port.IsInbound() && !port.IsControl()) { + if (node != last_all_or_none_node) { + ExtendSubgraphAllOrNone(parent, node); + last_all_or_none_node = node; + } + sit.SkipPort(); + } else if (neighbor.node->AllInputsOrNone() && !port.IsInbound() && + !port.IsControl()) { + if (parent->id().find(neighbor.node) == parent->id().end()) { + // Not added yet. + ExtendSubgraphAllOrNone(parent, neighbor.node); + } + } else if (node->IsMultiInput(port)) { + ExtendSubgraphPortAllOrNone(parent, node, port); + sit.SkipPort(); + } else if (neighbor.node->IsMultiInput(neighbor.port)) { + // Would need to add all inputs of the neighbor node at this port at + // once. + if (parent->id().find(neighbor.node) != parent->id().end()) { + continue; // Already added. 
+ } + ExtendSubgraphPortAllOrNone(parent, neighbor.node, neighbor.port); + } else { + Subgraph* sg = sg_set.ExtendParent(parent->id(), neighbor.node); + if (!will_complete && sg != nullptr) { + todo_.push_back(sg); + } + } + } +} + +void GraphAnalyzer::ExtendSubgraphAllOrNone(Subgraph* parent, + const GenNode* node) { + Subgraph::Identity id = parent->id(); + id.insert(node); + + auto range_end = node->links().end(); + + for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) { + auto port = nbit->first; + if (!port.IsInbound() || port.IsControl()) { + continue; + } + + // Since there might be multiple links to the same nodes, + // have to add all links one-by-one to check whether the subgraph + // would grow too large. But if it does grow too large, there is no + // point in growing it more, can just skip over the rest of the links. + for (const auto& link : nbit->second) { + id.insert(link.node); + if (id.size() > subgraph_size_) { + return; // Too big. + } + } + } + + AddExtendedSubgraph(parent, id); +} + +void GraphAnalyzer::ExtendSubgraphPortAllOrNone(Subgraph* parent, + const GenNode* node, + GenNode::Port port) { + auto nbit = node->links().find(port); + if (nbit == node->links().end()) { + return; // Should never happen. + } + + Subgraph::Identity id = parent->id(); + id.insert(node); + + // Since there might be multiple links to the same nodes, + // have to add all links one-by-one to check whether the subgraph + // would grow too large. But if it does grow too large, there is no + // point in growing it more, can just skip over the rest of the links. + for (const auto& link : nbit->second) { + id.insert(link.node); + if (id.size() > subgraph_size_) { + return; // Too big. + } + } + + AddExtendedSubgraph(parent, id); +} + +void GraphAnalyzer::AddExtendedSubgraph(Subgraph* parent, + const Subgraph::Identity& id) { + if (id.size() == parent->id().size()) { + return; // Nothing new was added. 
+ } + + auto sg = absl::make_unique<Subgraph>(id); + SubgraphPtrSet& spec_sg_set = + (id.size() == subgraph_size_) ? result_ : partial_; + if (spec_sg_set.find(sg) != spec_sg_set.end()) { + // This subgraph was already found by extending from a different path. + return; + } + + if (id.size() != subgraph_size_) { + todo_.push_back(sg.get()); + } + spec_sg_set.insert(std::move(sg)); +} + +void GraphAnalyzer::DropInvalidSubgraphs() { + auto resit = result_.begin(); + while (resit != result_.end()) { + if (HasInvalidMultiInputs(resit->get())) { + auto delit = resit; + ++resit; + result_.erase(delit); + } else { + ++resit; + } + } +} + +bool GraphAnalyzer::HasInvalidMultiInputs(Subgraph* sg) { + // Do the all-or-none-input nodes. + for (auto const& node : sg->id()) { + if (!node->AllInputsOrNone()) { + continue; + } + + bool anyIn = false; + bool anyOut = false; + + auto range_end = node->links().end(); + for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) { + auto port = nbit->first; + if (!port.IsInbound() || port.IsControl()) { + continue; + } + + // Since there might be multiple links to the same nodes, + // have to add all links one-by-one to check whether the subgraph + // would grow too large. But if it does grow too large, there is no + // point in growing it more, can just skip over the rest of the links. + for (const auto& link : nbit->second) { + if (sg->id().find(link.node) == sg->id().end()) { + anyOut = true; + } else { + anyIn = true; + } + } + } + + if (anyIn && anyOut) { + return true; + } + } + + // Do the multi-input ports. 
+ for (SubgraphIterator sit(sg); !sit.AtEnd(); sit.Next()) { + if (sit.GetNode()->IsMultiInput(sit.GetPort())) { + bool anyIn = false; + bool anyOut = false; + do { + GenNode* peer = sit.GetNeighbor().node; + if (sg->id().find(peer) == sg->id().end()) { + anyOut = true; + } else { + anyIn = true; + } + } while (sit.NextIfSamePort()); + + if (anyIn && anyOut) { + return true; + } + } + } + return false; +} + +Status GraphAnalyzer::CollateResult() { + ordered_collation_.clear(); + collation_map_.clear(); + + // Collate by the signatures of the graphs. + for (const auto& it : result_) { + auto sig = absl::make_unique<Signature>(); + it->ExtractForSignature(&sig->map); + Status status = sig->Compute(); + if (!status.ok()) { + return status; + } + + auto& coll_entry = collation_map_[sig.get()]; + if (coll_entry.sig == nullptr) { + coll_entry.sig = std::move(sig); + } + ++coll_entry.count; + } + + // Then order them by the count. + for (auto& entry : collation_map_) { + ordered_collation_.insert(&entry.second); + } + + result_.clear(); // Not needed after collation. 
+ + return Status::OK(); +} + +std::vector<string> GraphAnalyzer::DumpRawSubgraphs() { + std::vector<string> result; + for (const auto& it : result_) { + result.emplace_back(it->Dump()); + } + return result; +} + +std::vector<string> GraphAnalyzer::DumpSubgraphs() { + std::vector<string> result; + for (auto ptr : ordered_collation_) { + result.emplace_back( + absl::StrFormat("%d %s", ptr->count, ptr->sig->ToString())); + } + return result; +} + +Status GraphAnalyzer::OutputSubgraphs() { + size_t total = 0; + for (auto ptr : ordered_collation_) { + std::cout << ptr->count << ' ' << ptr->sig->ToString() << '\n'; + total += ptr->count; + } + std::cout << "Total: " << total << '\n'; + if (std::cout.fail()) { + return Status(error::DATA_LOSS, "Failed to write to stdout"); + } else { + return Status::OK(); + } +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h new file mode 100644 index 0000000000..26d38a4931 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h @@ -0,0 +1,154 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ + +#include <deque> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/graph_analyzer/map_tools.h" +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" +#include "tensorflow/core/grappler/graph_analyzer/subgraph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +namespace test { +class GraphAnalyzerTest; +} // end namespace test + +// Finds all the subgraphs of a given size and groups them by equivalence. +class GraphAnalyzer { + public: + // Makes a copy of the graph. + GraphAnalyzer(const GraphDef& graph, int subgraph_size); + + virtual ~GraphAnalyzer(); + + // Performs the analysis and collects the subgraphs. + Status Run(); + + // Returns the subgraphs found in Run() printed to text. + std::vector<string> DumpSubgraphs(); + + // Prints the subgraphs found in Run() to stdout. + Status OutputSubgraphs(); + + // TODO(babkin): add a way to extract the subgraphs as direct data + // structures and as protobufs, and to write protobufs to a RecordIO. + + private: + GraphAnalyzer() = delete; + GraphAnalyzer(const GraphAnalyzer&) = delete; + void operator=(const GraphAnalyzer&) = delete; + + friend class tensorflow::grappler::graph_analyzer::test::GraphAnalyzerTest; + + // Builds the map of nodes from the original graph definition. + Status BuildMap(); + + // Using nodes_, finds all the subgraphs of size subgraph_size_ and places + // them into result_. + void FindSubgraphs(); + + // Deletes from result_ the unacceptable subgraphs. 
Those include the + // subgraphs where not all the inputs at a multi-input port are included (this + // could happen if some of these inputs were reached and included through + // different paths). + void DropInvalidSubgraphs(); + + // Deletes from result_ duplicate entries of equivalent topology. + Status CollateResult(); + + // Returns the raw subgraphs found in FindSubgraphs() printed to text. + std::vector<string> DumpRawSubgraphs(); + + // Finds and adds appropriately to either partial_ or result_ all the + // subgraphs that can be created by extending the parent subgraph by one node. + // Ignores the duplicates. + void ExtendSubgraph(Subgraph* parent); + + // Extends the parent subgraph by adding another node (if it wasn't already + // added) and all its non-control inputs in the link map range at once. + // If the subgraph would grow over subgraph_size_, it gets ignored. + void ExtendSubgraphAllOrNone(Subgraph* parent, const GenNode* node); + // Same but adds one specific inbound port (even control) all-or-none. + void ExtendSubgraphPortAllOrNone(Subgraph* parent, const GenNode* node, + GenNode::Port port); + // The common final step called by ExtendSubgraph*AllOrNone() methods. + void AddExtendedSubgraph(Subgraph* parent, const Subgraph::Identity& id); + + // Returns true if this subgraph has any multi-inputs that aren't all-in or + // all-out. + bool HasInvalidMultiInputs(Subgraph* sg); + + // Graph to run the analysis on. + GraphDef graph_; + int subgraph_size_; + + // The enriched graph of parsed nodes and connections. + GenNodeMap nodes_; + // The resulting set of subgraphs. + SubgraphPtrSet result_; + // The subgraphs of partial size, stored while finding the result. + SubgraphPtrSet partial_; + // The subgraphs of partial size (stored in partial_) that are still waiting + // to be extended. + // + // TODO(babkin): This is rather simple-minded, each subgraph is examined from + // scratch, which means that all its internal links get iterated too. 
But it's + // OK for the small subgraphs. This can be improved by keeping not just + // subgraphs but iterators on the list, each of them having the list not-yet + // examined nodes (and the link position of the next link to be examined for + // the first node). This would add extra constant overhead, so the break-even + // subgraph size is not clear yet. + std::deque<Subgraph*> todo_; + + // The collation map by signature is designed to allow the removal of entries + // and moving of the signature references from the keys of this map to the + // outside world. Must be careful at inserting and removal: make sure that + // when a new entry is inserted, its signature reference gets populated with + // the same data as the key of the map, and that if a reference is moved out, + // the map entry gets removed before that reference gets destroyed. + struct CollationEntry { + std::shared_ptr<Signature> sig; + size_t count = 0; + }; + using CollationMap = + std::unordered_map<Signature*, CollationEntry, HashAtPtr<Signature*>, + EqAtPtr<Signature*> >; + CollationMap collation_map_; + + // The entries are owned by collation_map_, so must be removed from + // ordered_collation_ before removing them from collation_map_. + struct ReverseLessByCount { + bool operator()(CollationEntry* left, CollationEntry* right) { + return left->count > right->count; // Reverse order. 
+ } + }; + using CollationOrderByCount = + std::multiset<CollationEntry*, ReverseLessByCount>; + CollationOrderByCount ordered_collation_; +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc new file mode 100644 index 0000000000..e94c472056 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc @@ -0,0 +1,569 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h" + +#include <algorithm> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "absl/memory/memory.h" +#include "tensorflow/core/grappler/graph_analyzer/test_tools.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { +namespace test { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Ne; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +class GraphAnalyzerTest : public ::testing::Test, protected TestGraphs { + protected: + Status BuildMap() { return gran_->BuildMap(); } + + void FindSubgraphs() { gran_->FindSubgraphs(); } + + void DropInvalidSubgraphs() { gran_->DropInvalidSubgraphs(); } + + Status CollateResult() { return gran_->CollateResult(); } + + void ExtendSubgraph(Subgraph* parent) { gran_->ExtendSubgraph(parent); } + + void ExtendSubgraphPortAllOrNone(Subgraph* parent, GenNode* node, + GenNode::Port port) { + gran_->ExtendSubgraphPortAllOrNone(parent, node, port); + } + + void ExtendSubgraphAllOrNone(Subgraph* parent, GenNode* node) { + gran_->ExtendSubgraphAllOrNone(parent, node); + } + + std::vector<string> DumpRawSubgraphs() { return gran_->DumpRawSubgraphs(); } + + std::vector<string> DumpPartials() { + std::vector<string> result; + for (const auto& it : gran_->partial_) { + result.emplace_back(it->Dump()); + } + return result; + } + + const GenNodeMap& GetNodes() { return gran_->nodes_; } + + GenNode* GetNode(const string& name) { return gran_->nodes_.at(name).get(); } + + SubgraphPtrSet& GetResult() { return gran_->result_; } + SubgraphPtrSet& GetPartial() { return gran_->partial_; } + std::deque<Subgraph*>& GetTodo() { return gran_->todo_; } + + // Gets initialized by a particular test from a suitable GraphDef. 
+ std::unique_ptr<GraphAnalyzer> gran_; +}; + +TEST_F(GraphAnalyzerTest, BuildMap) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1); + Status st = BuildMap(); + EXPECT_THAT(st, Eq(Status::OK())); + + auto& map = GetNodes(); + EXPECT_THAT(map.find("node1"), Ne(map.end())); + EXPECT_THAT(map.find("node2"), Ne(map.end())); + EXPECT_THAT(map.find("node3"), Ne(map.end())); +} + +TEST_F(GraphAnalyzerTest, BuildMapError) { + // A duplicate node. + (*graph_3n_self_control_.add_node()) = MakeNodeConst("node1"); + gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1); + Status st = BuildMap(); + ASSERT_THAT( + st, Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'."))); +} + +TEST_F(GraphAnalyzerTest, FindSubgraphs0) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 0); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + FindSubgraphs(); + auto& subgraphs = GetResult(); + EXPECT_THAT(subgraphs, SizeIs(0)); + EXPECT_THAT(DumpRawSubgraphs(), ElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +TEST_F(GraphAnalyzerTest, FindSubgraphs1) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + FindSubgraphs(); + auto& subgraphs = GetResult(); + EXPECT_THAT(subgraphs, SizeIs(3)); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: BroadcastGradientArgs(node3)", + "1: Const(node1)", + "1: Sub(node2)" + )); + // clang-format on + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// The required subgraphs are larger than the graph. 
+TEST_F(GraphAnalyzerTest, FindSubgraphsTooLarge) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + FindSubgraphs(); + EXPECT_THAT(DumpRawSubgraphs(), ElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +//=== + +// Successfully propagate backwards through a multi-input link, +// with the base (currently-extending) node already in the graph. +TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseIn) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")})); + + ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"), + GenNode::Port(true, 0)); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)" + )); + // clang-format on + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Successfully propagate backwards through a multi-input link, +// with the base (currently-extending) node not in the graph yet. 
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseOut) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto parent = absl::make_unique<Subgraph>(Subgraph::Identity()); + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")})); + + ExtendSubgraphPortAllOrNone(parent.get(), GetNode("add2"), + GenNode::Port(true, 0)); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)" + )); + // clang-format on + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Successfully propagate backwards through a multi-input link, +// where the target subgraph size is larger. +TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsIncomplete) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 5); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")})); + + ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"), + GenNode::Port(true, 0)); + + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre()); + // clang-format off + EXPECT_THAT(DumpPartials(), UnorderedElementsAre( + "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)" + )); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(1)); +} + +// Propagate backwards through a multi-input link, finding that the +// resulting subgraph would be too large. 
+TEST_F(GraphAnalyzerTest, MultiInputTooLargeBackwards) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")})); + + ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"), + GenNode::Port(true, 0)); + + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Propagate backwards through a multi-input link, finding that nothing +// would be added to the parent subgraph. +TEST_F(GraphAnalyzerTest, MultiInputNothingAddedBackwards) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = absl::make_unique<Subgraph>( + Subgraph::Identity({GetNode("add2"), GetNode("const2_1"), + GetNode("const2_2"), GetNode("const2_3")})); + + ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"), + GenNode::Port(true, 0)); + + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Successfully propagate forwards through a multi-input link, +// with the base (currently-extending) node not in the subgraph yet. 
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsBaseOut) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")})); + + ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"), + GenNode::Port(true, 0)); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)" + )); + // clang-format on + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Successfully propagate backwards through a multi-input link. +TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsFull) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")})); + + ExtendSubgraph(root.get()); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)" + )); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre( + "1: AddN(add2), Sub(sub)" + )); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(1)); +} + +// Successfully propagate forwards through a multi-input link. 
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsFull) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")})); + + ExtendSubgraph(root.get()); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)" + )); + // clang-format on + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsMulti) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + // A good one, multi-input is all-in. + GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("const1_1"), + GetNode("const1_2"), + GetNode("add1"), + }))); + // A good one, multi-input is all-out + GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("add1"), + GetNode("add2"), + GetNode("sub"), + }))); + // A bad one, multi-input is partially in. + GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("const1_1"), + GetNode("add1"), + GetNode("sub"), + }))); + // A bad one, multi-input is partially in. + GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("add2"), + GetNode("const2_1"), + GetNode("const2_2"), + }))); + + DropInvalidSubgraphs(); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: AddN(add1), AddN(add2), Sub(sub)", + "1: AddN(add1), Const(const1_1), Const(const1_2)" + )); + // clang-format on + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +//=== + +// Successfully propagate backwards through a multi-input link, +// with the base (currently-extending) node already in the graph. 
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwards) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")})); + + ExtendSubgraphAllOrNone(root.get(), GetNode("pass2")); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)" + )); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Successfully propagate backwards through a multi-input link, +// but no control links propagate. It also tests the situation +// where the target subgraph size is larger. +TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsNoControl) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")})); + + ExtendSubgraphAllOrNone(root.get(), GetNode("pass1")); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre( + "1: Const(const1_1), Const(const1_2), IdentityN(pass1)" + )); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(1)); +} + +// The control links propagate separately as all-or-none, even on the nodes +// that are all-or-none for the normal inputs. 
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSeparateControl) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")})); + + ExtendSubgraphPortAllOrNone(root.get(), GetNode("pass1"), + GenNode::Port(true, -1)); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre( + "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)" + )); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(1)); +} + +// Propagate backwards from all-or-none-input node, finding that the +// resulting subgraph would be too large. +TEST_F(GraphAnalyzerTest, AllOrNoneInputTooLargeBackwards) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")})); + + ExtendSubgraphAllOrNone(root.get(), GetNode("pass2")); + + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Propagate backwards from all-or-none-input node, finding that nothing +// would be added to the parent subgraph. 
+TEST_F(GraphAnalyzerTest, AllOrNoneInputNothingAddedBackwards) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = absl::make_unique<Subgraph>( + Subgraph::Identity({GetNode("pass2"), GetNode("const2_1"), + GetNode("const2_2"), GetNode("const2_3")})); + + ExtendSubgraphAllOrNone(root.get(), GetNode("pass2")); + + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre()); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Successfully propagate forwards to all-or-none-input node, +// with the base (currently-extending) node not in the subgraph yet. +TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsBaseOut) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")})); + + ExtendSubgraphAllOrNone(root.get(), GetNode("pass2")); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)" + )); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +// Successfully propagate backwards from all-or-none-input node. 
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsFull) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")})); + + ExtendSubgraph(root.get()); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)" + )); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre( + "1: IdentityN(pass2), Sub(sub)" + )); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(1)); +} + +// Successfully propagate forwards to all-or-none-input node. This includes +// both all-or-none-input for the normal inputs, and multi-input by the +// control path. +TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsFull) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + auto root = + absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")})); + + ExtendSubgraph(root.get()); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)", + "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)" + )); + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + // clang-format on + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsAllOrNone) { + gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3); + Status st = BuildMap(); + ASSERT_THAT(st, Eq(Status::OK())); + + // A good one, all-or-none is all-in. 
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("const1_1"), + GetNode("const1_2"), + GetNode("pass1"), + }))); + // A good one, all-or-none is all-out + GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("pass1"), + GetNode("pass2"), + GetNode("sub"), + }))); + // A bad one, all-or-none is partially in. + GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("const1_1"), + GetNode("pass1"), + GetNode("sub"), + }))); + // A bad one, all-or-none is partially in. + GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({ + GetNode("pass2"), + GetNode("const2_1"), + GetNode("const2_2"), + }))); + + DropInvalidSubgraphs(); + + // clang-format off + EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre( + "1: IdentityN(pass1), IdentityN(pass2), Sub(sub)", + "1: Const(const1_1), Const(const1_2), IdentityN(pass1)" + )); + // clang-format on + EXPECT_THAT(DumpPartials(), UnorderedElementsAre()); + EXPECT_THAT(GetTodo(), SizeIs(0)); +} + +} // end namespace test +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc new file mode 100644 index 0000000000..924ca11e61 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +// Dies on failure. +static void LoadModel(const string& filename, + tensorflow::MetaGraphDef* metagraph) { + LOG(INFO) << "Loading model from " << filename; + Status st; + st = ReadBinaryProto(Env::Default(), filename, metagraph); + if (!st.ok()) { + LOG(WARNING) << "Failed to read a binary metagraph: " << st; + st = ReadTextProto(Env::Default(), filename, metagraph); + if (!st.ok()) { + LOG(FATAL) << "Failed to read a text metagraph: " << st; + } + } +} + +// Prune the graph to only keep the transitive fanin part with respect to a set +// of train ops (if provided). +void MaybePruneGraph(const tensorflow::MetaGraphDef& metagraph, + tensorflow::GraphDef* graph) { + std::vector<string> fetch_nodes; + for (const auto& fetch : + metagraph.collection_def().at("train_op").node_list().value()) { + LOG(INFO) << "Fetch node: " << fetch; + fetch_nodes.push_back(fetch); + } + if (fetch_nodes.empty()) { + *graph = metagraph.graph_def(); + } else { + std::vector<const tensorflow::NodeDef*> fanin_nodes = + tensorflow::grappler::ComputeTransitiveFanin(metagraph.graph_def(), + fetch_nodes); + for (const tensorflow::NodeDef* node : fanin_nodes) { + *(graph->add_node()) = *node; + } + LOG(INFO) << "Pruned " + << metagraph.graph_def().node_size() - graph->node_size() + << " nodes. 
Original graph size: " + << metagraph.graph_def().node_size() + << ". New graph size: " << graph->node_size() << "."; + } +} + +void GraphAnalyzerTool(const string& file_name, int n) { + if (n < 1) { + LOG(FATAL) << "Invalid subgraph size " << n << ", must be at least 1"; + } + + tensorflow::MetaGraphDef metagraph; + LoadModel(file_name, &metagraph); + tensorflow::GraphDef graph; + MaybePruneGraph(metagraph, &graph); + tensorflow::grappler::graph_analyzer::GraphAnalyzer analyzer(graph, n); + LOG(INFO) << "Running the analysis"; + tensorflow::Status st = analyzer.Run(); + if (!st.ok()) { + LOG(FATAL) << "Analysis failed: " << st; + } + + LOG(INFO) << "Printing the result"; + st = analyzer.OutputSubgraphs(); + if (!st.ok()) { + LOG(FATAL) << "Failed to print the result: " << st; + } + + LOG(INFO) << "Completed"; +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h new file mode 100644 index 0000000000..5a91fe7dc8 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_ + +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +void GraphAnalyzerTool(const string& file_name, int n); + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_ diff --git a/tensorflow/core/grappler/graph_analyzer/hash_tools.h b/tensorflow/core/grappler/graph_analyzer/hash_tools.h new file mode 100644 index 0000000000..b0e79f9a68 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/hash_tools.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_ + +#include <cstddef> + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +// Unfortunately, std::hash provides no way to combine hashes, so everyone +// is copying boost::hash_combine. This is a version that follows Google's +// guidelines on the arguments, and contains only the combination, without +// hashing. 
+inline void CombineHash(size_t from, size_t* to) { + *to ^= from + 0x9e3779b9 + (*to << 6) + (*to >> 2); +} + +// Combine two hashes in such a way that the order of combination doesn't matter +// (so it's really both commutative and associative). The result is not a very +// high-quality hash but can be used in case if the order of sub-elements must +// not matter in the following comparison. An alternative would be to sort the +// hashes of the sub-elements and then combine them normally in the sorted +// order. +inline void CombineHashCommutative(size_t from, size_t* to) { + *to = *to + from + 0x9e3779b9; +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_ diff --git a/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc new file mode 100644 index 0000000000..b5e9ce6b8e --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { +namespace test { +namespace { + +using ::testing::Eq; + +TEST(HashToolsTest, CombineHashCommutative) { + size_t a = 0; + size_t b = 999; + + size_t c = a; + CombineHashCommutative(b, &c); + + size_t d = b; + CombineHashCommutative(a, &d); + + EXPECT_THAT(c, Eq(d)); +} + +} // namespace +} // end namespace test +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/map_tools.h b/tensorflow/core/grappler/graph_analyzer/map_tools.h new file mode 100644 index 0000000000..584062c5f2 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/map_tools.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_ + +#include <functional> + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +// Helpers for building maps of pointers. 
+ +template <typename Ptr> +struct LessAtPtr : std::binary_function<Ptr, Ptr, bool> { + bool operator()(const Ptr& x, const Ptr& y) const { return *x < *y; } +}; + +template <typename Ptr> +struct EqAtPtr : std::binary_function<Ptr, Ptr, bool> { + bool operator()(const Ptr& x, const Ptr& y) const { return *x == *y; } +}; + +template <typename Ptr> +struct HashAtPtr : std::unary_function<Ptr, size_t> { + size_t operator()(const Ptr& x) const { return x->Hash(); } +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_ diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.cc b/tensorflow/core/grappler/graph_analyzer/sig_node.cc new file mode 100644 index 0000000000..b5cca6a512 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/sig_node.cc @@ -0,0 +1,453 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" + +#include <algorithm> + +#include "absl/strings/str_format.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +static constexpr bool debug = false; + +//=== SigNode + +SigNode::SigNode(const NodeDef* node) : node_(node) {} + +void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) { + hash_to_link_.clear(); + hashed_peers_.clear(); + + std::map<LinkTag, Link> link_map; + CopyLinksPass1(from, map, &link_map); + CopyLinksPass2(&link_map); +} + +void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map, + std::map<LinkTag, Link>* link_map) { + LinkTag::Hasher link_hasher; + + for (const auto& entry : from.links()) { + for (const auto& target : entry.second) { + auto nodeit = map.find(target.node); + if (nodeit == map.end()) { + // Node is not in the subgraph, ignore. + continue; + } + + LinkTag tag(entry.first, target.port); + size_t hval = link_hasher(tag); + + // This instantiates the entry if it was not present. + Link& map_entry = (*link_map)[tag]; + if (map_entry.peers.empty()) { + map_entry.tag = tag; + map_entry.unique_hash = hval; + } + map_entry.peers.push_back(nodeit->second); + } + } +} + +void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) { + for (auto& entry : *link_map) { + Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash]; + // In case of a conflict, rehash. This should almost never happen. + // Because the order of iteration is predictable, the rehashed values + // will also be predictable. 
+    while (!hl_entry_ptr->peers.empty()) {
+      CombineHash(1, &entry.second.unique_hash);
+      hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
+    }
+
+    for (const auto& peer : entry.second.peers) {
+      hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
+    }
+
+    hl_entry_ptr->tag = entry.second.tag;
+    hl_entry_ptr->unique_hash = entry.second.unique_hash;
+    hl_entry_ptr->peers.swap(entry.second.peers);
+  }
+}
+
+void SigNode::ComputeTopoHash0() {
+  topo_hash_.clear();
+  last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;
+
+  // TODO(babkin): include the attributes too, as an option.
+  size_t hval = std::hash<string>()(opcode());
+
+  // Getting the topology of the links in to the hash early should get more
+  // conflicts resolved early.
+  for (const auto& entry : hashed_peers_) {
+    CombineHash(entry.link_hash, &hval);
+  }
+
+  topo_hash_.push_back(hval);
+}
+
+void SigNode::ComputeTopoHash(int distance) {
+  // The new starting point.
+  next_hashed_nodes_ = last_hashed_nodes_;
+  if (debug) {
+    LOG(INFO) << "DEBUG node " << name() << " mask=" << std::hex
+              << next_hashed_nodes_;
+  }
+
+  if (hash_is_final_) {
+    return;
+  }
+
+  CHECK(topo_hash_.size() == distance);
+
+  int prev = distance - 1;
+
+  // Start with own's local topology hash. This value is stable, so
+  // if the hashes of the surrounding nodes don't change on the following
+  // distances, the hash of this node won't change either.
+  size_t hval = topo_hash_[0];
+
+  if (!hashed_peers_.empty()) {
+    size_t last_link_hash = hashed_peers_[0].link_hash;
+    size_t comm_hash = 0;
+
+    for (const auto& entry : hashed_peers_) {
+      if (entry.link_hash != last_link_hash) {
+        CombineHash(last_link_hash, &hval);
+        CombineHash(comm_hash, &hval);
+        comm_hash = 0;
+        last_link_hash = entry.link_hash;
+      }
+
+      // The links in the same vector are commutative, so combine their
+      // hashes in a commutative way.
+ CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash); + next_hashed_nodes_ |= entry.peer->last_hashed_nodes_; + if (debug) { + LOG(INFO) << "DEBUG node " << name() << " += " << entry.peer->name() + << " mask=" << std::hex << next_hashed_nodes_; + } + } + + // The last commutative group. + CombineHash(last_link_hash, &hval); + CombineHash(comm_hash, &hval); + } + + topo_hash_.push_back(hval); +} + +size_t SigNode::GetTopoHash(int distance) const { + CHECK(!topo_hash_.empty()); + if (distance >= topo_hash_.size()) { + CHECK(hash_is_final_); + return topo_hash_.back(); + } else { + return topo_hash_[distance]; + } +} + +bool SigNode::operator==(const SigNode& other) const { + // TODO(babkin): add attributes too. + if (opcode() != other.opcode()) { + return false; + } + + // Normally the caller is expected to compare the nodes + // at the same rank in different graphs, but just in case... + if (unique_rank_ != other.unique_rank_) { + return false; + } + + if (hashed_peers_.size() != other.hashed_peers_.size()) { + return false; + } + + for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin(); + it1 != hashed_peers_.end(); ++it1, ++it2) { + // TODO(babkin): might compare the actual values too + // but the hash is probably just as good. + if (it1->link_hash != it2->link_hash) { + return false; + } + if (it1->peer->unique_rank_ != it2->peer->unique_rank_) { + return false; + } + } + + return true; +} + +//=== Signature + +constexpr int Signature::kMaxGraphSize; + +string Signature::ToString() const { + string result; + for (size_t n = 0; n < nodes.size(); ++n) { + // TODO(babkin): add attributes too. + result += absl::StrFormat("%d:%s", n, nodes[n]->opcode()); + for (const auto& entry : nodes[n]->hashed_peers_) { + const auto& link = nodes[n]->hash_to_link_[entry.link_hash]; + + // The link entries are already sorted, by tags and then by the + // node ranks. 
+ if (link.tag.local.IsInbound()) { + result += + absl::StrFormat("[%s:%s:%d]", string(link.tag.local), + string(link.tag.remote), entry.peer->unique_rank_); + } + } + result.push_back(','); + } + return result; +} + +Status Signature::Compute() { + if (map.size() > kMaxGraphSize) { + return Status( + error::INVALID_ARGUMENT, + absl::StrFormat( + "A graph of %d nodes is too big for signature computation, " + "the maximal supported node count is %d.", + map.size(), kMaxGraphSize)); + } + + // The value that will be assigned next as the unique node id. + // This also means that all the entries in nodes at indexes less than this + // have been finalized and don't need to be touched any more. + size_t next_node_id = 0; + + sig_short = 0; + sig_full.resize(0); // Keep the storage. + + // The main signature generation. + PrepareNodes(); + FindUniqueHashes(&next_node_id); + while (next_node_id < map.size()) { + ComputeOneRound(next_node_id); + FindUniqueHashes(&next_node_id); + } + + OrderLinks(); + + return Status::OK(); +} + +void Signature::PrepareNodes() { + nodes.resize(0); // Keep the storage. + + // Initialize the nodes. + int64_t mask = 1; + for (const auto& entry : map) { + SigNode* node = entry.second.get(); + node->last_hashed_nodes_ = node->node_mask_ = mask; + mask <<= 1; + node->unique_rank_ = ~0; + node->hash_is_final_ = false; + node->ComputeTopoHash0(); + if (node->GetHighTopoHash() <= map.size()) { + // Would conflict with one of the reserved values. + node->ReHighTopoHash(); + } + + // The initial order is random. + nodes.emplace_back(node); + } +} + +void Signature::FindUniqueHashes(size_t* next_node_id_p) { + // Start by sorting by the hash value. + std::sort(nodes.begin() + *next_node_id_p, nodes.end(), + SigNode::NodeOrderLess()); + + // At each call, if no nodes have unique hashes, one node that has a + // non-unique (shared) hash can be made unique by assigning a unique id. + // This node gets picked predictably by taking the last node. 
+ // TODO(babkin): Technically, more than one node can be unshared, + // as long as their last_hashed_nodes_ overlap only by the nodes that + // already had the assigned ids before the current round. But it's not clear + // yet, how often would this beneficial, because it looks like for many + // subgraphs unsharing one node should be enough to untangle them. This + // would need more measurement before implementing. + bool found_unique = false; + for (size_t n = *next_node_id_p; n < nodes.size(); ++n) { + size_t cur_hash = nodes[n]->GetHighTopoHash(); + if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) { + // A sequence of nodes sharing the same hash. Skip over it. + // TODO(babkin): check here for the arbitrary hash conflicts and resolve + // them. + for (++n; + n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash; + ++n) { + } + if (found_unique || n != nodes.size() - 1) { + // Either some unique nodes have already been found, or this is + // not the last chance, keep trying to find the unique nodes. + continue; + } + // Here we're at the last node and haven't found any unique ones. + // So fall through and make this last node unique. + } + + found_unique = true; + size_t id = (*next_node_id_p)++; + nodes[n]->unique_rank_ = id; + + size_t last_hash = nodes[n]->GetHighTopoHash(); + CombineHash(last_hash, &sig_short); + sig_full.push_back(last_hash); + + // Take the hash at 0 and mix the unique rank into it. After that it will + // stay fixed. + nodes[n]->topo_hash_.resize(1); + nodes[n]->topo_hash_[0] = id + 1; // Avoid the value of 0. + + nodes[n]->hash_is_final_ = true; + nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_; + if (n != id) { + std::swap(nodes[id], nodes[n]); + } + } +} + +void Signature::ComputeOneRound(size_t next_node_id) { + // Reset the state of the nodes. 
+ int debug_i = 0; + for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) { + auto node = *it; + // The hash at distance 0 never changes, so preserve it. + node->topo_hash_.resize(1); + node->last_hashed_nodes_ = node->node_mask_; + node->hash_is_final_ = false; + if (debug) { + LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " " + << node->name() << " mask=" << std::hex + << node->last_hashed_nodes_; + } + } + + bool stop = false; + // The distance can reach up to nodes.size()+1, to include not only all the + // nodes but also all the redundant paths. + for (int distance = 1; !stop; ++distance) { + for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) { + auto node = *it; + if (node->hash_is_final_) { + continue; + } + node->ComputeTopoHash(distance); + if (node->GetHighTopoHash() <= nodes.size()) { + // Would conflict with one of the reserved values. + node->ReHighTopoHash(); + } + } + + // Will be looking for the indications to not stop. + stop = true; + + debug_i = 0; + // The bitmasks get moved after all the hash computations are done. + for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) { + auto node = *it; + if (debug) { + LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++ + << " " << node->name() << " oldmask=" << std::hex + << node->last_hashed_nodes_ << " mask=" << std::hex + << node->next_hashed_nodes_; + } + if (node->last_hashed_nodes_ == node->next_hashed_nodes_) { + // Stopped growing, this part of the graph must be fully + // surrounded by nodes that already have the unique ids. 
+ node->hash_is_final_ = true; + } else { + node->last_hashed_nodes_ = node->next_hashed_nodes_; + stop = false; + } + } + } +} + +void Signature::OrderLinks() { + for (const auto& node : nodes) { + if (node->hashed_peers_.empty()) { + continue; + } + + size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1; + int first_idx = -1; + + int idx; + for (idx = 0; idx < node->hashed_peers_.size(); ++idx) { + auto& entry = node->hashed_peers_[idx]; + if (entry.link_hash == cur_link_hash) { + continue; + } + if (idx - first_idx > 1) { + // Need to sort. + std::sort(node->hashed_peers_.begin() + first_idx, + node->hashed_peers_.begin() + idx, + SigNode::HashedPeer::LessByRank()); + } + + cur_link_hash = entry.link_hash; + first_idx = idx; + } + if (idx - first_idx > 1) { + // Sort the last bunch. + std::sort(node->hashed_peers_.begin() + first_idx, + node->hashed_peers_.begin() + idx, + SigNode::HashedPeer::LessByRank()); + } + } +} + +bool Signature::operator==(const Signature& other) const { + // Tries to find the differences as early as possible by + // comparing the hashes first. 
+ + if (sig_short != other.sig_short) { + return false; + } + if (sig_full.size() != other.sig_full.size()) { + return false; + } + + for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin(); + it1 != sig_full.end(); ++it1, ++it2) { + if (*it1 != *it2) { + return false; + } + } + + if (nodes.size() != other.nodes.size()) { + return false; + } + for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end(); + ++it1, ++it2) { + if (**it1 != **it2) { + return false; + } + } + + return true; +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.h b/tensorflow/core/grappler/graph_analyzer/sig_node.h new file mode 100644 index 0000000000..45c0ed3162 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/sig_node.h @@ -0,0 +1,304 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ + +#include <map> +#include <memory> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" +#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +namespace test { +class SigBaseTest; +} // end namespace test + +class SigNode; + +// To find nodes by name. Having the map ordered makes the tests easier, +// and it isn't used in production code often enough to get any win from +// using an unordered map. +using SigNodeMap = std::map<string, std::unique_ptr<SigNode>>; + +// One node in the graph, in the form convenient for generation of the signature +// of the graph, and comparison of two (sub)graphs for equivalence. It refers to +// the original NodeDef protobuf for most information and adds the extra +// enrichment. +// +// The graph building is 2-stage: first match a SigNode with each NodeDef and +// collect them into a map that finds them by name, then process the map, +// deep-parse the underlying NodeDefs and connect the SigNodes together. +class SigNode { + public: + friend struct Signature; + + // Will keep the pointer to the underlying NodeDef, so that + // underlying object must not be deleted while SigNode is alive. + explicit SigNode(const NodeDef* node); + + // Access wrappers. 
+ const string& name() const { return node_->name(); } + const string& opcode() const { return node_->op(); } + const NodeDef* node_def() const { return node_; } + + // For extraction of subgraphs into a separate SigNodeMap, copies the links + // that point inside the subgraph from a full-graph SigNode to a subgraph + // SigNode. The translation map defines the subgraph and gives the mapping + // from the nodes in the full graph to the matching nodes in subgraph. + using TranslationMap = + std::unordered_map<const GenNode* /*full_graph*/, SigNode* /*subgraph*/>; + void CopyLinks(const GenNode& from, const TranslationMap& map); + + // A link is an edge of the graph that connects 2 nodes. Each of the connected + // nodes has its own perspective on the link, seeing its local port, remote + // port and the remote node. The direction of the link is encoded in the + // ports, one port is always incoming and another one outgoing. + // + // The link tag here contains both ports of the link viewed from the + // perspective of this node; consisting of both the local port (i.e. at this + // node) and remote port (i.e. on the other node), the local one going first. + struct LinkTag { + struct Hasher { + size_t operator()(const LinkTag& tag) const noexcept { + size_t hval = port_hasher(tag.local); + CombineHash(port_hasher(tag.remote), &hval); + return hval; + } + GenNode::Port::Hasher port_hasher; + }; + + LinkTag(GenNode::Port a_local, GenNode::Port a_remote) + : local(a_local), remote(a_remote) {} + + // The default constructor is used for the default values in maps. + // (false, 99) is an arbitrary value that makes the uninitialized + // links easy to tell when debugging (they should never happen). + LinkTag() : local(false, 99), remote(false, 99) {} + + // Port of the link on the local node. + GenNode::Port local; + // Port of the link on the remote node. 
+ GenNode::Port remote; + + bool operator==(const LinkTag& other) const { + return local == other.local && remote == other.remote; + } + bool operator<(const LinkTag& other) const { + return local < other.local || + (local == other.local && remote < other.remote); + } + }; + + // Since the signature logic doesn't differentiate between the links + // with the same tag (other than by the "peer" nodes on their other ends), + // all the links with the same tag are grouped into a single structure. + struct Link { + LinkTag tag; + size_t unique_hash; // Hash of the tag after conflict resolution. + // The remote node(s) on the other side on the link(s). + using PeerVector = std::vector<SigNode*>; + PeerVector peers; + }; + + // A way to look up the link description by its hash. + using LinkHashMap = std::map<size_t, Link>; + const LinkHashMap& hash_to_link() const { return hash_to_link_; } + + // The enumeration of all the peer nodes in a predictable order. + // Before the signature generation, only the link values determine the + // order, after the signature generation the entries at the same + // links get further sorted by their peer node ranks. + struct HashedPeer { + HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {} + + struct LessByRank { + bool operator()(const SigNode::HashedPeer& left, + const SigNode::HashedPeer& right) { + return left.peer->unique_rank_ < right.peer->unique_rank_; + } + }; + + size_t link_hash; + SigNode* peer; + }; + using HashedPeerVector = std::vector<HashedPeer>; + const HashedPeerVector& hashed_peers() const { return hashed_peers_; } + + // Compares two nodes in two different graphs for equivalence (two nodes in + // the same graph would never be equivalent). Expects that the signatures of + // the graphs have already been computed, so unique_rank_ is filled in and + // the hashed_peers_ properly ordered. 
+  bool operator==(const SigNode& other) const;
+
+  bool operator!=(const SigNode& other) const { return !(*this == other); }
+
+ private:
+  friend class test::SigBaseTest;
+
+  // The CopyLinks code is split into 2 parts for testability.
+  // The first pass builds a map ordered by LinkTag for predictability.
+  void CopyLinksPass1(const GenNode& from, const TranslationMap& map,
+                      std::map<LinkTag, Link>* link_map);
+  // The second pass converts to the map by hash value,
+  // resolves any hash conflicts, and builds the hashed peer vector.
+  void CopyLinksPass2(std::map<LinkTag, Link>* link_map);
+
+  // Computes the topological hash at distance 0. Resets the topo_hash_ vector
+  // and hashed_nodes_.
+  void ComputeTopoHash0();
+
+  // Compute the topological hash at the given distance. The hashes for all the
+  // lower distances must be already computed for all the nodes in the graph.
+  // Also computes next_hashed_nodes_ from last_hashed_nodes_.
+  void ComputeTopoHash(int distance);
+
+  // Get the hash value for a particular distance. It must be previously
+  // computed.
+  size_t GetTopoHash(int distance) const;
+
+  // The hash value for the highest computed distance. It must be previously
+  // computed.
+  size_t GetHighTopoHash() const {
+    CHECK(!topo_hash_.empty());
+    return topo_hash_.back();
+  }
+
+  // Rehash the topmost hash, to avoid conflicts.
+  void ReHighTopoHash() {
+    CHECK(!topo_hash_.empty());
+    CombineHash(1, &topo_hash_.back());
+  }
+
+  // Ordering by node order and highest available hash (it must be
+  // previously computed).
+  struct NodeOrderLess {
+    bool operator()(const SigNode* left, const SigNode* right) {
+      return left->topo_hash_.back() < right->topo_hash_.back();
+    }
+  };
+
+ private:
+  const NodeDef* node_;
+
+  // The bitmap mask with 1 bit set that represents this node in the set
+  // during the computation of the signature.
+  uint64_t node_mask_ = 0;
+
+  // The code that populates this map makes sure that there are no hash
+  // conflicts, rehashing if necessary.
+  LinkHashMap hash_to_link_;
+
+  // The enumeration of all the direct peers in the predictable order (which
+  // happens to be the order of their link tags, but the order of the hashes
+  // would do too). It is used for the quick enumeration during the signature
+  // computation. After the signature building is completed, the entries that
+  // have the same link tag get further sorted in the order of the ranks of
+  // their nodes.
+  HashedPeerVector hashed_peers_;
+
+  // The unique rank represents the order in which the node will be included
+  // into the signature. It gets assigned in order either when the topo_hash_ of
+  // this node becomes unique in the graph, or when the nodes are completely
+  // equivalent, one of them is picked at random to assign the next rank, and
+  // then the rest of the nodes attempt to disambiguate based on that
+  // information.
+  size_t unique_rank_ = ~0;
+  // When hash_is_final_ is set, the topo_hash_ vector stops growing, and the
+  // last value from it is used for all the further hashes.
+  bool hash_is_final_ = false;
+  // The hashes that include the topology of the nodes up to the distance N. The
+  // hash for distance 0 is produced from the attributes of this node itself and
+  // its general connectivity properties but no information about the
+  // neighboring nodes. The hash for distance D+1 is built from hashes at level
+  // D of this node and of all its immediate neighbors. The neighbors that are
+  // connected by equivalent links are included in a commutative way.
+  std::vector<size_t> topo_hash_;
+  // The set of nodes that got included into the computation of the
+  // last topo_hash_ entry.
+  uint64_t last_hashed_nodes_ = 0;
+  // The next set of nodes that gets used for the current topo_hash entry.
+  uint64_t next_hashed_nodes_ = 0;
+};
+
+// Signature of a graph.
The computation is intertwined with the private methods +// of SigNode, so keeping both in the same file looks more convenient. +struct Signature { + friend class test::SigBaseTest; + + // Maximal size of the graphs for which the signature can be computed. + // Changing this constant won't magically add the support for a larger size, + // the rest of implementation would have to be extended. The value of 64 is + // driven by the size of a bitset in an uint64_t, and should be enough for our + // purposes, while having a high efficiency of implementation. + static constexpr int kMaxGraphSize = 64; + + // Using the map, computes the rest of the fields of a signature. + // Returns an error if the graph is too big. + Status Compute(); + + // Convert the computed signature to a string representation. + string ToString() const; + + SigNodeMap map; // The nodes in the graph, accessible by name. + size_t sig_short = 0; // Hash of the signature, for the quick equality check. + // The full signature: hashes of the nodes in a predictable order. + std::vector<size_t> sig_full; + // The nodes in the same order as they go in the signature. + std::vector<SigNode*> nodes; + + // For building the unordered maps. + size_t Hash() const { return sig_short; } + + // Returns true if the graphs are equivalent. The signature must be already + // computed. + bool operator==(const Signature& other) const; + + private: + // Populates the nodes vector from the map and initializes the state of the + // nodes for the signature computation. + void PrepareNodes(); + + // Finds the nodes with the hashes that are unique and assigns the unique ids + // to them. If there are nodes with non-unique hashes, exactly one node from + // the first such sequence (in the order of hash values) will be picked and + // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have + // been already assigned the unique ids. Advances next_node_id by at least 1. 
+ void FindUniqueHashes(size_t* next_node_id_p); + + // One round of the signature computation. Assumes that the + // nodes[0...(next_node_id-1)] have been already assigned the fixed + // positions, and thus computes the hashes only for the remaining nodes. + void ComputeOneRound(size_t next_node_id); + + // Additional ordering of the hashed_peers_ links in the nodes, so that they + // can be compared and printed in a predictable order. + void OrderLinks(); +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc new file mode 100644 index 0000000000..4c6a9ba9e0 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc @@ -0,0 +1,1235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/grappler/graph_analyzer/subgraph.h" +#include "tensorflow/core/grappler/graph_analyzer/test_tools.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { +namespace test { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Gt; +using ::testing::Ne; +using ::testing::SizeIs; + +//=== + +TEST(SigNodeLinkTag, Compare) { + SigNode::LinkTag a(GenNode::Port(false, 1), GenNode::Port(false, 2)); + SigNode::LinkTag b(GenNode::Port(false, 1), GenNode::Port(false, 2)); + SigNode::LinkTag c(GenNode::Port(false, 2), GenNode::Port(false, 1)); + SigNode::LinkTag d(GenNode::Port(false, 1), GenNode::Port(false, 3)); + SigNode::LinkTag e(GenNode::Port(false, 2), GenNode::Port(false, 2)); + + EXPECT_TRUE(a == b); + EXPECT_FALSE(a == c); + EXPECT_FALSE(a == e); + + EXPECT_FALSE(a < b); + EXPECT_FALSE(b < a); + + EXPECT_TRUE(a < c); + EXPECT_FALSE(c < a); + + EXPECT_TRUE(a < d); + EXPECT_FALSE(d < a); +} + +//=== + +class SigBaseTest : public ::testing::Test, protected TestGraphs { + protected: + void BuildSigMap(const GraphDef& graph) { + gen_map_.clear(); + sig_.map.clear(); + CHECK(GenNode::BuildGraphInMap(graph, &gen_map_).ok()); + Subgraph::Identity id; + for (const auto& entry : gen_map_) { + id.insert(entry.second.get()); + } + Subgraph sg(id); + sg.ExtractForSignature(&sig_.map); + } + + static void CopyLinksPass2( + std::map<SigNode::LinkTag, SigNode::Link>* link_map, SigNode* node) { + node->CopyLinksPass2(link_map); + } + + static void ComputeTopoHash0(SigNode* node) { node->ComputeTopoHash0(); } + + static void ComputeTopoHash(int distance, SigNode* node) { + node->ComputeTopoHash(distance); 
+ } + + static size_t GetTopoHash(int distance, SigNode* node) { + return node->GetTopoHash(distance); + } + + static size_t GetHighTopoHash(SigNode* node) { + return node->GetHighTopoHash(); + } + + static void ReHighTopoHash(SigNode* node) { node->ReHighTopoHash(); } + + static SigNode::HashedPeerVector& RefHashedPeers(SigNode* node) { + return node->hashed_peers_; + } + static size_t& RefUniqueRank(SigNode* node) { return node->unique_rank_; } + static bool& RefHashIsFinal(SigNode* node) { return node->hash_is_final_; } + static std::vector<size_t>& RefTopoHash(SigNode* node) { + return node->topo_hash_; + } + static uint64_t& RefNodeMask(SigNode* node) { return node->node_mask_; } + static uint64_t& RefLastHashedNodes(SigNode* node) { + return node->last_hashed_nodes_; + } + static uint64_t& RefNextHashedNodes(SigNode* node) { + return node->next_hashed_nodes_; + } + + static void PrepareNodes(Signature* signature) { signature->PrepareNodes(); } + + static void FindUniqueHashes(size_t* next_node_id_p, Signature* signature) { + signature->FindUniqueHashes(next_node_id_p); + } + + static void ComputeOneRound(size_t next_node_id, Signature* signature) { + signature->ComputeOneRound(next_node_id); + } + + static void OrderLinks(Signature* signature) { signature->OrderLinks(); } + + // These get initialized in BuildSigMap(). + GenNodeMap gen_map_; + Signature sig_; +}; + +//=== + +class SigNodeTest : public SigBaseTest {}; + +// Tests that the duplicate hashes get resolved by rehashing. 
+TEST_F(SigNodeTest, DuplicateHash) { + NodeDef node1 = MakeNodeConst("node1"); + NodeDef node2 = MakeNodeConst("node2"); + NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2"); + + SigNode sn1(&node1); + SigNode sn2(&node2); + SigNode sn3(&node3); + + constexpr size_t kSameHash = 999; + + SigNode::Link link1; + link1.tag = SigNode::LinkTag(GenNode::Port(true, 0), GenNode::Port(false, 0)); + link1.unique_hash = kSameHash; + link1.peers.emplace_back(&sn1); + + SigNode::Link link2; + link2.tag = SigNode::LinkTag(GenNode::Port(true, 1), GenNode::Port(false, 0)); + link2.unique_hash = kSameHash; + link2.peers.emplace_back(&sn2); + + SigNode::Link link3; + link3.tag = SigNode::LinkTag(GenNode::Port(true, 2), GenNode::Port(false, 0)); + link3.unique_hash = kSameHash; + link3.peers.emplace_back(&sn3); + + std::map<SigNode::LinkTag, SigNode::Link> link_map; + link_map[link1.tag] = link1; + link_map[link2.tag] = link2; + link_map[link3.tag] = link3; + + CopyLinksPass2(&link_map, &sn3); + auto& hl = sn3.hash_to_link(); + EXPECT_THAT(hl, SizeIs(3)); + + // Check that the hashes are self_consistent, and put the entries into + // another map with a known order. + std::map<SigNode::LinkTag, SigNode::Link> rehashed; + auto hlit = hl.begin(); + ASSERT_THAT(hlit, Ne(hl.end())); + EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first)); + rehashed[hlit->second.tag] = hlit->second; + ++hlit; + ASSERT_THAT(hlit, Ne(hl.end())); + EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first)); + rehashed[hlit->second.tag] = hlit->second; + ++hlit; + ASSERT_THAT(hlit, Ne(hl.end())); + EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first)); + rehashed[hlit->second.tag] = hlit->second; + + // Just in case. 
+ ASSERT_THAT(rehashed, SizeIs(3)); + + auto rhit = rehashed.begin(); + ASSERT_THAT(rhit, Ne(rehashed.end())); + EXPECT_TRUE(rhit->second.tag == link1.tag); + EXPECT_THAT(rhit->second.unique_hash, Eq(kSameHash)); + EXPECT_THAT(rhit->second.peers, ElementsAre(&sn1)); + + ++rhit; + ASSERT_THAT(rhit, Ne(rehashed.end())); + EXPECT_TRUE(rhit->second.tag == link2.tag); + // This hash must be rehashed. + EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash)); + size_t hash2 = rhit->second.unique_hash; + EXPECT_THAT(rhit->second.peers, ElementsAre(&sn2)); + + ++rhit; + ASSERT_THAT(rhit, Ne(rehashed.end())); + EXPECT_TRUE(rhit->second.tag == link3.tag); + // This hash must be rehashed. + EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash)); + EXPECT_THAT(rhit->second.unique_hash, Ne(hash2)); + size_t hash3 = rhit->second.unique_hash; + EXPECT_THAT(rhit->second.peers, ElementsAre(&sn3)); + + auto& peers = sn3.hashed_peers(); + EXPECT_THAT(peers, SizeIs(3)); + + auto peerit = peers.begin(); + ASSERT_THAT(peerit, Ne(peers.end())); + EXPECT_THAT(peerit->link_hash, Eq(kSameHash)); + EXPECT_THAT(peerit->peer, Eq(&sn1)); + + ++peerit; + ASSERT_THAT(peerit, Ne(peers.end())); + EXPECT_THAT(peerit->link_hash, Eq(hash2)); + EXPECT_THAT(peerit->peer, Eq(&sn2)); + + ++peerit; + ASSERT_THAT(peerit, Ne(peers.end())); + EXPECT_THAT(peerit->link_hash, Eq(hash3)); + EXPECT_THAT(peerit->peer, Eq(&sn3)); +} + +// The full CopyLinks() is tested in (SubgraphTest, ExtractForSignature). + +TEST_F(SigNodeTest, GetTopoHash) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + + // Fake some hash values. 
+ RefTopoHash(&sn1).emplace_back(123); + RefTopoHash(&sn1).emplace_back(456); + + EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123)); + EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456)); + + RefHashIsFinal(&sn1) = true; + + EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123)); + EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456)); + EXPECT_THAT(GetTopoHash(2, &sn1), Eq(456)); + + EXPECT_THAT(GetHighTopoHash(&sn1), Eq(456)); +} + +TEST_F(SigNodeTest, ReTopoHash) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + + // Fake some hash values. + RefTopoHash(&sn1).emplace_back(123); + RefTopoHash(&sn1).emplace_back(456); + + EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123)); + EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456)); + + ReHighTopoHash(&sn1); + + size_t expected_hash = 456; + CombineHash(1, &expected_hash); + + EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123)); + EXPECT_THAT(GetTopoHash(1, &sn1), Eq(expected_hash)); +} + +TEST_F(SigNodeTest, ComputeTopoHash0) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + + // Fake a topology. + RefUniqueRank(&sn1) = 10; + RefNodeMask(&sn1) = 0x02; + + RefTopoHash(&sn1).emplace_back(123); + RefTopoHash(&sn1).emplace_back(456); + + // Fake a state. + RefLastHashedNodes(&sn1) = 0xFF; + RefNextHashedNodes(&sn1) = 0xFF; + + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(2, nullptr)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr)); + + // Run the test. 
+ ComputeTopoHash0(&sn1); + + EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x02)); + EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x02)); + EXPECT_THAT(RefTopoHash(&sn1), SizeIs(1)); + + size_t exp_hval = std::hash<string>()(sn1.opcode()); + CombineHash(1, &exp_hval); + CombineHash(1, &exp_hval); + CombineHash(2, &exp_hval); + CombineHash(3, &exp_hval); + CombineHash(3, &exp_hval); + + EXPECT_THAT(GetTopoHash(0, &sn1), Eq(exp_hval)); +} + +TEST_F(SigNodeTest, ComputeTopoHashNotFinal) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + NodeDef node2 = MakeNodeConst("node2"); + SigNode sn2(&node2); + NodeDef node3 = MakeNodeConst("node3"); + SigNode sn3(&node3); + + // Fake a topology. + RefUniqueRank(&sn1) = 0; + RefNodeMask(&sn1) = 0x01; + RefUniqueRank(&sn2) = 0; + RefNodeMask(&sn2) = 0x02; + RefUniqueRank(&sn3) = 0; + RefNodeMask(&sn3) = 0x04; + + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2)); + + // Fake a state. + RefTopoHash(&sn1).emplace_back(123); + RefTopoHash(&sn1).emplace_back(321); + + RefTopoHash(&sn2).emplace_back(456); + RefTopoHash(&sn2).emplace_back(654); + + RefTopoHash(&sn3).emplace_back(789); + RefTopoHash(&sn3).emplace_back(987); + + // These values are not realistic in the way that they don't include the bits + // from the mask of nodes themselves, but that's the point of this test: only + // the previous nodes' node sets are used in the computation, not their own + // masks directly. + RefLastHashedNodes(&sn1) = 0x8; + RefLastHashedNodes(&sn2) = 0x10; + RefLastHashedNodes(&sn3) = 0x20; + + // A scratch value to get overwritten. 
+ RefNextHashedNodes(&sn1) = 0x100; + + ComputeTopoHash(2, &sn1); + + EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged. + EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x38)); + + // This computes the hash form the explicit numbers above. + size_t exp_hash = 123; // The 0th hash is the starting point. + size_t comm_hash; + + comm_hash = 0; + CombineHashCommutative(654, &comm_hash); + CombineHashCommutative(987, &comm_hash); + + CombineHash(10, &exp_hash); + CombineHash(comm_hash, &exp_hash); + + comm_hash = 0; + CombineHashCommutative(654, &comm_hash); + + CombineHash(20, &exp_hash); + CombineHash(comm_hash, &exp_hash); + + comm_hash = 0; + CombineHashCommutative(654, &comm_hash); + CombineHashCommutative(987, &comm_hash); + + CombineHash(30, &exp_hash); + CombineHash(comm_hash, &exp_hash); + + EXPECT_THAT(GetTopoHash(2, &sn1), Eq(exp_hash)); + EXPECT_THAT(RefTopoHash(&sn1), SizeIs(3)); +} + +TEST_F(SigNodeTest, ComputeTopoHashFinal) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + NodeDef node2 = MakeNodeConst("node2"); + SigNode sn2(&node2); + NodeDef node3 = MakeNodeConst("node3"); + SigNode sn3(&node3); + + // Fake a topology - same as for ComputeTopoHashNotFinal. + RefUniqueRank(&sn1) = 0; + RefNodeMask(&sn1) = 0x01; + RefUniqueRank(&sn2) = 0; + RefNodeMask(&sn2) = 0x02; + RefUniqueRank(&sn3) = 0; + RefNodeMask(&sn3) = 0x04; + + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3)); + RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2)); + + // Fake a state - mostly same as for ComputeTopoHashNotFinal. 
+ RefTopoHash(&sn1).emplace_back(123); + RefTopoHash(&sn1).emplace_back(321); + + RefTopoHash(&sn2).emplace_back(456); + RefTopoHash(&sn2).emplace_back(654); + + RefTopoHash(&sn3).emplace_back(789); + RefTopoHash(&sn3).emplace_back(987); + + // These values are not realistic in the way that they don't include the bits + // from the mask of nodes themselves, but that's the point of this test: only + // the previous nodes' node sets are used in the computation, not their own + // masks directly. + RefLastHashedNodes(&sn1) = 0x8; + RefLastHashedNodes(&sn2) = 0x10; + RefLastHashedNodes(&sn3) = 0x20; + + // A scratch value to get overwritten. + RefNextHashedNodes(&sn1) = 0x100; + + // This is the difference in configuration. + RefHashIsFinal(&sn1) = true; + + ComputeTopoHash(2, &sn1); + + EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged. + EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x8)); + EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2)); + EXPECT_THAT(GetTopoHash(2, &sn1), Eq(321)); +} + +TEST_F(SigNodeTest, EqualsOpcode) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + SigNode sn2(&node2); + + EXPECT_TRUE(sn1 == sn2); + EXPECT_FALSE(sn1 != sn2); + + node2.set_op("Mul"); + + EXPECT_TRUE(sn1 != sn2); + EXPECT_FALSE(sn1 == sn2); +} + +TEST_F(SigNodeTest, EqualsRank) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + + NodeDef node2 = MakeNodeConst("node2"); + SigNode sn2(&node2); + + EXPECT_TRUE(sn1 == sn2); + EXPECT_FALSE(sn1 != sn2); + + RefUniqueRank(&sn1) = 1; + RefUniqueRank(&sn2) = 2; + + EXPECT_TRUE(sn1 != sn2); + EXPECT_FALSE(sn1 == sn2); +} + +// Checks that if the nodes have a different number of links, +// they will be considered unequal. 
+TEST_F(SigNodeTest, EqualsLinkSize) { + GraphDef graph1; + (*graph1.add_node()) = MakeNodeConst("node1"); + (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1"); + + GenNodeMap gen_map1; + ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK())); + + Subgraph::Identity id1; + id1.insert(gen_map1["node1"].get()); + id1.insert(gen_map1["node2"].get()); + Subgraph sg1(id1); + + SigNodeMap sig_map1; + sg1.ExtractForSignature(&sig_map1); + + GraphDef graph2; + (*graph2.add_node()) = MakeNodeConst("node1"); + // The difference between graph1 and graph2: one more input. + auto node22 = graph2.add_node(); + *node22 = MakeNodeMul("node2", "node1", "node1"); + node22->add_input("node2"); + + GenNodeMap gen_map2; + ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), Eq(Status::OK())); + + Subgraph::Identity id2; + id2.insert(gen_map2["node1"].get()); + id2.insert(gen_map2["node2"].get()); + Subgraph sg2(id2); + + SigNodeMap sig_map2; + sg2.ExtractForSignature(&sig_map2); + + EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]); + EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]); + EXPECT_FALSE(*sig_map2["node2"] == *sig_map1["node2"]); +} + +TEST_F(SigNodeTest, EqualsLinks) { + // Start with 2 copies of the same graph. 
+ GraphDef graph1; + (*graph1.add_node()) = MakeNodeConst("node1"); + (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1"); + + GenNodeMap gen_map1; + ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK())); + + Subgraph::Identity id1; + id1.insert(gen_map1["node1"].get()); + id1.insert(gen_map1["node2"].get()); + Subgraph sg1(id1); + + SigNodeMap sig_map1; + sg1.ExtractForSignature(&sig_map1); + + GenNodeMap gen_map2; + ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map2), Eq(Status::OK())); + + Subgraph::Identity id2; + id2.insert(gen_map2["node1"].get()); + id2.insert(gen_map2["node2"].get()); + Subgraph sg2(id2); + + SigNodeMap sig_map2; + sg2.ExtractForSignature(&sig_map2); + + EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]); + EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]); + + // Alter the link hash of one of the nodes. + SigNode* sn2 = sig_map2["node2"].get(); + ++RefHashedPeers(sn2)[0].link_hash; + + EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]); + + // Restore back. + --RefHashedPeers(sn2)[0].link_hash; + EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]); + + // Alter the unique rank of a referenced node. + ++RefUniqueRank(sig_map2["node1"].get()); + + EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]); +} + +//=== + +class SignatureTest : public SigBaseTest { + protected: + // Initializeds the state used to generate the permutations of a given size. + static void InitPermutation(size_t size, + std::vector<size_t>* plain_permutation, + std::vector<size_t>* countdown) { + plain_permutation->clear(); + countdown->clear(); + for (size_t i = 0; i < size; ++i) { + plain_permutation->emplace_back(i); + countdown->emplace_back(size - 1 - i); + } + } + + // Builds a permutation guided by the count-down value. 
+ static void BuildPermutation(const std::vector<size_t>& plain_permutation, + const std::vector<size_t>& countdown, + std::vector<size_t>* result) { + *result = plain_permutation; + for (int i = 0; i < result->size(); ++i) { + std::swap((*result)[i], (*result)[i + countdown[i]]); + } + } + + // Returns false when the count-down is finished. + static bool CountDown(std::vector<size_t>* countdown) { + // The last position always contains 0, so skip it. + int pos; + for (pos = countdown->size() - 2; pos >= 0; --pos) { + if ((*countdown)[pos] > 0) { + --(*countdown)[pos]; + break; + } + (*countdown)[pos] = (countdown->size() - 1 - pos); + } + + return pos >= 0; + } + + // Permutes the nodes every which way and checks that all the signatures + // produced are the same. This is reasonable for the graphs up to the + // size 5, maybe 6 at the stretch. After that the number of permutation grows + // huge and the test becomes very slow. + void TestGraphEveryWay(const GraphDef& graph) { + size_t graph_size = graph.node_size(); + + gen_map_.clear(); + sig_.map.clear(); + Status result = GenNode::BuildGraphInMap(graph, &gen_map_); + ASSERT_THAT(result, Eq(Status::OK())); + Subgraph::Identity id; + for (const auto& entry : gen_map_) { + id.insert(entry.second.get()); + } + Subgraph sg(id); + sg.ExtractForSignature(&sig_.map); + + std::vector<size_t> plain_permutation; + std::vector<size_t> countdown; + InitPermutation(graph_size, &plain_permutation, &countdown); + + std::set<string> signatures; + std::vector<size_t> permutation; + do { + BuildPermutation(plain_permutation, countdown, &permutation); + + constexpr bool kDebugPermutation = false; + if (kDebugPermutation) { + string p; + for (int i = 0; i < permutation.size(); ++i) { + p.push_back('0' + permutation[i]); + } + LOG(INFO) << "Permutation: " << p; + } + + std::vector<std::unique_ptr<SigNode>> hold(graph_size); + int idx; + + // Permute the nodes. 
+ sig_.nodes.clear(); + idx = 0; + if (kDebugPermutation) { + LOG(INFO) << " nodes before permutation:"; + } + for (auto& entry : sig_.map) { + if (kDebugPermutation) { + LOG(INFO) << " " << entry.second.get(); + } + hold[idx++] = std::move(entry.second); + } + idx = 0; + if (kDebugPermutation) { + LOG(INFO) << " nodes after permutation:"; + } + for (auto& entry : sig_.map) { + entry.second = std::move(hold[permutation[idx++]]); + if (kDebugPermutation) { + LOG(INFO) << " " << entry.second.get(); + } + // This is used to order the links per permutation. + sig_.nodes.emplace_back(entry.second.get()); + RefUniqueRank(entry.second.get()) = idx; + } + // Order the links with the same tags per permutation. + OrderLinks(&sig_); + + // The test as such. + ASSERT_THAT(sig_.Compute(), Eq(Status::OK())); + + signatures.insert(sig_.ToString()); + + EXPECT_THAT(sig_.sig_full, SizeIs(graph_size)); + size_t hval = 0; + for (size_t ih : sig_.sig_full) { + // The space 1..graph_size is reserved. + EXPECT_THAT(ih, Gt(graph_size)); + CombineHash(ih, &hval); + } + EXPECT_THAT(sig_.sig_short, Eq(hval)); + + // Un-permute the nodes for the next iteration. + idx = 0; + for (auto& entry : sig_.map) { + hold[permutation[idx++]] = std::move(entry.second); + } + idx = 0; + if (kDebugPermutation) { + LOG(INFO) << " nodes after un-permutation:"; + } + for (auto& entry : sig_.map) { + entry.second = std::move(hold[idx++]); + if (kDebugPermutation) { + LOG(INFO) << " " << entry.second.get(); + } + } + } while (CountDown(&countdown)); + + for (const auto& s : signatures) { + LOG(INFO) << "Signature: " << s; + } + + // All the permutations should produce the same signature. 
+ EXPECT_THAT(signatures, SizeIs(1)); + } +}; + +TEST_F(SignatureTest, PrepareNodes) { + NodeDef node1 = MakeNodeConst("node1"); + sig_.map["node1"] = absl::make_unique<SigNode>(&node1); + NodeDef node2 = MakeNodeConst("node2"); + sig_.map["node2"] = absl::make_unique<SigNode>(&node2); + NodeDef node3 = MakeNodeConst("node3"); + sig_.map["node3"] = absl::make_unique<SigNode>(&node3); + + PrepareNodes(&sig_); + + ASSERT_THAT(sig_.nodes, SizeIs(3)); + + int idx = 0; + for (const auto& entry : sig_.map) { + EXPECT_THAT(RefNodeMask(entry.second.get()), Eq(1 << idx)) + << " at index " << idx; + EXPECT_THAT(RefUniqueRank(entry.second.get()), Eq(static_cast<size_t>(~0))) + << " at index " << idx; + EXPECT_THAT(RefHashIsFinal(entry.second.get()), false) + << " at index " << idx; + EXPECT_THAT(RefTopoHash(entry.second.get()), SizeIs(1)) + << " at index " << idx; + ++idx; + } +} + +TEST_F(SignatureTest, FindUniqueHashesAllDifferent) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + NodeDef node2 = MakeNodeConst("node2"); + SigNode sn2(&node2); + NodeDef node3 = MakeNodeConst("node3"); + SigNode sn3(&node3); + NodeDef node4 = MakeNodeConst("node4"); + SigNode sn4(&node4); + + // The last values in the arrays values go in the backwards order. + RefTopoHash(&sn1).emplace_back(100); + RefTopoHash(&sn1).emplace_back(900); + + RefTopoHash(&sn2).emplace_back(200); + RefTopoHash(&sn2).emplace_back(800); + + RefTopoHash(&sn3).emplace_back(300); + RefTopoHash(&sn3).emplace_back(700); + + RefTopoHash(&sn4).emplace_back(400); + RefTopoHash(&sn4).emplace_back(600); + + sig_.nodes.emplace_back(&sn1); + sig_.nodes.emplace_back(&sn2); + sig_.nodes.emplace_back(&sn3); + sig_.nodes.emplace_back(&sn4); + + size_t next = 1; // Skips over sn1. + + FindUniqueHashes(&next, &sig_); + EXPECT_THAT(next, Eq(4)); + + EXPECT_THAT(sig_.nodes[0], Eq(&sn1)); + // The nodes after first one get sorted by the high hash. 
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn4)); + EXPECT_THAT(sig_.nodes[2], Eq(&sn3)); + EXPECT_THAT(sig_.nodes[3], Eq(&sn2)); + + EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false)); + // Nodes that get finalized are marked as such. + EXPECT_THAT(RefHashIsFinal(&sn2), Eq(true)); + EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true)); + EXPECT_THAT(RefHashIsFinal(&sn4), Eq(true)); + + EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2)); + ASSERT_THAT(RefTopoHash(&sn2), SizeIs(1)); + ASSERT_THAT(RefTopoHash(&sn3), SizeIs(1)); + ASSERT_THAT(RefTopoHash(&sn4), SizeIs(1)); + + EXPECT_THAT(RefTopoHash(&sn2)[0], Eq(4)); + EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(3)); + EXPECT_THAT(RefTopoHash(&sn4)[0], Eq(2)); + + EXPECT_THAT(sig_.sig_full, ElementsAre(600, 700, 800)); + + size_t exp_short_hash = 0; + CombineHash(600, &exp_short_hash); + CombineHash(700, &exp_short_hash); + CombineHash(800, &exp_short_hash); + EXPECT_THAT(sig_.sig_short, Eq(exp_short_hash)); +} + +TEST_F(SignatureTest, FindUniqueHashesDuplicatesExceptOne) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + NodeDef node2 = MakeNodeConst("node2"); + SigNode sn2(&node2); + NodeDef node3 = MakeNodeConst("node3"); + SigNode sn3(&node3); + NodeDef node4 = MakeNodeConst("node4"); + SigNode sn4(&node4); + NodeDef node5 = MakeNodeConst("node5"); + SigNode sn5(&node5); + + RefTopoHash(&sn1).emplace_back(100); + RefTopoHash(&sn1).emplace_back(600); + + RefTopoHash(&sn2).emplace_back(200); + RefTopoHash(&sn2).emplace_back(600); + + RefTopoHash(&sn3).emplace_back(300); + RefTopoHash(&sn3).emplace_back(700); + + RefTopoHash(&sn4).emplace_back(400); + RefTopoHash(&sn4).emplace_back(800); + + RefTopoHash(&sn5).emplace_back(500); + RefTopoHash(&sn5).emplace_back(800); + + sig_.nodes.emplace_back(&sn1); + sig_.nodes.emplace_back(&sn2); + sig_.nodes.emplace_back(&sn3); + sig_.nodes.emplace_back(&sn4); + sig_.nodes.emplace_back(&sn5); + + size_t next = 0; + + FindUniqueHashes(&next, &sig_); + EXPECT_THAT(next, Eq(1)); + + // The unique node 
goes first. + EXPECT_THAT(sig_.nodes[0], Eq(&sn3)); + + // The rest of the nodes are assumed to be sorted in a stable order. + EXPECT_THAT(sig_.nodes[1], Eq(&sn2)); + // Node 1 gets swapped with node 3. + EXPECT_THAT(sig_.nodes[2], Eq(&sn1)); + EXPECT_THAT(sig_.nodes[3], Eq(&sn4)); + EXPECT_THAT(sig_.nodes[4], Eq(&sn5)); + + EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false)); + EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false)); + EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true)); + EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false)); + EXPECT_THAT(RefHashIsFinal(&sn5), Eq(false)); + + EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2)); + EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2)); + EXPECT_THAT(RefTopoHash(&sn3), SizeIs(1)); + EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2)); + EXPECT_THAT(RefTopoHash(&sn5), SizeIs(2)); + + EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(1)); +} + +TEST_F(SignatureTest, FindUniqueHashesDuplicates) { + NodeDef node1 = MakeNodeConst("node1"); + SigNode sn1(&node1); + NodeDef node2 = MakeNodeConst("node2"); + SigNode sn2(&node2); + NodeDef node3 = MakeNodeConst("node3"); + SigNode sn3(&node3); + NodeDef node4 = MakeNodeConst("node4"); + SigNode sn4(&node4); + NodeDef node5 = MakeNodeConst("node5"); + SigNode sn5(&node5); + + RefTopoHash(&sn1).emplace_back(100); + RefTopoHash(&sn1).emplace_back(600); + + RefTopoHash(&sn2).emplace_back(200); + RefTopoHash(&sn2).emplace_back(600); + + RefTopoHash(&sn3).emplace_back(300); + RefTopoHash(&sn3).emplace_back(700); + + RefTopoHash(&sn4).emplace_back(400); + RefTopoHash(&sn4).emplace_back(700); + + RefTopoHash(&sn5).emplace_back(500); + RefTopoHash(&sn5).emplace_back(700); + + sig_.nodes.emplace_back(&sn1); + sig_.nodes.emplace_back(&sn2); + sig_.nodes.emplace_back(&sn3); + sig_.nodes.emplace_back(&sn4); + sig_.nodes.emplace_back(&sn5); + + size_t next = 0; + + FindUniqueHashes(&next, &sig_); + EXPECT_THAT(next, Eq(1)); + + // The last copy of the last duplicate wins. 
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn5)); + + // The rest of the nodes are assumed to be sorted in a stable order. + // Node 1 gets swapped. + EXPECT_THAT(sig_.nodes[1], Eq(&sn2)); + EXPECT_THAT(sig_.nodes[2], Eq(&sn3)); + EXPECT_THAT(sig_.nodes[3], Eq(&sn4)); + EXPECT_THAT(sig_.nodes[4], Eq(&sn1)); + + EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false)); + EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false)); + EXPECT_THAT(RefHashIsFinal(&sn3), Eq(false)); + EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false)); + EXPECT_THAT(RefHashIsFinal(&sn5), Eq(true)); + + EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2)); + EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2)); + EXPECT_THAT(RefTopoHash(&sn3), SizeIs(2)); + EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2)); + EXPECT_THAT(RefTopoHash(&sn5), SizeIs(1)); + + EXPECT_THAT(RefTopoHash(&sn5)[0], Eq(1)); +} + +// On a circular topology. +TEST_F(SignatureTest, ComputeOneRoundCircular) { + BuildSigMap(graph_circular_onedir_); + PrepareNodes(&sig_); + + ASSERT_THAT(sig_.nodes, SizeIs(5)); + + // This skips FindUniqueHashes() which would pick one node, so that + // all the nodes are equivalent for ComputeOneRound(). + + ComputeOneRound(0, &sig_); + + // All the nodes are the same, so the computed hashes will also be the same. + size_t hval = GetHighTopoHash(sig_.nodes[0]); + for (int i = 0; i < 5; ++i) { + EXPECT_THAT(GetHighTopoHash(sig_.nodes[i]), Eq(hval)) << " at index " << i; + EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i; + EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F)) + << " at index " << i; + EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F)) + << " at index " << i; + // The sets of hashed nodes go like this: + // Step 0: self. + // Step 1: self, previous (-1) and next (+1) node. + // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph + // Step 3: still all 5 nodes in the graph + EXPECT_THAT(RefTopoHash(sig_.nodes[i]), SizeIs(4)) << " at index " << i; + } +} + +// On a linear topology. 
+TEST_F(SignatureTest, ComputeOneRoundLinear) { + BuildSigMap(graph_linear_); + PrepareNodes(&sig_); + + ASSERT_THAT(sig_.nodes, SizeIs(5)); + + // This skips FindUniqueHashes() which would pick one node, so that + // all the nodes are equivalent for ComputeOneRound(). + + ComputeOneRound(0, &sig_); + + std::vector<size_t> hash_size; + for (int i = 0; i < 5; ++i) { + EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i; + EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F)) + << " at index " << i; + EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F)) + << " at index " << i; + hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size()); + } + + // The sets of hashed nodes for the central node go like this: + // Step 0: self. + // Step 1: self, previous (-1) and next (+1) node. + // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph + // Step 3: still all 5 nodes in the graph + // + // The nodes one step closer to the ends require one more step. The end nodes + // require one more step yet. + std::sort(hash_size.begin(), hash_size.end()); + EXPECT_THAT(hash_size, ElementsAre(4, 5, 5, 6, 6)); +} + +// On a linear topology where the cental node has been already marked as unique +// (yeah, not a very realistic case but tests the situations when the +// disconnected subgraphs get created). +TEST_F(SignatureTest, ComputeOneRoundSplitLinear) { + BuildSigMap(graph_linear_); + PrepareNodes(&sig_); + + ASSERT_THAT(sig_.nodes, SizeIs(5)); + + // This test relies on the order of SigNodeMap imposed on sig_.nodes. + + // The middle node gets separated by moving it to the front. + std::swap(sig_.nodes[0], sig_.nodes[2]); + ASSERT_THAT(RefNodeMask(sig_.nodes[0]), Eq(0x04)); + ASSERT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04)); + ASSERT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04)); + RefHashIsFinal(sig_.nodes[0]) = true; + + ComputeOneRound(1, &sig_); + + // These should stay unchanged. 
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04)); + EXPECT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04)); + + std::vector<size_t> hash_size; + for (int i = 1; i < 5; ++i) { + EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i; + hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size()); + } + + std::sort(hash_size.begin(), hash_size.end()); + // The end nodes take 4 steps, closer to the center 3 steps. + EXPECT_THAT(hash_size, ElementsAre(3, 3, 4, 4)); + + EXPECT_THAT(RefLastHashedNodes(sig_.nodes[1]), Eq(0x07)); + EXPECT_THAT(RefNextHashedNodes(sig_.nodes[1]), Eq(0x07)); + EXPECT_THAT(RefLastHashedNodes(sig_.nodes[2]), Eq(0x07)); + EXPECT_THAT(RefNextHashedNodes(sig_.nodes[2]), Eq(0x07)); + + EXPECT_THAT(RefLastHashedNodes(sig_.nodes[3]), Eq(0x1C)); + EXPECT_THAT(RefNextHashedNodes(sig_.nodes[3]), Eq(0x1C)); + EXPECT_THAT(RefLastHashedNodes(sig_.nodes[4]), Eq(0x1C)); + EXPECT_THAT(RefNextHashedNodes(sig_.nodes[4]), Eq(0x1C)); +} + +TEST_F(SignatureTest, OrderLinks) { + gen_map_.clear(); + sig_.map.clear(); + Status result = GenNode::BuildGraphInMap(graph_for_link_order_, &gen_map_); + ASSERT_THAT(result, Eq(Status::OK())); + Subgraph::Identity id; + for (const auto& entry : gen_map_) { + id.insert(entry.second.get()); + } + Subgraph sg(id); + sg.ExtractForSignature(&sig_.map); + + // Populate the fake signature and assign the ranks in the backwards order. + for (auto it = sig_.map.rbegin(); it != sig_.map.rend(); ++it) { + auto& entry = *it; + RefUniqueRank(entry.second.get()) = sig_.nodes.size(); + sig_.nodes.emplace_back(entry.second.get()); + } + + // How it was ordered in the original graph. 
+ string before = sig_.ToString(); + // clang-format off + EXPECT_THAT(before, Eq( + "0:Mul[i0:o0:5][i0:o0:4][i0:o1:4][i0:o2:3][i0:o2:2][i0:o3:2]," + "1:Mul[i0:o0:5][i0:o0:4][i0:o0:3][i0:o0:2]," + "2:Const," + "3:Const," + "4:Const," + "5:Const," + )); + // clang-format on + + OrderLinks(&sig_); + + string after = sig_.ToString(); + // clang-format off + EXPECT_THAT(after, Eq( + "0:Mul[i0:o0:4][i0:o0:5][i0:o1:4][i0:o2:2][i0:o2:3][i0:o3:2]," + "1:Mul[i0:o0:2][i0:o0:3][i0:o0:4][i0:o0:5]," + "2:Const," + "3:Const," + "4:Const," + "5:Const," + )); + // clang-format on +} + +TEST_F(SignatureTest, GraphTooBig) { + GraphDef graph; + for (int i = 0; i <= Signature::kMaxGraphSize; ++i) { + (*graph.add_node()) = MakeNodeConst(absl::StrFormat("node%d", i)); + } + + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_), Eq(Status::OK())); + + Subgraph::Identity id; + for (const auto& entry : gen_map_) { + id.insert(entry.second.get()); + } + Subgraph sg(id); + sg.ExtractForSignature(&sig_.map); + + ASSERT_THAT(sig_.Compute(), + Eq(Status(error::INVALID_ARGUMENT, + "A graph of 65 nodes is too big for signature " + "computation, the maximal supported node count is " + "64."))); +} + +TEST_F(SignatureTest, ToString) { + BuildSigMap(graph_circular_onedir_); + PrepareNodes(&sig_); + + ASSERT_THAT(sig_.nodes, SizeIs(5)); + + // Fake the works by assigning unique ranks as they go in the initial order. + for (int i = 0; i < 5; ++i) { + RefUniqueRank(sig_.nodes[i]) = i; + RefHashIsFinal(sig_.nodes[i]) = true; + } + + string result = sig_.ToString(); + + // clang-format off + ASSERT_THAT(result, Eq( + "0:Mul[i0:o0:4][i0:o0:4]," + "1:Mul[i0:o0:0][i0:o0:0]," + "2:Mul[i0:o0:1][i0:o0:1]," + "3:Mul[i0:o0:2][i0:o0:2]," + "4:Mul[i0:o0:3][i0:o0:3]," + )); + // clang-format on +} + +// This is a test of the permutation logic itself. 
+TEST_F(SignatureTest, Permutation) { + std::vector<size_t> plain_permutation; + std::vector<size_t> countdown; + InitPermutation(5, &plain_permutation, &countdown); + + std::set<string> results; + + std::vector<size_t> permutation; + do { + BuildPermutation(plain_permutation, countdown, &permutation); + EXPECT_THAT(permutation, SizeIs(5)); + + string p; + for (int i = 0; i < permutation.size(); ++i) { + p.push_back('0' + permutation[i]); + } + LOG(INFO) << "Permutation: " << p; + results.insert(p); + } while (CountDown(&countdown)); + + EXPECT_THAT(results, SizeIs(5 * 4 * 3 * 2 * 1)); +} + +TEST_F(SignatureTest, ComputeCircularOneDir) { + TestGraphEveryWay(graph_circular_onedir_); +} + +TEST_F(SignatureTest, ComputeCircularBiDir) { + TestGraphEveryWay(graph_circular_bidir_); +} + +TEST_F(SignatureTest, ComputeLinear) { TestGraphEveryWay(graph_linear_); } + +TEST_F(SignatureTest, ComputeMultiInput) { + TestGraphEveryWay(graph_multi_input_); +} + +TEST_F(SignatureTest, ComputeAllOrNone) { + TestGraphEveryWay(graph_all_or_none_); +} + +TEST_F(SignatureTest, ComputeCross) { TestGraphEveryWay(graph_small_cross_); } + +TEST_F(SignatureTest, Equals) { + // Start with 2 copies of the same graph. + GenNodeMap gen_map1; + ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map1), + Eq(Status::OK())); + + Subgraph::Identity id1; + id1.insert(gen_map1["node1"].get()); + id1.insert(gen_map1["node2"].get()); + Subgraph sg1(id1); + + Signature sig1; + sg1.ExtractForSignature(&sig1.map); + ASSERT_THAT(sig1.Compute(), Eq(Status::OK())); + + GenNodeMap gen_map2; + ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map2), + Eq(Status::OK())); + + Subgraph::Identity id2; + id2.insert(gen_map2["node1"].get()); + id2.insert(gen_map2["node2"].get()); + Subgraph sg2(id2); + + Signature sig2; + sg2.ExtractForSignature(&sig2.map); + ASSERT_THAT(sig2.Compute(), Eq(Status::OK())); + + EXPECT_TRUE(sig1 == sig2); + + // Change the short hash. 
+ ++sig2.sig_short; + EXPECT_FALSE(sig1 == sig2); + + // Restore back. + --sig2.sig_short; + EXPECT_TRUE(sig1 == sig2); + + // Change the full hash. + ++sig2.sig_full[0]; + EXPECT_FALSE(sig1 == sig2); + + // Restore back. + --sig2.sig_full[0]; + EXPECT_TRUE(sig1 == sig2); + + // Make the nodes different. + std::swap(sig2.nodes[0], sig2.nodes[1]); + EXPECT_FALSE(sig1 == sig2); + + // Restore back. + std::swap(sig2.nodes[0], sig2.nodes[1]); + EXPECT_TRUE(sig1 == sig2); + + // Different number of nodes. + sig2.nodes.emplace_back(sig2.nodes[0]); + EXPECT_FALSE(sig1 == sig2); + EXPECT_FALSE(sig2 == sig1); +} + +} // end namespace test +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.cc b/tensorflow/core/grappler/graph_analyzer/subgraph.cc new file mode 100644 index 0000000000..28a91e0f84 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/subgraph.cc @@ -0,0 +1,235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/subgraph.h" + +#include <functional> + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +//=== Subgraph::Identity + +Subgraph::Identity::Identity(InitializerList init) { + for (auto element : init) { + insert(element); + } +} + +bool Subgraph::Identity::operator<(const Identity& other) const { + // Shorter sets go first. + if (this->size() < other.size()) { + return true; + } + if (this->size() > other.size()) { + return false; + } + for (auto lit = this->begin(), rit = other.begin(); lit != this->end(); + ++lit, ++rit) { + if (*lit < *rit) { + return true; + } + if (*lit > *rit) { + return false; + } + } + return false; // Equal. +} + +bool Subgraph::Identity::operator==(const Identity& other) const { + if (this->size() != other.size()) { + return false; + } + for (auto lit = this->begin(), rit = other.begin(); lit != this->end(); + ++lit, ++rit) { + if (*lit != *rit) { + return false; + } + } + return true; // Equal. +} + +size_t Subgraph::Identity::Hash() const { + std::hash<const GenNode*> hasher; + size_t result = 0; + for (auto ptr : *this) { + CombineHash(hasher(ptr), &result); + } + return result; +} + +string Subgraph::Dump() { + // TODO(babkin): this is simplified for now. + std::vector<string> nodes; + for (const auto& n : id_) { + if (specific_) { + nodes.emplace_back(absl::StrFormat("%s(%s)", n->opcode(), n->name())); + } else { + nodes.emplace_back(n->opcode()); + } + } + std::sort(nodes.begin(), nodes.end()); + + return absl::StrFormat("%d: ", collation_count_) + absl::StrJoin(nodes, ", "); +} + +void Subgraph::ExtractForSignature(SigNodeMap* result) { + // Mapping of nodes from the original graph to the new one. 
+ SigNode::TranslationMap full_to_new; + + for (auto node : id_) { + auto newnode_ref = absl::make_unique<SigNode>(node->node_def()); + auto newnode = newnode_ref.get(); + (*result)[node->name()] = std::move(newnode_ref); + full_to_new[node] = newnode; + } + + for (const auto& mapping : full_to_new) { + mapping.second->CopyLinks(*mapping.first, full_to_new); + } +} + +//=== Subgraph + +Subgraph::Subgraph(const Identity& parent_id, GenNode* add_node) + : id_(parent_id) { + id_.insert(add_node); + hash_ = id_.Hash(); +} + +//=== SubgraphIterator + +SubgraphIterator::SubgraphIterator(const Subgraph::Identity* id) + : id_(id), id_it_(id_->begin()) { + if (!id_->empty()) { + link_map_it_ = (*id_it_)->links().begin(); + // In case if the node has no links. + while (link_map_it_ == (*id_it_)->links().end()) { + if (++id_it_ == id_->end()) { + return; + } + link_map_it_ = (*id_it_)->links().begin(); + } + link_idx_ = 0; + // The LinkTargetVector should never be empty but just in case safeguard + // against that too. + PropagateNext(); + } +} + +bool SubgraphIterator::Next() { + if (AtEnd()) { + return false; + } + ++link_idx_; + return PropagateNext(); +} + +bool SubgraphIterator::NextIfSamePort() { + if (AtEnd()) { + return false; + } + if (link_idx_ + 1 < link_map_it_->second.size()) { + ++link_idx_; + return true; + } else { + return false; + } +} + +void SubgraphIterator::SkipPort() { + if (AtEnd()) { + return; + } + link_idx_ = link_map_it_->second.size() - 1; +} + +void SubgraphIterator::SkipNode() { + if (AtEnd()) { + return; + } + for (auto next = link_map_it_; next != (*id_it_)->links().end(); ++next) { + link_map_it_ = next; + } + link_idx_ = link_map_it_->second.size() - 1; +} + +bool SubgraphIterator::PropagateNext() { + // Loops are used to skip over the empty entries. 
+ while (link_idx_ >= link_map_it_->second.size()) { + ++link_map_it_; + while (link_map_it_ == (*id_it_)->links().end()) { + if (++id_it_ == id_->end()) { + return false; + } + link_map_it_ = (*id_it_)->links().begin(); + } + link_idx_ = 0; + } + return true; +} + +bool SubgraphIterator::operator==(const SubgraphIterator& other) const { + if (id_ != other.id_) { + return false; + } + if (id_it_ != other.id_it_) { + return false; + } + // When AtEnd(), the rest of the fields are not valid. + if (AtEnd()) { + return true; + } + if (link_map_it_ != other.link_map_it_) { + return false; + } + if (link_idx_ != other.link_idx_) { + return false; + } + return true; +} + +//=== SubgraphPtrSet + +Subgraph* SubgraphPtrSet::ExtendParent(const Subgraph::Identity& parent_id, + GenNode* node) { + if (parent_id.find(node) != parent_id.end()) { + // This was another link to the node that is already in the parent. + return nullptr; + } + + // Constructing an object just to check that an equivalent one is already + // present is kind of ugly but storing the references rather than the objects + // in the set avoids the need to make the object copyable. + auto sg = absl::make_unique<Subgraph>(parent_id, node); + if (find(sg) != end()) { + // This subgraph was already found by extending from a different path. + return nullptr; + } + + Subgraph* ptr = sg.get(); + insert(std::move(sg)); + return ptr; +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.h b/tensorflow/core/grappler/graph_analyzer/subgraph.h new file mode 100644 index 0000000000..4de31d5dfa --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/subgraph.h @@ -0,0 +1,189 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ + +#include <initializer_list> +#include <set> + +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" +#include "tensorflow/core/grappler/graph_analyzer/map_tools.h" +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +// The description of a single subgraph for processing. +class Subgraph { + public: + // Identity of a single subgraph as a set of nodes. + class Identity : public gtl::FlatSet<const GenNode*> { + public: + using InitializerList = std::initializer_list<GenNode*>; + + Identity() = default; + Identity(InitializerList init); + bool operator<(const Identity& other) const; + bool operator==(const Identity& other) const; + + // Compute the hash. + size_t Hash() const; + }; + + explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {} + + // Construct by extending the parent identity with an extra node. + Subgraph(const Identity& parent_id, GenNode* add_node); + + Subgraph() = delete; + Subgraph(const Subgraph& other) = delete; + void operator=(const Subgraph& other) = delete; + + // Order for building sets of subgraphs. + bool operator<(const Subgraph& other) const { return this->id_ < other.id_; } + // Support for hashed sets. 
+ bool operator==(const Subgraph& other) const { + return this->id_ == other.id_; + } + size_t Hash() const { return hash_; } + + // Dump the subgraph information to a string. + string Dump(); + + // Extract this subgraph into a separate graph representation for signature + // building, that includes only the links between the nodes in the subgraph + // and drops all the external links. The result map should be clear before the + // call. + void ExtractForSignature(SigNodeMap* result); + + const Identity& id() const { return id_; } + bool specific() const { return specific_; } + void SetSpecific(bool value) { specific_ = value; } + int32_t collation_count() const { return collation_count_; } + void AddCollation(int32_t n = 1) { collation_count_ += n; } + void ResetCollation() { collation_count_ = 1; } + void MergeCollation(const Subgraph& other) { + collation_count_ += other.collation_count_; + } + + private: + // Identity also serves as the list of nodes. It never changes throughout the + // life of subgraph. + Identity id_; + size_t hash_; // Cached from the identity. + // Whether the dump should include the specific names of the nodes. The + // non-specific (i.e. generic) subgraphs represent a collation of multiple + // subgraphs. + bool specific_ = true; + // How many collated subgraphs are represented by this subgraph. + int32_t collation_count_ = 1; +}; + +// Iteration of all links in a subgraph. This is more like Java iterators than +// the normal C++ iterators. It's simpler this way and there seems to be no +// major reason to make it a proper C++ iterator. +class SubgraphIterator { + public: + // Obviously an iterator is valid only until the original object + // gets destroyed. + explicit SubgraphIterator(const Subgraph::Identity* id); + explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {} + + // Check whether the built-in iterator is at the end. 
+ bool AtEnd() const { return id_it_ == id_->end(); } + + // Get the neighbor at the current iterator. + // MUST NOT be called when AtEnd(); + const GenNode::LinkTarget& GetNeighbor() const { + return link_map_it_->second[link_idx_]; + } + + // Get the node at the current iterator. + // MUST NOT be called when AtEnd(); + const GenNode* GetNode() const { return *id_it_; } + + // Get the port leading to the neighbor at the current iterator. + // MUST NOT be called when AtEnd(); + GenNode::Port GetPort() const { return link_map_it_->first; } + + // Increases the iterator. + // Returns true if NOT AtEnd() after increasing the iterator. + // Safe to call if already AtEnd(). + bool Next(); + + // If there are more links at the same port, increases the iterator and + // returns true. Otherwise leaves the iterator unchanged and returns false. + bool NextIfSamePort(); + + // Increases the iterator directly to the last position on the current port + // (or if already there then doesn't increase). Equivalent to calling + // NextIfSamePort() while it returns true, but faster. + // Safe to call if already AtEnd(). + void SkipPort(); + + // Increases the iterator directly to the last position on the current node. + // Safe to call if already AtEnd(). + void SkipNode(); + + // Returns true if the iterators are exactly the same. + bool operator==(const SubgraphIterator& other) const; + bool operator!=(const SubgraphIterator& other) const { + return !(*this == other); + } + + private: + // After link_idx_ has been increased, make sure that it points to the + // next valid element (or end) by increasing the higher levels of iteration if + // needed. + // Returns true if NOT AtEnd() after increasing the iterator. + // NOT safe to call if already AtEnd(). + bool PropagateNext(); + + // Identity of the subgraph being iterated over. + const Subgraph::Identity* id_; + + // The current position, allowing to iterate through the links (see the + // reasoning for it in the public section). 
+ // + // (1) Iterator of the nodes in the subgraph. + Subgraph::Identity::const_iterator id_it_; + // (2) Iterator in the link map of the node. + GenNode::LinkMap::const_iterator link_map_it_; + // (3) Index in the vector of the links. + int32_t link_idx_; +}; + +// A convenient way to store subgraphs: in a set of unique_ptrs. This way the +// addresses of subgraph objects will stay stable, and the objects themselves +// won't be copied. +class SubgraphPtrSet + : public std::unordered_set<std::unique_ptr<Subgraph>, + HashAtPtr<std::unique_ptr<Subgraph>>, + EqAtPtr<std::unique_ptr<Subgraph>>> { + public: + // Attempts to extend the set by adding a new subgraph that gets created by + // adding one node to the parent subgraph. If such a subgraph already exists, + // returns nullptr, otherwise returns the pointer to the new subgraph. + Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node); +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc new file mode 100644 index 0000000000..0f90dc8f0d --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc @@ -0,0 +1,348 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(SubgraphTest, Comparison) {
+  GraphDef graph;
+  // Two unconnected const nodes.
+  (*graph.add_node()) = MakeNodeConst("node1");
+  (*graph.add_node()) = MakeNodeConst("node2");
+  GenNodeMap map;
+  ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+  auto gn1 = map["node1"].get();
+  auto gn2 = map["node2"].get();
+  ASSERT_THAT(gn1, Ne(nullptr));
+  ASSERT_THAT(gn2, Ne(nullptr));
+
+  Subgraph::Identity id1;
+  Subgraph::Identity id2;
+
+  id1.insert(gn1);
+  id2.insert(gn2);
+
+  Subgraph sg1(id1);
+  Subgraph sg2(id2);
+
+  EXPECT_TRUE(id1 == sg1.id());
+  EXPECT_TRUE(id2 == sg2.id());
+
+  EXPECT_THAT(sg1 < sg2, Eq(id1 < id2));
+}
+
+TEST(SubgraphTest, EmptyIteration) {
+  NodeDef node1 = MakeNodeConst("node1");
+  auto gn1 = absl::make_unique<GenNode>(&node1);
+  Subgraph::Identity id1;
+  id1.insert(gn1.get());
+  Subgraph sg1(id1);
+  SubgraphIterator sit(&sg1);
+
+  EXPECT_TRUE(sit.AtEnd());
+  EXPECT_FALSE(sit.Next());
+  EXPECT_TRUE(sit.AtEnd());
+
+  SubgraphIterator sit2(&sg1);
+  EXPECT_TRUE(sit == sit2);
+}
+
+TEST(SubgraphTest, Iteration) {
+  GraphDef graph;
+  // A topology with a loop.
+  (*graph.add_node()) = MakeNodeConst("node1");
+  (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+  auto node3 = graph.add_node();
+  *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+  node3->add_input("^node3");  // The control link goes back to self.
+ + GenNodeMap map; + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK())); + ASSERT_THAT(map.find("node3"), Ne(map.end())); + + Subgraph::Identity id; + id.insert(map["node3"].get()); + Subgraph sg(id); + + // node3 has 2 incoming data links, 2 outgoing data , 1 control incoming, 1 + // control outgoing = total of 6 + { + SubgraphIterator sit(&sg); + EXPECT_FALSE(sit.AtEnd()); // 1 + EXPECT_TRUE(sit.Next()); + EXPECT_FALSE(sit.AtEnd()); // 2 + EXPECT_TRUE(sit.Next()); + EXPECT_FALSE(sit.AtEnd()); // 3 + EXPECT_TRUE(sit.Next()); + EXPECT_FALSE(sit.AtEnd()); // 4 + EXPECT_TRUE(sit.Next()); + EXPECT_FALSE(sit.AtEnd()); // 5 + EXPECT_TRUE(sit.Next()); + EXPECT_FALSE(sit.AtEnd()); // 6 + EXPECT_FALSE(sit.Next()); + EXPECT_TRUE(sit.AtEnd()); + } + + // Now get the values out. And more equality testing along the way. + { + SubgraphIterator sit(&sg); + SubgraphIterator sit2(&sg); + std::vector<string> links; + for (; !sit.AtEnd(); sit.Next()) { + EXPECT_TRUE(sit == sit2); + sit2.Next(); + EXPECT_FALSE(sit == sit2); + + links.push_back(absl::StrFormat("[%s,%s,%s]", string(sit.GetPort()), + sit.GetNeighbor().node->name(), + string(sit.GetNeighbor().port))); + } + EXPECT_TRUE(sit == sit2); + + std::sort(links.begin(), links.end()); + // clang-format off + EXPECT_THAT(links, ElementsAre( + "[i0,node1,o0]", + "[i1,node2,o0]", + "[iC,node3,oC]", + "[o0,node2,i1]", + "[o1,node2,i0]", + "[oC,node3,iC]" + )); + // clang-format on + } +} + +TEST(SubgraphTest, IterationSamePort) { + GraphDef graph; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3"); + (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2"); + + GenNodeMap map; + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK())); + ASSERT_THAT(map.find("node3"), Ne(map.end())); + + Subgraph::Identity id; + id.insert(map["node3"].get()); + Subgraph sg(id); + + int total_links = 0; + for (SubgraphIterator sit(&sg); !sit.AtEnd(); 
sit.Next()) { + ++total_links; + } + + // Initialize the port as control, which doesn't occur in this graph. + GenNode::Port last_port(false, -1); + int steps_total_same_port = 0; + int steps_with_same_port = 0; + for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) { + GenNode::Port new_port = sit.GetPort(); + EXPECT_THAT(last_port.Encoded(), Ne(new_port.Encoded())) + << "At step " << steps_total_same_port; + last_port = new_port; + + ++steps_total_same_port; + + SubgraphIterator sit2(sit); + sit2.SkipPort(); + + while (sit.NextIfSamePort()) { + new_port = sit.GetPort(); + EXPECT_THAT(last_port.Encoded(), Eq(new_port.Encoded())) + << "At step " << steps_total_same_port; + ++steps_total_same_port; + ++steps_with_same_port; + } + + EXPECT_TRUE(sit == sit2); + } + + EXPECT_THAT(steps_total_same_port, Eq(total_links)); + // There is one 2-way input and one 2-way output. + EXPECT_THAT(steps_with_same_port, Eq(2)); +} + +TEST(SubgraphTest, IterationSameNode) { + GraphDef graph; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3"); + (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2"); + + GenNodeMap map; + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK())); + ASSERT_THAT(map.find("node3"), Ne(map.end())); + + Subgraph::Identity id; + id.insert(map["node3"].get()); + Subgraph sg(id); + + const GenNode* last_node = nullptr; + SubgraphIterator sit(&sg); + while (!sit.AtEnd()) { + const GenNode* new_node = sit.GetNode(); + + EXPECT_THAT(new_node, Ne(last_node)) << "At node " << new_node->name(); + + SubgraphIterator sit2(sit); + sit2.SkipNode(); + + ASSERT_FALSE(sit2.AtEnd()); + EXPECT_THAT(sit2.GetNode(), Eq(new_node)) + << "At expected node " << new_node->name() << ", got " + << sit2.GetNode()->name(); + + while (sit != sit2 && !sit.AtEnd()) { + sit.Next(); + } + + ASSERT_FALSE(sit.AtEnd()); + EXPECT_THAT(sit.GetNode(), Eq(new_node)) + << "At expected node " << new_node->name() << 
", got " + << sit2.GetNode()->name(); + + sit.Next(); + + last_node = new_node; + } + + // Check that it doesn't fail if already at end. + sit.SkipNode(); + EXPECT_TRUE(sit.AtEnd()); +} + +TEST(SubgraphTest, ExtendSet) { + GraphDef graph; + // A topology with a loop. + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0"); + auto node3 = graph.add_node(); + *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2"); + node3->add_input("^node3"); // The control link goes back to self. + + GenNodeMap map; + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK())); + ASSERT_THAT(map.find("node2"), Ne(map.end())); + ASSERT_THAT(map.find("node3"), Ne(map.end())); + + Subgraph::Identity id_empty; + + Subgraph::Identity id3; + id3.insert(map["node3"].get()); + + Subgraph::Identity id23 = id3; + id23.insert(map["node2"].get()); + + Subgraph* sg; + SubgraphPtrSet set; + + // Extend an empty identity. + sg = set.ExtendParent(id_empty, map["node3"].get()); + EXPECT_THAT(set.size(), Eq(1)); + ASSERT_THAT(sg, Ne(nullptr)); + EXPECT_TRUE(sg->id() == id3); + + // Extend with a node that is already in the parent. + sg = set.ExtendParent(id3, map["node3"].get()); + EXPECT_THAT(set.size(), Eq(1)); + EXPECT_THAT(sg, Eq(nullptr)); + + // Extend to a 2-node subgraph. + sg = set.ExtendParent(id3, map["node2"].get()); + EXPECT_THAT(set.size(), Eq(2)); + ASSERT_THAT(sg, Ne(nullptr)); + EXPECT_TRUE(sg->id() == id23); + + // The second insert of the same node gets ignored. 
+ sg = set.ExtendParent(id3, map["node2"].get()); + EXPECT_THAT(set.size(), Eq(2)); + EXPECT_THAT(sg, Eq(nullptr)); +} + +TEST(SubgraphTest, ExtractForSignature) { + GraphDef graph; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0"); + auto node3 = graph.add_node(); + *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2"); + node3->add_input("^node1"); + node3->add_input("^node2"); + node3->add_input("^node3"); // The control link goes back to self. + + GenNodeMap map; + ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK())); + ASSERT_THAT(map.find("node1"), Ne(map.end())); + ASSERT_THAT(map.find("node2"), Ne(map.end())); + ASSERT_THAT(map.find("node3"), Ne(map.end())); + + Subgraph::Identity id; + id.insert(map["node1"].get()); + id.insert(map["node3"].get()); + + Subgraph sg(id); + + SigNodeMap map2; + sg.ExtractForSignature(&map2); + ASSERT_THAT(map2.find("node1"), Ne(map2.end())); + ASSERT_THAT(map2.find("node2"), Eq(map2.end())); + ASSERT_THAT(map2.find("node3"), Ne(map2.end())); + + // clang-format off + EXPECT_THAT(DumpLinkHashMap(map2["node1"]->hash_to_link()), ElementsAre( + "oC:iC: node3", + "o0:i0: node3" + )); + EXPECT_THAT(DumpHashedPeerVector(map2["node1"]->hashed_peers()), ElementsAre( + "node3", + "node3" + )); + EXPECT_THAT(DumpLinkHashMap(map2["node3"]->hash_to_link()), ElementsAre( + "oC:iC: node3", + "iC:oC: node1, node3", + "i0:o0: node1" + )); + EXPECT_THAT(DumpHashedPeerVector(map2["node3"]->hashed_peers()), ElementsAre( + "node3", + "node1", + "node3", + "node1" + )); + // clang-format on +} + +} // end namespace +} // end namespace test +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.cc b/tensorflow/core/grappler/graph_analyzer/test_tools.cc new file mode 100644 index 0000000000..fc9495bc7d --- /dev/null +++ 
b/tensorflow/core/grappler/graph_analyzer/test_tools.cc @@ -0,0 +1,296 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/test_tools.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { +namespace test { + +//=== Helper methods to construct the nodes. + +NodeDef MakeNodeConst(const string& name) { + NodeDef n; + n.set_name(name); + n.set_op("Const"); + return n; +} + +NodeDef MakeNode2Arg(const string& name, const string& opcode, + const string& arg1, const string& arg2) { + NodeDef n; + n.set_name(name); + n.set_op(opcode); + n.add_input(arg1); + n.add_input(arg2); + return n; +} + +NodeDef MakeNode4Arg(const string& name, const string& opcode, + const string& arg1, const string& arg2, const string& arg3, + const string& arg4) { + NodeDef n; + n.set_name(name); + n.set_op(opcode); + n.add_input(arg1); + n.add_input(arg2); + n.add_input(arg3); + n.add_input(arg4); + return n; +} + +// Not really a 2-argument but convenient to construct. +NodeDef MakeNodeShapeN(const string& name, const string& arg1, + const string& arg2) { + // This opcode is multi-input but not commutative. + return MakeNode2Arg(name, "ShapeN", arg1, arg2); +} + +// Not really a 2-argument but convenient to construct. 
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1, + const string& arg2) { + // The argument is of a list type. + return MakeNode2Arg(name, "IdentityN", arg1, arg2); +} + +NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1, + const string& arg2, const string& arg3, + const string& arg4) { + // This opcode has multiple multi-inputs. + return MakeNode4Arg(name, "QuantizedConcat", arg1, arg2, arg3, arg4); +} + +//=== Helper methods for analysing the structures. + +std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map) { + // This will order the entries first. + std::map<string, string> ordered; + for (const auto& link : link_map) { + string key = string(link.first); + + // Order the other sides too. They may be repeating, so store them + // in a multiset. + std::multiset<string> others; + for (const auto& other : link.second) { + others.emplace( + absl::StrFormat("%s[%s]", other.node->name(), string(other.port))); + } + ordered[key] = absl::StrJoin(others, ", "); + } + // Now dump the result in a predictable order. + std::vector<string> result; + result.reserve(ordered.size()); + for (const auto& link : ordered) { + result.emplace_back(link.first + ": " + link.second); + } + return result; +} + +std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map) { + // The entries in this map are ordered by hash value which might change + // at any point. Re-order them by the link tag. + std::map<SigNode::LinkTag, size_t> tags; + for (const auto& entry : link_hash_map) { + tags[entry.second.tag] = entry.first; + } + + std::vector<string> result; + for (const auto& id : tags) { + // For predictability, the nodes need to be sorted. 
+ std::vector<string> nodes; + for (const auto& peer : link_hash_map.at(id.second).peers) { + nodes.emplace_back(peer->name()); + } + std::sort(nodes.begin(), nodes.end()); + result.emplace_back(string(id.first.local) + ":" + string(id.first.remote) + + ": " + absl::StrJoin(nodes, ", ")); + } + return result; +} + +std::vector<string> DumpHashedPeerVector( + const SigNode::HashedPeerVector& hashed_peers) { + std::vector<string> result; + + // Each subset of nodes with the same hash has to be sorted by name. + // Other than that, the vector is already ordered by full tags. + size_t last_hash = 0; + // Index, since iterators may get invalidated on append. + size_t subset_start = 0; + + for (const auto& entry : hashed_peers) { + if (entry.link_hash != last_hash) { + std::sort(result.begin() + subset_start, result.end()); + subset_start = result.size(); + } + result.emplace_back(entry.peer->name()); + } + std::sort(result.begin() + subset_start, result.end()); + + return result; +} + +TestGraphs::TestGraphs() { + { + GraphDef& graph = graph_3n_self_control_; + // The topology includes a loop and a link to self. + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0"); + auto node3 = graph.add_node(); + *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2"); + node3->add_input("^node3"); // The control link goes back to self. + } + { + GraphDef& graph = graph_multi_input_; + // The topology includes a loop and a link to self. + (*graph.add_node()) = MakeNodeConst("const1_1"); + (*graph.add_node()) = MakeNodeConst("const1_2"); + (*graph.add_node()) = MakeNodeAddN("add1", "const1_1", "const1_2"); + + (*graph.add_node()) = MakeNodeConst("const2_1"); + (*graph.add_node()) = MakeNodeConst("const2_2"); + (*graph.add_node()) = MakeNodeConst("const2_3"); + + auto add2 = graph.add_node(); + *add2 = MakeNodeAddN("add2", "const2_1", "const2_2"); + // The 3rd node is connected twice, to 4 links total. 
+ add2->add_input("const2_3"); + add2->add_input("const2_3"); + + (*graph.add_node()) = MakeNodeSub("sub", "add1", "add2"); + } + { + GraphDef& graph = graph_all_or_none_; + // The topology includes a loop and a link to self. + (*graph.add_node()) = MakeNodeConst("const1_1"); + (*graph.add_node()) = MakeNodeConst("const1_2"); + auto pass1 = graph.add_node(); + *pass1 = MakeNodeIdentityN("pass1", "const1_1", "const1_2"); + + (*graph.add_node()) = MakeNodeConst("const2_1"); + (*graph.add_node()) = MakeNodeConst("const2_2"); + (*graph.add_node()) = MakeNodeConst("const2_3"); + + auto pass2 = graph.add_node(); + *pass2 = MakeNodeIdentityN("pass2", "const2_1", "const2_2"); + // The 3rd node is connected twice, to 4 links total. + pass2->add_input("const2_3"); + pass2->add_input("const2_3"); + + // Add the control links, they get handled separately than the normal + // links. + pass1->add_input("^const2_1"); + pass1->add_input("^const2_2"); + pass1->add_input("^const2_3"); + + (*graph.add_node()) = MakeNodeSub("sub", "pass1", "pass2"); + } + { + GraphDef& graph = graph_circular_onedir_; + (*graph.add_node()) = MakeNodeMul("node1", "node5", "node5"); + (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1"); + (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2"); + (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3"); + (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4"); + } + { + GraphDef& graph = graph_circular_bidir_; + // The left and right links are intentionally mixed up. 
+ (*graph.add_node()) = MakeNodeMul("node1", "node5", "node2"); + (*graph.add_node()) = MakeNodeMul("node2", "node3", "node1"); + (*graph.add_node()) = MakeNodeMul("node3", "node2", "node4"); + (*graph.add_node()) = MakeNodeMul("node4", "node5", "node3"); + (*graph.add_node()) = MakeNodeMul("node5", "node4", "node1"); + } + { + GraphDef& graph = graph_linear_; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1"); + (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2"); + (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3"); + (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4"); + } + { + GraphDef& graph = graph_cross_; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1"); + (*graph.add_node()) = MakeNodeConst("node3"); + (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3"); + (*graph.add_node()) = MakeNodeConst("node5"); + (*graph.add_node()) = MakeNodeMul("node6", "node5", "node5"); + (*graph.add_node()) = MakeNodeConst("node7"); + (*graph.add_node()) = MakeNodeMul("node8", "node7", "node7"); + + auto center = graph.add_node(); + *center = MakeNodeMul("node9", "node2", "node4"); + center->add_input("node6"); + center->add_input("node8"); + } + { + GraphDef& graph = graph_small_cross_; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeConst("node2"); + (*graph.add_node()) = MakeNodeConst("node3"); + (*graph.add_node()) = MakeNodeConst("node4"); + + auto center = graph.add_node(); + *center = MakeNodeMul("node5", "node1", "node2"); + center->add_input("node3"); + center->add_input("node4"); + } + { + GraphDef& graph = graph_for_link_order_; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeConst("node2"); + (*graph.add_node()) = MakeNodeConst("node3"); + (*graph.add_node()) = MakeNodeConst("node4"); + + // One group of equivalent links. 
+ auto center = graph.add_node(); + *center = MakeNodeMul("node5", "node1", "node2"); + center->add_input("node3"); + center->add_input("node4"); + + // Multiple groups, separated by unique links. + auto center2 = graph.add_node(); + *center2 = MakeNodeMul("node6", "node1", "node2"); + center2->add_input("node2:1"); + center2->add_input("node3:2"); + center2->add_input("node4:2"); + center2->add_input("node4:3"); + } + { + GraphDef& graph = graph_sun_; + (*graph.add_node()) = MakeNodeConst("node1"); + (*graph.add_node()) = MakeNodeConst("node2"); + (*graph.add_node()) = MakeNodeConst("node3"); + (*graph.add_node()) = MakeNodeConst("node4"); + (*graph.add_node()) = MakeNodeConst("node5"); + (*graph.add_node()) = MakeNodeSub("node6", "node1", "node10"); + (*graph.add_node()) = MakeNodeSub("node7", "node2", "node6"); + (*graph.add_node()) = MakeNodeSub("node8", "node3", "node7"); + (*graph.add_node()) = MakeNodeSub("node9", "node4", "node8"); + (*graph.add_node()) = MakeNodeSub("node10", "node5", "node9"); + } +} + +} // end namespace test +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.h b/tensorflow/core/grappler/graph_analyzer/test_tools.h new file mode 100644 index 0000000000..98e269d57e --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/test_tools.h @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/grappler/op_types.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+//=== Helper methods to construct the nodes.
+
+// Builds a Const node with no inputs.
+NodeDef MakeNodeConst(const string& name);
+
+// Builds a node of the given opcode with exactly two inputs.
+NodeDef MakeNode2Arg(const string& name, const string& opcode,
+                     const string& arg1, const string& arg2);
+
+// Builds a node of the given opcode with exactly four inputs.
+NodeDef MakeNode4Arg(const string& name, const string& opcode,
+                     const string& arg1, const string& arg2, const string& arg3,
+                     const string& arg4);
+
+inline NodeDef MakeNodeMul(const string& name, const string& arg1,
+                           const string& arg2) {
+  return MakeNode2Arg(name, "Mul", arg1, arg2);
+}
+
+// Not really a 2-argument but convenient to construct.
+inline NodeDef MakeNodeAddN(const string& name, const string& arg1,
+                            const string& arg2) {
+  return MakeNode2Arg(name, "AddN", arg1, arg2);
+}
+
+inline NodeDef MakeNodeSub(const string& name, const string& arg1,
+                           const string& arg2) {
+  return MakeNode2Arg(name, "Sub", arg1, arg2);
+}
+
+// Has 2 honest outputs.
+inline NodeDef MakeNodeBroadcastGradientArgs(const string& name,
+                                             const string& arg1,
+                                             const string& arg2) {
+  return MakeNode2Arg(name, "BroadcastGradientArgs", arg1, arg2);
+}
+
+// Multi-input but not commutative (see the .cc for details).
+NodeDef MakeNodeShapeN(const string& name, const string& arg1,
+                       const string& arg2);
+
+// Takes a list-typed argument (see the .cc for details).
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
+                          const string& arg2);
+
+// Has multiple multi-inputs (see the .cc for details).
+NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
+                                const string& arg2, const string& arg3,
+                                const string& arg4);
+
+//=== A container of pre-constructed graphs.
+
+class TestGraphs {
+ public:
+  TestGraphs();
+
+  // Graph with 3 nodes and a control link to self (which is not valid in
+  // reality but adds excitement to the tests).
+  GraphDef graph_3n_self_control_;
+  // Graph that has the multi-input links.
+  GraphDef graph_multi_input_;
+  // Graph that has the all-or-none nodes.
+  GraphDef graph_all_or_none_;
+  // All the nodes are connected in a circle that goes in one direction.
+  GraphDef graph_circular_onedir_;
+  // All the nodes are connected in a circle that goes in both directions.
+  GraphDef graph_circular_bidir_;
+  // The nodes are connected in a line.
+  GraphDef graph_linear_;
+  // The nodes are connected in a cross shape.
+  GraphDef graph_cross_;
+  // A smaller cross: four constants feeding one central node.
+  GraphDef graph_small_cross_;
+  // For testing the ordering of links at the end of signature generation,
+  // a variation of a cross.
+  GraphDef graph_for_link_order_;
+  // Sun-shaped, a ring with "rays".
+  GraphDef graph_sun_;
+};
+
+//=== Helper methods for analysing the structures.
+
+// Dumps a link map as strings ordered by the local port.
+std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map);
+
+// Dumps the hash-to-link map re-ordered by link tag for predictability.
+// NOTE(review): despite the original comment claiming a consistency check
+// of the hash values, no such check is visible in the implementation —
+// confirm intent before relying on it.
+std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map);
+
+// Dumps the hashed-peer vector; peers within one hash group are sorted.
+std::vector<string> DumpHashedPeerVector(
+    const SigNode::HashedPeerVector& hashed_peers);
+
+}  // end namespace test
+}  // end namespace graph_analyzer
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index f62a927925..5af6437c56 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3777,6 +3777,7 @@ tf_py_wrap_cc(
         "framework/python_op_gen.i",
         "grappler/cluster.i",
         "grappler/cost_analyzer.i",
+        "grappler/graph_analyzer.i",
         "grappler/item.i",
         "grappler/model_analyzer.i",
         "grappler/tf_optimizer.i",
@@ -3835,6 +3836,7 @@ tf_py_wrap_cc(
         "//tensorflow/core/grappler/clusters:single_machine",
         "//tensorflow/core/grappler/clusters:virtual_cluster",
         "//tensorflow/core/grappler/costs:graph_memory",
+        "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
        "//tensorflow/core/grappler/optimizers:meta_optimizer",
         "//tensorflow/core:lib",
         "//tensorflow/core:reader_base",
@@ -5536,6 +5538,18 @@ py_test(
     ],
 )
 
+py_binary(
+    name = "graph_analyzer",
+    srcs = [
+        "grappler/graph_analyzer.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":framework_for_generated_wrappers",
+        ":pywrap_tensorflow_internal",
+    ],
+)
+
 pyx_library(
     name = "framework_fast_tensor_util",
     srcs = ["framework/fast_tensor_util.pyx"],
diff --git a/tensorflow/python/grappler/graph_analyzer.i b/tensorflow/python/grappler/graph_analyzer.i
new file mode 100644
index 0000000000..cc7b5358eb
--- /dev/null
+++ b/tensorflow/python/grappler/graph_analyzer.i
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%{
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
+%}
+
+%{
+/* Thin SWIG shim: forwards the file path and subgraph size to the C++
+   graph analyzer tool. */
+void GraphAnalyzer(const string& file_path, int n) {
+  tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n);
+}
+%}
+
+void GraphAnalyzer(const string& file_path, int n);
diff --git a/tensorflow/python/grappler/graph_analyzer.py b/tensorflow/python/grappler/graph_analyzer.py
new file mode 100644
index 0000000000..ec5544e38e
--- /dev/null
+++ b/tensorflow/python/grappler/graph_analyzer.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""A tool that finds all subgraphs of a given size in a TF graph.
+
+The subgraph patterns are sorted by occurrence, and only the transitive fanin
+part of the graph with regard to the fetch nodes is considered.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.python import pywrap_tensorflow as tf_wrap
+from tensorflow.python.platform import app
+
+
+def main(_):
+  # Delegates to the SWIG-wrapped C++ tool.
+  # NOTE(review): FLAGS is a module-level global populated only under
+  # __main__ below; both flags default to None and are passed straight to
+  # the C++ wrapper, which presumably rejects missing values — confirm.
+  tf_wrap.GraphAnalyzer(FLAGS.input, FLAGS.n)
+
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--input",
+      type=str,
+      default=None,
+      help="Input file path for a TensorFlow MetaGraphDef.")
+  parser.add_argument(
+      "--n", type=int, default=None, help="The size of the subgraphs.")
+  # Unrecognized args are forwarded to app.run so absl/TF flags still work.
+  FLAGS, unparsed = parser.parse_known_args()
+  app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 26e8acd897..39174fa589 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -54,4 +54,5 @@ limitations under the License.
 %include "tensorflow/python/grappler/item.i"
 %include "tensorflow/python/grappler/tf_optimizer.i"
 %include "tensorflow/python/grappler/cost_analyzer.i"
+%include "tensorflow/python/grappler/graph_analyzer.i"
 %include "tensorflow/python/grappler/model_analyzer.i"