diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-05 17:49:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-05 17:53:44 -0700 |
commit | 2366bd07dd3fc0e82f34f92deeebdc9cb87649de (patch) | |
tree | d7701870bc31fb6112352b80dc11923492ed0549 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 5105350be955422169de1f22bb99f928c1f4c2ae (diff) |
Automated g4 rollback of changelist 197562826
PiperOrigin-RevId: 199388675
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 151 |
1 file changed, 151 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 44a14ef7eb..51110b4bda 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2334,6 +2334,156 @@ class SimplifyAggregation : public ArithmeticOptimizerStage { } }; +class ConvertPowStage : public ArithmeticOptimizerStage { + public: + explicit ConvertPowStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {} + + bool IsSupported(const NodeDef* node) const override { + return IsPow(*node) && + ctx().graph_properties->GetInputProperties(node->name()).size() == 2; + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1]; + for (int i = 0; i < p.shape().dim_size(); ++i) { + if (p.shape().dim(i).size() < 0) { + // skip if p is not fully defined. + return Status::OK(); + } + } + if (TensorShape::IsValid(p.shape()) && p.has_value()) { + Tensor pow(p.dtype(), p.shape()); + if (!pow.FromProto(p.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + p.value().DebugString()); + } + + complex128 prev, curr; + for (int i = 0; i < pow.NumElements(); ++i) { + TF_RETURN_IF_ERROR(GetElement(pow, i, &curr)); + if (i != 0 && curr != prev) { + // pow has different values on different elements. Skip. 
+ return Status::OK(); + } + prev = curr; + } + NodeDef *x, *y; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); + if (curr == complex128(2, 0)) { + node->set_op("Square"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(1, 0)) { + node->set_op("Identity"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0.5, 0)) { + node->set_op("Sqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0, 0)) { + const auto& b = + ctx().graph_properties->GetInputProperties(node->name())[0]; + for (int i = 0; i < b.shape().dim_size(); ++i) { + if (b.shape().dim(i).size() < 0) { + // skip if b is not fully defined. + return Status::OK(); + } + } + if (TensorShape::IsValid(b.shape()) && b.has_value()) { + Tensor base(b.dtype(), b.shape()); + if (!base.FromProto(b.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + b.value().DebugString()); + } + node->set_op("Const"); + Tensor c(base.dtype(), base.shape()); + for (int i = 0; i < c.NumElements(); ++i) { + TF_RETURN_IF_ERROR(SetElementToOne(i, &c)); + } + (*node->mutable_attr())["dtype"].set_type(base.dtype()); + c.AsProtoTensorContent( + (*node->mutable_attr())["value"].mutable_tensor()); + node->mutable_attr()->erase("T"); + node->set_input(0, AsControlDependency(x->name())); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(x); + AddToOptimizationQueue(y); + } + } else if (curr == complex128(-0.5, 0)) { + node->set_op("Rsqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(-1, 0)) { + 
node->set_op("Reciprocal"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } + } + return Status::OK(); + } + + private: + Status GetElement(const Tensor& t, int i, complex128* element) { + switch (t.dtype()) { + case DT_INT32: + *element = complex128(t.flat<int32>()(i)); + return Status::OK(); + case DT_INT64: + *element = complex128(t.flat<int64>()(i)); + return Status::OK(); + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return Status::OK(); + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return Status::OK(); + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return Status::OK(); + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return Status::OK(); + default: + return errors::InvalidArgument("Invalid data type: ", t.dtype()); + } + } + + Status SetElementToOne(int i, Tensor* t) { + switch (t->dtype()) { + case DT_INT32: + t->flat<int32>()(i) = 1; + return Status::OK(); + case DT_INT64: + t->flat<int64>()(i) = 1L; + return Status::OK(); + case DT_FLOAT: + t->flat<float>()(i) = 1.0f; + return Status::OK(); + case DT_DOUBLE: + t->flat<double>()(i) = 1.0; + return Status::OK(); + case DT_COMPLEX64: + t->flat<complex64>()(i) = complex64(1); + return Status::OK(); + case DT_COMPLEX128: + t->flat<complex128>()(i) = complex128(1); + return Status::OK(); + default: + return errors::InvalidArgument("Invalid data type: ", t->dtype()); + } + } +}; + } // namespace class UniqueNodes { @@ -2608,6 +2758,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext); if (options_.remove_idempotent) pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext); + if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); |