aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-09 10:43:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 10:47:17 -0700
commit6b51853e3ab388af8f56685450f3b6fa5eb54ced (patch)
tree987c2a5d840d56498af6be18109e87b8d37eca51 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parentf83a382e87ca09e8f688515a9549c81d0f46554a (diff)
Automated rollback of commit 6874e1ef40c4189d96c105227f60b507953f95d3
PiperOrigin-RevId: 203790544
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc171
1 files changed, 49 insertions, 122 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index b7369c7b4a..97862d1ed0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -178,42 +178,6 @@ NodeDef* GetTailOfIdempotentChain(
is_idempotent_non_branching);
}
-// GetElementUnexhaustive tries to get the value of an element in a tensor and
-// turn it into complex128 type. It only check for a limited number of data
-// types, so it's unexhaustive.
-bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
- complex128* element) {
- if (dtypes.find(t.dtype()) == dtypes.end()) return false;
- switch (t.dtype()) {
- case DT_BFLOAT16:
- *element = complex128(t.flat<bfloat16>()(i));
- return true;
- case DT_HALF:
- *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
- return true;
- case DT_INT32:
- *element = complex128(t.flat<int32>()(i));
- return true;
- case DT_INT64:
- *element = complex128(t.flat<int64>()(i));
- return true;
- case DT_FLOAT:
- *element = complex128(t.flat<float>()(i));
- return true;
- case DT_DOUBLE:
- *element = complex128(t.flat<double>()(i));
- return true;
- case DT_COMPLEX64:
- *element = complex128(t.flat<complex64>()(i));
- return true;
- case DT_COMPLEX128:
- *element = t.flat<complex128>()(i);
- return true;
- default:
- return false;
- }
-}
-
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -2397,13 +2361,7 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
- // input data type is not supported by Pow. Skip.
- return Status::OK();
- }
+ TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
if (i != 0 && curr != prev) {
// pow has different values on different elements. Skip.
return Status::OK();
@@ -2474,6 +2432,31 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
private:
+ Status GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return Status::OK();
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return Status::OK();
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return Status::OK();
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return Status::OK();
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return Status::OK();
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t.dtype());
+ }
+ }
+
Status SetElementToOne(int i, Tensor* t) {
switch (t->dtype()) {
case DT_INT32:
@@ -2561,10 +2544,7 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElementUnexhaustive(constant, k,
- {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &element)) {
+ if (!GetElement(constant, k, &element)) {
// input data type is not supported by log1p. Skip.
return Status::OK();
}
@@ -2589,81 +2569,30 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
return Status::OK();
}
-};
-class ConvertExpm1Stage : public ArithmeticOptimizerStage {
- public:
- explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
- const ArithmeticOptimizerContext& ctx_ext)
- : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
- ~ConvertExpm1Stage() override = default;
-
- bool IsSupported(const NodeDef* node) const override { return IsExp(*node); }
-
- Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- NodeDef* input;
- TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- if (!IsSub(*input)) {
- return Status::OK();
- }
-
- if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
- return Status::OK();
- }
-
- const auto& t =
- ctx().graph_properties->GetInputProperties(input->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(input->name())[1];
- for (int k = 0; k < c.shape().dim_size(); ++k) {
- // Skip if c shape is not fully determined.
- if (c.shape().dim(k).size() < 0) {
- return Status::OK();
- }
- }
- TensorShapeProto broadcast_shape;
- if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
- return Status::OK();
- }
- if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
- // skip if the non-constant tensor doesn't have the same shape after
- // broadcast.
- return Status::OK();
- }
- if (TensorShape::IsValid(c.shape()) && c.has_value()) {
- Tensor constant(c.dtype(), c.shape());
- if (!constant.FromProto(c.value())) {
- return errors::InvalidArgument("Cannot parse tensor from proto: ",
- c.value().DebugString());
- }
- complex128 element;
- for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElementUnexhaustive(constant, k,
- {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &element)) {
- // input data type is not supported by expm1. Skip.
- return Status::OK();
- }
- if (element != complex128(1)) {
- // current element is not 1. Skip.
- return Status::OK();
- }
- }
- NodeDef *x, *y;
- TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x));
- TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y));
- node->set_op("Expm1");
- node->set_input(0, input->input(0));
- node->add_input(AsControlDependency(y->name()));
- ForwardControlDependencies(node, {input});
-
- AddToOptimizationQueue(node);
- AddToOptimizationQueue(input);
- AddToOptimizationQueue(x);
- AddToOptimizationQueue(y);
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
}
- return Status::OK();
}
};
@@ -3165,8 +3094,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
if (options_.optimize_max_or_min_of_monotonic)
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
- if (options_.convert_expm1)
- pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
if (options_.unary_ops_composition)
pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);