path: root/tensorflow/core/ops/math_ops.cc
Diffstat (limited to 'tensorflow/core/ops/math_ops.cc')
-rw-r--r--  tensorflow/core/ops/math_ops.cc  21
1 file changed, 21 insertions, 0 deletions
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 0f9ee4942a..b220a2d2d6 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -238,6 +238,13 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
.Attr("T: {half, float, double, complex64, complex128}") \
.SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
+#define UNARY_GRADIENT_COMPLEX() \
+ Input("x: T") \
+ .Input("y: T") \
+ .Output("z: T") \
+ .Attr("T: {half, float, double, complex64, complex128}") \
+ .SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
+
REGISTER_OP("Neg")
.UNARY()
.Doc(R"doc(
@@ -292,6 +299,13 @@ REGISTER_OP("Tanh")
Computes hyperbolic tangent of `x` element-wise.
)doc");
+REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient for the tanh of `x` wrt its input.
+
+Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
+is the corresponding input gradient.
+)doc");
+
REGISTER_OP("Lgamma")
.UNARY_REAL()
.Doc(R"doc(
@@ -325,6 +339,13 @@ Computes sigmoid of `x` element-wise.
Specifically, `y = 1 / (1 + exp(-x))`.
)doc");
+REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient of the sigmoid of `x` wrt its input.
+
+Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
+`dy` is the corresponding input gradient.
+)doc");
+
REGISTER_OP("Sin")
.UNARY_COMPLEX()
.Doc(R"doc(