diff options
author | Yanan Cao <ycao@google.com> | 2018-09-12 08:33:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 08:42:00 -0700 |
commit | 9098f75af917df9b9d4f5ecc423037fd2fb365f9 (patch) | |
tree | b9bd2fcf5873a318cc993981d03f0567d7976bef /tensorflow/compiler/tests | |
parent | 6bf71666feb2184771ec3d0d304329b50a9a4ad2 (diff) |
Parameterize test matrix_band_part_test
PiperOrigin-RevId: 212643986
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tests/matrix_band_part_test.py | 190 |
2 files changed, 161 insertions, 30 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 050d827a09..e7623582f6 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -581,6 +581,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 9222db4b7e..c61965b97f 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(xla_test.XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): - def _testMatrixBandPart(self, dtype, shape): - with self.cached_session(): - batch_shape = shape[:-2] - mat = np.ones(shape).astype(dtype) - batch_mat = np.tile(mat, batch_shape + [1, 1]) - for lower in -1, 0, 1, shape[-2] - 1: - for upper in -1, 0, 1, shape[-1] - 1: - band_np = mat - if lower >= 0: - band_np = np.triu(band_np, -lower) - if upper >= 0: - band_np = np.tril(band_np, upper) - if batch_shape: - band_np = np.tile(band_np, batch_shape + [1, 1]) - - placeholder = array_ops.placeholder(dtype) - with self.test_scope(): - band = array_ops.matrix_band_part( - placeholder, - constant_op.constant(lower, dtype=dtypes.int32), - constant_op.constant(upper, dtype=dtypes.int32)) - feed_dict = {placeholder: batch_mat} - self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) - - def testMatrixBandPart(self): + @parameterized.parameters( + { + 'batch_shape': [], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 7 + }, + ) + def testMatrixBandPart(self, batch_shape, rows, cols): for dtype in self.float_types: - for batch_shape in [[], [2,], [1, 3, 2]]: - for rows in 1, 2, 7: - for cols in 1, 2, 7: - self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + with self.cached_session(): + mat = np.ones(batch_shape + [rows, cols]).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, rows - 1: + for upper in -1, 0, 1, cols - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, constant_op.constant(lower, dtype=dtypes.int32), + constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) if __name__ == "__main__": |