aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-21 14:08:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-21 14:20:02 -0700
commit56c4856f61dd9b42181803722b40ffe80c1297a8 (patch)
tree3c72e46ff5062da225908ab533afcaa9c08e8886 /tensorflow/core/grappler/costs/virtual_scheduler.cc
parentcbfd50ff0f01e1825922230a8bc6e5766da98dd7 (diff)
Fix _Recv op caching for multi-output port ops in VirtualScheduler.
PiperOrigin-RevId: 162782896
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc48
1 files changed, 40 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 15ebef188f..8b51bb9096 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -56,6 +56,28 @@ Costs CombineCosts(const Costs& left, const Costs& right) {
<< " max_per_op_streaming=" << result.max_per_op_streaming;
return result;
}
+
+// Key to the cached _Recv ops map, and its hash and predicate structures.
+struct RecvNodeDescriptor {
+ const NodeDef* node;
+ const int port_num;
+ const string& device;
+};
+
+struct RecvNodeDescritorHash {
+ std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
+ return std::hash<const NodeDef*>()(recv_node.node) ^
+ std::hash<int>()(recv_node.port_num) ^
+ std::hash<string>()(recv_node.device);
+ }
+};
+
+struct RecvNodeDescriptorEqual {
+ bool operator()(const RecvNodeDescriptor& a,
+ const RecvNodeDescriptor& b) const {
+ return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
+ }
+};
} // namespace
VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item,
@@ -109,6 +131,11 @@ Status VirtualScheduler::Init() {
name_to_node[node->name()] = node;
}
+ // To reuse _Recv ops.
+ std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash,
+ RecvNodeDescriptorEqual>
+ cached_recv_nodes;
+
// Build node_map; for each node, create its NodeState and connect its inputs
// and outputs.
for (const auto* curr_node : nodes) {
@@ -131,12 +158,14 @@ Status VirtualScheduler::Init() {
auto& input_node_state = GetNodeStateOrCreateIt(input_node);
input_node_state.outputs[input_node_port_num].push_back(curr_node);
} else {
- if (cached_recv_nodes_.count(input_node) > 0 &&
- cached_recv_nodes_[input_node].count(curr_node_device) > 0) {
+ RecvNodeDescriptor recv_node = {.node = input_node,
+ .port_num = input_node_port_num,
+ .device = curr_node_device};
+ auto it = cached_recv_nodes.find(recv_node);
+ if (it != cached_recv_nodes.end()) {
// Different device, but found an already-cached copy (a _Recv op);
// connect the _Recv to curr_node.
- const auto* recv_op =
- cached_recv_nodes_[input_node][curr_node_device];
+ const NodeDef* recv_op = it->second;
// recv_op's output port is hard-coded to zero.
curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
auto& input_node_state = node_map_.at(recv_op);
@@ -156,7 +185,7 @@ Status VirtualScheduler::Init() {
input_node_state.outputs[input_node_port_num].push_back(send);
// Cache the _Recv op for future use.
- cached_recv_nodes_[input_node][curr_node_device] = recv;
+ cached_recv_nodes[recv_node] = recv;
}
}
}
@@ -269,11 +298,13 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
// input names, attrs, etc.
auto input_node_port_num = NodePosition(input_name);
+ const string port_num_string =
+ input_node_port_num >= 0 ? std::to_string(input_node_port_num) : "minus1";
// _Send op.
auto* send = new NodeDef();
- send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " +
- DeviceName(to));
+ send->set_name("Send " + from->name() + ":" + port_num_string + " from " +
+ DeviceName(from) + " to " + DeviceName(to));
send->set_op("_Send");
send->add_input(from->name());
send->set_device(ChannelDeviceName(from, to));
@@ -284,7 +315,8 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
// _Recv op.
auto* recv = new NodeDef();
- recv->set_name("Recv " + from->name() + " on " + DeviceName(to));
+ recv->set_name("Recv " + from->name() + ":" + port_num_string + " on " +
+ DeviceName(to));
recv->set_op("_Recv");
recv->add_input(send->name());
recv->set_device(DeviceName(to));