diff options
author | 2018-06-14 17:12:51 -0700 | |
---|---|---|
committer | 2018-06-14 17:15:41 -0700 | |
commit | 9e4cbaf3a3a3bfca913bebdcfc082265c7a13ad6 (patch) | |
tree | c8fcc2f493b5a6f7d38a9b2c036ecffd5105ac37 /tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | |
parent | 261ab05537885556f92d7322017ddf73ea5a7357 (diff) |
Convert log(x+1) to log1p(x).
PiperOrigin-RevId: 200645461
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index c41b152d21..9d500f8f54 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2487,6 +2487,119 @@ class ConvertPowStage : public ArithmeticOptimizerStage { } }; +class ConvertLog1pStage : public ArithmeticOptimizerStage { + public: + explicit ConvertLog1pStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {} + ~ConvertLog1pStage() override = default; + + bool IsSupported(const NodeDef* node) const override { return IsLog(*node); } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + if (!IsAdd(*input)) { + return Status::OK(); + } + + if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) { + return Status::OK(); + } + + bool modified = false; + TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified)); + if (!modified) { + TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified)); + } + if (modified) { + *simplified_node_name = node->name(); + } + return Status::OK(); + } + + private: + Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j, + bool* modified) { + const auto& t = + ctx().graph_properties->GetInputProperties(input->name())[i]; + for (int k = 0; k < t.shape().dim_size(); ++k) { + // Skip if t shape is not fully determined. + if (t.shape().dim(k).size() < 0) { + return Status::OK(); + } + } + const auto& c = + ctx().graph_properties->GetInputProperties(input->name())[j]; + TensorShapeProto broadcast_shape; + if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { + return errors::InvalidArgument("Cannot get broadcast shape for: ", + t.DebugString(), " and ", c.DebugString()); + } + if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) { + // skip if the non-constant tensor doesn't have the same shape after + // broadcast. + return Status::OK(); + } + if (TensorShape::IsValid(t.shape()) && t.has_value()) { + Tensor tensor(t.dtype(), t.shape()); + if (!tensor.FromProto(t.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + t.value().DebugString()); + } + complex128 element; + for (int k = 0; k < tensor.NumElements(); ++k) { + if (!GetElement(tensor, k, &element)) { + // input data type is not supported by log1p. Skip. + return Status::OK(); + } + if (element != complex128(1)) { + // current element is not 1. Skip. + return Status::OK(); + } + } + NodeDef *x, *y; + TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x)); + TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y)); + node->set_op("Log1p"); + node->set_input(0, y->name()); + node->add_input(AsControlDependency(x->name())); + ForwardControlDependencies(node, {input}); + + AddToOptimizationQueue(node); + AddToOptimizationQueue(x); + AddToOptimizationQueue(y); + *modified = true; + } + return Status::OK(); + } + + bool GetElement(const Tensor& t, int i, complex128* element) { + switch (t.dtype()) { + case DT_BFLOAT16: + *element = complex128(t.flat<bfloat16>()(i)); + return true; + case DT_HALF: + *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0); + return true; + case DT_FLOAT: + *element = complex128(t.flat<float>()(i)); + return true; + case DT_DOUBLE: + *element = complex128(t.flat<double>()(i)); + return true; + case DT_COMPLEX64: + *element = complex128(t.flat<complex64>()(i)); + return true; + case DT_COMPLEX128: + *element = t.flat<complex128>()(i); + return true; + default: + return false; + } + } +}; + } // namespace class UniqueNodes { @@ -2763,6 +2876,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.remove_idempotent) pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext); if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext); + if (options_.convert_log1p) + pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); |