author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-07-09 10:43:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-07-09 10:47:17 -0700
commit    6b51853e3ab388af8f56685450f3b6fa5eb54ced (patch)
tree      987c2a5d840d56498af6be18109e87b8d37eca51
parent    f83a382e87ca09e8f688515a9549c81d0f46554a (diff)
Automated rollback of commit 6874e1ef40c4189d96c105227f60b507953f95d3
PiperOrigin-RevId: 203790544
-rw-r--r--  tensorflow/core/grappler/op_types.cc                              |   2
-rw-r--r--  tensorflow/core/grappler/op_types.h                               |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc      | 171
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h       |   1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc |  42
5 files changed, 49 insertions(+), 168 deletions(-)
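
Note: the diff below replaces the shared free function GetElementUnexhaustive with per-stage private GetElement helpers, and removes ConvertExpm1Stage together with its convert_expm1 option and its test. The restored helpers all share one shape: switch on the tensor's runtime dtype and widen each supported element type to complex128. A minimal standalone sketch of that dispatch pattern follows; Scalar, DType, and main are stand-ins invented for illustration, not TensorFlow types.

// element_dispatch_sketch.cc -- illustrative only, not TensorFlow code.
// Mirrors the switch-on-dtype shape of the restored GetElement helpers:
// widen every supported element type to complex128, reject the rest.
#include <complex>
#include <iostream>

using complex128 = std::complex<double>;

enum class DType { kInt32, kInt64, kFloat, kDouble, kComplex128, kString };

// Stand-in for a single tensor element carrying a runtime type tag.
struct Scalar {
  DType dtype;
  double real = 0.0;  // holds the value for the real-number dtypes
  complex128 c128;    // used only when dtype == kComplex128
};

// Returns false for unsupported dtypes, like the bool-returning GetElement
// in ConvertLog1pStage; ConvertPowStage's variant returns a Status instead.
bool GetElement(const Scalar& s, complex128* element) {
  switch (s.dtype) {
    case DType::kInt32:
    case DType::kInt64:
    case DType::kFloat:
    case DType::kDouble:
      *element = complex128(s.real);  // widen any real value to complex128
      return true;
    case DType::kComplex128:
      *element = s.c128;
      return true;
    default:
      return false;  // e.g. kString: not a numeric dtype
  }
}

int main() {
  complex128 e;
  bool ok = GetElement({DType::kDouble, 1.0, {}}, &e);
  std::cout << ok << " " << e.real() << "\n";  // prints: 1 1
}
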
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088b1d..bdeb5c66fc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -161,8 +161,6 @@ bool IsExit(const NodeDef& node) {
return op == "Exit" || op == "RefExit";
}
-bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
-
bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 94439265c9..2de7d8cc9a 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -60,7 +60,6 @@ bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
bool IsExit(const NodeDef& node);
-bool IsExp(const NodeDef& node);
bool IsFill(const NodeDef& node);
bool IsFloorDiv(const NodeDef& node);
bool IsFloorMod(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index b7369c7b4a..97862d1ed0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -178,42 +178,6 @@ NodeDef* GetTailOfIdempotentChain(
is_idempotent_non_branching);
}
-// GetElementUnexhaustive tries to get the value of an element in a tensor and
-// turn it into complex128 type. It only checks for a limited number of data
-// types, so it's unexhaustive.
-bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
- complex128* element) {
- if (dtypes.find(t.dtype()) == dtypes.end()) return false;
- switch (t.dtype()) {
- case DT_BFLOAT16:
- *element = complex128(t.flat<bfloat16>()(i));
- return true;
- case DT_HALF:
- *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
- return true;
- case DT_INT32:
- *element = complex128(t.flat<int32>()(i));
- return true;
- case DT_INT64:
- *element = complex128(t.flat<int64>()(i));
- return true;
- case DT_FLOAT:
- *element = complex128(t.flat<float>()(i));
- return true;
- case DT_DOUBLE:
- *element = complex128(t.flat<double>()(i));
- return true;
- case DT_COMPLEX64:
- *element = complex128(t.flat<complex64>()(i));
- return true;
- case DT_COMPLEX128:
- *element = t.flat<complex128>()(i);
- return true;
- default:
- return false;
- }
-}
-
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -2397,13 +2361,7 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
- // input data type is not supported by Pow. Skip.
- return Status::OK();
- }
+ TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
if (i != 0 && curr != prev) {
// pow has different values on different elements. Skip.
return Status::OK();
@@ -2474,6 +2432,31 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
private:
+ Status GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return Status::OK();
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return Status::OK();
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return Status::OK();
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return Status::OK();
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return Status::OK();
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t.dtype());
+ }
+ }
+
Status SetElementToOne(int i, Tensor* t) {
switch (t->dtype()) {
case DT_INT32:
@@ -2561,10 +2544,7 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElementUnexhaustive(constant, k,
- {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &element)) {
+ if (!GetElement(constant, k, &element)) {
// input data type is not supported by log1p. Skip.
return Status::OK();
}
@@ -2589,81 +2569,30 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
return Status::OK();
}
-};
-class ConvertExpm1Stage : public ArithmeticOptimizerStage {
- public:
- explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
- const ArithmeticOptimizerContext& ctx_ext)
- : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
- ~ConvertExpm1Stage() override = default;
-
- bool IsSupported(const NodeDef* node) const override { return IsExp(*node); }
-
- Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- NodeDef* input;
- TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- if (!IsSub(*input)) {
- return Status::OK();
- }
-
- if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
- return Status::OK();
- }
-
- const auto& t =
- ctx().graph_properties->GetInputProperties(input->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(input->name())[1];
- for (int k = 0; k < c.shape().dim_size(); ++k) {
- // Skip if c shape is not fully determined.
- if (c.shape().dim(k).size() < 0) {
- return Status::OK();
- }
- }
- TensorShapeProto broadcast_shape;
- if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
- return Status::OK();
- }
- if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
- // skip if the non-constant tensor doesn't have the same shape after
- // broadcast.
- return Status::OK();
- }
- if (TensorShape::IsValid(c.shape()) && c.has_value()) {
- Tensor constant(c.dtype(), c.shape());
- if (!constant.FromProto(c.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- c.value().DebugString());
- }
- complex128 element;
- for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElementUnexhaustive(constant, k,
- {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &element)) {
- // input data type is not supported by expm1. Skip.
- return Status::OK();
- }
- if (element != complex128(1)) {
- // current element is not 1. Skip.
- return Status::OK();
- }
- }
- NodeDef *x, *y;
- TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x));
- TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y));
- node->set_op("Expm1");
- node->set_input(0, input->input(0));
- node->add_input(AsControlDependency(y->name()));
- ForwardControlDependencies(node, {input});
-
- AddToOptimizationQueue(node);
- AddToOptimizationQueue(input);
- AddToOptimizationQueue(x);
- AddToOptimizationQueue(y);
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
}
- return Status::OK();
}
};
@@ -3165,8 +3094,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
if (options_.optimize_max_or_min_of_monotonic)
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
- if (options_.convert_expm1)
- pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
if (options_.unary_ops_composition)
pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 551c3652bf..00c02d19bd 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -77,7 +77,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool simplify_aggregation = true;
bool convert_pow = true;
bool convert_log1p = true;
- bool convert_expm1 = true;
bool unary_ops_composition = true;
// Choose which arithmetic optimizer stages will be enabled for a given
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 54fdc01adb..c387b00303 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -279,11 +279,6 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.optimize_max_or_min_of_monotonic = true;
}
- void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
- DisableAllStages(optimizer);
- optimizer->options_.convert_expm1 = true;
- }
-
void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.unary_ops_composition = true;
@@ -2547,43 +2542,6 @@ TEST_F(ArithmeticOptimizerTest, Log1p) {
CompareGraphs(want, got);
}
-TEST_F(ArithmeticOptimizerTest, Expm1) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
- auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
- auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
- auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
- auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2);
- auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3);
- Output out1 = ops::Exp(s.WithOpName("out1"), s12);
- Output out2 = ops::Exp(s.WithOpName("out2"), s23);
-
- GrapplerItem item;
- item.fetch = {"out1", "out2"};
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(2, tensors_expected.size());
-
- GraphDef got;
- ArithmeticOptimizer optimizer;
- EnableOnlyExpm1(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &got);
- auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(2, tensors.size());
-
- GraphDef want;
- AddNode("x1", "Const", {}, {}, &want);
- AddNode("x2", "Const", {}, {}, &want);
- AddNode("x3", "Const", {}, {}, &want);
- AddNode("s23", "Sub", {"x2", "x3"}, {}, &want);
- AddNode("out1", "Expm1",
- {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {},
- &want);
- AddNode("out2", "Exp", {"s23"}, {}, &want);
-
- CompareGraphs(want, got);
-}
-
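
The deleted test above follows this file's standard single-stage pattern: build a graph with the Scope API, enable exactly one optimizer stage via an EnableOnly* helper, run OptimizeAndPrune, and compare the result against a hand-assembled expected graph. A schematic of that flow follows, as plain C++ with an assert; GraphDef and the optimizer here are stand-ins, not the real Grappler test harness.

// single_stage_test_sketch.cc -- illustrative only, not TensorFlow code.
#include <cassert>
#include <string>
#include <vector>

// Stand-in graph: just a list of "Op(name)" strings.
using GraphDef = std::vector<std::string>;

struct OptimizerSketch {
  bool stage_enabled = false;
  // Identity here; the real optimizer rewrites matching nodes.
  GraphDef Optimize(const GraphDef& in) const { return in; }
};

void DisableAllStages(OptimizerSketch* opt) { opt->stage_enabled = false; }

// Mirrors the EnableOnly<Stage> helpers: disable everything, then
// re-enable only the stage under test.
void EnableOnlyOneStage(OptimizerSketch* opt) {
  DisableAllStages(opt);
  opt->stage_enabled = true;
}

int main() {
  GraphDef graph = {"Const(x1)", "Const(x2)", "Exp(out)"};
  OptimizerSketch optimizer;
  EnableOnlyOneStage(&optimizer);
  GraphDef got = optimizer.Optimize(graph);
  GraphDef want = graph;  // hand-built expectation, as in the real test
  assert(got == want);    // CompareGraphs analogue
  return 0;
}
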
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();