aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.cc
diff options
context:
space:
mode:
authorGravatar Piotr Padlewski <prazek@google.com>2018-07-25 13:37:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 13:41:05 -0700
commit45aa4e968ff743eea697c9fbc8778c33b2926a6a (patch)
treedebf0da8cc7a3094d2b847b2be1b40b9fc2fe9ec /tensorflow/core/grappler/utils.cc
parentf0b189f3f24f5032642ff71338c1be66fbd446b5 (diff)
MutableGraphView and other graph utils
MutableGraphView was implemented so that the view could be updated when new nodes are added or connections changed. The current passes do not require it only because they do not do any optimization on already optimized nodes, but optimizations like MapFusion require it. PiperOrigin-RevId: 206046420
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r--tensorflow/core/grappler/utils.cc42
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index c8e63f95e1..153785d3b4 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/grappler/utils.h"
+
#include <memory>
+#include <queue>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -21,7 +24,6 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -354,13 +356,51 @@ void DedupControlInputs(NodeDef* node) {
}
namespace {
+
+template <typename UniqueContainer>
+void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
+ GraphDef* graph) {
+ static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
+ "Need to pass container of ints");
+
+ int last = graph->node_size() - 1;
+ for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
+ const int index = *it;
+ graph->mutable_node()->SwapElements(index, last);
+ last--;
+ }
+ graph->mutable_node()->DeleteSubrange(last + 1, nodes_to_delete.size());
+}
+
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
std::sort(v->begin(), v->end());
v->erase(std::unique(v->begin(), v->end()), v->end());
}
+
} // namespace
+void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
+ GraphDef* graph) {
+ EraseNodesFromGraphImpl(nodes_to_delete, graph);
+}
+
+void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
+ STLSortAndRemoveDuplicates(&nodes_to_delete);
+ EraseNodesFromGraphImpl(nodes_to_delete, graph);
+}
+
+void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
+ GraphDef* graph) {
+ std::vector<int> nodes_idx_to_delete;
+ nodes_idx_to_delete.reserve(nodes_to_delete.size());
+ for (int i = 0; i < graph->node_size(); ++i) {
+ if (nodes_to_delete.count(graph->node(i).name()))
+ nodes_idx_to_delete.push_back(i);
+ }
+ EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
+}
+
Status SimpleGraphView::Initialize(
const GraphDef& graph,
const std::vector<std::pair<const NodeDef*, const NodeDef*>>*