aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-05 17:49:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-05 17:53:44 -0700
commit2366bd07dd3fc0e82f34f92deeebdc9cb87649de (patch)
treed7701870bc31fb6112352b80dc11923492ed0549 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent5105350be955422169de1f22bb99f928c1f4c2ae (diff)
Automated g4 rollback of changelist 197562826
PiperOrigin-RevId: 199388675
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc151
1 files changed, 151 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 44a14ef7eb..51110b4bda 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2334,6 +2334,156 @@ class SimplifyAggregation : public ArithmeticOptimizerStage {
}
};
+class ConvertPowStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertPowStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsPow(*node) &&
+ ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < p.shape().dim_size(); ++i) {
+ if (p.shape().dim(i).size() < 0) {
+ // skip if p is is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor pow(p.dtype(), p.shape());
+ if (!pow.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+
+ complex128 prev, curr;
+ for (int i = 0; i < pow.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+ if (i != 0 && curr != prev) {
+ // pow has different values on different elements. Skip.
+ return Status::OK();
+ }
+ prev = curr;
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ if (curr == complex128(2, 0)) {
+ node->set_op("Square");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(1, 0)) {
+ node->set_op("Identity");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0.5, 0)) {
+ node->set_op("Sqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0, 0)) {
+ const auto& b =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ for (int i = 0; i < b.shape().dim_size(); ++i) {
+ if (b.shape().dim(i).size() < 0) {
+ // skip if b is is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(b.shape()) && b.has_value()) {
+ Tensor base(b.dtype(), b.shape());
+ if (!base.FromProto(b.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ b.value().DebugString());
+ }
+ node->set_op("Const");
+ Tensor c(base.dtype(), base.shape());
+ for (int i = 0; i < c.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
+ }
+ (*node->mutable_attr())["dtype"].set_type(base.dtype());
+ c.AsProtoTensorContent(
+ (*node->mutable_attr())["value"].mutable_tensor());
+ node->mutable_attr()->erase("T");
+ node->set_input(0, AsControlDependency(x->name()));
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ }
+ } else if (curr == complex128(-0.5, 0)) {
+ node->set_op("Rsqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(-1, 0)) {
+ node->set_op("Reciprocal");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return Status::OK();
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return Status::OK();
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return Status::OK();
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return Status::OK();
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return Status::OK();
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t.dtype());
+ }
+ }
+
+ Status SetElementToOne(int i, Tensor* t) {
+ switch (t->dtype()) {
+ case DT_INT32:
+ t->flat<int32>()(i) = 1;
+ return Status::OK();
+ case DT_INT64:
+ t->flat<int64>()(i) = 1L;
+ return Status::OK();
+ case DT_FLOAT:
+ t->flat<float>()(i) = 1.0f;
+ return Status::OK();
+ case DT_DOUBLE:
+ t->flat<double>()(i) = 1.0;
+ return Status::OK();
+ case DT_COMPLEX64:
+ t->flat<complex64>()(i) = complex64(1);
+ return Status::OK();
+ case DT_COMPLEX128:
+ t->flat<complex128>()(i) = complex128(1);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t->dtype());
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2608,6 +2758,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
if (options_.remove_idempotent)
pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
+ if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");