diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py | 42 |
1 files changed, 27 insertions, 15 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py index df5462dd2d..e8b94294b1 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py @@ -30,34 +30,44 @@ from tensorflow.python.platform import test class SparseTensorDenseMatMulGradientTest(test.TestCase): - def _sparsify(self, x): + def _sparsify(self, x, indices_dtype=np.int64): x[x < 0.5] = 0 non_zero = np.where(x) - x_indices = np.vstack(non_zero).astype(np.int64).T + x_indices = np.vstack(non_zero).astype(indices_dtype).T x_values = x[non_zero] x_shape = x.shape return sparse_tensor.SparseTensor( indices=x_indices, values=x_values, dense_shape=x_shape), len(x_values) - def _randomTensor(self, size, np_dtype, adjoint=False, sparse=False): + def _randomTensor(self, + size, + values_dtype, + adjoint=False, + sparse=False, + indices_dtype=np.int64): n, m = size - x = np.random.randn(n, m).astype(np_dtype) + x = np.random.randn(n, m).astype(values_dtype) if adjoint: x = x.transpose() if sparse: - return self._sparsify(x) + return self._sparsify(x, indices_dtype=indices_dtype) else: - return constant_op.constant(x, dtype=np_dtype) + return constant_op.constant(x, dtype=values_dtype) - def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype): + def _testGradients(self, adjoint_a, adjoint_b, name, values_dtype, + indices_dtype): n, k, m = np.random.randint(1, 10, size=3) sp_t, nnz = self._randomTensor( - [n, k], np_dtype, adjoint=adjoint_a, sparse=True) - dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b) + [n, k], + values_dtype, + adjoint=adjoint_a, + sparse=True, + indices_dtype=indices_dtype) + dense_t = self._randomTensor([k, m], values_dtype, adjoint=adjoint_b) matmul = sparse_ops.sparse_tensor_dense_matmul( sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name) @@ -71,17 +81,19 @@ class SparseTensorDenseMatMulGradientTest(test.TestCase): print("%s gradient err = %s" % (name, err)) self.assertLess(err, 1e-3) - def _testGradientsType(self, np_dtype): + def _testGradientsType(self, values_dtype, indices_dtype): for adjoint_a in [True, False]: for adjoint_b in [True, False]: - name = "sparse_tensor_dense_matmul_%s_%s_%s" % (adjoint_a, adjoint_b, - np_dtype.__name__) - self._testGradients(adjoint_a, adjoint_b, name, np_dtype) + name = "sparse_tensor_dense_matmul_%s_%s_%s_%s" % ( + adjoint_a, adjoint_b, values_dtype.__name__, indices_dtype.__name__) + self._testGradients(adjoint_a, adjoint_b, name, values_dtype, + indices_dtype) def testGradients(self): np.random.seed(5) # Fix seed to avoid flakiness - self._testGradientsType(np.float32) - self._testGradientsType(np.float64) + self._testGradientsType(np.float32, np.int64) + self._testGradientsType(np.float64, np.int64) + self._testGradientsType(np.float32, np.int32) if __name__ == "__main__": |