author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-06-14 17:12:51 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-06-14 17:15:41 -0700
commit     9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (patch)
tree       c8fcc2f493b5a6f7d38a9b2c036ecffd5105ac37
parent     261ab05537885556f92d7322017ddf73ea5a7357 (diff)
Convert log(x+1) to log1p(x).
PiperOrigin-RevId: 200645461
5 files changed, 161 insertions, 0 deletions
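
The rewrite is motivated by floating-point accuracy: for small x, computing 1 + x first rounds away the low-order bits of x before the logarithm is taken, while log1p(x) evaluates log(1 + x) directly from x. A minimal standalone sketch of the effect (illustration only, not part of this commit):

```cpp
// Why log(1 + x) loses precision for small x while log1p(x) does not.
#include <cmath>
#include <cstdio>

int main() {
  const double x = 1e-17;
  // 1.0 + x rounds to exactly 1.0 in double precision, so the naive form returns 0.
  std::printf("log(1 + x) = %.17g\n", std::log(1.0 + x));
  // std::log1p works from x directly and returns a value close to x here.
  std::printf("log1p(x)   = %.17g\n", std::log1p(x));
  return 0;
}
```

Accordingly, the new ConvertLog1pStage below rewrites Log(Add(x, y)) to Log1p(y) when the operand x is a constant tensor whose elements are all 1 and y already has the full broadcast shape of the addition.
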
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 2a47a4c495..2227904dbf 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -193,6 +193,8 @@ bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
 
 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
 
+bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
+
 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
 
 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index e7f39981c0..7110a9c63d 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -74,6 +74,7 @@ bool IsImag(const NodeDef& node);
 bool IsInvGrad(const NodeDef& node);
 bool IsLess(const NodeDef& node);
 bool IsLessEqual(const NodeDef& node);
+bool IsLog(const NodeDef& node);
 bool IsLogicalAnd(const NodeDef& node);
 bool IsLogicalNot(const NodeDef& node);
 bool IsLogicalOr(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index c41b152d21..9d500f8f54 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2487,6 +2487,119 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
   }
 };
 
+class ConvertLog1pStage : public ArithmeticOptimizerStage {
+ public:
+  explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
+                             const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
+  ~ConvertLog1pStage() override = default;
+
+  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    NodeDef* input;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+    if (!IsAdd(*input)) {
+      return Status::OK();
+    }
+
+    if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+      return Status::OK();
+    }
+
+    bool modified = false;
+    TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
+    if (!modified) {
+      TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
+    }
+    if (modified) {
+      *simplified_node_name = node->name();
+    }
+    return Status::OK();
+  }
+
+ private:
+  Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
+                             bool* modified) {
+    const auto& t =
+        ctx().graph_properties->GetInputProperties(input->name())[i];
+    for (int k = 0; k < t.shape().dim_size(); ++k) {
+      // Skip if t shape is not fully determined.
+      if (t.shape().dim(k).size() < 0) {
+        return Status::OK();
+      }
+    }
+    const auto& c =
+        ctx().graph_properties->GetInputProperties(input->name())[j];
+    TensorShapeProto broadcast_shape;
+    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+      return errors::InvalidArgument("Cannot get broadcast shape for: ",
+                                     t.DebugString(), " and ", c.DebugString());
+    }
+    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+      // skip if the non-constant tensor doesn't have the same shape after
+      // broadcast.
+      return Status::OK();
+    }
+    if (TensorShape::IsValid(t.shape()) && t.has_value()) {
+      Tensor tensor(t.dtype(), t.shape());
+      if (!tensor.FromProto(t.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       t.value().DebugString());
+      }
+      complex128 element;
+      for (int k = 0; k < tensor.NumElements(); ++k) {
+        if (!GetElement(tensor, k, &element)) {
+          // input data type is not supported by log1p. Skip.
+          return Status::OK();
+        }
+        if (element != complex128(1)) {
+          // current element is not 1. Skip.
+          return Status::OK();
+        }
+      }
+      NodeDef *x, *y;
+      TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
+      TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
+      node->set_op("Log1p");
+      node->set_input(0, y->name());
+      node->add_input(AsControlDependency(x->name()));
+      ForwardControlDependencies(node, {input});
+
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(x);
+      AddToOptimizationQueue(y);
+      *modified = true;
+    }
+    return Status::OK();
+  }
+
+  bool GetElement(const Tensor& t, int i, complex128* element) {
+    switch (t.dtype()) {
+      case DT_BFLOAT16:
+        *element = complex128(t.flat<bfloat16>()(i));
+        return true;
+      case DT_HALF:
+        *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+        return true;
+      case DT_FLOAT:
+        *element = complex128(t.flat<float>()(i));
+        return true;
+      case DT_DOUBLE:
+        *element = complex128(t.flat<double>()(i));
+        return true;
+      case DT_COMPLEX64:
+        *element = complex128(t.flat<complex64>()(i));
+        return true;
+      case DT_COMPLEX128:
+        *element = t.flat<complex128>()(i);
+        return true;
+      default:
+        return false;
+    }
+  }
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -2763,6 +2876,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
   if (options_.remove_idempotent)
     pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
   if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
+  if (options_.convert_log1p)
+    pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
 
   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
           << str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 40c5e9fc56..9a6081dcd8 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -75,6 +75,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool replace_mul_with_square = true;
     bool simplify_aggregation = true;
     bool convert_pow = true;
+    bool convert_log1p = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index fe70c7db5c..177c237fe7 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -264,6 +264,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     DisableAllStages(optimizer);
     optimizer->options_.simplify_aggregation = true;
   }
+
+  void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.convert_log1p = true;
+  }
 };
 
 TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2486,6 +2491,43 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
   CompareGraphs(want, got);
 }
 
+TEST_F(ArithmeticOptimizerTest, Log1p) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
+  auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
+  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
+  auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
+  auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
+  Output out1 = ops::Log(s.WithOpName("out1"), a12);
+  Output out2 = ops::Log(s.WithOpName("out2"), a23);
+
+  GrapplerItem item;
+  item.fetch = {"out1", "out2"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(2, tensors_expected.size());
+
+  GraphDef got;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyLog1p(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &got);
+  auto tensors = EvaluateNodes(got, item.fetch);
+  EXPECT_EQ(2, tensors.size());
+
+  GraphDef want;
+  AddNode("x1", "Const", {}, {}, &want);
+  AddNode("x2", "Const", {}, {}, &want);
+  AddNode("x3", "Const", {}, {}, &want);
+  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
+  AddNode("out1", "Log1p",
+          {"x2", AsControlDependency("x1"), AsControlDependency("x3")}, {},
+          &want);
+  AddNode("out2", "Log", {"a23"}, {}, &want);
+
+  CompareGraphs(want, got);
+}
+
 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
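
The test covers both outcomes: a12 = x1 + x2 qualifies because x1 is all ones, so out1 becomes Log1p(x2) with the dropped constant (and the control dependency that was attached to the Add) preserved as control inputs, while a23 = x2 + x3 has no all-ones operand and out2 stays a plain Log. A minimal user-level sketch of the pattern that triggers the rewrite, assuming the same TensorFlow C++ graph-building API used in the test (ops::Const, ops::Add, ops::Log); this only builds the graph and describes in comments what the optimizer would do to it:

```cpp
// Sketch only: build log(ones + y). With the arithmetic optimizer's
// convert_log1p stage enabled, Grappler rewrites the "out" node in place to
// Log1p(y), keeps "^ones" as a control input, and prunes the Add.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"

int main() {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto ones = tensorflow::ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2});
  auto y = tensorflow::ops::Const(s.WithOpName("y"), {2.0f, 2.0f}, {1, 2});
  auto add = tensorflow::ops::Add(s.WithOpName("add"), ones, y);
  // Before optimization: out = Log(add). After ConvertLog1pStage: out = Log1p(y).
  auto out = tensorflow::ops::Log(s.WithOpName("out"), add);

  tensorflow::GraphDef graph;
  TF_CHECK_OK(s.ToGraphDef(&graph));
  return 0;
}
```
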