aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/subgraph.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2016-01-25 11:39:16 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-25 13:19:24 -0800
commit71184628900752e0602c93332918061c7c337f7a (patch)
treecc1fee1ac302b25dbda324cf6bd47d53f1a52155 /tensorflow/core/graph/subgraph.cc
parent668b2a7667921db344b9725f7909ead0eb1f7c6b (diff)
Added constant folding optimization pass.
- Graph* -> Graph* pass - Creates a local executor and executes a copy of the constant "slice" of the original graph, and replaces nodes in original graph with constant nodes. Change: 112971745
Diffstat (limited to 'tensorflow/core/graph/subgraph.cc')
-rw-r--r--tensorflow/core/graph/subgraph.cc109
1 files changed, 49 insertions, 60 deletions
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
index 092bc6757f..570d73f78c 100644
--- a/tensorflow/core/graph/subgraph.cc
+++ b/tensorflow/core/graph/subgraph.cc
@@ -44,8 +44,6 @@ namespace tensorflow {
namespace {
-typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
-
// Rewrite graph by replacing the output tensors specified in
// "fed_outputs" with special feed nodes for each specified output
// tensor, and removing any nodes that are now disconnected from the
@@ -57,7 +55,7 @@ typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
// state).
static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
const gtl::ArraySlice<string>& fed_outputs,
- NameIndex* name_index) {
+ subgraph::NameIndex* name_index) {
for (const string& t : fed_outputs) {
TensorId id(ParseTensorName(t));
@@ -121,18 +119,54 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
return Status::OK();
}
-// Augment "*g" by adding special "fetch" nodes that connect to the
-// tensor outputs specified in "fetch_outputs" to retrieve the output
-// of the tensors. The new nodes added are set up to execute on
-// "client_device_name", and are returned in "*fetch_nodes".
-//
-// Return true on success. On error, return false and sets *error to
-// an appropriate error message (and *g is left in an indeterminate
-// state).
-static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
- const gtl::ArraySlice<string>& fetch_outputs,
- NameIndex* name_index,
- std::vector<Node*>* fetch_nodes) {
+static bool AddNodeToTargets(const string& node_or_tensor_name,
+ const subgraph::NameIndex& name_index,
+ std::unordered_set<const Node*>* targets) {
+ TensorId id = ParseTensorName(node_or_tensor_name);
+ auto iter = name_index.find(id.first);
+ if (iter == name_index.end()) {
+ return false;
+ }
+ const Node* n = iter->second;
+ if (n->name() != node_or_tensor_name) {
+ return false;
+ }
+
+ targets->insert(n);
+ return true;
+}
+
+static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index,
+ const std::vector<Node*>& fetch_nodes,
+ const gtl::ArraySlice<string>& target_nodes) {
+ string not_found;
+ std::unordered_set<const Node*> targets;
+ for (Node* n : fetch_nodes) {
+ if (!AddNodeToTargets(n->name(), name_index, &targets)) {
+ strings::StrAppend(&not_found, n->name(), " ");
+ }
+ }
+ for (const string& s : target_nodes) {
+ if (!AddNodeToTargets(s, name_index, &targets)) {
+ strings::StrAppend(&not_found, s, " ");
+ }
+ }
+ if (!not_found.empty()) {
+ return errors::NotFound("PruneForTargets: Some target nodes not found: ",
+ not_found);
+ }
+ PruneForReverseReachability(g, targets);
+
+ return Status::OK();
+}
+
+} // namespace
+
+namespace subgraph {
+
+Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ NameIndex* name_index, std::vector<Node*>* fetch_nodes) {
fetch_nodes->clear();
for (const string& t : fetch_outputs) {
// Parse t into node_name and output_index.
@@ -188,51 +222,6 @@ static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
return Status::OK();
}
-static bool AddNodeToTargets(const string& node_or_tensor_name,
- const NameIndex& name_index,
- std::unordered_set<const Node*>* targets) {
- TensorId id = ParseTensorName(node_or_tensor_name);
- auto iter = name_index.find(id.first);
- if (iter == name_index.end()) {
- return false;
- }
- const Node* n = iter->second;
- if (n->name() != node_or_tensor_name) {
- return false;
- }
-
- targets->insert(n);
- return true;
-}
-
-static Status PruneForTargets(Graph* g, const NameIndex& name_index,
- const std::vector<Node*>& fetch_nodes,
- const gtl::ArraySlice<string>& target_nodes) {
- string not_found;
- std::unordered_set<const Node*> targets;
- for (Node* n : fetch_nodes) {
- if (!AddNodeToTargets(n->name(), name_index, &targets)) {
- strings::StrAppend(&not_found, n->name(), " ");
- }
- }
- for (const string& s : target_nodes) {
- if (!AddNodeToTargets(s, name_index, &targets)) {
- strings::StrAppend(&not_found, s, " ");
- }
- }
- if (!not_found.empty()) {
- return errors::NotFound("PruneForTargets: Some target nodes not found: ",
- not_found);
- }
- PruneForReverseReachability(g, targets);
-
- return Status::OK();
-}
-
-} // namespace
-
-namespace subgraph {
-
Status RewriteGraphForExecution(
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
const gtl::ArraySlice<string>& fetch_outputs,