diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-05 17:49:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-05 17:53:44 -0700 |
commit | 2366bd07dd3fc0e82f34f92deeebdc9cb87649de (patch) | |
tree | d7701870bc31fb6112352b80dc11923492ed0549 | |
parent | 5105350be955422169de1f22bb99f928c1f4c2ae (diff) |
Automated g4 rollback of changelist 197562826
PiperOrigin-RevId: 199388675
3 files changed, 209 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 44a14ef7eb..51110b4bda 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2334,6 +2334,156 @@ class SimplifyAggregation : public ArithmeticOptimizerStage { } }; +class ConvertPowStage : public ArithmeticOptimizerStage { + public: + explicit ConvertPowStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {} + + bool IsSupported(const NodeDef* node) const override { + return IsPow(*node) && + ctx().graph_properties->GetInputProperties(node->name()).size() == 2; + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1]; + for (int i = 0; i < p.shape().dim_size(); ++i) { + if (p.shape().dim(i).size() < 0) { + // skip if p is not fully defined. + return Status::OK(); + } + } + if (TensorShape::IsValid(p.shape()) && p.has_value()) { + Tensor pow(p.dtype(), p.shape()); + if (!pow.FromProto(p.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + p.value().DebugString()); + } + + complex128 prev, curr; + for (int i = 0; i < pow.NumElements(); ++i) { + TF_RETURN_IF_ERROR(GetElement(pow, i, &curr)); + if (i != 0 && curr != prev) { + // pow has different values on different elements. Skip. 
+ return Status::OK(); + } + prev = curr; + } + NodeDef *x, *y; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); + if (curr == complex128(2, 0)) { + node->set_op("Square"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(1, 0)) { + node->set_op("Identity"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0.5, 0)) { + node->set_op("Sqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0, 0)) { + const auto& b = + ctx().graph_properties->GetInputProperties(node->name())[0]; + for (int i = 0; i < b.shape().dim_size(); ++i) { + if (b.shape().dim(i).size() < 0) { + // skip if b is not fully defined. + return Status::OK(); + } + } + if (TensorShape::IsValid(b.shape()) && b.has_value()) { + Tensor base(b.dtype(), b.shape()); + if (!base.FromProto(b.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + b.value().DebugString()); + } + node->set_op("Const"); + Tensor c(base.dtype(), base.shape()); + for (int i = 0; i < c.NumElements(); ++i) { + TF_RETURN_IF_ERROR(SetElementToOne(i, &c)); + } + (*node->mutable_attr())["dtype"].set_type(base.dtype()); + c.AsProtoTensorContent( + (*node->mutable_attr())["value"].mutable_tensor()); + node->mutable_attr()->erase("T"); + node->set_input(0, AsControlDependency(x->name())); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(x); + AddToOptimizationQueue(y); + } + } else if (curr == complex128(-0.5, 0)) { + node->set_op("Rsqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(-1, 0)) { + 
node->set_op("Reciprocal"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } + } + return Status::OK(); + } + + private: + Status GetElement(const Tensor& t, int i, complex128* element) { + switch (t.dtype()) { + case DT_INT32: + *element = complex128(t.flat<int32>()(i)); + return Status::OK(); + case DT_INT64: + *element = complex128(t.flat<int64>()(i)); + return Status::OK(); + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return Status::OK(); + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return Status::OK(); + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return Status::OK(); + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return Status::OK(); + default: + return errors::InvalidArgument("Invalid data type: ", t.dtype()); + } + } + + Status SetElementToOne(int i, Tensor* t) { + switch (t->dtype()) { + case DT_INT32: + t->flat<int32>()(i) = 1; + return Status::OK(); + case DT_INT64: + t->flat<int64>()(i) = 1L; + return Status::OK(); + case DT_FLOAT: + t->flat<float>()(i) = 1.0f; + return Status::OK(); + case DT_DOUBLE: + t->flat<double>()(i) = 1.0; + return Status::OK(); + case DT_COMPLEX64: + t->flat<complex64>()(i) = complex64(1); + return Status::OK(); + case DT_COMPLEX128: + t->flat<complex128>()(i) = complex128(1); + return Status::OK(); + default: + return errors::InvalidArgument("Invalid data type: ", t->dtype()); + } + } +}; + } // namespace class UniqueNodes { @@ -2608,6 +2758,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext); if (options_.remove_idempotent) pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext); + if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); diff --git 
a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index f37458eba4..40c5e9fc56 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -74,6 +74,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool reorder_cast_and_transpose = true; bool replace_mul_with_square = true; bool simplify_aggregation = true; + bool convert_pow = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 8083b6051f..ff96cb6480 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -245,6 +245,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true; } + void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_pow = true; + } + void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_idempotent = true; @@ -2429,6 +2434,58 @@ TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) { } } +TEST_F(ArithmeticOptimizerTest, ConvertPow) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2}); + auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2}); + auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2}); + auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2}); + auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2}); + auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, 
-1.0f}, {1, 2}); + auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2}); + Output out2 = ops::Pow(s.WithOpName("out2"), x, y2); + Output out1 = ops::Pow(s.WithOpName("out1"), x, y1); + Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5); + Output out0 = ops::Pow(s.WithOpName("out0"), x, y0); + Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5); + Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1); + Output out = ops::Pow(s.WithOpName("out"), x, y); + + GrapplerItem item; + item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(7, tensors_expected.size()); + + GraphDef got; + ArithmeticOptimizer optimizer; + EnableOnlyConvertPow(&optimizer); + OptimizeAndPrune(&optimizer, &item, &got); + auto tensors = EvaluateNodes(got, item.fetch); + EXPECT_EQ(7, tensors.size()); + + GraphDef want; + AddNode("x", "Const", {}, {}, &want); + AddNode("y2", "Const", {}, {}, &want); + AddNode("y1", "Const", {}, {}, &want); + AddNode("y.5", "Const", {}, {}, &want); + AddNode("y0", "Const", {}, {}, &want); + AddNode("y_.5", "Const", {}, {}, &want); + AddNode("y_1", "Const", {}, {}, &want); + AddNode("y", "Const", {}, {}, &want); + AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want); + AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want); + AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want); + AddNode("out0", "Const", + {AsControlDependency("x"), AsControlDependency("y0")}, {}, &want); + AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want); + AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want); + AddNode("out", "Pow", {"x", "y"}, {}, &want); + + CompareGraphs(want, got); +} + TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); |