diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-01-27 16:54:47 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2016-01-27 23:07:04 -0800 |
commit | bcd9722be4250b8584e4fe5bc4f60b8793cf87d0 (patch) | |
tree | a096d66cd49c45f582dfcd5e4b9fe9ca1ea63b99 /tensorflow/core/common_runtime/constant_folding.cc | |
parent | 5515a4977ba9c461cfa7d07ce5e5cda8348baf46 (diff) |
Fixed constant folding to handle nodes with multiple outputs.
Change: 113215834
Diffstat (limited to 'tensorflow/core/common_runtime/constant_folding.cc')
-rw-r--r-- | tensorflow/core/common_runtime/constant_folding.cc | 117 |
1 files changed, 55 insertions, 62 deletions
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 434b6eee70..d5c2eb756f 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include <algorithm> -#include <atomic> #include <set> #include <unordered_map> #include <vector> @@ -23,13 +22,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -60,18 +57,15 @@ void FindConstantFoldableNodes(const Graph* graph, ConstantFoldingOptions opts, ReverseDFS(*graph, nullptr, [&nodes, &node_set, &internal_node_inserted, opts](Node* n) { if (n->IsConstant()) { - // Constants with no control inputs (except from _SOURCE node) - // are definitely constant foldable. - if (n->in_edges().size() == 0 || - (n->in_edges().size() == 1 && - (*n->in_edges().begin())->src()->IsSource())) { - node_set.insert(n); - nodes.push_back(n); - } + // Constants are definitely constant foldable + node_set.insert(n); + nodes.push_back(n); } else if (IsConstantFoldable(n, opts.consider)) { // Check whether the set of this node's in_nodes is completely - // included in the set of constant foldable nodes. If true, - // then this node is also constant foldable. + // included in + // the set of constant foldable nodes. If true, then this nodes + // is also + // constant foldable. bool all_parents_constant = n->num_inputs() > 0; for (const Node* parent : n->in_nodes()) { if (node_set.count(parent) == 0) { @@ -92,16 +86,14 @@ void FindConstantFoldableNodes(const Graph* graph, ConstantFoldingOptions opts, } } -typedef std::pair<Node*, int> NodeAndOutput; - // Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g' // will contain copies of the nodes in 'nodes'. In addition, if there is an edge // going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in -// 'nodes', then 'tensors_to_fetch' will contain the mapping from the -// corresponding copy of 'n' and the edge number in 'g' to 'n'. +// 'nodes', then 'nodes_to_fetch' will contain the mapping from the +// corresponding copy of 'n' in 'g' to 'n'. Graph* GetConstantGraph(const Graph* orig_graph, const std::vector<Node*>& nodes, - std::map<NodeAndOutput, Node*>* tensors_to_fetch) { + std::unordered_map<Node*, Node*>* nodes_to_fetch) { Graph* constant_graph = new Graph(orig_graph->op_registry()); std::unordered_map<Node*, Node*> node_map; std::set<Node*> already_added; @@ -115,51 +107,45 @@ Graph* GetConstantGraph(const Graph* orig_graph, already_added.insert(added); for (const Edge* in_edge : n->in_edges()) { Node* in = in_edge->src(); - CHECK_GT(node_map.count(in), 0) << n->DebugString() << " <-" - << in->DebugString(); - CHECK_GT(already_added.count(node_map[in]), 0) << in->DebugString(); + CHECK_GT(node_map.count(in), 0); + CHECK_GT(already_added.count(node_map[in]), 0); constant_graph->AddEdge(node_map[in], in_edge->src_output(), added, in_edge->dst_input()); } } for (auto const& added_nodes : node_map) { + bool should_fetch = false; for (const Edge* out_edge : added_nodes.first->out_edges()) { if (node_map.count(out_edge->dst()) == 0) { - tensors_to_fetch->insert( - {{added_nodes.second, out_edge->src_output()}, added_nodes.first}); + should_fetch = true; + break; + } } + if (should_fetch) { + nodes_to_fetch->insert({added_nodes.second, added_nodes.first}); } } return constant_graph; } -int64 UniqueConstantId() { - static std::atomic_int_fast64_t id; - return id.fetch_add(1); -} - -void ReplaceTensorWithConstant(Graph* graph, NodeAndOutput tensor, - const Tensor& constant) { - Node* n = tensor.first; - std::vector<const Edge*> edges_to_remove; +void ReplaceNodeWithConstant(Graph* graph, Node* n, const Tensor& constant) { + std::vector<std::tuple<int, Node*, int>> old_edges; for (const Edge* out_edge : n->out_edges()) { - if (out_edge->src_output() == tensor.second) { - edges_to_remove.push_back(out_edge); - } + old_edges.push_back(std::make_tuple(out_edge->src_output(), out_edge->dst(), + out_edge->dst_input())); } string node_name = n->name(); + graph->RemoveNode(n); Node* constant_node; - TF_CHECK_OK(NodeBuilder(strings::StrCat(graph->NewName(node_name), "__cf__", - UniqueConstantId()), - "Const") + TF_CHECK_OK(NodeBuilder(graph->NewName(node_name), "Const") .Attr("dtype", constant.dtype()) .Attr("value", constant) .Finalize(graph, &constant_node)); - for (auto edge : edges_to_remove) { - graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); - graph->RemoveEdge(edge); + for (auto edge : old_edges) { + graph->AddEdge(constant_node, std::get<0>(edge), std::get<1>(edge), + std::get<2>(edge)); } } @@ -232,7 +218,6 @@ class SimpleRendezvous : public Rendezvous { } // namespace bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) { - DumpGraph("Before", graph); Device* device = GetCPUDevice(); thread::ThreadPool* thread_pool = GetThreadPool(); if (!device || !thread_pool) { @@ -248,12 +233,11 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) { return false; } - std::map<NodeAndOutput, Node*> tensors_to_fetch; + std::unordered_map<Node*, Node*> nodes_to_fetch; Graph* constant_graph = - GetConstantGraph(graph, constant_foldable_nodes, &tensors_to_fetch); - DumpGraph("Constant graph", constant_graph); + GetConstantGraph(graph, constant_foldable_nodes, &nodes_to_fetch); - if (tensors_to_fetch.empty()) { + if (nodes_to_fetch.empty()) { VLOG(1) << "No constant nodes found that feed into the original graph."; delete constant_graph; return false; @@ -268,23 +252,21 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) { } std::vector<Node*> fetch_nodes; - std::vector<string> tensors_to_fetch_names; - std::vector<NodeAndOutput> tensors_to_replace; - for (auto n : tensors_to_fetch) { - tensors_to_fetch_names.push_back( - strings::StrCat(n.first.first->name(), ":", n.first.second)); - tensors_to_replace.push_back({n.second, n.first.second}); + std::vector<string> nodes_to_fetch_names; + std::vector<Node*> nodes_to_replace; + for (auto n : nodes_to_fetch) { + nodes_to_fetch_names.push_back(n.first->name()); + nodes_to_replace.push_back(n.second); } // For nodes that need to be fetched back from the constant_graph, attach Send // nodes. if (!subgraph::FetchOutputs(constant_graph, device->attributes(), - tensors_to_fetch_names, &name_index, &fetch_nodes) + nodes_to_fetch_names, &name_index, &fetch_nodes) .ok()) { - VLOG(1) << "Could not fetch constants"; return false; } - CHECK_EQ(fetch_nodes.size(), tensors_to_fetch.size()); + CHECK_EQ(fetch_nodes.size(), nodes_to_fetch.size()); // Create the local executor and the Rendezvous for fetching back the // constants. @@ -329,7 +311,17 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) { } executor_done.WaitForNotification(); - // Fetch the constant tensors and replace the corresponding tensors in the + // Keep track of the nodes that will be orphaned once the internal nodes have + // been constant folded and replaced, so we can delete them later. + std::set<Node*> replaced_nodes_set(nodes_to_replace.begin(), + nodes_to_replace.end()); + std::vector<Node*> to_delete; + for (Node* n : constant_foldable_nodes) { + if (replaced_nodes_set.count(n) == 0) { + to_delete.push_back(n); + } + } + // Fetch the constant nodes and replace the corresponding nodes in the // original graph with those constants. for (size_t c = 0; c < fetch_nodes.size(); ++c) { Tensor output; @@ -344,14 +336,15 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) { if (!s.ok() || is_dead) { return c > 0; } - VLOG(1) << "Replacing " << tensors_to_replace[c].first->DebugString() - << " :: " << tensors_to_replace[c].second << " with constant " - << output.DebugString(); - ReplaceTensorWithConstant(graph, tensors_to_replace[c], output); + VLOG(1) << "Replacing " << nodes_to_replace[c]->DebugString() + << " with constant " << output.DebugString(); + ReplaceNodeWithConstant(graph, nodes_to_replace[c], output); } - DumpGraph("After", graph); - + // Delete the orphaned nodes in the original graph. + for (Node* n : to_delete) { + graph->RemoveNode(n); + } return true; } |