aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/virtual_scheduler.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-10-12 11:25:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-12 11:29:39 -0700
commit4241b86dc8da0f8ba23cb832c090469635bf09a9 (patch)
treed0a863fe29e3c0a29c0d79c370f71840da639bc6 /tensorflow/core/grappler/costs/virtual_scheduler.cc
parente975d947929c3f396ec536a086e1fb3756efa2e4 (diff)
Updated the virtual scheduler to use legal names when inserting Send/Recv nodes in the graph.
PiperOrigin-RevId: 171986401
Diffstat (limited to 'tensorflow/core/grappler/costs/virtual_scheduler.cc')
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc27
1 files changed, 17 insertions, 10 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 1ae6fac8c8..d5625ae58f 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -310,11 +310,18 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
return placer_.get_canonical_device_name(*node);
}
+string VirtualScheduler::SanitizedDeviceName(const NodeDef* node) const {
+ // Replace the ":" characters that may be present in the device name with "_".
+ // This makes it possible to then use the resulting string in a node name.
+ return str_util::StringReplace(placer_.get_canonical_device_name(*node), ":",
+ "_", true);
+}
+
string VirtualScheduler::ChannelDeviceName(const NodeDef* from,
const NodeDef* to) const {
CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
- return kChannelDevice + ": from " + DeviceName(from) + " to " +
- DeviceName(to);
+ return kChannelDevice + "_from_" + SanitizedDeviceName(from) + "_to_" +
+ SanitizedDeviceName(to);
}
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
@@ -335,15 +342,15 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
auto input_node_port_num = NodePosition(input_name);
string src_name;
if (input_node_port_num >= 0) {
- src_name = strings::StrCat(from->name(), ":", input_node_port_num);
+ src_name = strings::StrCat(from->name(), "_", input_node_port_num);
} else {
- src_name = strings::StrCat(from->name(), ":minus1");
+ src_name = strings::StrCat(from->name(), "_minus1");
}
// _Send op.
auto* send = new NodeDef();
- send->set_name("Send " + src_name + " from " + DeviceName(from) + " to " +
- DeviceName(to));
+ send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
+ "_to_" + SanitizedDeviceName(to));
send->set_op("_Send");
send->add_input(from->name());
send->set_device(ChannelDeviceName(from, to));
@@ -354,7 +361,7 @@ std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
// _Recv op.
auto* recv = new NodeDef();
- recv->set_name("Recv " + src_name + " on " + DeviceName(to));
+ recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
recv->set_op("_Recv");
recv->add_input(send->name());
recv->set_device(DeviceName(to));
@@ -500,8 +507,8 @@ Costs& VirtualScheduler::FindOrCreateZero(const string& op_name,
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
// Update graph_costs_ and per-op costs.
graph_costs_ = CombineCosts(graph_costs_, node_costs);
- const auto* node = ready_nodes_->GetCurrNode();
- const auto& op_name = node->op();
+ const NodeDef* node = ready_nodes_->GetCurrNode();
+ const string& op_name = node->op();
// Also keep track of op counts and times per op (with their shapes).
OpContext op_context = GetCurrNode();
@@ -651,7 +658,7 @@ Costs VirtualScheduler::Summary() const {
<< ", num_nodes = " << state.nodes_executed.size()
<< ", execution_time = " << state.GetCurrTime().count()
<< ", memory usage: "
- << "persistenst = "
+ << "persistent = "
<< strings::HumanReadableNumBytes(persistent_memory_usage)
<< ", peak = "
<< strings::HumanReadableNumBytes(state.max_memory_usage)