aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py34
1 files changed, 33 insertions, 1 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index a0f505e47b..b8266a527d 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -122,8 +122,10 @@ def _ProdGrad(op, grad):
# so we need to cast here. We put all the shape-related ops on CPU to avoid
# copying back and forth, and since listdiff is CPU only.
with ops.device("/cpu:0"):
+ rank = array_ops.rank(op.inputs[0])
+ reduction_indices = (reduction_indices + rank) % rank
reduced = math_ops.cast(reduction_indices, dtypes.int32)
- idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
+ idx = math_ops.range(0, rank)
other, _ = array_ops.setdiff1d(idx, reduced)
perm = array_ops.concat([reduced, other], 0)
reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
@@ -397,6 +399,36 @@ def _TanhGrad(op, grad):
return gen_math_ops._tanh_grad(y, grad)
+@ops.RegisterGradient("Asinh")
+def _AsinhGrad(op, grad):
+ """Returns grad * 1/cosh(y)."""
+ y = op.outputs[0]
+ with ops.control_dependencies([grad.op]):
+ y = math_ops.conj(y)
+ return grad / math_ops.cosh(y)
+
+
+@ops.RegisterGradient("Acosh")
+def _AcoshGrad(op, grad):
+ """Returns grad * 1/sinh(y)."""
+ y = op.outputs[0]
+ with ops.control_dependencies([grad.op]):
+ y = math_ops.conj(y)
+ return grad / math_ops.sinh(y)
+
+
+@ops.RegisterGradient("Atanh")
+def _AtanhGrad(op, grad):
+ """Returns grad * 1/ (1 - x^2)."""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ x = math_ops.conj(x)
+ x2 = math_ops.square(x)
+ one = constant_op.constant(1, dtype=grad.dtype)
+ inv = math_ops.reciprocal(math_ops.subtract(one, x2))
+ return grad * inv
+
+
@ops.RegisterGradient("TanhGrad")
def _TanhGradGrad(op, grad):
with ops.control_dependencies([grad.op]):