aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-18 15:23:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 15:26:27 -0700
commit209662bac4a3e04ae359939f67ab892456453b92 (patch)
tree1b5c3ec1cb287973f578e886f3661042cf77098e /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parentc26ba8f104cd6efd16080ada5f6414baa1f4e372 (diff)
Fix bug in RemoveIdempotent optimizer stage.
Minor cleanup in RemoveIdentityTranspose. PiperOrigin-RevId: 201069367
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc45
1 files changed, 21 insertions, 24 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d518685216..0d69e0dde3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1083,14 +1083,6 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
- NodeDef* tail = node;
- // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
- if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
- }
- NodeDef* first_transpose;
- TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
NodeDef* node_perm;
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
@@ -1099,7 +1091,21 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
}
std::vector<int64> node_perm_values;
TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
- if (first_transpose->op() == node->op()) {
+
+ // Remove simple identity transposes.
+ if (IsIdentityPermutation(node_perm_values)) {
+ *simplified_node_name = node->input(0);
+ return Status::OK();
+ }
+
+ NodeDef* tail = node;
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+ NodeDef* first_transpose;
+ TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
+
+ if (first_transpose->op() == node->op() &&
+ NumNonControlOutputs(*first_transpose, *ctx().node_map) == 1) {
// Remove pairs of transposes that cancel each other.
NodeDef* first_transpose_perm;
TF_RETURN_IF_ERROR(
@@ -1124,11 +1130,6 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
*simplified_node_name = node->input(0);
}
}
- } else {
- // Remove simple identity transposes.
- if (IsIdentityPermutation(node_perm_values)) {
- *simplified_node_name = node->input(0);
- }
}
return Status::OK();
}
@@ -1722,19 +1723,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage {
~RemoveIdempotentStage() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsIdempotent(*node) && !IsInPreserveSet(*node);
+ return node->input_size() == 1 && IsIdempotent(*node) &&
+ !IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- auto root_scope_and_name = ParseNodeScopeAndName(node->name());
- const string new_name = OptimizedNodeName(root_scope_and_name);
- if (input->op() == node->op() && input->device() == node->device() &&
- IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) {
- NodeDef* new_input_node = AddCopyNode(new_name, input);
- ForwardControlDependencies(new_input_node, {node});
- *simplified_node_name = new_input_node->name();
+ if (input->op() == node->op() && input->device() == node->device()) {
+ *simplified_node_name = node->input(0);
}
return Status::OK();
}
@@ -2901,7 +2898,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
if (options_.minimize_broadcasts && can_use_shapes)
pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
- if (options_.remove_identity_transpose && can_use_shapes)
+ if (options_.remove_identity_transpose)
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
if (options_.remove_involution)
pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
@@ -2909,7 +2906,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
if (options_.remove_redundant_cast)
pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
- if (options_.remove_redundant_reshape)
+ if (options_.remove_redundant_reshape && can_use_shapes)
pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);