diff options
Diffstat (limited to 'tensorflow/python/ops/sparse_grad.py')
-rw-r--r-- | tensorflow/python/ops/sparse_grad.py | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 5d64e0eef1..0d43c85634 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -171,3 +171,50 @@ def _SparseTensorDenseMatMulGrad(op, grad): # gradients w.r.t. (a_indices, a_values, a_shape, b) return (None, a_values_grad, None, b_grad) + + +def _SparseDenseCwiseMulOrDivGrad(op, grad, is_mul): + """Common code for SparseDenseCwise{Mul,Div} gradients.""" + x_indices = op.inputs[0] + x_shape = op.inputs[2] + y = op.inputs[3] + + y_shape = math_ops.to_int64(array_ops.shape(y)) + num_added_dims = array_ops.expand_dims( + array_ops.size(x_shape) - array_ops.size(y_shape), 0) + augmented_y_shape = array_ops.concat(0, [array_ops.ones(num_added_dims, + ops.dtypes.int64), + y_shape]) + + scaling = x_shape // augmented_y_shape + scaled_indices = x_indices // scaling + scaled_indices = array_ops.slice(scaled_indices, + array_ops.concat(0, [[0], num_added_dims]), + [-1, -1]) + dense_vals = array_ops.gather_nd(y, scaled_indices) + + if is_mul: + dx = grad * dense_vals + dy_val = grad * op.inputs[1] + else: + dx = grad / dense_vals + dy_val = grad * (-op.inputs[1] / math_ops.square(dense_vals)) + # indices can repeat after scaling, so we can't use sparse_to_dense(). + dy = sparse_ops.sparse_add( + array_ops.zeros_like(y), + ops.SparseTensor(scaled_indices, dy_val, y_shape)) + + # (sp_indices, sp_vals, sp_shape, dense) + return (None, dx, None, dy) + + +@ops.RegisterGradient("SparseDenseCwiseMul") +def _SparseDenseCwiseMulGrad(op, grad): + """Gradients for SparseDenseCwiseMul.""" + return _SparseDenseCwiseMulOrDivGrad(op, grad, True) + + +@ops.RegisterGradient("SparseDenseCwiseDiv") +def _SparseDenseCwiseDivGrad(op, grad): + """Gradients for SparseDenseCwiseDiv.""" + return _SparseDenseCwiseMulOrDivGrad(op, grad, False) |