diff options
author | 2018-05-31 14:01:45 -0700 | |
---|---|---|
committer | 2018-05-31 14:03:45 -0700 | |
commit | 395428bcaf02c9a9e8067083993d7e6b5afdc0a6 (patch) | |
tree | 028e3a7c9922edab67e89cddeaec37f90ac1bec7 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | d3b5b07e7810782c3760468312f9cace10b89073 (diff) |
Move RemodeRedundantReshape optimization to a separate stage.
PiperOrigin-RevId: 198775276
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 114 |
1 files changed, 61 insertions, 53 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index e7f385cbd6..0edea16aac 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -196,22 +196,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) { bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } -// Returns whether `reshape` is an identity op. The tensor that `reshape` -// reshapes is the `output_pos`-th output of node `input`. -bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, - const int output_pos, - const GraphProperties& graph_properties) { - const std::vector<OpInfo::TensorProperties>& reshape_props = - graph_properties.GetOutputProperties(reshape.name()); - const std::vector<OpInfo::TensorProperties>& input_props = - graph_properties.GetOutputProperties(input.name()); - if (reshape_props.empty() || input_props.size() <= output_pos) { - return false; - } - - return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]); -} - NodeDef* GetTailOfValuePreservingChain( const NodeDef& node, const NodeMap& node_map, const std::unordered_set<string>& nodes_to_preserve) { @@ -1823,6 +1807,65 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { } }; +// Bypass redundant reshape nodes: +// +// Reshape Reshape <-+ +// ^ | +// | | +// Reshape becomes Reshape | +// ^ | +// | | +// input input ---+ +class RemoveRedundantReshape : public ArithmeticOptimizerStage { + public: + explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {} + ~RemoveRedundantReshape() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsReshape(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + + // 1. Bypass reshape followed by reshape. + if (IsReshape(*input) && !HasControlInputs(*input)) { + node->set_input(0, input->input(0)); + ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0)); + *simplified_node_name = node->name(); + AddToOptimizationQueue(node); + return Status::OK(); + } + + // 2. If the reshape is a no-op, forward its input to its consumers, unless + // it anchors a control dependency since we want to make sure that control + // dependency is triggered. + if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) { + *simplified_node_name = node->input(0); + return Status::OK(); + } + + return Status::OK(); + } + + private: + // Returns whether `reshape` is an identity op. + bool ReshapeIsIdentity(const NodeDef& reshape) { + OpInfo::TensorProperties reshape_props; + OpInfo::TensorProperties input_props; + + if (!GetTensorProperties(reshape.name(), &reshape_props).ok() || + !GetTensorProperties(reshape.input(0), &input_props).ok()) { + return false; + } + + return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape()); + } +}; + } // namespace class UniqueNodes { @@ -2076,43 +2119,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) { - if (node->op() == "Reshape") { - // Reshape - // ^ - // | - // Reshape - // ^ - // | - // input - // - // becomes - // - // Reshape <-+ - // | - // Reshape | - // ^ | - // | | - // input ---+ - NodeDef* reshape = const_cast<NodeDef*>(node); - int output_pos = 0; - string input_node_name = ParseNodeName(reshape->input(0), &output_pos); - const NodeDef* input = node_map_->GetNode(input_node_name); - if (input->op() == "Reshape" && !HasControlInputs(*input)) { - reshape->set_input(0, input->input(0)); - node_map_->UpdateInput(reshape->name(), input->name(), input->input(0)); - nodes_to_simplify->PushBack(reshape); - return reshape->name(); - } - - // If the reshape is a no-op, forward its input to its consumers, unless it - // anchors a control dependency since we want to make sure that control - // dependency is triggered. - if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) && - !HasControlInputs(*reshape)) { - return reshape->input(0); - } - } - if (node->op() == "Transpose") { // Reorder Cast and Transpose if beneficial. // @@ -2450,6 +2456,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext); if (options_.remove_redundant_cast) pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext); + if (options_.remove_redundant_reshape) + pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext); if (options_.remove_logical_not) |