aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/constant_folding.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-01-27 16:54:47 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-27 23:07:04 -0800
commitbcd9722be4250b8584e4fe5bc4f60b8793cf87d0 (patch)
treea096d66cd49c45f582dfcd5e4b9fe9ca1ea63b99 /tensorflow/core/common_runtime/constant_folding.cc
parent5515a4977ba9c461cfa7d07ce5e5cda8348baf46 (diff)
Fixed constant folding to handle nodes with multiple outputs.
Change: 113215834
Diffstat (limited to 'tensorflow/core/common_runtime/constant_folding.cc')
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc117
1 files changed, 55 insertions, 62 deletions
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 434b6eee70..d5c2eb756f 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include <algorithm>
-#include <atomic>
#include <set>
#include <unordered_map>
#include <vector>
@@ -23,13 +22,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
-#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@@ -60,18 +57,15 @@ void FindConstantFoldableNodes(const Graph* graph, ConstantFoldingOptions opts,
ReverseDFS(*graph, nullptr,
[&nodes, &node_set, &internal_node_inserted, opts](Node* n) {
if (n->IsConstant()) {
- // Constants with no control inputs (except from _SOURCE node)
- // are definitely constant foldable.
- if (n->in_edges().size() == 0 ||
- (n->in_edges().size() == 1 &&
- (*n->in_edges().begin())->src()->IsSource())) {
- node_set.insert(n);
- nodes.push_back(n);
- }
+ // Constants are definitely constant foldable
+ node_set.insert(n);
+ nodes.push_back(n);
} else if (IsConstantFoldable(n, opts.consider)) {
// Check whether the set of this node's in_nodes is completely
- // included in the set of constant foldable nodes. If true,
- // then this node is also constant foldable.
+ // included in
+ // the set of constant foldable nodes. If true, then this nodes
+ // is also
+ // constant foldable.
bool all_parents_constant = n->num_inputs() > 0;
for (const Node* parent : n->in_nodes()) {
if (node_set.count(parent) == 0) {
@@ -92,16 +86,14 @@ void FindConstantFoldableNodes(const Graph* graph, ConstantFoldingOptions opts,
}
}
-typedef std::pair<Node*, int> NodeAndOutput;
-
// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
-// 'nodes', then 'tensors_to_fetch' will contain the mapping from the
-// corresponding copy of 'n' and the edge number in 'g' to 'n'.
+// 'nodes', then 'nodes_to_fetch' will contain the mapping from the
+// corresponding copy of 'n' in 'g' to 'n'.
Graph* GetConstantGraph(const Graph* orig_graph,
const std::vector<Node*>& nodes,
- std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
+ std::unordered_map<Node*, Node*>* nodes_to_fetch) {
Graph* constant_graph = new Graph(orig_graph->op_registry());
std::unordered_map<Node*, Node*> node_map;
std::set<Node*> already_added;
@@ -115,51 +107,45 @@ Graph* GetConstantGraph(const Graph* orig_graph,
already_added.insert(added);
for (const Edge* in_edge : n->in_edges()) {
Node* in = in_edge->src();
- CHECK_GT(node_map.count(in), 0) << n->DebugString() << " <-"
- << in->DebugString();
- CHECK_GT(already_added.count(node_map[in]), 0) << in->DebugString();
+ CHECK_GT(node_map.count(in), 0);
+ CHECK_GT(already_added.count(node_map[in]), 0);
constant_graph->AddEdge(node_map[in], in_edge->src_output(), added,
in_edge->dst_input());
}
}
for (auto const& added_nodes : node_map) {
+ bool should_fetch = false;
for (const Edge* out_edge : added_nodes.first->out_edges()) {
if (node_map.count(out_edge->dst()) == 0) {
- tensors_to_fetch->insert(
- {{added_nodes.second, out_edge->src_output()}, added_nodes.first});
+ should_fetch = true;
+ break;
+ }
}
+ if (should_fetch) {
+ nodes_to_fetch->insert({added_nodes.second, added_nodes.first});
}
}
return constant_graph;
}
-int64 UniqueConstantId() {
- static std::atomic_int_fast64_t id;
- return id.fetch_add(1);
-}
-
-void ReplaceTensorWithConstant(Graph* graph, NodeAndOutput tensor,
- const Tensor& constant) {
- Node* n = tensor.first;
- std::vector<const Edge*> edges_to_remove;
+void ReplaceNodeWithConstant(Graph* graph, Node* n, const Tensor& constant) {
+ std::vector<std::tuple<int, Node*, int>> old_edges;
for (const Edge* out_edge : n->out_edges()) {
- if (out_edge->src_output() == tensor.second) {
- edges_to_remove.push_back(out_edge);
- }
+ old_edges.push_back(std::make_tuple(out_edge->src_output(), out_edge->dst(),
+ out_edge->dst_input()));
}
string node_name = n->name();
+ graph->RemoveNode(n);
Node* constant_node;
- TF_CHECK_OK(NodeBuilder(strings::StrCat(graph->NewName(node_name), "__cf__",
- UniqueConstantId()),
- "Const")
+ TF_CHECK_OK(NodeBuilder(graph->NewName(node_name), "Const")
.Attr("dtype", constant.dtype())
.Attr("value", constant)
.Finalize(graph, &constant_node));
- for (auto edge : edges_to_remove) {
- graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
- graph->RemoveEdge(edge);
+ for (auto edge : old_edges) {
+ graph->AddEdge(constant_node, std::get<0>(edge), std::get<1>(edge),
+ std::get<2>(edge));
}
}
@@ -232,7 +218,6 @@ class SimpleRendezvous : public Rendezvous {
} // namespace
bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
- DumpGraph("Before", graph);
Device* device = GetCPUDevice();
thread::ThreadPool* thread_pool = GetThreadPool();
if (!device || !thread_pool) {
@@ -248,12 +233,11 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
return false;
}
- std::map<NodeAndOutput, Node*> tensors_to_fetch;
+ std::unordered_map<Node*, Node*> nodes_to_fetch;
Graph* constant_graph =
- GetConstantGraph(graph, constant_foldable_nodes, &tensors_to_fetch);
- DumpGraph("Constant graph", constant_graph);
+ GetConstantGraph(graph, constant_foldable_nodes, &nodes_to_fetch);
- if (tensors_to_fetch.empty()) {
+ if (nodes_to_fetch.empty()) {
VLOG(1) << "No constant nodes found that feed into the original graph.";
delete constant_graph;
return false;
@@ -268,23 +252,21 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
}
std::vector<Node*> fetch_nodes;
- std::vector<string> tensors_to_fetch_names;
- std::vector<NodeAndOutput> tensors_to_replace;
- for (auto n : tensors_to_fetch) {
- tensors_to_fetch_names.push_back(
- strings::StrCat(n.first.first->name(), ":", n.first.second));
- tensors_to_replace.push_back({n.second, n.first.second});
+ std::vector<string> nodes_to_fetch_names;
+ std::vector<Node*> nodes_to_replace;
+ for (auto n : nodes_to_fetch) {
+ nodes_to_fetch_names.push_back(n.first->name());
+ nodes_to_replace.push_back(n.second);
}
// For nodes that need to be fetched back from the constant_graph, attach Send
// nodes.
if (!subgraph::FetchOutputs(constant_graph, device->attributes(),
- tensors_to_fetch_names, &name_index, &fetch_nodes)
+ nodes_to_fetch_names, &name_index, &fetch_nodes)
.ok()) {
- VLOG(1) << "Could not fetch constants";
return false;
}
- CHECK_EQ(fetch_nodes.size(), tensors_to_fetch.size());
+ CHECK_EQ(fetch_nodes.size(), nodes_to_fetch.size());
// Create the local executor and the Rendezvous for fetching back the
// constants.
@@ -329,7 +311,17 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
}
executor_done.WaitForNotification();
- // Fetch the constant tensors and replace the corresponding tensors in the
+ // Keep track of the nodes that will be orphaned once the internal nodes have
+ // been constant folded and replaced, so we can delete them later.
+ std::set<Node*> replaced_nodes_set(nodes_to_replace.begin(),
+ nodes_to_replace.end());
+ std::vector<Node*> to_delete;
+ for (Node* n : constant_foldable_nodes) {
+ if (replaced_nodes_set.count(n) == 0) {
+ to_delete.push_back(n);
+ }
+ }
+ // Fetch the constant nodes and replace the corresponding nodes in the
// original graph with those constants.
for (size_t c = 0; c < fetch_nodes.size(); ++c) {
Tensor output;
@@ -344,14 +336,15 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, Graph* graph) {
if (!s.ok() || is_dead) {
return c > 0;
}
- VLOG(1) << "Replacing " << tensors_to_replace[c].first->DebugString()
- << " :: " << tensors_to_replace[c].second << " with constant "
- << output.DebugString();
- ReplaceTensorWithConstant(graph, tensors_to_replace[c], output);
+ VLOG(1) << "Replacing " << nodes_to_replace[c]->DebugString()
+ << " with constant " << output.DebugString();
+ ReplaceNodeWithConstant(graph, nodes_to_replace[c], output);
}
- DumpGraph("After", graph);
-
+ // Delete the orphaned nodes in the original graph.
+ for (Node* n : to_delete) {
+ graph->RemoveNode(n);
+ }
return true;
}