aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/utils.h')
-rw-r--r--tensorflow/core/grappler/utils.h10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 1c6fef59ea..b297caa8d4 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -209,6 +209,13 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
+void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
+
+void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
+
+void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
+ GraphDef* graph);
+
class SimpleGraphView {
public:
// Build a graph view for the specified graphdef.
@@ -237,6 +244,9 @@ class SimpleGraphView {
DCHECK(it != name_to_index_.end());
return it == name_to_index_.end() ? -1 : it->second;
}
+ inline const NodeDef& node(int node_idx) const {
+ return graph_->node(node_idx);
+ }
inline const string& node_name(int node_idx) const {
return index_to_name_[node_idx];
}