aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-28 22:11:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 22:13:59 -0700
commit6874e1ef40c4189d96c105227f60b507953f95d3 (patch)
treeca8c6d0f04aec4c2a32c99582084116606f9adaa /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent50839154899377f89367d851f6d1e2c45fcd6431 (diff)
Convert exp(x-1) into expm1(x).
PiperOrigin-RevId: 202598404
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc171
1 files changed, 122 insertions, 49 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d8c5d09c4d..72ca3c3fa2 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -178,6 +178,42 @@ NodeDef* GetTailOfIdempotentChain(
is_idempotent_non_branching);
}
+// GetElementUnexhaustive tries to get the value of an element in a tensor and
+// turn it into complex128 type. It only check for a limited number of data
+// types, so it's unexhaustive.
+bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
+ complex128* element) {
+ if (dtypes.find(t.dtype()) == dtypes.end()) return false;
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return true;
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
+ }
+}
+
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -2361,7 +2397,13 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+ if (!GetElementUnexhaustive(pow, i,
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_COMPLEX128},
+ &curr)) {
+ // input data type is not supported by Pow. Skip.
+ return Status::OK();
+ }
if (i != 0 && curr != prev) {
// pow has different values on different elements. Skip.
return Status::OK();
@@ -2432,31 +2474,6 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
private:
- Status GetElement(const Tensor& t, int i, complex128* element) {
- switch (t.dtype()) {
- case DT_INT32:
- *element = complex128(t.flat<int32>()(i));
- return Status::OK();
- case DT_INT64:
- *element = complex128(t.flat<int64>()(i));
- return Status::OK();
- case DT_FLOAT:
- *element = complex128(t.flat<float>()(i));
- return Status::OK();
- case DT_DOUBLE:
- *element = complex128(t.flat<double>()(i));
- return Status::OK();
- case DT_COMPLEX64:
- *element = complex128(t.flat<complex64>()(i));
- return Status::OK();
- case DT_COMPLEX128:
- *element = t.flat<complex128>()(i);
- return Status::OK();
- default:
- return errors::InvalidArgument("Invalid data type: ", t.dtype());
- }
- }
-
Status SetElementToOne(int i, Tensor* t) {
switch (t->dtype()) {
case DT_INT32:
@@ -2544,7 +2561,10 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElement(constant, k, &element)) {
+ if (!GetElementUnexhaustive(constant, k,
+ {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_COMPLEX128},
+ &element)) {
// input data type is not supported by log1p. Skip.
return Status::OK();
}
@@ -2569,30 +2589,81 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
return Status::OK();
}
+};
- bool GetElement(const Tensor& t, int i, complex128* element) {
- switch (t.dtype()) {
- case DT_BFLOAT16:
- *element = complex128(t.flat<bfloat16>()(i));
- return true;
- case DT_HALF:
- *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
- return true;
- case DT_FLOAT:
- *element = complex128(t.flat<float>()(i));
- return true;
- case DT_DOUBLE:
- *element = complex128(t.flat<double>()(i));
- return true;
- case DT_COMPLEX64:
- *element = complex128(t.flat<complex64>()(i));
- return true;
- case DT_COMPLEX128:
- *element = t.flat<complex128>()(i);
- return true;
- default:
- return false;
+class ConvertExpm1Stage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
+ ~ConvertExpm1Stage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsExp(*node); }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+ if (!IsSub(*input)) {
+ return Status::OK();
}
+
+ if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+ return Status::OK();
+ }
+
+ const auto& t =
+ ctx().graph_properties->GetInputProperties(input->name())[0];
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[1];
+ for (int k = 0; k < c.shape().dim_size(); ++k) {
+ // Skip if c shape is not fully determined.
+ if (c.shape().dim(k).size() < 0) {
+ return Status::OK();
+ }
+ }
+ TensorShapeProto broadcast_shape;
+ if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+ return Status::OK();
+ }
+ if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+ // skip if the non-constant tensor doesn't have the same shape after
+ // broadcast.
+ return Status::OK();
+ }
+ if (TensorShape::IsValid(c.shape()) && c.has_value()) {
+ Tensor constant(c.dtype(), c.shape());
+ if (!constant.FromProto(c.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ c.value().DebugString());
+ }
+ complex128 element;
+ for (int k = 0; k < constant.NumElements(); ++k) {
+ if (!GetElementUnexhaustive(constant, k,
+ {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_COMPLEX128},
+ &element)) {
+ // input data type is not supported by expm1. Skip.
+ return Status::OK();
+ }
+ if (element != complex128(1)) {
+ // current element is not 1. Skip.
+ return Status::OK();
+ }
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y));
+ node->set_op("Expm1");
+ node->set_input(0, input->input(0));
+ node->add_input(AsControlDependency(y->name()));
+ ForwardControlDependencies(node, {input});
+
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(input);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ }
+ return Status::OK();
}
};
@@ -2928,6 +2999,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
if (options_.optimize_max_or_min_of_monotonic)
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
+ if (options_.convert_expm1)
+ pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");