diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-24 12:23:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-24 12:27:37 -0700 |
commit | 1b578268f5407b4f40c4226b7d00b08e948a4b09 (patch) | |
tree | 65f644e80f01d4e7dce65111a0e164f2b8f5a30b /tensorflow/core/grappler/costs/virtual_scheduler.cc | |
parent | 62bced8280075c4964033cb8b8e43b9855d655e1 (diff) |
Fix _Recv op caching for multi-output port ops in VirtualScheduler.
PiperOrigin-RevId: 162970793
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 52 |
1 files changed, 45 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 1dad81ed90..6b0b869df5 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -56,6 +56,32 @@ Costs CombineCosts(const Costs& left, const Costs& right) { << " max_per_op_streaming=" << result.max_per_op_streaming; return result; } + +// Key to the cached _Recv ops map, and its hash and predicate structures. +struct RecvNodeDescriptor { + const NodeDef* node; + const int port_num; + const string& device; + + RecvNodeDescriptor(const NodeDef* node_, const int port_num_, + const string& device_) + : node(node_), port_num(port_num_), device(device_) {} +}; + +struct RecvNodeDescritorHash { + std::size_t operator()(const RecvNodeDescriptor& recv_node) const { + return std::hash<const NodeDef*>()(recv_node.node) ^ + std::hash<int>()(recv_node.port_num) ^ + std::hash<string>()(recv_node.device); + } +}; + +struct RecvNodeDescriptorEqual { + bool operator()(const RecvNodeDescriptor& a, + const RecvNodeDescriptor& b) const { + return a.node == b.node && a.port_num == b.port_num && a.device == b.device; + } +}; } // namespace VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, @@ -109,6 +135,11 @@ Status VirtualScheduler::Init() { name_to_node[node->name()] = node; } + // To reuse _Recv ops. + std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescritorHash, + RecvNodeDescriptorEqual> + cached_recv_nodes; + // Build node_map; for each node, create its NodeState and connect its inputs // and outputs. for (const auto* curr_node : nodes) { @@ -131,12 +162,13 @@ Status VirtualScheduler::Init() { auto& input_node_state = GetNodeStateOrCreateIt(input_node); input_node_state.outputs[input_node_port_num].push_back(curr_node); } else { - if (cached_recv_nodes_.count(input_node) > 0 && - cached_recv_nodes_[input_node].count(curr_node_device) > 0) { + RecvNodeDescriptor recv_node(input_node, input_node_port_num, + curr_node_device); + auto it = cached_recv_nodes.find(recv_node); + if (it != cached_recv_nodes.end()) { // Different device, but found an already-cached copy (a _Recv op); // connect the _Recv to curr_node. - const auto* recv_op = - cached_recv_nodes_[input_node][curr_node_device]; + const NodeDef* recv_op = it->second; // recv_op's output port is hard-coded to zero. curr_node_state.inputs.push_back(std::make_pair(recv_op, 0)); auto& input_node_state = node_map_.at(recv_op); @@ -156,7 +188,7 @@ Status VirtualScheduler::Init() { input_node_state.outputs[input_node_port_num].push_back(send); // Cache the _Recv op for future use. - cached_recv_nodes_[input_node][curr_node_device] = recv; + cached_recv_nodes[recv_node] = recv; } } } @@ -269,10 +301,16 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( // input names, attrs, etc. auto input_node_port_num = NodePosition(input_name); + string src_name; + if (input_node_port_num >= 0) { + src_name = strings::StrCat(from->name(), ":", input_node_port_num); + } else { + src_name = strings::StrCat(from->name(), ":minus1"); + } // _Send op. auto* send = new NodeDef(); - send->set_name("Send " + from->name() + " from " + DeviceName(from) + " to " + + send->set_name("Send " + src_name + " from " + DeviceName(from) + " to " + DeviceName(to)); send->set_op("_Send"); send->add_input(from->name()); @@ -284,7 +322,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( // _Recv op. auto* recv = new NodeDef(); - recv->set_name("Recv " + from->name() + " on " + DeviceName(to)); + recv->set_name("Recv " + src_name + " on " + DeviceName(to)); recv->set_op("_Recv"); recv->add_input(send->name()); recv->set_device(DeviceName(to)); |