aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Yanan Cao <ycao@google.com>2018-09-12 08:33:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 08:42:00 -0700
commit9098f75af917df9b9d4f5ecc423037fd2fb365f9 (patch)
treeb9bd2fcf5873a318cc993981d03f0567d7976bef /tensorflow/compiler/tests
parent6bf71666feb2184771ec3d0d304329b50a9a4ad2 (diff)
Parameterize test matrix_band_part_test
PiperOrigin-RevId: 212643986
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/BUILD1
-rw-r--r--tensorflow/compiler/tests/matrix_band_part_test.py190
2 files changed, 161 insertions, 30 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 050d827a09..e7623582f6 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -581,6 +581,7 @@ tf_xla_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 9222db4b7e..c61965b97f 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
@@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MatrixBandPartTest(xla_test.XLATestCase):
+class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase):
- def _testMatrixBandPart(self, dtype, shape):
- with self.cached_session():
- batch_shape = shape[:-2]
- mat = np.ones(shape).astype(dtype)
- batch_mat = np.tile(mat, batch_shape + [1, 1])
- for lower in -1, 0, 1, shape[-2] - 1:
- for upper in -1, 0, 1, shape[-1] - 1:
- band_np = mat
- if lower >= 0:
- band_np = np.triu(band_np, -lower)
- if upper >= 0:
- band_np = np.tril(band_np, upper)
- if batch_shape:
- band_np = np.tile(band_np, batch_shape + [1, 1])
-
- placeholder = array_ops.placeholder(dtype)
- with self.test_scope():
- band = array_ops.matrix_band_part(
- placeholder,
- constant_op.constant(lower, dtype=dtypes.int32),
- constant_op.constant(upper, dtype=dtypes.int32))
- feed_dict = {placeholder: batch_mat}
- self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
-
- def testMatrixBandPart(self):
+ @parameterized.parameters(
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 7
+ },
+ )
+ def testMatrixBandPart(self, batch_shape, rows, cols):
for dtype in self.float_types:
- for batch_shape in [[], [2,], [1, 3, 2]]:
- for rows in 1, 2, 7:
- for cols in 1, 2, 7:
- self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
+ with self.cached_session():
+ mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
+ batch_mat = np.tile(mat, batch_shape + [1, 1])
+ for lower in -1, 0, 1, rows - 1:
+ for upper in -1, 0, 1, cols - 1:
+ band_np = mat
+ if lower >= 0:
+ band_np = np.triu(band_np, -lower)
+ if upper >= 0:
+ band_np = np.tril(band_np, upper)
+ if batch_shape:
+ band_np = np.tile(band_np, batch_shape + [1, 1])
+
+ placeholder = array_ops.placeholder(dtype)
+ with self.test_scope():
+ band = array_ops.matrix_band_part(
+ placeholder, constant_op.constant(lower, dtype=dtypes.int32),
+ constant_op.constant(upper, dtype=dtypes.int32))
+ feed_dict = {placeholder: batch_mat}
+ self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
if __name__ == "__main__":