aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-29 11:55:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 11:58:17 -0700
commit7be843c68f6c97f9d885c56b026bfe564402d072 (patch)
tree6744f120150c09fba5271f5597016f5b3321d3d5
parent7b394717dccce7d4b252e6d935a7c32ed5daff6d (diff)
Extracts the 'remove split or splitv nodes' optimization into its own method.
PiperOrigin-RevId: 198432976
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc23
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h4
2 files changed, 20 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 90862b0dc5..1ea916a250 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1645,13 +1645,7 @@ Status ConstantFolding::SimplifyGraph(bool use_shape_info,
Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
GraphDef* optimized_graph,
GraphProperties* properties) {
- if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
- return Status::OK();
- }
-
- if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
- ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+ if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
return Status::OK();
}
@@ -1792,6 +1786,21 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}
+bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
+ GraphDef* optimized_graph,
+ NodeDef* node) {
+ if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
+ return true;
+ }
+
+ if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
+ ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+ return true;
+ }
+ return false;
+}
+
Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 43cabb484c..b42d5f201e 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -205,6 +205,10 @@ class ConstantFolding : public GraphOptimizer {
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success);
+
+ // Removes Split or SplitV node if possible.
+ bool RemoveSplitOrSplitV(const GraphProperties& properties,
+ GraphDef* optimized_graph, NodeDef* node);
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;