aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index da72803ee7..8099175186 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -45,7 +45,12 @@ def _maybe_complex(x):
class SparseTensorDenseMatMulTest(test.TestCase):
- def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False):
+ def _testMatmul(self,
+ x,
+ y,
+ adjoint_a=False,
+ adjoint_b=False,
+ indices_dtype=np.int64):
x_mat = np.matrix(x)
if adjoint_a:
x_mat = x_mat.H
@@ -55,7 +60,7 @@ class SparseTensorDenseMatMulTest(test.TestCase):
np_ans = x_mat * y_mat
- x_indices = np.vstack(np.where(x)).astype(np.int64).T
+ x_indices = np.vstack(np.where(x)).astype(indices_dtype).T
x_values = x[np.where(x)]
x_shape = x.shape
@@ -82,13 +87,13 @@ class SparseTensorDenseMatMulTest(test.TestCase):
else:
self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
- def _testBasic(self, np_dtype):
- x = _maybe_complex(np.random.rand(10, 10).astype(np_dtype))
+ def _testBasic(self, value_dtype, indices_dtype=np.int64):
+ x = _maybe_complex(np.random.rand(10, 10).astype(value_dtype))
x[np.abs(x) < 0.5] = 0 # Make it sparse
- y = _maybe_complex(np.random.randn(10, 20).astype(np_dtype))
+ y = _maybe_complex(np.random.randn(10, 20).astype(value_dtype))
- self._testMatmul(x, y)
+ self._testMatmul(x, y, indices_dtype=indices_dtype)
def testBasic(self):
np.random.seed(127) # Repeatable results
@@ -97,6 +102,8 @@ class SparseTensorDenseMatMulTest(test.TestCase):
self._testBasic(np.float64)
self._testBasic(np.complex64)
self._testBasic(np.complex128)
+ self._testBasic(np.int32, indices_dtype=np.int32)
+ self._testBasic(np.float32, indices_dtype=np.int32)
def testShapeInference(self):
x = np.random.rand(10, 10)