aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-17 14:36:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-17 14:43:22 -0700
commit76d27f74541befd883b88f185c88f4d920c2beee (patch)
tree4f530fa3cc2702cf3a96314c9ac7d2ad5fe23a89 /tensorflow/core/grappler/costs/virtual_scheduler.cc
parent02f87fee25552e220c8295b58ab8e58b6fbe598b (diff)
Extend Grappler API to accept feed and fetch node lists.
PiperOrigin-RevId: 165630567
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc23
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 6b0b869df5..88d5156b9e 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -119,6 +120,12 @@ Status VirtualScheduler::Init() {
const auto& graph = grappler_item_->graph;
const auto& fetch_nodes = grappler_item_->fetch;
+ std::set<string> feed_nodes;
+ for (const auto& f : grappler_item_->feed) {
+ auto iter_and_inserted_flag = feed_nodes.insert(f.first);
+ QCHECK(iter_and_inserted_flag.second)
+ << "Duplicate feed node found: " << f.first;
+ }
// Get the nodes that would run to output fetch_nodes.
std::vector<const NodeDef*> nodes =
@@ -193,12 +200,21 @@ Status VirtualScheduler::Init() {
}
}
- if (curr_node->input().empty()) {
- // Node without input: ready at time 0.
+ // Special case: given feed nodes are ready at time 0.
+ const bool given_as_feed =
+ feed_nodes.find(curr_node->name()) != feed_nodes.end();
+
+ // Default case: node without inputs are ready at time 0.
+ const bool has_no_inputs = curr_node->input().empty();
+
+ if (given_as_feed || has_no_inputs) {
curr_node_state.time_ready = Costs::Duration();
ready_nodes_->AddNode(curr_node);
+ VLOG(1) << "Added ready node: " << curr_node->name();
}
+ feed_nodes.erase(curr_node->name());
+
if (IsPersistentNode(curr_node)) {
auto& device_state = device_[curr_node_device];
for (int port_num = 0;
@@ -213,6 +229,9 @@ Status VirtualScheduler::Init() {
return Status(error::UNAVAILABLE, "No ready nodes in the graph.");
}
+ CHECK(feed_nodes.empty()) << "Some feed nodes were not found in the graph: "
+ << str_util::Join(feed_nodes, ",");
+
initialized_ = true;
return Status::OK();
}