diff options
author | 2018-06-18 15:23:36 -0700 | |
---|---|---|
committer | 2018-06-18 15:26:27 -0700 | |
commit | 209662bac4a3e04ae359939f67ab892456453b92 (patch) | |
tree | 1b5c3ec1cb287973f578e886f3661042cf77098e /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | c26ba8f104cd6efd16080ada5f6414baa1f4e372 (diff) |
Fix bug in RemoveIdempotent optimizer stage.
Minor cleanup in RemoveIdentityTranspose.
PiperOrigin-RevId: 201069367
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 45 |
1 files changed, 21 insertions, 24 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d518685216..0d69e0dde3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1083,14 +1083,6 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); - NodeDef* tail = node; - // TODO(rmlarsen): Enable after debugging breakage in Bayesflow. - if (ctx().opt_level == RewriterConfig::AGGRESSIVE) { - tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, - *ctx().nodes_to_preserve); - } - NodeDef* first_transpose; - TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); NodeDef* node_perm; TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm)); @@ -1099,7 +1091,21 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { } std::vector<int64> node_perm_values; TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values)); - if (first_transpose->op() == node->op()) { + + // Remove simple identity transposes. + if (IsIdentityPermutation(node_perm_values)) { + *simplified_node_name = node->input(0); + return Status::OK(); + } + + NodeDef* tail = node; + tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, + *ctx().nodes_to_preserve); + NodeDef* first_transpose; + TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); + + if (first_transpose->op() == node->op() && + NumNonControlOutputs(*first_transpose, *ctx().node_map) == 1) { // Remove pairs of transposes that cancel each other. NodeDef* first_transpose_perm; TF_RETURN_IF_ERROR( @@ -1124,11 +1130,6 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { *simplified_node_name = node->input(0); } } - } else { - // Remove simple identity transposes. - if (IsIdentityPermutation(node_perm_values)) { - *simplified_node_name = node->input(0); - } } return Status::OK(); } @@ -1722,19 +1723,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage { ~RemoveIdempotentStage() override = default; bool IsSupported(const NodeDef* node) const override { - return IsIdempotent(*node) && !IsInPreserveSet(*node); + return node->input_size() == 1 && IsIdempotent(*node) && + !IsInPreserveSet(*node); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { NodeDef* input; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); - auto root_scope_and_name = ParseNodeScopeAndName(node->name()); - const string new_name = OptimizedNodeName(root_scope_and_name); - if (input->op() == node->op() && input->device() == node->device() && - IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) { - NodeDef* new_input_node = AddCopyNode(new_name, input); - ForwardControlDependencies(new_input_node, {node}); - *simplified_node_name = new_input_node->name(); + if (input->op() == node->op() && input->device() == node->device()) { + *simplified_node_name = node->input(0); } return Status::OK(); } @@ -2901,7 +2898,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext); if (options_.minimize_broadcasts && can_use_shapes) pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext); - if (options_.remove_identity_transpose && can_use_shapes) + if (options_.remove_identity_transpose) pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext); if (options_.remove_involution) pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext); @@ -2909,7 +2906,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext); if (options_.remove_redundant_cast) pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext); - if (options_.remove_redundant_reshape) + if (options_.remove_redundant_reshape && can_use_shapes) pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext); |