aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r--tensorflow/core/grappler/utils.cc42
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index c8e63f95e1..153785d3b4 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/grappler/utils.h"
+
#include <memory>
+#include <queue>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -21,7 +24,6 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -354,13 +356,51 @@ void DedupControlInputs(NodeDef* node) {
}
namespace {
+
+template <typename UniqueContainer>
+void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
+ GraphDef* graph) {
+ static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
+ "Need to pass container of ints");
+
+ int last = graph->node_size() - 1;
+ for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
+ const int index = *it;
+ graph->mutable_node()->SwapElements(index, last);
+ last--;
+ }
+ graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
+}
+
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
std::sort(v->begin(), v->end());
v->erase(std::unique(v->begin(), v->end()), v->end());
}
+
} // namespace
+void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
+ GraphDef* graph) {
+ EraseNodesFromGraphImpl(nodes_to_delete, graph);
+}
+
+void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
+ STLSortAndRemoveDuplicates(&nodes_to_delete);
+ EraseNodesFromGraphImpl(nodes_to_delete, graph);
+}
+
+void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
+ GraphDef* graph) {
+ std::vector<int> nodes_idx_to_delete;
+ nodes_idx_to_delete.reserve(nodes_to_delete.size());
+ for (int i = 0; i < graph->node_size(); ++i) {
+ if (nodes_to_delete.count(graph->node(i).name()))
+ nodes_idx_to_delete.push_back(i);
+ }
+ EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
+}
+
Status SimpleGraphView::Initialize(
const GraphDef& graph,
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*