diff options
Diffstat (limited to 'tensorflow/core/grappler/graph_analyzer/subgraph.cc')
-rw-r--r-- | tensorflow/core/grappler/graph_analyzer/subgraph.cc | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.cc b/tensorflow/core/grappler/graph_analyzer/subgraph.cc new file mode 100644 index 0000000000..28a91e0f84 --- /dev/null +++ b/tensorflow/core/grappler/graph_analyzer/subgraph.cc @@ -0,0 +1,235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/graph_analyzer/subgraph.h" + +#include <functional> + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +//=== Subgraph::Identity + +Subgraph::Identity::Identity(InitializerList init) { + for (auto element : init) { + insert(element); + } +} + +bool Subgraph::Identity::operator<(const Identity& other) const { + // Shorter sets go first. + if (this->size() < other.size()) { + return true; + } + if (this->size() > other.size()) { + return false; + } + for (auto lit = this->begin(), rit = other.begin(); lit != this->end(); + ++lit, ++rit) { + if (*lit < *rit) { + return true; + } + if (*lit > *rit) { + return false; + } + } + return false; // Equal. +} + +bool Subgraph::Identity::operator==(const Identity& other) const { + if (this->size() != other.size()) { + return false; + } + for (auto lit = this->begin(), rit = other.begin(); lit != this->end(); + ++lit, ++rit) { + if (*lit != *rit) { + return false; + } + } + return true; // Equal. +} + +size_t Subgraph::Identity::Hash() const { + std::hash<const GenNode*> hasher; + size_t result = 0; + for (auto ptr : *this) { + CombineHash(hasher(ptr), &result); + } + return result; +} + +string Subgraph::Dump() { + // TODO(babkin): this is simplified for now. + std::vector<string> nodes; + for (const auto& n : id_) { + if (specific_) { + nodes.emplace_back(absl::StrFormat("%s(%s)", n->opcode(), n->name())); + } else { + nodes.emplace_back(n->opcode()); + } + } + std::sort(nodes.begin(), nodes.end()); + + return absl::StrFormat("%d: ", collation_count_) + absl::StrJoin(nodes, ", "); +} + +void Subgraph::ExtractForSignature(SigNodeMap* result) { + // Mapping of nodes from the original graph to the new one. + SigNode::TranslationMap full_to_new; + + for (auto node : id_) { + auto newnode_ref = absl::make_unique<SigNode>(node->node_def()); + auto newnode = newnode_ref.get(); + (*result)[node->name()] = std::move(newnode_ref); + full_to_new[node] = newnode; + } + + for (const auto& mapping : full_to_new) { + mapping.second->CopyLinks(*mapping.first, full_to_new); + } +} + +//=== Subgraph + +Subgraph::Subgraph(const Identity& parent_id, GenNode* add_node) + : id_(parent_id) { + id_.insert(add_node); + hash_ = id_.Hash(); +} + +//=== SubgraphIterator + +SubgraphIterator::SubgraphIterator(const Subgraph::Identity* id) + : id_(id), id_it_(id_->begin()) { + if (!id_->empty()) { + link_map_it_ = (*id_it_)->links().begin(); + // In case if the node has no links. + while (link_map_it_ == (*id_it_)->links().end()) { + if (++id_it_ == id_->end()) { + return; + } + link_map_it_ = (*id_it_)->links().begin(); + } + link_idx_ = 0; + // The LinkTargetVector should never be empty but just in case safeguard + // against that too. + PropagateNext(); + } +} + +bool SubgraphIterator::Next() { + if (AtEnd()) { + return false; + } + ++link_idx_; + return PropagateNext(); +} + +bool SubgraphIterator::NextIfSamePort() { + if (AtEnd()) { + return false; + } + if (link_idx_ + 1 < link_map_it_->second.size()) { + ++link_idx_; + return true; + } else { + return false; + } +} + +void SubgraphIterator::SkipPort() { + if (AtEnd()) { + return; + } + link_idx_ = link_map_it_->second.size() - 1; +} + +void SubgraphIterator::SkipNode() { + if (AtEnd()) { + return; + } + for (auto next = link_map_it_; next != (*id_it_)->links().end(); ++next) { + link_map_it_ = next; + } + link_idx_ = link_map_it_->second.size() - 1; +} + +bool SubgraphIterator::PropagateNext() { + // Loops are used to skip over the empty entries. + while (link_idx_ >= link_map_it_->second.size()) { + ++link_map_it_; + while (link_map_it_ == (*id_it_)->links().end()) { + if (++id_it_ == id_->end()) { + return false; + } + link_map_it_ = (*id_it_)->links().begin(); + } + link_idx_ = 0; + } + return true; +} + +bool SubgraphIterator::operator==(const SubgraphIterator& other) const { + if (id_ != other.id_) { + return false; + } + if (id_it_ != other.id_it_) { + return false; + } + // When AtEnd(), the rest of the fields are not valid. + if (AtEnd()) { + return true; + } + if (link_map_it_ != other.link_map_it_) { + return false; + } + if (link_idx_ != other.link_idx_) { + return false; + } + return true; +} + +//=== SubgraphPtrSet + +Subgraph* SubgraphPtrSet::ExtendParent(const Subgraph::Identity& parent_id, + GenNode* node) { + if (parent_id.find(node) != parent_id.end()) { + // This was another link to the node that is already in the parent. + return nullptr; + } + + // Constructing an object just to check that an equivalent one is already + // present is kind of ugly but storing the references rather than the objects + // in the set avoids the need to make the object copyable. + auto sg = absl::make_unique<Subgraph>(parent_id, node); + if (find(sg) != end()) { + // This subgraph was already found by extending from a different path. + return nullptr; + } + + Subgraph* ptr = sg.get(); + insert(std::move(sg)); + return ptr; +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow |