diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-17 14:36:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-17 14:43:22 -0700 |
commit | 76d27f74541befd883b88f185c88f4d920c2beee (patch) | |
tree | 4f530fa3cc2702cf3a96314c9ac7d2ad5fe23a89 /tensorflow/core/grappler/costs/virtual_scheduler.cc | |
parent | 02f87fee25552e220c8295b58ab8e58b6fbe598b (diff) |
Extend Grappler API to accept feed and fetch node lists.
PiperOrigin-RevId: 165630567
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 6b0b869df5..88d5156b9e 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -119,6 +120,12 @@ Status VirtualScheduler::Init() { const auto& graph = grappler_item_->graph; const auto& fetch_nodes = grappler_item_->fetch; + std::set<string> feed_nodes; + for (const auto& f : grappler_item_->feed) { + auto iter_and_inserted_flag = feed_nodes.insert(f.first); + QCHECK(iter_and_inserted_flag.second) + << "Duplicate feed node found: " << f.first; + } // Get the nodes that would run to output fetch_nodes. std::vector<const NodeDef*> nodes = @@ -193,12 +200,21 @@ Status VirtualScheduler::Init() { } } - if (curr_node->input().empty()) { - // Node without input: ready at time 0. + // Special case: given feed nodes are ready at time 0. + const bool given_as_feed = + feed_nodes.find(curr_node->name()) != feed_nodes.end(); + + // Default case: node without inputs are ready at time 0. + const bool has_no_inputs = curr_node->input().empty(); + + if (given_as_feed || has_no_inputs) { curr_node_state.time_ready = Costs::Duration(); ready_nodes_->AddNode(curr_node); + VLOG(1) << "Added ready node: " << curr_node->name(); } + feed_nodes.erase(curr_node->name()); + if (IsPersistentNode(curr_node)) { auto& device_state = device_[curr_node_device]; for (int port_num = 0; @@ -213,6 +229,9 @@ Status VirtualScheduler::Init() { return Status(error::UNAVAILABLE, "No ready nodes in the graph."); } + CHECK(feed_nodes.empty()) << "Some feed nodes were not found in the graph: " + << str_util::Join(feed_nodes, ","); + initialized_ = true; return Status::OK(); } |