aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/sparse_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/sparse_grad.py')
-rw-r--r--tensorflow/python/ops/sparse_grad.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index 5d64e0eef1..0d43c85634 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -171,3 +171,50 @@ def _SparseTensorDenseMatMulGrad(op, grad):
# gradients w.r.t. (a_indices, a_values, a_shape, b)
return (None, a_values_grad, None, b_grad)
+
+
+def _SparseDenseCwiseMulOrDivGrad(op, grad, is_mul):
+ """Common code for SparseDenseCwise{Mul,Div} gradients."""
+ x_indices = op.inputs[0]
+ x_shape = op.inputs[2]
+ y = op.inputs[3]
+
+ y_shape = math_ops.to_int64(array_ops.shape(y))
+ num_added_dims = array_ops.expand_dims(
+ array_ops.size(x_shape) - array_ops.size(y_shape), 0)
+ augmented_y_shape = array_ops.concat(0, [array_ops.ones(num_added_dims,
+ ops.dtypes.int64),
+ y_shape])
+
+ scaling = x_shape // augmented_y_shape
+ scaled_indices = x_indices // scaling
+ scaled_indices = array_ops.slice(scaled_indices,
+ array_ops.concat(0, [[0], num_added_dims]),
+ [-1, -1])
+ dense_vals = array_ops.gather_nd(y, scaled_indices)
+
+ if is_mul:
+ dx = grad * dense_vals
+ dy_val = grad * op.inputs[1]
+ else:
+ dx = grad / dense_vals
+ dy_val = grad * (-op.inputs[1] / math_ops.square(dense_vals))
+ # indices can repeat after scaling, so we can't use sparse_to_dense().
+ dy = sparse_ops.sparse_add(
+ array_ops.zeros_like(y),
+ ops.SparseTensor(scaled_indices, dy_val, y_shape))
+
+ # (sp_indices, sp_vals, sp_shape, dense)
+ return (None, dx, None, dy)
+
+
+@ops.RegisterGradient("SparseDenseCwiseMul")
+def _SparseDenseCwiseMulGrad(op, grad):
+ """Gradients for SparseDenseCwiseMul."""
+ return _SparseDenseCwiseMulOrDivGrad(op, grad, True)
+
+
+@ops.RegisterGradient("SparseDenseCwiseDiv")
+def _SparseDenseCwiseDivGrad(op, grad):
+ """Gradients for SparseDenseCwiseDiv."""
+ return _SparseDenseCwiseMulOrDivGrad(op, grad, False)