diff options
Diffstat (limited to 'tensorflow/python/ops/sparse_grad.py')
-rw-r--r-- | tensorflow/python/ops/sparse_grad.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index fa015856ce..b8e356c78c 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -136,12 +136,13 @@ def _SparseTensorDenseMatMulGrad(op, grad): Raises: TypeError: When the two operands don't have the same type. """ - sp_t = sparse_tensor.SparseTensor(*op.inputs[:3]) + a_indices, a_values, a_shape = op.inputs[:3] + b = op.inputs[3] adj_a = op.get_attr("adjoint_a") adj_b = op.get_attr("adjoint_b") - a_type = sp_t.values.dtype.base_dtype - b_type = op.inputs[3].dtype.base_dtype + a_type = a_values.dtype.base_dtype + b_type = b.dtype.base_dtype if a_type != b_type: raise TypeError("SparseTensorDenseMatMul op received operands with " "different types: ", a_type, " and ", b_type) @@ -150,15 +151,12 @@ def _SparseTensorDenseMatMulGrad(op, grad): "complex gradients.") # gradient w.r.t. dense - b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad, - adjoint_a=not adj_a) + b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul( # pylint: disable=protected-access + a_indices, a_values, a_shape, grad, adjoint_a=not adj_a) if adj_b: b_grad = array_ops.transpose(b_grad) # gradient w.r.t. sparse values - a_indices = op.inputs[0] - b = op.inputs[3] - rows = a_indices[:, 0] cols = a_indices[:, 1] |