aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/sparse_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/sparse_grad.py')
-rw-r--r--tensorflow/python/ops/sparse_grad.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index fa015856ce..b8e356c78c 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -136,12 +136,13 @@ def _SparseTensorDenseMatMulGrad(op, grad):
Raises:
TypeError: When the two operands don't have the same type.
"""
- sp_t = sparse_tensor.SparseTensor(*op.inputs[:3])
+ a_indices, a_values, a_shape = op.inputs[:3]
+ b = op.inputs[3]
adj_a = op.get_attr("adjoint_a")
adj_b = op.get_attr("adjoint_b")
- a_type = sp_t.values.dtype.base_dtype
- b_type = op.inputs[3].dtype.base_dtype
+ a_type = a_values.dtype.base_dtype
+ b_type = b.dtype.base_dtype
if a_type != b_type:
raise TypeError("SparseTensorDenseMatMul op received operands with "
"different types: ", a_type, " and ", b_type)
@@ -150,15 +151,12 @@ def _SparseTensorDenseMatMulGrad(op, grad):
"complex gradients.")
# gradient w.r.t. dense
- b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
- adjoint_a=not adj_a)
+ b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul( # pylint: disable=protected-access
+ a_indices, a_values, a_shape, grad, adjoint_a=not adj_a)
if adj_b:
b_grad = array_ops.transpose(b_grad)
# gradient w.r.t. sparse values
- a_indices = op.inputs[0]
- b = op.inputs[3]
-
rows = a_indices[:, 0]
cols = a_indices[:, 1]