aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-31 14:01:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 14:03:45 -0700
commit395428bcaf02c9a9e8067083993d7e6b5afdc0a6 (patch)
tree028e3a7c9922edab67e89cddeaec37f90ac1bec7 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parentd3b5b07e7810782c3760468312f9cace10b89073 (diff)
Move RemodeRedundantReshape optimization to a separate stage.
PiperOrigin-RevId: 198775276
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc114
1 files changed, 61 insertions, 53 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index e7f385cbd6..0edea16aac 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -196,22 +196,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
-// Returns whether `reshape` is an identity op. The tensor that `reshape`
-// reshapes is the `output_pos`-th output of node `input`.
-bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
- const int output_pos,
- const GraphProperties& graph_properties) {
- const std::vector<OpInfo::TensorProperties>& reshape_props =
- graph_properties.GetOutputProperties(reshape.name());
- const std::vector<OpInfo::TensorProperties>& input_props =
- graph_properties.GetOutputProperties(input.name());
- if (reshape_props.empty() || input_props.size() <= output_pos) {
- return false;
- }
-
- return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]);
-}
-
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
@@ -1823,6 +1807,65 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
}
};
+// Bypass redundant reshape nodes:
+//
+// Reshape Reshape <-+
+// ^ |
+// | |
+// Reshape becomes Reshape |
+// ^ |
+// | |
+// input input ---+
+class RemoveRedundantReshape : public ArithmeticOptimizerStage {
+ public:
+ explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {}
+ ~RemoveRedundantReshape() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsReshape(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+
+ // 1. Bypass reshape followed by reshape.
+ if (IsReshape(*input) && !HasControlInputs(*input)) {
+ node->set_input(0, input->input(0));
+ ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
+ *simplified_node_name = node->name();
+ AddToOptimizationQueue(node);
+ return Status::OK();
+ }
+
+ // 2. If the reshape is a no-op, forward its input to its consumers, unless
+ // it anchors a control dependency since we want to make sure that control
+ // dependency is triggered.
+ if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) {
+ *simplified_node_name = node->input(0);
+ return Status::OK();
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // Returns whether `reshape` is an identity op.
+ bool ReshapeIsIdentity(const NodeDef& reshape) {
+ OpInfo::TensorProperties reshape_props;
+ OpInfo::TensorProperties input_props;
+
+ if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
+ !GetTensorProperties(reshape.input(0), &input_props).ok()) {
+ return false;
+ }
+
+ return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape());
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2076,43 +2119,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
- if (node->op() == "Reshape") {
- // Reshape
- // ^
- // |
- // Reshape
- // ^
- // |
- // input
- //
- // becomes
- //
- // Reshape <-+
- // |
- // Reshape |
- // ^ |
- // | |
- // input ---+
- NodeDef* reshape = const_cast<NodeDef*>(node);
- int output_pos = 0;
- string input_node_name = ParseNodeName(reshape->input(0), &output_pos);
- const NodeDef* input = node_map_->GetNode(input_node_name);
- if (input->op() == "Reshape" && !HasControlInputs(*input)) {
- reshape->set_input(0, input->input(0));
- node_map_->UpdateInput(reshape->name(), input->name(), input->input(0));
- nodes_to_simplify->PushBack(reshape);
- return reshape->name();
- }
-
- // If the reshape is a no-op, forward its input to its consumers, unless it
- // anchors a control dependency since we want to make sure that control
- // dependency is triggered.
- if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) &&
- !HasControlInputs(*reshape)) {
- return reshape->input(0);
- }
- }
-
if (node->op() == "Transpose") {
// Reorder Cast and Transpose if beneficial.
//
@@ -2450,6 +2456,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
if (options_.remove_redundant_cast)
pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+ if (options_.remove_redundant_reshape)
+ pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
if (options_.remove_logical_not)