diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-09 10:43:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 10:47:17 -0700 |
commit | 6b51853e3ab388af8f56685450f3b6fa5eb54ced (patch) | |
tree | 987c2a5d840d56498af6be18109e87b8d37eca51 | |
parent | f83a382e87ca09e8f688515a9549c81d0f46554a (diff) |
Automated rollback of commit 6874e1ef40c4189d96c105227f60b507953f95d3
PiperOrigin-RevId: 203790544
5 files changed, 49 insertions, 168 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 653b088b1d..bdeb5c66fc 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -161,8 +161,6 @@ bool IsExit(const NodeDef& node) { return op == "Exit" || op == "RefExit"; } -bool IsExp(const NodeDef& node) { return node.op() == "Exp"; } - bool IsFill(const NodeDef& node) { return node.op() == "Fill"; } bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 94439265c9..2de7d8cc9a 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -60,7 +60,6 @@ bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); bool IsExit(const NodeDef& node); -bool IsExp(const NodeDef& node); bool IsFill(const NodeDef& node); bool IsFloorDiv(const NodeDef& node); bool IsFloorMod(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index b7369c7b4a..97862d1ed0 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -178,42 +178,6 @@ NodeDef* GetTailOfIdempotentChain( is_idempotent_non_branching); } -// GetElementUnexhaustive tries to get the value of an element in a tensor and -// turn it into complex128 type. It only check for a limited number of data -// types, so it's unexhaustive. -bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes, - complex128* element) { - if (dtypes.find(t.dtype()) == dtypes.end()) return false; - switch (t.dtype()) { - case DT_BFLOAT16: - *element = complex128(t.flat<bfloat16>()(i)); - return true; - case DT_HALF: - *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0); - return true; - case DT_INT32: - *element = complex128(t.flat<int32>()(i)); - return true; - case DT_INT64: - *element = complex128(t.flat<int64>()(i)); - return true; - case DT_FLOAT: - *element = complex128(t.flat<float>()(i)); - return true; - case DT_DOUBLE: - *element = complex128(t.flat<double>()(i)); - return true; - case DT_COMPLEX64: - *element = complex128(t.flat<complex64>()(i)); - return true; - case DT_COMPLEX128: - *element = t.flat<complex128>()(i); - return true; - default: - return false; - } -} - // Graph optimizer context extension specific to ArithmeticOptimizer. struct ArithmeticOptimizerContext { explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify) @@ -2397,13 +2361,7 @@ class ConvertPowStage : public ArithmeticOptimizerStage { complex128 prev, curr; for (int i = 0; i < pow.NumElements(); ++i) { - if (!GetElementUnexhaustive(pow, i, - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128}, - &curr)) { - // input data type is not supported by Pow. Skip. - return Status::OK(); - } + TF_RETURN_IF_ERROR(GetElement(pow, i, &curr)); if (i != 0 && curr != prev) { // pow has different values on different elements. Skip. return Status::OK(); @@ -2474,6 +2432,31 @@ class ConvertPowStage : public ArithmeticOptimizerStage { } private: + Status GetElement(const Tensor& t, int i, complex128* element) { + switch (t.dtype()) { + case DT_INT32: + *element = complex128(t.flat<int32>()(i)); + return Status::OK(); + case DT_INT64: + *element = complex128(t.flat<int64>()(i)); + return Status::OK(); + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return Status::OK(); + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return Status::OK(); + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return Status::OK(); + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return Status::OK(); + default: + return errors::InvalidArgument("Invalid data type: ", t.dtype()); + } + } + Status SetElementToOne(int i, Tensor* t) { switch (t->dtype()) { case DT_INT32: @@ -2561,10 +2544,7 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } complex128 element; for (int k = 0; k < constant.NumElements(); ++k) { - if (!GetElementUnexhaustive(constant, k, - {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128}, - &element)) { + if (!GetElement(constant, k, &element)) { // input data type is not supported by log1p. Skip. return Status::OK(); } @@ -2589,81 +2569,30 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } return Status::OK(); } -}; -class ConvertExpm1Stage : public ArithmeticOptimizerStage { - public: - explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx, - const ArithmeticOptimizerContext& ctx_ext) - : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {} - ~ConvertExpm1Stage() override = default; - - bool IsSupported(const NodeDef* node) const override { return IsExp(*node); } - - Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - NodeDef* input; - TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); - if (!IsSub(*input)) { - return Status::OK(); - } - - if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) { - return Status::OK(); - } - - const auto& t = - ctx().graph_properties->GetInputProperties(input->name())[0]; - const auto& c = - ctx().graph_properties->GetInputProperties(input->name())[1]; - for (int k = 0; k < c.shape().dim_size(); ++k) { - // Skip if c shape is not fully determined. - if (c.shape().dim(k).size() < 0) { - return Status::OK(); - } - } - TensorShapeProto broadcast_shape; - if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { - return Status::OK(); - } - if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) { - // skip if the non-constant tensor doesn't have the same shape after - // broadcast. - return Status::OK(); - } - if (TensorShape::IsValid(c.shape()) && c.has_value()) { - Tensor constant(c.dtype(), c.shape()); - if (!constant.FromProto(c.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - c.value().DebugString()); - } - complex128 element; - for (int k = 0; k < constant.NumElements(); ++k) { - if (!GetElementUnexhaustive(constant, k, - {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128}, - &element)) { - // input data type is not supported by expm1. Skip. - return Status::OK(); - } - if (element != complex128(1)) { - // current element is not 1. Skip. - return Status::OK(); - } - } - NodeDef *x, *y; - TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x)); - TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y)); - node->set_op("Expm1"); - node->set_input(0, input->input(0)); - node->add_input(AsControlDependency(y->name())); - ForwardControlDependencies(node, {input}); - - AddToOptimizationQueue(node); - AddToOptimizationQueue(input); - AddToOptimizationQueue(x); - AddToOptimizationQueue(y); + bool GetElement(const Tensor& t, int i, complex128* element) { + switch (t.dtype()) { + case DT_BFLOAT16: + *element = complex128(t.flat<bfloat16>()(i)); + return true; + case DT_HALF: + *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0); + return true; + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return true; + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return true; + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return true; + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return true; + default: + return false; } - return Status::OK(); } }; @@ -3165,8 +3094,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext); if (options_.optimize_max_or_min_of_monotonic) pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext); - if (options_.convert_expm1) - pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext); if (options_.unary_ops_composition) pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 551c3652bf..00c02d19bd 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -77,7 +77,6 @@ class ArithmeticOptimizer : public GraphOptimizer { bool simplify_aggregation = true; bool convert_pow = true; bool convert_log1p = true; - bool convert_expm1 = true; bool unary_ops_composition = true; // Choose which arithmetic optimizer stages will be enabled for a given diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 54fdc01adb..c387b00303 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -279,11 +279,6 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.optimize_max_or_min_of_monotonic = true; } - void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.convert_expm1 = true; - } - void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.unary_ops_composition = true; @@ -2547,43 +2542,6 @@ TEST_F(ArithmeticOptimizerTest, Log1p) { CompareGraphs(want, got); } -TEST_F(ArithmeticOptimizerTest, Expm1) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2}); - auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2}); - auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2}); - auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2); - auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3); - Output out1 = ops::Exp(s.WithOpName("out1"), s12); - Output out2 = ops::Exp(s.WithOpName("out2"), s23); - - GrapplerItem item; - item.fetch = {"out1", "out2"}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - EXPECT_EQ(2, tensors_expected.size()); - - GraphDef got; - ArithmeticOptimizer optimizer; - EnableOnlyExpm1(&optimizer); - OptimizeAndPrune(&optimizer, &item, &got); - auto tensors = EvaluateNodes(got, item.fetch); - EXPECT_EQ(2, tensors.size()); - - GraphDef want; - AddNode("x1", "Const", {}, {}, &want); - AddNode("x2", "Const", {}, {}, &want); - AddNode("x3", "Const", {}, {}, &want); - AddNode("s23", "Sub", {"x2", "x3"}, {}, &want); - AddNode("out1", "Expm1", - {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {}, - &want); - AddNode("out2", "Exp", {"s23"}, {}, &want); - - CompareGraphs(want, got); -} - TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); |