diff options
author | 2016-01-25 11:39:16 -0800 | |
---|---|---|
committer | 2016-01-25 13:19:24 -0800 | |
commit | 71184628900752e0602c93332918061c7c337f7a (patch) | |
tree | cc1fee1ac302b25dbda324cf6bd47d53f1a52155 /tensorflow/core/graph/subgraph.cc | |
parent | 668b2a7667921db344b9725f7909ead0eb1f7c6b (diff) |
Added constant folding optimization pass.
- Graph* -> Graph* pass
- Creates a local executor and executes a copy of the constant "slice" of the
original graph, and replaces nodes in original graph with constant nodes.
Change: 112971745
Diffstat (limited to 'tensorflow/core/graph/subgraph.cc')
-rw-r--r-- | tensorflow/core/graph/subgraph.cc | 109 |
1 files changed, 49 insertions, 60 deletions
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 092bc6757f..570d73f78c 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -44,8 +44,6 @@ namespace tensorflow { namespace { -typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex; - // Rewrite graph by replacing the output tensors specified in // "fed_outputs" with special feed nodes for each specified output // tensor, and removing any nodes that are now disconnected from the @@ -57,7 +55,7 @@ typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex; // state). static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, const gtl::ArraySlice<string>& fed_outputs, - NameIndex* name_index) { + subgraph::NameIndex* name_index) { for (const string& t : fed_outputs) { TensorId id(ParseTensorName(t)); @@ -121,18 +119,54 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, return Status::OK(); } -// Augment "*g" by adding special "fetch" nodes that connect to the -// tensor outputs specified in "fetch_outputs" to retrieve the output -// of the tensors. The new nodes added are set up to execute on -// "client_device_name", and are returned in "*fetch_nodes". -// -// Return true on success. On error, return false and sets *error to -// an appropriate error message (and *g is left in an indeterminate -// state). -static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, - const gtl::ArraySlice<string>& fetch_outputs, - NameIndex* name_index, - std::vector<Node*>* fetch_nodes) { +static bool AddNodeToTargets(const string& node_or_tensor_name, + const subgraph::NameIndex& name_index, + std::unordered_set<const Node*>* targets) { + TensorId id = ParseTensorName(node_or_tensor_name); + auto iter = name_index.find(id.first); + if (iter == name_index.end()) { + return false; + } + const Node* n = iter->second; + if (n->name() != node_or_tensor_name) { + return false; + } + + targets->insert(n); + return true; +} + +static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index, + const std::vector<Node*>& fetch_nodes, + const gtl::ArraySlice<string>& target_nodes) { + string not_found; + std::unordered_set<const Node*> targets; + for (Node* n : fetch_nodes) { + if (!AddNodeToTargets(n->name(), name_index, &targets)) { + strings::StrAppend(¬_found, n->name(), " "); + } + } + for (const string& s : target_nodes) { + if (!AddNodeToTargets(s, name_index, &targets)) { + strings::StrAppend(¬_found, s, " "); + } + } + if (!not_found.empty()) { + return errors::NotFound("PruneForTargets: Some target nodes not found: ", + not_found); + } + PruneForReverseReachability(g, targets); + + return Status::OK(); +} + +} // namespace + +namespace subgraph { + +Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, + const gtl::ArraySlice<string>& fetch_outputs, + NameIndex* name_index, std::vector<Node*>* fetch_nodes) { fetch_nodes->clear(); for (const string& t : fetch_outputs) { // Parse t into node_name and output_index. @@ -188,51 +222,6 @@ static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info, return Status::OK(); } -static bool AddNodeToTargets(const string& node_or_tensor_name, - const NameIndex& name_index, - std::unordered_set<const Node*>* targets) { - TensorId id = ParseTensorName(node_or_tensor_name); - auto iter = name_index.find(id.first); - if (iter == name_index.end()) { - return false; - } - const Node* n = iter->second; - if (n->name() != node_or_tensor_name) { - return false; - } - - targets->insert(n); - return true; -} - -static Status PruneForTargets(Graph* g, const NameIndex& name_index, - const std::vector<Node*>& fetch_nodes, - const gtl::ArraySlice<string>& target_nodes) { - string not_found; - std::unordered_set<const Node*> targets; - for (Node* n : fetch_nodes) { - if (!AddNodeToTargets(n->name(), name_index, &targets)) { - strings::StrAppend(¬_found, n->name(), " "); - } - } - for (const string& s : target_nodes) { - if (!AddNodeToTargets(s, name_index, &targets)) { - strings::StrAppend(¬_found, s, " "); - } - } - if (!not_found.empty()) { - return errors::NotFound("PruneForTargets: Some target nodes not found: ", - not_found); - } - PruneForReverseReachability(g, targets); - - return Status::OK(); -} - -} // namespace - -namespace subgraph { - Status RewriteGraphForExecution( Graph* g, const gtl::ArraySlice<string>& fed_outputs, const gtl::ArraySlice<string>& fetch_outputs, |