aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
author: Eugene Zhulenev <ezhulenev@google.com> 2018-06-05 12:19:43 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-05 12:22:39 -0700
commit: 2b5f598fbd822f911ad305ae1e57325aefd50826 (patch)
tree: 30ced01eceaa62a99ea7908688df5f79bf4c46d6 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent: 920df27282b3f5d03d79f54ef05cea305c2a30d7 (diff)
Move ReplaceMulWithSquare to a separate optimizer stage.
PiperOrigin-RevId: 199338297
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 68
1 files changed, 45 insertions, 23 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 400af82627..561930f858 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2079,6 +2079,49 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
}
};
+// Replace Mul node with identical inputs with a Square (Mul(x, x) -> Square(x)).
+class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
+ public:
+  explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
+                                const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
+  ~ReplaceMulWithSquare() override = default;
+
+  // Checks input_size() first so input(0)/input(1) are never read out of range.
+  bool IsSupported(const NodeDef* node) const override {
+    return IsMul(*node) && node->input_size() >= 2 &&
+           node->input(0) == node->input(1);
+  }
+
+  // Adds a Square copy of `node`; sets *simplified_node_name only on rewrite.
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
+    const string optimized_node_name = OptimizedNodeName(mul);
+    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+    const DataType type = GetDataTypeFromAttr(*node, "T");
+    const bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
+    string task;
+    string device;
+    const bool is_on_cpu =
+        DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+        str_util::StrContains(device, DEVICE_CPU);
+    // Complex-typed nodes are rewritten only if the device contains DEVICE_CPU.
+    if (is_complex && !is_on_cpu) return Status::OK();
+    NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
+    new_square_node->set_op("Square");
+    // Shift inputs left by one and drop the last: removes the duplicate operand.
+    for (int i = 1; i < new_square_node->input_size(); ++i) {
+      new_square_node->set_input(i - 1, new_square_node->input(i));
+    }
+    new_square_node->mutable_input()->RemoveLast();
+    for (const string& input : new_square_node->input()) {
+      ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
+    }
+    *simplified_node_name = new_square_node->name();
+    return Status::OK();
+  }
+};
+
} // namespace
class UniqueNodes {
@@ -2331,29 +2374,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
// ArithmeticOptimizerStage
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
- if (node->op() == "Mul" && node->input(0) == node->input(1) &&
- !OptimizedNodeExists(*node, "square")) {
- const DataType type = GetDataTypeFromAttr(*node, "T");
- bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
- string dontcare;
- string device;
- bool is_on_cpu =
- DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) &&
- str_util::StrContains(device, DEVICE_CPU);
- if (!is_complex || is_on_cpu) {
- NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true);
- new_square_node->set_op("Square");
- for (int i = 1; i < new_square_node->input_size(); ++i) {
- new_square_node->set_input(i - 1, new_square_node->input(i));
- }
- new_square_node->mutable_input()->RemoveLast();
- for (const string& input : new_square_node->input()) {
- node_map_->AddOutput(NodeName(input), new_square_node->name());
- }
- return new_square_node->name();
- }
- }
-
if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
// Discard aggregate nodes with a single input and no control dependencies.
if (node->input_size() == 1) {
@@ -2528,6 +2548,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+ if (options_.replace_mul_with_square)
+ pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
if (options_.remove_logical_not)
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
if (options_.reorder_cast_and_transpose)