aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rui Zhao <rzhao@google.com>2018-03-10 00:29:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-10 00:33:46 -0800
commit754dd339c141babf5aeee9495479ff0da380da52 (patch)
treeacc796310a19b2657dc5b79733c40cc967d9de5f
parent2cd50a9fd2900c2bf7e74a7795823254d5383fb4 (diff)
Increment node_ids when merging CostGraphDef.
PiperOrigin-RevId: 188586552
-rw-r--r--tensorflow/core/grappler/clusters/single_machine.cc35
1 files changed, 33 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index 8e236c9ee8..313ef90d81 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -378,10 +378,15 @@ void SingleMachine::MergeCosts(CostGraphDef* graph_costs,
init_costs.node_size() +
queue_costs.node_size());
std::unordered_set<string> nodes_seen;
+ int queue_costs_id_offset = graph_costs->node_size();
for (const auto& node : graph_costs->node()) {
nodes_seen.insert(node.name());
+ if (node.id() >= queue_costs_id_offset) {
+ queue_costs_id_offset = node.id() + 1;
+ }
}
+ int init_costs_id_offset = queue_costs_id_offset + queue_costs.node_size();
// The costs obtained by running the main graph could be more stable than
// the one we get from the queue runners since the queue runners run
// asynchronously.
@@ -389,7 +394,22 @@ void SingleMachine::MergeCosts(CostGraphDef* graph_costs,
if (nodes_seen.find(node.name()) != nodes_seen.end()) {
continue;
}
- graph_costs->add_node()->MergeFrom(node);
+
+ auto* new_node = graph_costs->add_node();
+ new_node->MergeFrom(node);
+
+ new_node->set_id(node.id() + queue_costs_id_offset);
+ if (new_node->id() >= init_costs_id_offset) {
+ init_costs_id_offset = new_node->id() + 1;
+ }
+
+ for (auto& input_info : *new_node->mutable_input_info()) {
+ input_info.set_preceding_node(input_info.preceding_node() +
+ queue_costs_id_offset);
+ }
+ for (auto& control_input : *new_node->mutable_control_input()) {
+ control_input += queue_costs_id_offset;
+ }
}
// Don't overwrite the costs with that generated during initialization since
@@ -398,7 +418,18 @@ void SingleMachine::MergeCosts(CostGraphDef* graph_costs,
if (nodes_seen.find(node.name()) != nodes_seen.end()) {
continue;
}
- graph_costs->add_node()->MergeFrom(node);
+
+ auto* new_node = graph_costs->add_node();
+ new_node->MergeFrom(node);
+
+ new_node->set_id(node.id() + init_costs_id_offset);
+ for (auto& input_info : *new_node->mutable_input_info()) {
+ input_info.set_preceding_node(input_info.preceding_node() +
+ init_costs_id_offset);
+ }
+ for (auto& control_input : *new_node->mutable_control_input()) {
+ control_input += init_costs_id_offset;
+ }
}
}