diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-21 13:24:35 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-21 13:30:19 -0800 |
commit | ae6ffcadafcd83f3488ceb3f47a670f5c6ea45cd (patch) | |
tree | 07c50d45a08f7cb97a0a5db4861f433cea0b23c0 /tensorflow/core/grappler/costs/virtual_scheduler.cc | |
parent | 6419fd98883cd051213f0daeaea465728cf7a27c (diff) |
In VirtualScheduler, if there is a Recv without a Send, handle the Recv as an
initially ready node.
PiperOrigin-RevId: 186509851
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 14b4ed7507..b9a80fbff2 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -366,8 +366,16 @@ Status VirtualScheduler::Init() { std::vector<string> inputs; if (IsRecv(*curr_node)) { const auto& attr = curr_node->attr(); - const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; - inputs = {send->name()}; + if (attr.count("tensor_name")) { + const auto& send_node_name = attr.at("tensor_name").s(); + auto it = name_to_send.find(send_node_name); + // If there is a _Send associated with the curr_node (_Recv), add it as + // input. + if (it != name_to_send.end()) { + const NodeDef* send = it->second; + inputs = {send->name()}; + } + } } else { for (const string& input : curr_node->input()) { inputs.push_back(input); @@ -426,9 +434,11 @@ Status VirtualScheduler::Init() { feed_nodes.find(curr_node->name()) != feed_nodes.end(); // Default case: node without inputs are ready at time 0. - const bool has_no_inputs = curr_node->input().empty(); + // Note that we check inputs vector which may be different to + // curr_node->input(); e.g., we add Send as input to Recv. + const bool has_no_inputs = inputs.empty(); - if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) { + if (given_as_feed || has_no_inputs) { curr_node_state.time_ready = Costs::Duration(); ready_nodes_->AddNode(curr_node); VLOG(3) << "Added ready node: " << curr_node->name(); |