diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-10-12 11:25:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-12 11:29:39 -0700 |
commit | 4241b86dc8da0f8ba23cb832c090469635bf09a9 (patch) | |
tree | d0a863fe29e3c0a29c0d79c370f71840da639bc6 /tensorflow/core/grappler/costs/virtual_scheduler.cc | |
parent | e975d947929c3f396ec536a086e1fb3756efa2e4 (diff) |
Updated the virtual scheduler to use legal names when inserting Send/Recv nodes in the graph.
PiperOrigin-RevId: 171986401
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 1ae6fac8c8..d5625ae58f 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -310,11 +310,18 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const { return placer_.get_canonical_device_name(*node); } +string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const { + // Replace the ":" characters that may be present in the device name with "_". + // This makes it possible to then use the resulting string in a node name. + return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":", + "_", true); +} + string VirtualScheduler::ChannelDeviceName(const NodeDef* from, const NodeDef* to) const { CHECK(!initialized_) << "ChannelDeviceName is called after Init()."; - return kChannelDevice + ": from " + DeviceName(from) + " to " + - DeviceName(to); + return kChannelDevice + "_from_" + SanitizedDeviceName(from) + "_to_" + + SanitizedDeviceName(to); } std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( @@ -335,15 +342,15 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( auto input_node_port_num = NodePosition(input_name); string src_name; if (input_node_port_num >= 0) { - src_name = strings::StrCat(from->name(), ":", input_node_port_num); + src_name = strings::StrCat(from->name(), "_", input_node_port_num); } else { - src_name = strings::StrCat(from->name(), ":minus1"); + src_name = strings::StrCat(from->name(), "_minus1"); } // _Send op. auto* send = new NodeDef(); - send->set_name("Send " + src_name + " from " + DeviceName(from) + " to " + - DeviceName(to)); + send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) + + "_to_" + SanitizedDeviceName(to)); send->set_op("_Send"); send->add_input(from->name()); send->set_device(ChannelDeviceName(from, to)); @@ -354,7 +361,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv( // _Recv op. auto* recv = new NodeDef(); - recv->set_name("Recv " + src_name + " on " + DeviceName(to)); + recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to)); recv->set_op("_Recv"); recv->add_input(send->name()); recv->set_device(DeviceName(to)); @@ -500,8 +507,8 @@ Costs& VirtualScheduler::FindOrCreateZero(const string& op_name, bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { // Update graph_costs_ and per-op costs. graph_costs_ = CombineCosts(graph_costs_, node_costs); - const auto* node = ready_nodes_->GetCurrNode(); - const auto& op_name = node->op(); + const NodeDef* node = ready_nodes_->GetCurrNode(); + const string& op_name = node->op(); // Also keep track of op counts and times per op (with their shapes). OpContext op_context = GetCurrNode(); @@ -651,7 +658,7 @@ Costs VirtualScheduler::Summary() const { << ", num_nodes = " << state.nodes_executed.size() << ", execution_time = " << state.GetCurrTime().count() << ", memory usage: " - << "persistenst = " + << "persistent = " << strings::HumanReadableNumBytes(persistent_memory_usage) << ", peak = " << strings::HumanReadableNumBytes(state.max_memory_usage) |