diff options
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index c8e63f95e1..153785d3b4 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/grappler/utils.h" + #include <memory> +#include <queue> #include <vector> #include "tensorflow/core/framework/attr_value.pb.h" @@ -21,7 +24,6 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -354,13 +356,51 @@ void DedupControlInputs(NodeDef* node) { } namespace { + +template <typename UniqueContainer> +void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete, + GraphDef* graph) { + static_assert(std::is_same<typename UniqueContainer::value_type, int>::value, + "Need to pass container of ints"); + + int last = graph->node_size() - 1; + for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) { + const int index = *it; + graph->mutable_node()->SwapElements(index, last); + last--; + } + graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size()); +} + template <typename T> inline void STLSortAndRemoveDuplicates(T* v) { std::sort(v->begin(), v->end()); v->erase(std::unique(v->begin(), v->end()), v->end()); } + } // namespace +void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, + GraphDef* graph) { + EraseNodesFromGraphImpl(nodes_to_delete, graph); +} + +void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) { + STLSortAndRemoveDuplicates(&nodes_to_delete); + EraseNodesFromGraphImpl(nodes_to_delete, graph); +} + +void EraseNodesFromGraph(const std::set<string>& nodes_to_delete, + GraphDef* graph) { + std::vector<int> nodes_idx_to_delete; + nodes_idx_to_delete.reserve(nodes_to_delete.size()); + for (int i = 0; i < graph->node_size(); ++i) { + if (nodes_to_delete.count(graph->node(i).name())) + nodes_idx_to_delete.push_back(i); + } + EraseNodesFromGraphImpl(nodes_idx_to_delete, graph); +} + Status SimpleGraphView::Initialize( const GraphDef& graph, const std::vector<std::pair<const NodeDef*, const NodeDef*>>* |