path: root/tensorflow/python/ops/math_grad.py
author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-12-15 15:40:51 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-12-15 15:45:49 -0800
commit     6731e21360ac2fd2fa4999a12e87ae72acfcb2d2 (patch)
tree       dcc2c62ff497cd37bb9819f59111a8a7073296ea /tensorflow/python/ops/math_grad.py
parent     ffe08d38d747735eafe741467ce9092bcf126fa9 (diff)
Add Expm1 Op.
Change: 142198004
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--  tensorflow/python/ops/math_grad.py  143
1 file changed, 78 insertions(+), 65 deletions(-)
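
The substantive change in this commit is the new gradient registration for the Expm1 op; the other hunks in this file are formatting-only rewraps. Since expm1(x) = exp(x) - 1, its derivative is exp(x), which is exactly the factor _Expm1Grad applies to the incoming gradient. The snippet below is an illustrative finite-difference check of that identity in plain NumPy; it is not part of the commit, and the sample points and tolerance are arbitrary choices.

    import numpy as np

    # Illustrative check (not part of the commit): d/dx expm1(x) == exp(x).
    x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    eps = 1e-6

    # Central finite difference of expm1 at each sample point.
    numeric_grad = (np.expm1(x + eps) - np.expm1(x - eps)) / (2.0 * eps)
    analytic_grad = np.exp(x)  # the factor _Expm1Grad multiplies grad by

    assert np.allclose(numeric_grad, analytic_grad, atol=1e-5)
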
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index e4d42a4311..c7fb2c22e3 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Gradients for operators defined in math_ops.py."""
from __future__ import absolute_import
from __future__ import division
@@ -40,8 +39,8 @@ def _SumGrad(op, grad):
"""Gradient for Sum."""
# Fast path for when reducing to a scalar and ndims is known: adds only
# Reshape and Tile ops (and possibly a Shape).
- if (op.inputs[0].get_shape().ndims is not None and op.inputs[1].op.type ==
- "Const"):
+ if (op.inputs[0].get_shape().ndims is not None and
+ op.inputs[1].op.type == "Const"):
rank = op.inputs[0].get_shape().ndims
axes = tensor_util.MakeNdarray(op.inputs[1].op.get_attr("value"))
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
@@ -73,8 +72,7 @@ def _MinOrMaxGrad(op, grad):
# then the gradient will be divided between them.
indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
num_selected = array_ops.reshape(
- math_ops.reduce_sum(indicators, op.inputs[1]),
- output_shape_kept_dims)
+ math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)
return [math_ops.div(indicators, num_selected) * grad, None]
@@ -96,8 +94,8 @@ def _MeanGrad(op, grad):
sum_grad = _SumGrad(op, grad)[0]
input_shape = array_ops.shape(op.inputs[0])
output_shape = array_ops.shape(op.outputs[0])
- factor = _safe_shape_div(math_ops.reduce_prod(input_shape),
- math_ops.reduce_prod(output_shape))
+ factor = _safe_shape_div(
+ math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
@@ -170,42 +168,36 @@ def _SparseSegmentSumGrad(op, grad):
"""Gradient for SparseSegmentSum."""
input_rows = array_ops.shape(op.inputs[0])[0]
return (math_ops.unsorted_segment_sum(
- array_ops.gather(grad, op.inputs[2]),
- op.inputs[1], input_rows), None, None)
+ array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None,
+ None)
@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
"""Gradient for SparseSegmentMean."""
dim0 = array_ops.shape(op.inputs[0])[0]
- return (math_ops.sparse_segment_mean_grad(grad,
- op.inputs[1],
- op.inputs[2],
- dim0),
- None, None)
+ return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
+ dim0), None, None)
@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
"""Gradient for SparseSegmentSqrtN."""
dim0 = array_ops.shape(op.inputs[0])[0]
- return (math_ops.sparse_segment_sqrt_n_grad(grad,
- op.inputs[1],
- op.inputs[2],
- dim0),
- None, None)
+ return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
+ dim0), None, None)
def _SegmentMinOrMaxGrad(op, grad):
"""Gradient for SegmentMin and SegmentMax. Both share the same code."""
- zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype)
+ zeros = array_ops.zeros(
+ array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype)
# Get the number of selected (minimum or maximum) elements in each segment.
gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
- num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
- op.inputs[1])
+ num_selected = math_ops.segment_sum(
+ math_ops.cast(is_selected, grad.dtype), op.inputs[1])
# Compute the gradient for each segment. The gradient for the ith segment is
# divided evenly among the selected elements in that segment.
@@ -337,6 +329,16 @@ def _ExpGrad(op, grad):
return grad * y
+@ops.RegisterGradient("Expm1")
+def _Expm1Grad(op, grad):
+ """Returns grad * exp(x)."""
+ x = op.inputs[0]
+ with ops.control_dependencies([grad.op]):
+ x = math_ops.conj(x)
+ y = math_ops.exp(x)
+ return grad * y
+
+
@ops.RegisterGradient("Log")
def _LogGrad(op, grad):
"""Returns grad * (1/x)."""
@@ -388,8 +390,8 @@ def _ErfGrad(op, grad):
def _ErfcGrad(op, grad):
"""Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
x = op.inputs[0]
- minus_two_over_root_pi = constant_op.constant(-2 / np.sqrt(np.pi),
- dtype=grad.dtype)
+ minus_two_over_root_pi = constant_op.constant(
+ -2 / np.sqrt(np.pi), dtype=grad.dtype)
with ops.control_dependencies([grad.op]):
x = math_ops.conj(x)
return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))
@@ -425,7 +427,7 @@ def _IgammaGrad(op, grad):
# Perform operations in log space before summing, because Gamma(a)
# and Gamma'(a) can grow large.
- partial_x = math_ops.exp(-x + (a-1) * math_ops.log(x) - math_ops.lgamma(a))
+ partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a))
return (None,
array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
@@ -625,8 +627,9 @@ def _DivGrad(op, grad):
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx),
- array_ops.reshape(math_ops.reduce_sum(
- grad * math_ops.div(-x, math_ops.square(y)), ry), sy))
+ array_ops.reshape(
+ math_ops.reduce_sum(grad * math_ops.div(-x, math_ops.square(y)),
+ ry), sy))
@ops.RegisterGradient("FloorDiv")
@@ -652,10 +655,11 @@ def _RealDivGrad(op, grad):
# pylint: enable=protected-access
x = math_ops.conj(x)
y = math_ops.conj(y)
- return (array_ops.reshape(math_ops.reduce_sum(
- math_ops.realdiv(grad, y), rx), sx),
- array_ops.reshape(math_ops.reduce_sum(
- grad * math_ops.realdiv(-x, math_ops.square(y)), ry), sy))
+ return (array_ops.reshape(
+ math_ops.reduce_sum(math_ops.realdiv(grad, y), rx),
+ sx), array_ops.reshape(
+ math_ops.reduce_sum(grad * math_ops.realdiv(-x, math_ops.square(y)),
+ ry), sy))
@ops.RegisterGradient("Pow")
@@ -680,8 +684,7 @@ def _PowGrad(op, grad):
else:
# There's no sensible real value to return if x < 0, so return 0
log_x = array_ops.where(x > 0, math_ops.log(x), array_ops.zeros_like(x))
- gy = array_ops.reshape(
- math_ops.reduce_sum(grad * z * log_x, ry), sy)
+ gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
return gx, gy
@@ -760,19 +763,21 @@ def _MatMulGrad(op, grad):
t_a = op.get_attr("transpose_a")
t_b = op.get_attr("transpose_b")
if not t_a and not t_b:
- return (math_ops.matmul(grad, op.inputs[1], transpose_b=True),
- math_ops.matmul(op.inputs[0], grad, transpose_a=True))
+ return (math_ops.matmul(
+ grad, op.inputs[1], transpose_b=True), math_ops.matmul(
+ op.inputs[0], grad, transpose_a=True))
elif not t_a and t_b:
- return (math_ops.matmul(grad, op.inputs[1]),
- math_ops.matmul(grad, op.inputs[0], transpose_a=True))
+ return (math_ops.matmul(grad, op.inputs[1]), math_ops.matmul(
+ grad, op.inputs[0], transpose_a=True))
elif t_a and not t_b:
- return (math_ops.matmul(op.inputs[1], grad, transpose_b=True),
+ return (math_ops.matmul(
+ op.inputs[1], grad, transpose_b=True),
math_ops.matmul(op.inputs[0], grad))
elif t_a and t_b:
- return (math_ops.matmul(op.inputs[1], grad, transpose_a=True,
- transpose_b=True),
- math_ops.matmul(grad, op.inputs[0], transpose_a=True,
- transpose_b=True))
+ return (math_ops.matmul(
+ op.inputs[1], grad, transpose_a=True, transpose_b=True),
+ math_ops.matmul(
+ grad, op.inputs[0], transpose_a=True, transpose_b=True))
@ops.RegisterGradient("SparseMatMul")
@@ -787,8 +792,8 @@ def _SparseMatMulGrad(op, grad):
# Use heuristic to figure out if grad might be sparse
grad: (grad.op.type == "ReluGrad")
}
- def _SparseMatMul(t1, t2, out_dtype,
- transpose_a=False, transpose_b=False):
+
+ def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
"""Helper function to create SparseMatMul op."""
assert t1 in is_sparse and t2 in is_sparse
@@ -797,11 +802,13 @@ def _SparseMatMulGrad(op, grad):
if transpose_b:
t2 = array_ops.transpose(t2)
transpose_b = False
- prod = math_ops.matmul(t1, t2,
- transpose_a=transpose_a,
- transpose_b=transpose_b,
- a_is_sparse=t1_sparse,
- b_is_sparse=t2_sparse)
+ prod = math_ops.matmul(
+ t1,
+ t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ a_is_sparse=t1_sparse,
+ b_is_sparse=t2_sparse)
if prod.dtype != out_dtype:
prod = math_ops.cast(prod, out_dtype)
return prod
@@ -809,19 +816,21 @@ def _SparseMatMulGrad(op, grad):
dtype_a = op.inputs[0].dtype
dtype_b = op.inputs[1].dtype
if not t_a and not t_b:
- return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
- _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
+ return (_SparseMatMul(
+ grad, op.inputs[1], dtype_a, transpose_b=True), _SparseMatMul(
+ op.inputs[0], grad, dtype_b, transpose_a=True))
elif not t_a and t_b:
- return (_SparseMatMul(grad, op.inputs[1], dtype_a),
- _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
+ return (_SparseMatMul(grad, op.inputs[1], dtype_a), _SparseMatMul(
+ grad, op.inputs[0], dtype_b, transpose_a=True))
elif t_a and not t_b:
- return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
+ return (_SparseMatMul(
+ op.inputs[1], grad, dtype_a, transpose_b=True),
_SparseMatMul(op.inputs[0], grad, dtype_b))
elif t_a and t_b:
- return (_SparseMatMul(op.inputs[1], grad, dtype_a,
- transpose_a=True, transpose_b=True),
- _SparseMatMul(grad, op.inputs[0], dtype_b,
- transpose_a=True, transpose_b=True))
+ return (_SparseMatMul(
+ op.inputs[1], grad, dtype_a, transpose_a=True,
+ transpose_b=True), _SparseMatMul(
+ grad, op.inputs[0], dtype_b, transpose_a=True, transpose_b=True))
@ops.RegisterGradient("Floor")
@@ -908,8 +917,10 @@ def _ComplexAbsGrad(op, grad):
@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
- t = [dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.bfloat16, dtypes.complex64, dtypes.complex128]
+ t = [
+ dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
+ dtypes.complex64, dtypes.complex128
+ ]
src_type = op.inputs[0].dtype.base_dtype
dst_type = grad.dtype.base_dtype
if src_type in t and dst_type in t:
@@ -972,8 +983,10 @@ def _CumsumGrad(op, grad):
axis = op.inputs[1]
exclusive = op.get_attr("exclusive")
reverse = op.get_attr("reverse")
- return [math_ops.cumsum(grad, axis, exclusive=exclusive,
- reverse=not reverse), None]
+ return [
+ math_ops.cumsum(
+ grad, axis, exclusive=exclusive, reverse=not reverse), None
+ ]
@ops.RegisterGradient("Cumprod")
@@ -985,6 +998,6 @@ def _CumprodGrad(op, grad):
# TODO This fails when x contains 0 and should be fixed
prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
- out = math_ops.cumsum(prod * grad, axis, exclusive=exclusive,
- reverse=not reverse)
+ out = math_ops.cumsum(
+ prod * grad, axis, exclusive=exclusive, reverse=not reverse)
return [out / x, None]
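
Once the "Expm1" gradient is registered, backpropagation through the op picks it up automatically. The sketch below is a hypothetical TF 1.x-style usage, assuming the forward op is exposed as tf.expm1 in this build (the exact public name may vary by version); it only illustrates that tf.gradients now returns exp(x) for y = expm1(x).

    import tensorflow as tf

    # Hypothetical usage sketch (assumes the op is exported as tf.expm1).
    x = tf.constant([0.0, 1.0, 2.0])
    y = tf.expm1(x)               # forward op whose gradient this change adds
    dy_dx, = tf.gradients(y, x)   # resolved via the registered "Expm1" gradient

    with tf.Session() as sess:
        print(sess.run(dy_dx))    # ~ exp([0., 1., 2.]) = [1., 2.718..., 7.389...]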