aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-08 11:20:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-08 11:28:11 -0800
commit3b25be3081d9fa1ab6976334c1a2c0f6f8d0d1a7 (patch)
treeffc0a594be806851d027d0b5df32c5c401396982 /tensorflow/tools/graph_transforms
parent40f0cf009641ef0d827729bb01e8ed50d97fd109 (diff)
Preserving order when removing nodes.
PiperOrigin-RevId: 185023366
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather.cc12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 214ec721e2..701e350fc3 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -212,6 +212,14 @@ Status RemoveInputAtIndex(NodeDef* n, int index) {
return Status::OK();
}
+Status RemoveNodeAtIndex(GraphDef* g, int index) {
+ for (int i = index; i < g->node_size() - 1; i++) {
+ g->mutable_node()->SwapElements(i, i + 1);
+ }
+ g->mutable_node()->RemoveLast();
+ return Status::OK();
+}
+
Status SparsifyGatherInternal(
const GraphDef& input_graph_def,
const std::unique_ptr<std::unordered_map<string, string> >&
@@ -493,9 +501,7 @@ Status SparsifyGatherInternal(
removed_node_names.push_back(parsed_input);
}
}
- replaced_graph_def.mutable_node()->SwapElements(
- i, replaced_graph_def.node_size() - 1);
- replaced_graph_def.mutable_node()->RemoveLast();
+ TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
continue;
}
int j = 0;