author     A. Unique TensorFlower <gardener@tensorflow.org>    2018-06-28 22:11:49 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>     2018-06-28 22:13:59 -0700
commit     6874e1ef40c4189d96c105227f60b507953f95d3 (patch)
tree       ca8c6d0f04aec4c2a32c99582084116606f9adaa
parent     50839154899377f89367d851f6d1e2c45fcd6431 (diff)
Convert exp(x-1) into expm1(x).
PiperOrigin-RevId: 202598404
5 files changed, 168 insertions, 49 deletions
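Context for the diff below: the new ConvertExpm1Stage matches an Exp node whose input is Sub(x, c), where c is a constant tensor consisting entirely of ones, and rewrites the Exp in place to Expm1, re-pointing it at x and keeping a control dependency on the constant's producer. A toy sketch of that gating logic, with simplified stand-in types instead of the real NodeDef/Tensor machinery (every name in this sketch is illustrative, not from the commit):

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for the real NodeDef/graph types; illustrative only.
struct ToyNode {
  std::string op;
  std::vector<ToyNode*> inputs;
  std::vector<double> const_values;  // populated only for "Const" nodes
};

// Mirrors the stage's gating logic (shape checks omitted): an Exp whose
// input is Sub(x, c), with c a constant made entirely of ones, is eligible
// for the Expm1 rewrite.
bool MatchesExpm1Pattern(const ToyNode& node) {
  if (node.op != "Exp" || node.inputs.size() != 1) return false;
  const ToyNode* sub = node.inputs[0];
  if (sub->op != "Sub" || sub->inputs.size() != 2) return false;
  const ToyNode* c = sub->inputs[1];
  if (c->op != "Const") return false;
  for (double v : c->const_values) {
    if (v != 1.0) return false;  // any non-one element disables the rewrite
  }
  return true;
}

int main() {
  ToyNode x{"Placeholder", {}, {}};
  ToyNode ones{"Const", {}, {1.0, 1.0}};
  ToyNode sub{"Sub", {&x, &ones}, {}};
  ToyNode e{"Exp", {&sub}, {}};
  std::cout << MatchesExpm1Pattern(e) << "\n";  // prints 1
}

The real stage additionally requires, via the inferred shape properties, that c's shape is fully known and that broadcasting c against x leaves x's shape unchanged.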
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index bdeb5c66fc..653b088b1d 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -161,6 +161,8 @@ bool IsExit(const NodeDef& node) {
   return op == "Exit" || op == "RefExit";
 }
 
+bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
+
 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
 
 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 2de7d8cc9a..94439265c9 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -60,6 +60,7 @@ bool IsEluGrad(const NodeDef& node);
 bool IsEnter(const NodeDef& node);
 bool IsEqual(const NodeDef& node);
 bool IsExit(const NodeDef& node);
+bool IsExp(const NodeDef& node);
 bool IsFill(const NodeDef& node);
 bool IsFloorDiv(const NodeDef& node);
 bool IsFloorMod(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d8c5d09c4d..72ca3c3fa2 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -178,6 +178,42 @@ NodeDef* GetTailOfIdempotentChain(
       is_idempotent_non_branching);
 }
 
+// GetElementUnexhaustive tries to get the value of an element in a tensor and
+// turn it into complex128 type. It only checks for a limited number of data
+// types, so it is not exhaustive.
+bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
+                            complex128* element) {
+  if (dtypes.find(t.dtype()) == dtypes.end()) return false;
+  switch (t.dtype()) {
+    case DT_BFLOAT16:
+      *element = complex128(t.flat<bfloat16>()(i));
+      return true;
+    case DT_HALF:
+      *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+      return true;
+    case DT_INT32:
+      *element = complex128(t.flat<int32>()(i));
+      return true;
+    case DT_INT64:
+      *element = complex128(t.flat<int64>()(i));
+      return true;
+    case DT_FLOAT:
+      *element = complex128(t.flat<float>()(i));
+      return true;
+    case DT_DOUBLE:
+      *element = complex128(t.flat<double>()(i));
+      return true;
+    case DT_COMPLEX64:
+      *element = complex128(t.flat<complex64>()(i));
+      return true;
+    case DT_COMPLEX128:
+      *element = t.flat<complex128>()(i);
+      return true;
+    default:
+      return false;
+  }
+}
+
 // Graph optimizer context extension specific to ArithmeticOptimizer.
 struct ArithmeticOptimizerContext {
   explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -2361,7 +2397,13 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
 
     complex128 prev, curr;
     for (int i = 0; i < pow.NumElements(); ++i) {
-      TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+      if (!GetElementUnexhaustive(pow, i,
+                                  {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
+                                   DT_COMPLEX64, DT_COMPLEX128},
+                                  &curr)) {
+        // input data type is not supported by Pow. Skip.
+        return Status::OK();
+      }
       if (i != 0 && curr != prev) {
         // pow has different values on different elements. Skip.
        return Status::OK();
@@ -2432,31 +2474,6 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
   }
 
  private:
-  Status GetElement(const Tensor& t, int i, complex128* element) {
-    switch (t.dtype()) {
-      case DT_INT32:
-        *element = complex128(t.flat<int32>()(i));
-        return Status::OK();
-      case DT_INT64:
-        *element = complex128(t.flat<int64>()(i));
-        return Status::OK();
-      case DT_FLOAT:
-        *element = complex128(t.flat<float>()(i));
-        return Status::OK();
-      case DT_DOUBLE:
-        *element = complex128(t.flat<double>()(i));
-        return Status::OK();
-      case DT_COMPLEX64:
-        *element = complex128(t.flat<complex64>()(i));
-        return Status::OK();
-      case DT_COMPLEX128:
-        *element = t.flat<complex128>()(i);
-        return Status::OK();
-      default:
-        return errors::InvalidArgument("Invalid data type: ", t.dtype());
-    }
-  }
-
   Status SetElementToOne(int i, Tensor* t) {
     switch (t->dtype()) {
       case DT_INT32:
@@ -2544,7 +2561,10 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
     }
     complex128 element;
     for (int k = 0; k < constant.NumElements(); ++k) {
-      if (!GetElement(constant, k, &element)) {
+      if (!GetElementUnexhaustive(constant, k,
+                                  {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
+                                   DT_COMPLEX64, DT_COMPLEX128},
+                                  &element)) {
         // input data type is not supported by log1p. Skip.
         return Status::OK();
       }
@@ -2569,30 +2589,81 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
     }
     return Status::OK();
   }
+};
 
-  bool GetElement(const Tensor& t, int i, complex128* element) {
-    switch (t.dtype()) {
-      case DT_BFLOAT16:
-        *element = complex128(t.flat<bfloat16>()(i));
-        return true;
-      case DT_HALF:
-        *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
-        return true;
-      case DT_FLOAT:
-        *element = complex128(t.flat<float>()(i));
-        return true;
-      case DT_DOUBLE:
-        *element = complex128(t.flat<double>()(i));
-        return true;
-      case DT_COMPLEX64:
-        *element = complex128(t.flat<complex64>()(i));
-        return true;
-      case DT_COMPLEX128:
-        *element = t.flat<complex128>()(i);
-        return true;
-      default:
-        return false;
+class ConvertExpm1Stage : public ArithmeticOptimizerStage {
+ public:
+  explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
+                             const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
+  ~ConvertExpm1Stage() override = default;
+
+  bool IsSupported(const NodeDef* node) const override { return IsExp(*node); }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    NodeDef* input;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+    if (!IsSub(*input)) {
+      return Status::OK();
     }
+
+    if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+      return Status::OK();
+    }
+
+    const auto& t =
+        ctx().graph_properties->GetInputProperties(input->name())[0];
+    const auto& c =
+        ctx().graph_properties->GetInputProperties(input->name())[1];
+    for (int k = 0; k < c.shape().dim_size(); ++k) {
+      // Skip if c shape is not fully determined.
+      if (c.shape().dim(k).size() < 0) {
+        return Status::OK();
+      }
+    }
+    TensorShapeProto broadcast_shape;
+    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+      return Status::OK();
+    }
+    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+      // skip if the non-constant tensor doesn't have the same shape after
+      // broadcast.
+      return Status::OK();
+    }
+    if (TensorShape::IsValid(c.shape()) && c.has_value()) {
+      Tensor constant(c.dtype(), c.shape());
+      if (!constant.FromProto(c.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       c.value().DebugString());
+      }
+      complex128 element;
+      for (int k = 0; k < constant.NumElements(); ++k) {
+        if (!GetElementUnexhaustive(constant, k,
+                                    {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
+                                     DT_COMPLEX64, DT_COMPLEX128},
+                                    &element)) {
+          // input data type is not supported by expm1. Skip.
+          return Status::OK();
+        }
+        if (element != complex128(1)) {
+          // current element is not 1. Skip.
+          return Status::OK();
+        }
+      }
+      NodeDef *x, *y;
+      TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x));
+      TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y));
+      node->set_op("Expm1");
+      node->set_input(0, input->input(0));
+      node->add_input(AsControlDependency(y->name()));
+      ForwardControlDependencies(node, {input});
+
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(input);
+      AddToOptimizationQueue(x);
+      AddToOptimizationQueue(y);
+    }
+    return Status::OK();
   }
 };
 
@@ -2928,6 +2999,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
   if (options_.optimize_max_or_min_of_monotonic)
     pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
+  if (options_.convert_expm1)
+    pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
 
   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
           << str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 824ef35ef6..45a5f65b81 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -77,6 +77,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool simplify_aggregation = true;
     bool convert_pow = true;
     bool convert_log1p = true;
+    bool convert_expm1 = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d0e6b04679..3f6c04a5b5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -274,6 +274,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     DisableAllStages(optimizer);
     optimizer->options_.optimize_max_or_min_of_monotonic = true;
   }
+
+  void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.convert_expm1 = true;
+  }
 };
 
 TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2533,6 +2538,43 @@ TEST_F(ArithmeticOptimizerTest, Log1p) {
   CompareGraphs(want, got);
 }
 
+TEST_F(ArithmeticOptimizerTest, Expm1) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
+  auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
+  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
+  auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2);
+  auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3);
+  Output out1 = ops::Exp(s.WithOpName("out1"), s12);
+  Output out2 = ops::Exp(s.WithOpName("out2"), s23);
+
+  GrapplerItem item;
+  item.fetch = {"out1", "out2"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(2, tensors_expected.size());
+
+  GraphDef got;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyExpm1(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &got);
+  auto tensors = EvaluateNodes(got, item.fetch);
+  EXPECT_EQ(2, tensors.size());
+
+  GraphDef want;
+  AddNode("x1", "Const", {}, {}, &want);
+  AddNode("x2", "Const", {}, {}, &want);
+  AddNode("x3", "Const", {}, {}, &want);
+  AddNode("s23", "Sub", {"x2", "x3"}, {}, &want);
+  AddNode("out1", "Expm1",
+          {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {},
+          &want);
+  AddNode("out2", "Exp", {"s23"}, {}, &want);
+
+  CompareGraphs(want, got);
+}
+
 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
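A note on the shared helper introduced above: GetElementUnexhaustive widens each supported element type to complex128, so a single comparison (here, against complex128(1)) covers integer, floating-point, and complex tensors alike. A minimal standalone sketch of that widening idea, using std::complex<double> in place of TensorFlow's complex128 and std::vector in place of Tensor (names are illustrative, not from the commit):

#include <complex>
#include <iostream>
#include <vector>

// Stand-in for TensorFlow's complex128 alias.
using complex128 = std::complex<double>;

// Widens each element to complex128 and checks it against 1, mirroring the
// all-ones test that gates the Expm1 rewrite. Works for any arithmetic T
// because every supported type embeds into complex<double>.
template <typename T>
bool AllElementsAreOne(const std::vector<T>& values) {
  for (const T& v : values) {
    if (complex128(static_cast<double>(v), 0.0) != complex128(1.0, 0.0)) {
      return false;  // a non-one element means the rewrite must be skipped
    }
  }
  return true;
}

int main() {
  std::cout << AllElementsAreOne(std::vector<float>{1.0f, 1.0f}) << "\n";  // 1
  std::cout << AllElementsAreOne(std::vector<int>{1, 2}) << "\n";          // 0
}

The trade-off is the one named by the function itself: the dtype switch is deliberately not exhaustive, and any type outside the caller's whitelist simply returns false, which makes the stage skip the optimization rather than fail.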