diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2018-06-05 12:19:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-05 12:22:39 -0700 |
commit | 2b5f598fbd822f911ad305ae1e57325aefd50826 (patch) | |
tree | 30ced01eceaa62a99ea7908688df5f79bf4c46d6 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 920df27282b3f5d03d79f54ef05cea305c2a30d7 (diff) |
Move ReplaceMulWithSquare to a separate optimizer stage.
PiperOrigin-RevId: 199338297
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 68 |
1 files changed, 45 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 400af82627..561930f858 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2079,6 +2079,49 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage { } }; +// Replace Mul node with identical inputs with a Square. +class ReplaceMulWithSquare : public ArithmeticOptimizerStage { + public: + explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {} + ~ReplaceMulWithSquare() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMul(*node) && node->input(0) == node->input(1); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const NodeScopeAndName mul = ParseNodeScopeAndName(node->name()); + const string optimized_node_name = OptimizedNodeName(mul); + if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK(); + + const DataType type = GetDataTypeFromAttr(*node, "T"); + bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); + + string task; + string device; + bool is_on_cpu = + DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && + str_util::StrContains(device, DEVICE_CPU); + + if (!is_complex || is_on_cpu) { + NodeDef* new_square_node = AddCopyNode(optimized_node_name, node); + new_square_node->set_op("Square"); + for (int i = 1; i < new_square_node->input_size(); ++i) { + new_square_node->set_input(i - 1, new_square_node->input(i)); + } + new_square_node->mutable_input()->RemoveLast(); + for (const string& input : new_square_node->input()) { + ctx().node_map->AddOutput(NodeName(input), new_square_node->name()); + } + *simplified_node_name = new_square_node->name(); + } + + return Status::OK(); + } +}; + } // namespace class UniqueNodes { @@ -2331,29 +2374,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( // ArithmeticOptimizerStage string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) { - if (node->op() == "Mul" && node->input(0) == node->input(1) && - !OptimizedNodeExists(*node, "square")) { - const DataType type = GetDataTypeFromAttr(*node, "T"); - bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); - string dontcare; - string device; - bool is_on_cpu = - DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) && - str_util::StrContains(device, DEVICE_CPU); - if (!is_complex || is_on_cpu) { - NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true); - new_square_node->set_op("Square"); - for (int i = 1; i < new_square_node->input_size(); ++i) { - new_square_node->set_input(i - 1, new_square_node->input(i)); - } - new_square_node->mutable_input()->RemoveLast(); - for (const string& input : new_square_node->input()) { - node_map_->AddOutput(NodeName(input), new_square_node->name()); - } - return new_square_node->name(); - } - } - if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) { // Discard aggregate nodes with a single input and no control dependencies. if (node->input_size() == 1) { @@ -2528,6 +2548,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext); + if (options_.replace_mul_with_square) + pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext); if (options_.remove_logical_not) pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext); if (options_.reorder_cast_and_transpose) |