diff options
author | 2018-05-29 11:55:17 -0700 | |
---|---|---|
committer | 2018-05-29 11:58:17 -0700 | |
commit | 7be843c68f6c97f9d885c56b026bfe564402d072 (patch) | |
tree | 6744f120150c09fba5271f5597016f5b3321d3d5 | |
parent | 7b394717dccce7d4b252e6d935a7c32ed5daff6d (diff) |
Extracts the 'remove split or splitv nodes' optimization into its own method.
PiperOrigin-RevId: 198432976
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 23 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.h | 4 |
2 files changed, 20 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 90862b0dc5..1ea916a250 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1645,13 +1645,7 @@ Status ConstantFolding::SimplifyGraph(bool use_shape_info, Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, GraphDef* optimized_graph, GraphProperties* properties) { - if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { - ReplaceOperationWithIdentity(1, *properties, node, optimized_graph); - return Status::OK(); - } - - if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) { - ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); + if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) { return Status::OK(); } @@ -1792,6 +1786,21 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties, + GraphDef* optimized_graph, + NodeDef* node) { + if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { + ReplaceOperationWithIdentity(1, properties, node, optimized_graph); + return true; + } + + if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + return true; + } + return false; +} + Status ConstantFolding::RemoveShuffleOrTranspose( const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, bool* success) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 43cabb484c..b42d5f201e 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -205,6 +205,10 @@ class ConstantFolding : public GraphOptimizer { bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, bool* success); + + // Removes Split or SplitV node if possible. + bool RemoveSplitOrSplitV(const GraphProperties& properties, + GraphDef* optimized_graph, NodeDef* node); // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; |