diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py index da72803ee7..8099175186 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py @@ -45,7 +45,12 @@ def _maybe_complex(x): class SparseTensorDenseMatMulTest(test.TestCase): - def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False): + def _testMatmul(self, + x, + y, + adjoint_a=False, + adjoint_b=False, + indices_dtype=np.int64): x_mat = np.matrix(x) if adjoint_a: x_mat = x_mat.H @@ -55,7 +60,7 @@ class SparseTensorDenseMatMulTest(test.TestCase): np_ans = x_mat * y_mat - x_indices = np.vstack(np.where(x)).astype(np.int64).T + x_indices = np.vstack(np.where(x)).astype(indices_dtype).T x_values = x[np.where(x)] x_shape = x.shape @@ -82,13 +87,13 @@ class SparseTensorDenseMatMulTest(test.TestCase): else: self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4) - def _testBasic(self, np_dtype): - x = _maybe_complex(np.random.rand(10, 10).astype(np_dtype)) + def _testBasic(self, value_dtype, indices_dtype=np.int64): + x = _maybe_complex(np.random.rand(10, 10).astype(value_dtype)) x[np.abs(x) < 0.5] = 0 # Make it sparse - y = _maybe_complex(np.random.randn(10, 20).astype(np_dtype)) + y = _maybe_complex(np.random.randn(10, 20).astype(value_dtype)) - self._testMatmul(x, y) + self._testMatmul(x, y, indices_dtype=indices_dtype) def testBasic(self): np.random.seed(127) # Repeatable results @@ -97,6 +102,8 @@ class SparseTensorDenseMatMulTest(test.TestCase): self._testBasic(np.float64) self._testBasic(np.complex64) self._testBasic(np.complex128) + self._testBasic(np.int32, indices_dtype=np.int32) + self._testBasic(np.float32, indices_dtype=np.int32) def testShapeInference(self): x = np.random.rand(10, 10) |