about summary refs log tree commit diff homepage
path: root/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 17:12:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 17:15:41 -0700
commit9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (patch)
treec8fcc2f493b5a6f7d38a9b2c036ecffd5105ac37 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
parent261ab05537885556f92d7322017ddf73ea5a7357 (diff)
Convert log(x+1) to log1p(x).
PiperOrigin-RevId: 200645461
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 115
1 file changed, 115 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index c41b152d21..9d500f8f54 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2487,6 +2487,119 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
};
+class ConvertLog1pStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
+ ~ConvertLog1pStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+ if (!IsAdd(*input)) {
+ return Status::OK();
+ }
+
+ if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+ return Status::OK();
+ }
+
+ bool modified = false;
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
+ if (!modified) {
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
+ }
+ if (modified) {
+ *simplified_node_name = node->name();
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
+ bool* modified) {
+ const auto& t =
+ ctx().graph_properties->GetInputProperties(input->name())[i];
+ for (int k = 0; k < t.shape().dim_size(); ++k) {
+ // Skip if t shape is not fully determined.
+ if (t.shape().dim(k).size() < 0) {
+ return Status::OK();
+ }
+ }
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[j];
+ TensorShapeProto broadcast_shape;
+ if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+ return errors::InvalidArgument("Cannot get broadcast shape for: ",
+ t.DebugString(), " and ", c.DebugString());
+ }
+ if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+ // skip if the non-constant tensor doesn't have the same shape after
+ // broadcast.
+ return Status::OK();
+ }
+ if (TensorShape::IsValid(t.shape()) && t.has_value()) {
+ Tensor tensor(t.dtype(), t.shape());
+ if (!tensor.FromProto(t.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ t.value().DebugString());
+ }
+ complex128 element;
+ for (int k = 0; k < tensor.NumElements(); ++k) {
+ if (!GetElement(tensor, k, &element)) {
+ // input data type is not supported by log1p. Skip.
+ return Status::OK();
+ }
+ if (element != complex128(1)) {
+ // current element is not 1. Skip.
+ return Status::OK();
+ }
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
+ node->set_op("Log1p");
+ node->set_input(0, y->name());
+ node->add_input(AsControlDependency(x->name()));
+ ForwardControlDependencies(node, {input});
+
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ *modified = true;
+ }
+ return Status::OK();
+ }
+
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2763,6 +2876,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.remove_idempotent)
pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
+ if (options_.convert_log1p)
+ pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");