aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-08-27 15:11:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 15:14:54 -0700
commitab8e195d2e0978c21234a5632d4fabf47535eda1 (patch)
tree6145d69d652d9503b9a430d0b2a1a3f167e7139d /tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
parent599bb5522d622b2f4feaee06b5dd194b04424345 (diff)
Open source graph analyzer.
PiperOrigin-RevId: 210439649
Diffstat (limited to 'tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc')
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc341
1 files changed, 341 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
new file mode 100644
index 0000000000..f3796fcf86
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <deque>
+#include <iostream>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+GraphAnalyzer::GraphAnalyzer(const GraphDef& graph, int subgraph_size)
+ : graph_(graph), subgraph_size_(subgraph_size) {}
+
+GraphAnalyzer::~GraphAnalyzer() {}
+
+Status GraphAnalyzer::Run() {
+ // The signature computation code would detect this too, but better
+ // to report it up front than spend time computing all the graphs first.
+ if (subgraph_size_ > Signature::kMaxGraphSize) {
+ return Status(error::INVALID_ARGUMENT,
+ absl::StrFormat("Subgraphs of %d nodes are not supported, "
+ "the maximal supported node count is %d.",
+ subgraph_size_, Signature::kMaxGraphSize));
+ }
+
+ Status st = BuildMap();
+ if (!st.ok()) {
+ return st;
+ }
+
+ FindSubgraphs();
+ DropInvalidSubgraphs();
+ st = CollateResult();
+ if (!st.ok()) {
+ return st;
+ }
+
+ return Status::OK();
+}
+
+Status GraphAnalyzer::BuildMap() {
+ nodes_.clear();
+ return GenNode::BuildGraphInMap(graph_, &nodes_);
+}
+
+void GraphAnalyzer::FindSubgraphs() {
+ result_.clear();
+
+ if (subgraph_size_ < 1) {
+ return;
+ }
+
+ partial_.clear();
+ todo_.clear(); // Just in case.
+
+ // Start with all subgraphs of size 1.
+ const Subgraph::Identity empty_parent;
+ for (const auto& node : nodes_) {
+ if (subgraph_size_ == 1) {
+ result_.ExtendParent(empty_parent, node.second.get());
+ } else {
+ // At this point ExtendParent() is guaranteed to not return nullptr.
+ todo_.push_back(partial_.ExtendParent(empty_parent, node.second.get()));
+ }
+ }
+
+ // Then extend the subgraphs until no more extensions are possible.
+ while (!todo_.empty()) {
+ ExtendSubgraph(todo_.front());
+ todo_.pop_front();
+ }
+
+ partial_.clear();
+}
+
+void GraphAnalyzer::ExtendSubgraph(Subgraph* parent) {
+ bool will_complete = (parent->id().size() + 1 == subgraph_size_);
+ SubgraphPtrSet& sg_set = will_complete ? result_ : partial_;
+
+ const GenNode* last_all_or_none_node = nullptr;
+ for (SubgraphIterator sit(parent); !sit.AtEnd(); sit.Next()) {
+ const GenNode* node = sit.GetNode();
+ GenNode::Port port = sit.GetPort();
+ const GenNode::LinkTarget& neighbor = sit.GetNeighbor();
+
+ if (node->AllInputsOrNone() && port.IsInbound() && !port.IsControl()) {
+ if (node != last_all_or_none_node) {
+ ExtendSubgraphAllOrNone(parent, node);
+ last_all_or_none_node = node;
+ }
+ sit.SkipPort();
+ } else if (neighbor.node->AllInputsOrNone() && !port.IsInbound() &&
+ !port.IsControl()) {
+ if (parent->id().find(neighbor.node) == parent->id().end()) {
+ // Not added yet.
+ ExtendSubgraphAllOrNone(parent, neighbor.node);
+ }
+ } else if (node->IsMultiInput(port)) {
+ ExtendSubgraphPortAllOrNone(parent, node, port);
+ sit.SkipPort();
+ } else if (neighbor.node->IsMultiInput(neighbor.port)) {
+ // Would need to add all inputs of the neighbor node at this port at
+ // once.
+ if (parent->id().find(neighbor.node) != parent->id().end()) {
+ continue; // Already added.
+ }
+ ExtendSubgraphPortAllOrNone(parent, neighbor.node, neighbor.port);
+ } else {
+ Subgraph* sg = sg_set.ExtendParent(parent->id(), neighbor.node);
+ if (!will_complete && sg != nullptr) {
+ todo_.push_back(sg);
+ }
+ }
+ }
+}
+
+void GraphAnalyzer::ExtendSubgraphAllOrNone(Subgraph* parent,
+ const GenNode* node) {
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ auto range_end = node->links().end();
+
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
+void GraphAnalyzer::ExtendSubgraphPortAllOrNone(Subgraph* parent,
+ const GenNode* node,
+ GenNode::Port port) {
+ auto nbit = node->links().find(port);
+ if (nbit == node->links().end()) {
+ return; // Should never happen.
+ }
+
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
+void GraphAnalyzer::AddExtendedSubgraph(Subgraph* parent,
+ const Subgraph::Identity& id) {
+ if (id.size() == parent->id().size()) {
+ return; // Nothing new was added.
+ }
+
+ auto sg = absl::make_unique<Subgraph>(id);
+ SubgraphPtrSet& spec_sg_set =
+ (id.size() == subgraph_size_) ? result_ : partial_;
+ if (spec_sg_set.find(sg) != spec_sg_set.end()) {
+ // This subgraph was already found by extending from a different path.
+ return;
+ }
+
+ if (id.size() != subgraph_size_) {
+ todo_.push_back(sg.get());
+ }
+ spec_sg_set.insert(std::move(sg));
+}
+
+void GraphAnalyzer::DropInvalidSubgraphs() {
+ auto resit = result_.begin();
+ while (resit != result_.end()) {
+ if (HasInvalidMultiInputs(resit->get())) {
+ auto delit = resit;
+ ++resit;
+ result_.erase(delit);
+ } else {
+ ++resit;
+ }
+ }
+}
+
+bool GraphAnalyzer::HasInvalidMultiInputs(Subgraph* sg) {
+ // Do the all-or-none-input nodes.
+ for (auto const& node : sg->id()) {
+ if (!node->AllInputsOrNone()) {
+ continue;
+ }
+
+ bool anyIn = false;
+ bool anyOut = false;
+
+ auto range_end = node->links().end();
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ if (sg->id().find(link.node) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ }
+ }
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+
+ // Do the multi-input ports.
+ for (SubgraphIterator sit(sg); !sit.AtEnd(); sit.Next()) {
+ if (sit.GetNode()->IsMultiInput(sit.GetPort())) {
+ bool anyIn = false;
+ bool anyOut = false;
+ do {
+ GenNode* peer = sit.GetNeighbor().node;
+ if (sg->id().find(peer) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ } while (sit.NextIfSamePort());
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+Status GraphAnalyzer::CollateResult() {
+ ordered_collation_.clear();
+ collation_map_.clear();
+
+ // Collate by the signatures of the graphs.
+ for (const auto& it : result_) {
+ auto sig = absl::make_unique<Signature>();
+ it->ExtractForSignature(&sig->map);
+ Status status = sig->Compute();
+ if (!status.ok()) {
+ return status;
+ }
+
+ auto& coll_entry = collation_map_[sig.get()];
+ if (coll_entry.sig == nullptr) {
+ coll_entry.sig = std::move(sig);
+ }
+ ++coll_entry.count;
+ }
+
+ // Then order them by the count.
+ for (auto& entry : collation_map_) {
+ ordered_collation_.insert(&entry.second);
+ }
+
+ result_.clear(); // Not needed after collation.
+
+ return Status::OK();
+}
+
+std::vector<string> GraphAnalyzer::DumpRawSubgraphs() {
+ std::vector<string> result;
+ for (const auto& it : result_) {
+ result.emplace_back(it->Dump());
+ }
+ return result;
+}
+
+std::vector<string> GraphAnalyzer::DumpSubgraphs() {
+ std::vector<string> result;
+ for (auto ptr : ordered_collation_) {
+ result.emplace_back(
+ absl::StrFormat("%d %s", ptr->count, ptr->sig->ToString()));
+ }
+ return result;
+}
+
+Status GraphAnalyzer::OutputSubgraphs() {
+ size_t total = 0;
+ for (auto ptr : ordered_collation_) {
+ std::cout << ptr->count << ' ' << ptr->sig->ToString() << '\n';
+ total += ptr->count;
+ }
+ std::cout << "Total: " << total << '\n';
+ if (std::cout.fail()) {
+ return Status(error::DATA_LOSS, "Failed to write to stdout");
+ } else {
+ return Status::OK();
+ }
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow