about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-07-24 12:23:13 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-07-24 12:27:37 -0700
commit: 1b578268f5407b4f40c4226b7d00b08e948a4b09 (patch)
tree: 65f644e80f01d4e7dce65111a0e164f2b8f5a30b /tensorflow/core/grappler/costs/virtual_scheduler.cc
parent: 62bced8280075c4964033cb8b8e43b9855d655e1 (diff)
Fix _Recv op caching for multi-output port ops in VirtualScheduler.
PiperOrigin-RevId: 162970793
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.cc | 52
1 files changed, 45 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 1dad81ed90..6b0b869df5 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -56,6 +56,32 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
<< " max_per_op_streaming=" << result.max_per_op_streaming;
return result;
}
+
+// Key to the cached _Recv ops map, and its hash and predicate structures.
+struct RecvNodeDescriptor {  // Cache key: (node, port number, device name).
+ const NodeDef* node;  // Hashed and compared by address; not owned by the key.
+ const int port_num;  // Port number paired with `node`.
+ const string device;  // Held by value so a key stored in a map cannot dangle.
+
+ RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
+ const string& device_)
+ : node(node_), port_num(port_num_), device(device_) {}  // Copies device_.
+};
+
+struct RecvNodeDescritorHash {  // Hash functor for RecvNodeDescriptor keys.
+ std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
+ return std::hash<const NodeDef*>()(recv_node.node) ^  // Node pointer value.
+ std::hash<int>()(recv_node.port_num) ^  // Field hashes are XOR-combined.
+ std::hash<string>()(recv_node.device);  // Device string contents.
+ }
+};
+
+struct RecvNodeDescriptorEqual {  // Equality predicate for RecvNodeDescriptor.
+ bool operator()(const RecvNodeDescriptor& a,
+ const RecvNodeDescriptor& b) const {
+ return a.node == b.node && a.port_num == b.port_num && a.device == b.device;  // Equal only if all three fields match.
+ }
+};
} // namespace
VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
@@ -109,6 +135,11 @@ Status VirtualScheduler::Init() {
name_to_node[node->name()] = node;
}
+ // To reuse _Recv ops.
+ std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash,
+ RecvNodeDescriptorEqual>
+ cached_recv_nodes;
+
// Build node_map; for each node, create its NodeState and connect its inputs
// and outputs.
for (const auto* curr_node : nodes) {
@@ -131,12 +162,13 @@ Status VirtualScheduler::Init() {
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
input_node_state.outputs[input_node_port_num].push_back(curr_node);
} else {
- if (cached_recv_nodes_.count(input_node) > 0 &&
- cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
+ RecvNodeDescriptor recv_node(input_node, input_node_port_num,
+ curr_node_device);
+ auto it = cached_recv_nodes.find(recv_node);
+ if (it != cached_recv_nodes.end()) {
// Different device, but found an already-cached copy (a _Recv op);
// connect the _Recv to curr_node.
- const auto* recv_op =
- cached_recv_nodes_[input_node][curr_node_device];
+ const NodeDef* recv_op = it->second;
// recv_op's output port is hard-coded to zero.
curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
auto& input_node_state = node_map_.at(recv_op);
@@ -156,7 +188,7 @@ Status VirtualScheduler::Init() {
input_node_state.outputs[input_node_port_num].push_back(send);
// Cache the _Recv op for future use.
- cached_recv_nodes_[input_node][curr_node_device] = recv;
+ cached_recv_nodes[recv_node] = recv;
}
}
}
@@ -269,10 +301,16 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
// input names, attrs, etc.
auto input_node_port_num = NodePosition(input_name);
+ string src_name;
+ if (input_node_port_num >= 0) {
+ src_name = strings::StrCat(from->name(), ":", input_node_port_num);
+ } else {
+ src_name = strings::StrCat(from->name(), ":minus1");
+ }
// _Send op.
auto* send = new NodeDef();
- send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
+ send->set_name("Send " + src_name + " from " + DeviceName(from) + " to " +
DeviceName(to));
send->set_op("_Send");
send->add_input(from->name());
@@ -284,7 +322,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
// _Recv op.
auto* recv = new NodeDef();
- recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
+ recv->set_name("Recv " + src_name + " on " + DeviceName(to));
recv->set_op("_Recv");
recv->add_input(send->name());
recv->set_device(DeviceName(to));