aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/math_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/math_grad.py')
-rw-r--r--tensorflow/python/ops/math_grad.py506
1 files changed, 506 insertions, 0 deletions
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
new file mode 100644
index 0000000000..cb808ff5b8
--- /dev/null
+++ b/tensorflow/python/ops/math_grad.py
@@ -0,0 +1,506 @@
+"""Gradients for operators defined in math_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+
+
+def _ReductionGradAssist(op):
+ """Reduction grads have much in common, so factor the commonality out."""
+ inp = op.inputs[0] # Example:
+ input_shape = array_ops.shape(inp) # [2, 3, 5, 7]
+ input_rank = array_ops.rank(inp) # 4
+ indices = op.inputs[1] # [1, 2]
+ indices_shape = array_ops.shape(indices) # [2]
+ new_output_shape = data_flow_ops.dynamic_stitch( # [2, 1, 1, 7]
+ [math_ops.range(0, input_rank), # [0, 1, 2, 3]
+ indices], # [1, 2]
+ [input_shape, # [2, 3, 5, 7]
+ array_ops.fill(indices_shape, 1)]) # [1, 1]
+ return inp, new_output_shape, input_shape
+
+
+@ops.RegisterGradient("Sum")
+def _SumGrad(op, grad):
+ """Gradient for Sum."""
+ _, new_output_shape, input_shape = _ReductionGradAssist(op)
+ tile_scaling = input_shape / new_output_shape
+ grad = array_ops.reshape(grad, new_output_shape)
+ return [array_ops.tile(grad, tile_scaling), None]
+
+
+def _MinOrMaxGrad(op, grad):
+ """Gradient for Max or Max. Amazingly it's precisely the same code."""
+ inp, new_output_shape, _ = _ReductionGradAssist(op)
+ y = op.outputs[0]
+ y = array_ops.reshape(y, new_output_shape)
+ grad = array_ops.reshape(grad, new_output_shape)
+ indicators = math_ops.cast(math_ops.equal(y, inp), grad.dtype)
+ return [indicators * grad, None]
+
+
+@ops.RegisterGradient("Max")
+def _MaxGrad(op, grad):
+ """Gradient for Max."""
+ return _MinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("Min")
+def _MinGrad(op, grad):
+ return _MinOrMaxGrad(op, grad)
+
+
+@ops.RegisterGradient("Mean")
+def _MeanGrad(op, grad):
+ """Gradient for Mean."""
+ sum_grad = _SumGrad(op, grad)[0]
+ input_shape = array_ops.shape(op.inputs[0])
+ output_shape = array_ops.shape(op.outputs[0])
+ factor = (math_ops.reduce_prod(input_shape) /
+ math_ops.reduce_prod(output_shape))
+ return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
+
+
+@ops.RegisterGradient("Prod")
+def _ProdGrad(op, grad):
+ """Gradient for Prod."""
+ # TODO(kearnes): this gives NaNs for 0s in the input tensor
+ _, new_output_shape, input_shape = _ReductionGradAssist(op)
+ tile_scaling = input_shape / new_output_shape
+ grad = array_ops.reshape(grad * op.outputs[0], new_output_shape)
+ grad = math_ops.div(array_ops.tile(grad, tile_scaling), op.inputs[0])
+ return grad, None
+
+
+@ops.RegisterGradient("SegmentSum")
+def _SegmentSumGrad(op, grad):
+ """Gradient for SegmentSum."""
+ return array_ops.gather(grad, op.inputs[1]), None
+
+
+@ops.RegisterGradient("SegmentMean")
+def _SegmentMeanGrad(op, grad):
+ """Gradient for SegmentMean."""
+ input_rank = array_ops.rank(op.inputs[0])
+ ones_shape = array_ops.concat(
+ 0, [array_ops.shape(op.inputs[1]),
+ array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)])
+ ones = array_ops.fill(ones_shape,
+ constant_op.constant(1, dtype=grad.dtype))
+ scaled_grad = grad * math_ops.inv(math_ops.segment_sum(ones, op.inputs[1]))
+ return array_ops.gather(scaled_grad, op.inputs[1]), None
+
+
+@ops.RegisterGradient("SparseSegmentSum")
+def _SparseSegmentSumGrad(op, grad):
+ """Gradient for SparseSegmentSum."""
+ input_rows = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.unsorted_segment_sum(
+ array_ops.gather(grad, op.inputs[2]),
+ op.inputs[1], input_rows), None, None)
+
+
+@ops.RegisterGradient("SparseSegmentMean")
+def _SparseSegmentMeanGrad(op, grad):
+ """Gradient for SparseSegmentMean."""
+ dim0 = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.sparse_segment_mean_grad(grad,
+ op.inputs[1],
+ op.inputs[2],
+ dim0),
+ None, None)
+
+
+@ops.RegisterGradient("SegmentMin")
+def _SegmentMinGrad(op, grad):
+ """Gradient for SegmentMin."""
+ zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype)
+ gathered_grads = array_ops.gather(grad, op.inputs[1])
+ gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
+ return math_ops.select(math_ops.greater(op.inputs[0], gathered_outputs),
+ zeros,
+ gathered_grads), None
+
+
+@ops.RegisterGradient("SegmentMax")
+def _SegmentMaxGrad(op, grad):
+ """Gradient for SegmentMax."""
+ zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype)
+ gathered_grads = array_ops.gather(grad, op.inputs[1])
+ gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
+ return math_ops.select(math_ops.less(op.inputs[0], gathered_outputs),
+ zeros,
+ gathered_grads), None
+
+
+@ops.RegisterGradient("UnsortedSegmentSum")
+def _UnsortedSegmentSumGrad(op, grad):
+ """Gradient for SegmentSum."""
+ return array_ops.gather(grad, op.inputs[1]), None, None
+
+
+@ops.RegisterGradient("Abs")
+def _AbsGrad(op, grad):
+ x = op.inputs[0]
+ return grad * math_ops.sign(x)
+
+
+@ops.RegisterGradient("Neg")
+def _NegGrad(_, grad):
+ """Returns -grad."""
+ return - grad
+
+
+@ops.RegisterGradient("Inv")
+def _InvGrad(op, grad):
+ """Returns -grad * (1 / x^2)."""
+ y = op.outputs[0] # y = 1 / x
+ return grad * (- math_ops.square(y))
+
+
+@ops.RegisterGradient("Square")
+def _SquareGrad(op, grad):
+ x = op.inputs[0]
+ return grad * (2.0 * x)
+
+
+@ops.RegisterGradient("Sqrt")
+def _SqrtGrad(op, grad):
+ y = op.outputs[0] # y = x^(1/2)
+ return grad * (.5 * math_ops.inv(y))
+
+
+@ops.RegisterGradient("Rsqrt")
+def _RsqrtGrad(op, grad):
+ x = op.inputs[0]
+ y = op.outputs[0] # y = x^(-1/2)
+ return grad * ((-0.5) * math_ops.inv(x) * y)
+
+
+@ops.RegisterGradient("Exp")
+def _ExpGrad(op, grad):
+ """Returns grad * exp(x)."""
+ y = op.outputs[0] # y = e^x
+ return grad * y
+
+
+@ops.RegisterGradient("Log")
+def _LogGrad(op, grad):
+ """Returns grad * (1/x)."""
+ x = op.inputs[0]
+ return grad * math_ops.inv(x)
+
+
+@ops.RegisterGradient("Tanh")
+def _TanhGrad(op, grad):
+ """Returns grad * (1 - tanh(x) * tanh(x))."""
+ y = op.outputs[0] # y = tanh(x)
+ return grad * (1 - math_ops.square(y))
+
+
+@ops.RegisterGradient("Sigmoid")
+def _SigmoidGrad(op, grad):
+ """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
+ y = op.outputs[0] # y = sigmoid(x)
+ return grad * (y * (1 - y))
+
+
+@ops.RegisterGradient("Sign")
+def _SignGrad(op, _):
+ """Returns 0."""
+ x = op.inputs[0]
+ return array_ops.zeros(array_ops.shape(x), dtype=x.dtype)
+
+
+@ops.RegisterGradient("Sin")
+def _SinGrad(op, grad):
+ """Returns grad * cos(x)."""
+ x = op.inputs[0]
+ return grad * math_ops.cos(x)
+
+
+@ops.RegisterGradient("Cos")
+def _CosGrad(op, grad):
+ """Returns grad * -sin(x)."""
+ x = op.inputs[0]
+ return -grad * math_ops.sin(x)
+
+
+@ops.RegisterGradient("AddN")
+def _AddNGrad(op, grad):
+ """Copies the gradient to all inputs."""
+ # Not broadcasting.
+ return [grad] * len(op.inputs)
+
+
+@ops.RegisterGradient("Add")
+def _AddGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(grad, ry), sy))
+
+
+@ops.RegisterGradient("Sub")
+def _SubGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
+ array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
+
+
+@ops.RegisterGradient("Mul")
+def _MulGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ if x.dtype.base_dtype == types.complex64:
+ return (array_ops.reshape(math_ops.reduce_sum(grad * math_ops.conj(y), rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(math_ops.conj(x) * grad, ry), sy))
+ else:
+ return (array_ops.reshape(math_ops.reduce_sum(grad * y, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(x * grad, ry), sy))
+
+
+@ops.RegisterGradient("Div")
+def _DivGrad(op, grad):
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(grad *
+ (-x / math_ops.square(y)), ry), sy))
+
+
+@ops.RegisterGradient("Pow")
+def _PowGrad(op, grad):
+ """Returns grad * (y*x^(y-1), z*log(x))."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ z = op.outputs[0]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ gx = array_ops.reshape(math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx),
+ sx)
+ gy = array_ops.reshape(math_ops.reduce_sum(grad * z * math_ops.log(x), ry), sy)
+ return gx, gy
+
+
+def _MaximumMinimumGrad(op, grad, selector_op):
+ """Factor out the code for the gradient of Maximum or Minimum."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ gdtype = grad.dtype
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ gradshape = array_ops.shape(grad)
+ zeros = array_ops.zeros(gradshape, gdtype)
+ xmask = selector_op(x, y)
+ rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+ xgrad = math_ops.select(xmask, grad, zeros)
+ ygrad = math_ops.select(math_ops.logical_not(xmask), grad, zeros)
+ gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
+ gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
+ return (gx, gy)
+
+
+@ops.RegisterGradient("Maximum")
+def _MaximumGrad(op, grad):
+ """Returns grad*(x > y, x <= y) with type of grad."""
+ return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)
+
+
+@ops.RegisterGradient("Minimum")
+def _MinimumGrad(op, grad):
+ """Returns grad*(x < y, x >= y) with type of grad."""
+ return _MaximumMinimumGrad(op, grad, math_ops.less_equal)
+
+
+# Logical operations have no gradients.
+ops.NoGradient("Less")
+ops.NoGradient("LessEqual")
+ops.NoGradient("Greater")
+ops.NoGradient("GreaterEqual")
+ops.NoGradient("Equal")
+ops.NoGradient("NotEqual")
+ops.NoGradient("LogicalAnd")
+ops.NoGradient("LogicalOr")
+ops.NoGradient("LogicalNot")
+
+
+@ops.RegisterGradient("Select")
+def _SelectGrad(op, grad):
+ c = op.inputs[0]
+ x = op.inputs[1]
+ zeros = array_ops.zeros(array_ops.shape(c), dtype=x.dtype)
+ return (None, math_ops.select(c, grad, zeros),
+ math_ops.select(c, zeros, grad))
+
+
+@ops.RegisterGradient("MatMul")
+def _MatMulGrad(op, grad):
+ t_a = op.get_attr("transpose_a")
+ t_b = op.get_attr("transpose_b")
+ if not t_a and not t_b:
+ return (math_ops.matmul(grad, op.inputs[1], transpose_b=True),
+ math_ops.matmul(op.inputs[0], grad, transpose_a=True))
+ elif not t_a and t_b:
+ return (math_ops.matmul(grad, op.inputs[1]),
+ math_ops.matmul(grad, op.inputs[0], transpose_a=True))
+ elif t_a and not t_b:
+ return (math_ops.matmul(op.inputs[1], grad, transpose_b=True),
+ math_ops.matmul(op.inputs[0], grad))
+ elif t_a and t_b:
+ return (math_ops.matmul(op.inputs[1], grad, transpose_a=True,
+ transpose_b=True),
+ math_ops.matmul(grad, op.inputs[0], transpose_a=True,
+ transpose_b=True))
+
+
+@ops.RegisterGradient("SparseMatMul")
+def _SparseMatMulGrad(op, grad):
+ """Gradient for SparseMatMul."""
+
+ t_a = op.get_attr("transpose_a")
+ t_b = op.get_attr("transpose_b")
+ is_sparse = {
+ op.inputs[0]: op.get_attr("a_is_sparse"),
+ op.inputs[1]: op.get_attr("b_is_sparse"),
+ # Use heuristic to figure out if grad might be sparse
+ grad: (grad.op.type == "ReluGrad")
+ }
+ def _SparseMatMul(t1, t2, transpose_a=False, transpose_b=False):
+ """Helper function to create SparseMatMul op."""
+
+ assert t1 in is_sparse and t2 in is_sparse
+ t1_sparse = is_sparse[t1]
+ t2_sparse = is_sparse[t2]
+ if not t1_sparse and not t2_sparse:
+ return math_ops.matmul(t1, t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b)
+ transpose_out = False
+ if not t1_sparse:
+ transpose_out = True
+ t1, t2 = t2, t1
+ t1_sparse, t2_sparse = t2_sparse, t1_sparse
+ assert t1_sparse
+ transpose_a, transpose_b = not transpose_b, not transpose_a
+
+ if transpose_b:
+ t2 = array_ops.transpose(t2)
+ transpose_b = False
+ m = math_ops.matmul(t1, t2,
+ transpose_a=transpose_a,
+ transpose_b=transpose_b,
+ a_is_sparse=t1_sparse,
+ b_is_sparse=t2_sparse)
+ if transpose_out:
+ m = array_ops.transpose(m)
+ return m
+
+ if not t_a and not t_b:
+ return (_SparseMatMul(grad, op.inputs[1], transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad, transpose_a=True))
+ elif not t_a and t_b:
+ return (_SparseMatMul(grad, op.inputs[1]),
+ _SparseMatMul(grad, op.inputs[0], transpose_a=True))
+ elif t_a and not t_b:
+ return (_SparseMatMul(op.inputs[1], grad, transpose_b=True),
+ _SparseMatMul(op.inputs[0], grad))
+ elif t_a and t_b:
+ return (_SparseMatMul(op.inputs[1], grad,
+ transpose_a=True, transpose_b=True),
+ _SparseMatMul(grad, op.inputs[0],
+ transpose_a=True, transpose_b=True))
+
+
+@ops.RegisterGradient("Floor")
+def _FloorGrad(_, grad):
+ return grad
+
+
+@ops.RegisterGradient("BatchMatMul")
+def _BatchMatMul(op, grad):
+ """Returns the gradient of x and y given the gradient of x * y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ adj_x = op.get_attr("adj_x")
+ adj_y = op.get_attr("adj_y")
+
+ if not adj_x:
+ if not adj_y:
+ grad_x = math_ops.batch_matmul(grad, y, False, True)
+ grad_y = math_ops.batch_matmul(x, grad, True, False)
+ else:
+ grad_x = math_ops.batch_matmul(grad, y, False, False)
+ grad_y = math_ops.batch_matmul(grad, x, True, False)
+ else:
+ if not adj_y:
+ grad_x = math_ops.batch_matmul(y, grad, False, True)
+ grad_y = math_ops.batch_matmul(x, grad, False, False)
+ else:
+ grad_x = math_ops.batch_matmul(y, grad, True, True)
+ grad_y = math_ops.batch_matmul(grad, x, True, True)
+
+ return grad_x, grad_y
+
+
+ops.NoGradient("Range")
+ops.NoGradient("LinSpace")
+
+
+@ops.RegisterGradient("Complex")
+def _ComplexGrad(_, grad):
+ """Returns the real and imaginary components of 'grad', respectively."""
+ return math_ops.real(grad), math_ops.imag(grad)
+
+
+@ops.RegisterGradient("Real")
+def _RealGrad(_, grad):
+ """Returns 'grad' as the real part and set the imaginary part 0."""
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ return math_ops.complex(grad, zero)
+
+
+@ops.RegisterGradient("Imag")
+def _ImagGrad(_, grad):
+ """Returns 'grad' as the imaginary part and set the real part 0."""
+ zero = constant_op.constant(0, dtype=grad.dtype)
+ return math_ops.complex(zero, grad)
+
+
+@ops.RegisterGradient("Conj")
+def _ConjGrad(_, grad):
+ """Returns the complex conjugate of grad."""
+ return math_ops.conj(grad)
+
+
+@ops.RegisterGradient("Cast")
+def _CastGrad(op, grad):
+ t = [types.float32, types.float64, types.bfloat16]
+ src_type = op.inputs[0].dtype.base_dtype
+ dst_type = grad.dtype.base_dtype
+ if src_type in t and dst_type in t:
+ return math_ops.cast(grad, src_type)
+ else:
+ return None