author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-05 17:49:21 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-05 17:53:44 -0700
commit  2366bd07dd3fc0e82f34f92deeebdc9cb87649de
tree    d7701870bc31fb6112352b80dc11923492ed0549
parent  5105350be955422169de1f22bb99f928c1f4c2ae
Automated g4 rollback of changelist 197562826
PiperOrigin-RevId: 199388675
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc      | 151
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h       |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc  |  57
3 files changed, 209 insertions(+), 0 deletions(-)
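
In brief: this rollback restores ConvertPowStage, an arithmetic optimizer stage that rewrites Pow nodes whose exponent is a constant holding one repeated value into cheaper equivalent ops. As a reading aid (not part of the commit), the sketch below summarizes the rewrites implemented in the diff; in every case the displaced exponent input is kept as a control dependency so execution-order constraints are preserved.

// Rewrites performed by ConvertPowStage (see TrySimplify in the diff):
//
//   Pow(x,  2.0)  ->  Square(x)        Pow(x, -0.5)  ->  Rsqrt(x)
//   Pow(x,  1.0)  ->  Identity(x)      Pow(x, -1.0)  ->  Reciprocal(x)
//   Pow(x,  0.5)  ->  Sqrt(x)          Pow(x,  0.0)  ->  Const(all ones),
//                                      but only when x is itself a fully
//                                      defined constant with a known value
//
// Any other exponent, or an exponent whose elements differ, leaves the
// Pow node untouched.
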
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 44a14ef7eb..51110b4bda 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2334,6 +2334,156 @@ class SimplifyAggregation : public ArithmeticOptimizerStage {
}
};
+class ConvertPowStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertPowStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsPow(*node) &&
+ ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < p.shape().dim_size(); ++i) {
+ if (p.shape().dim(i).size() < 0) {
+ // skip if p is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor pow(p.dtype(), p.shape());
+ if (!pow.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+
+ complex128 prev, curr;
+ for (int i = 0; i < pow.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+ if (i != 0 && curr != prev) {
+ // pow has different values on different elements. Skip.
+ return Status::OK();
+ }
+ prev = curr;
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ if (curr == complex128(2, 0)) {
+ node->set_op("Square");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(1, 0)) {
+ node->set_op("Identity");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0.5, 0)) {
+ node->set_op("Sqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0, 0)) {
+ const auto& b =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ for (int i = 0; i < b.shape().dim_size(); ++i) {
+ if (b.shape().dim(i).size() < 0) {
+ // skip if b is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(b.shape()) && b.has_value()) {
+ Tensor base(b.dtype(), b.shape());
+ if (!base.FromProto(b.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ b.value().DebugString());
+ }
+ node->set_op("Const");
+ Tensor c(base.dtype(), base.shape());
+ for (int i = 0; i < c.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
+ }
+ (*node->mutable_attr())["dtype"].set_type(base.dtype());
+ c.AsProtoTensorContent(
+ (*node->mutable_attr())["value"].mutable_tensor());
+ node->mutable_attr()->erase("T");
+ node->set_input(0, AsControlDependency(x->name()));
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ }
+ } else if (curr == complex128(-0.5, 0)) {
+ node->set_op("Rsqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(-1, 0)) {
+ node->set_op("Reciprocal");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return Status::OK();
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return Status::OK();
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return Status::OK();
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return Status::OK();
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return Status::OK();
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t.dtype());
+ }
+ }
+
+ Status SetElementToOne(int i, Tensor* t) {
+ switch (t->dtype()) {
+ case DT_INT32:
+ t->flat<int32>()(i) = 1;
+ return Status::OK();
+ case DT_INT64:
+ t->flat<int64>()(i) = 1L;
+ return Status::OK();
+ case DT_FLOAT:
+ t->flat<float>()(i) = 1.0f;
+ return Status::OK();
+ case DT_DOUBLE:
+ t->flat<double>()(i) = 1.0;
+ return Status::OK();
+ case DT_COMPLEX64:
+ t->flat<complex64>()(i) = complex64(1);
+ return Status::OK();
+ case DT_COMPLEX128:
+ t->flat<complex128>()(i) = complex128(1);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t->dtype());
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2608,6 +2758,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
if (options_.remove_idempotent)
pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
+ if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
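
For context before the header change below: every stage registered in SimplifyArithmeticOps implements the same two-method interface that ConvertPowStage uses. A minimal skeleton, assuming only the ArithmeticOptimizerStage API visible in this diff (the class name and the trivial method bodies are illustrative, not part of the commit):

class NoOpStage : public ArithmeticOptimizerStage {
 public:
  explicit NoOpStage(const GraphOptimizerContext& ctx,
                     const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("NoOp", ctx, ctx_ext) {}

  // The pipeline only offers nodes for which IsSupported() returns true.
  bool IsSupported(const NodeDef* node) const override { return false; }

  // Returning OK without setting simplified_node_name, as ConvertPowStage
  // does, indicates the node was either handled in place or skipped.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    return Status::OK();
  }
};
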
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index f37458eba4..40c5e9fc56 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -74,6 +74,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool reorder_cast_and_transpose = true;
bool replace_mul_with_square = true;
bool simplify_aggregation = true;
+ bool convert_pow = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 8083b6051f..ff96cb6480 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -245,6 +245,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
}
+ void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_pow = true;
+ }
+
void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_idempotent = true;
@@ -2429,6 +2434,58 @@ TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
}
}
+TEST_F(ArithmeticOptimizerTest, ConvertPow) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
+ auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
+ auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
+ auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
+ auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
+ auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
+ auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
+ Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
+ Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
+ Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
+ Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
+ Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
+ Output out = ops::Pow(s.WithOpName("out"), x, y);
+
+ GrapplerItem item;
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(7, tensors_expected.size());
+
+ GraphDef got;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyConvertPow(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &got);
+ auto tensors = EvaluateNodes(got, item.fetch);
+ EXPECT_EQ(7, tensors.size());
+
+ GraphDef want;
+ AddNode("x", "Const", {}, {}, &want);
+ AddNode("y2", "Const", {}, {}, &want);
+ AddNode("y1", "Const", {}, {}, &want);
+ AddNode("y.5", "Const", {}, {}, &want);
+ AddNode("y0", "Const", {}, {}, &want);
+ AddNode("y_.5", "Const", {}, {}, &want);
+ AddNode("y_1", "Const", {}, {}, &want);
+ AddNode("y", "Const", {}, {}, &want);
+ AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
+ AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
+ AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
+ AddNode("out0", "Const",
+ {AsControlDependency("x"), AsControlDependency("y0")}, {}, &want);
+ AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
+ AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
+ AddNode("out", "Pow", {"x", "y"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
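
A note on the expected graph above, as a reading aid rather than part of the commit: out keeps its Pow op because its exponent y = {3.0, 4.0} holds two distinct values, exercising the early return in TrySimplify, while the six rewritten outputs carry their former exponent inputs as control dependencies. A hedged extra check one could append to the test (the test::ExpectTensorNear helper is assumed to be available here, as in other tests in this file):

  // Hypothetical addition: the optimized graph should compute the same
  // values as the original graph for every fetched output.
  for (int i = 0; i < static_cast<int>(tensors.size()); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
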
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();