aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 17:12:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 17:15:41 -0700
commit9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (patch)
treec8fcc2f493b5a6f7d38a9b2c036ecffd5105ac37
parent261ab05537885556f92d7322017ddf73ea5a7357 (diff)
Convert log(x+1) to log1p(x).
PiperOrigin-RevId: 200645461
-rw-r--r--tensorflow/core/grappler/op_types.cc2
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc115
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h1
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc42
5 files changed, 161 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 2a47a4c495..2227904dbf 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -193,6 +193,8 @@ bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
+bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
+
bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index e7f39981c0..7110a9c63d 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -74,6 +74,7 @@ bool IsImag(const NodeDef& node);
bool IsInvGrad(const NodeDef& node);
bool IsLess(const NodeDef& node);
bool IsLessEqual(const NodeDef& node);
+bool IsLog(const NodeDef& node);
bool IsLogicalAnd(const NodeDef& node);
bool IsLogicalNot(const NodeDef& node);
bool IsLogicalOr(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index c41b152d21..9d500f8f54 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2487,6 +2487,119 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
};
+class ConvertLog1pStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
+ ~ConvertLog1pStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+ if (!IsAdd(*input)) {
+ return Status::OK();
+ }
+
+ if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+ return Status::OK();
+ }
+
+ bool modified = false;
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
+ if (!modified) {
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
+ }
+ if (modified) {
+ *simplified_node_name = node->name();
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
+ bool* modified) {
+ const auto& t =
+ ctx().graph_properties->GetInputProperties(input->name())[i];
+ for (int k = 0; k < t.shape().dim_size(); ++k) {
+ // Skip if t shape is not fully determined.
+ if (t.shape().dim(k).size() < 0) {
+ return Status::OK();
+ }
+ }
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[j];
+ TensorShapeProto broadcast_shape;
+ if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+ return errors::InvalidArgument("Cannot get broadcast shape for: ",
+ t.DebugString(), " and ", c.DebugString());
+ }
+ if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+ // skip if the non-constant tensor doesn't have the same shape after
+ // broadcast.
+ return Status::OK();
+ }
+ if (TensorShape::IsValid(t.shape()) && t.has_value()) {
+ Tensor tensor(t.dtype(), t.shape());
+ if (!tensor.FromProto(t.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ t.value().DebugString());
+ }
+ complex128 element;
+ for (int k = 0; k < tensor.NumElements(); ++k) {
+ if (!GetElement(tensor, k, &element)) {
+ // input data type is not supported by log1p. Skip.
+ return Status::OK();
+ }
+ if (element != complex128(1)) {
+ // current element is not 1. Skip.
+ return Status::OK();
+ }
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
+ node->set_op("Log1p");
+ node->set_input(0, y->name());
+ node->add_input(AsControlDependency(x->name()));
+ ForwardControlDependencies(node, {input});
+
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ *modified = true;
+ }
+ return Status::OK();
+ }
+
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2763,6 +2876,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.remove_idempotent)
pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
+ if (options_.convert_log1p)
+ pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 40c5e9fc56..9a6081dcd8 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -75,6 +75,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool replace_mul_with_square = true;
bool simplify_aggregation = true;
bool convert_pow = true;
+ bool convert_log1p = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index fe70c7db5c..177c237fe7 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -264,6 +264,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.simplify_aggregation = true;
}
+
+ void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_log1p = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2486,6 +2491,43 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
CompareGraphs(want, got);
}
+TEST_F(ArithmeticOptimizerTest, Log1p) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
+ auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
+ auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
+ auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
+ auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
+ Output out1 = ops::Log(s.WithOpName("out1"), a12);
+ Output out2 = ops::Log(s.WithOpName("out2"), a23);
+
+ GrapplerItem item;
+ item.fetch = {"out1", "out2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef got;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyLog1p(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &got);
+ auto tensors = EvaluateNodes(got, item.fetch);
+ EXPECT_EQ(2, tensors.size());
+
+ GraphDef want;
+ AddNode("x1", "Const", {}, {}, &want);
+ AddNode("x2", "Const", {}, {}, &want);
+ AddNode("x3", "Const", {}, {}, &want);
+ AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
+ AddNode("out1", "Log1p",
+ {"x2", AsControlDependency("x1"), AsControlDependency("x3")}, {},
+ &want);
+ AddNode("out2", "Log", {"a23"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();