aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/binary_ops_test.py
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-10-02 14:38:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 14:49:11 -0700
commit553d10cfe42edcb6b3b8d748b315f13925fcf28f (patch)
tree4ac0714f0fbb82e3ce27741450f9e6b27df0dc7c /tensorflow/compiler/tests/binary_ops_test.py
parent061897179e9f576380f72fe2131cd48d4af3b581 (diff)
[TF:XLA] Add support for negative values of "split_dim" argument to Split operator.
PiperOrigin-RevId: 170755169
Diffstat (limited to 'tensorflow/compiler/tests/binary_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py46
1 files changed, 24 insertions, 22 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index f3ea57596e..792c01327c 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -790,28 +790,30 @@ class BinaryOpsTest(XLATestCase):
def testSplit(self):
for dtype in self.numeric_types:
- self._testBinary(
- lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
- np.int32(0),
- np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
- dtype=dtype),
- expected=[
- np.array([[[1], [2]]], dtype=dtype),
- np.array([[[3], [4]]], dtype=dtype),
- np.array([[[5], [6]]], dtype=dtype),
- ],
- equality_test=self.ListsAreClose)
-
- self._testBinary(
- lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
- np.int32(1),
- np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
- dtype=dtype),
- expected=[
- np.array([[[1]], [[3]], [[5]]], dtype=dtype),
- np.array([[[2]], [[4]], [[6]]], dtype=dtype),
- ],
- equality_test=self.ListsAreClose)
+ for axis in [0, -3]:
+ self._testBinary(
+ lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
+ np.int32(axis),
+ np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
+ dtype=dtype),
+ expected=[
+ np.array([[[1], [2]]], dtype=dtype),
+ np.array([[[3], [4]]], dtype=dtype),
+ np.array([[[5], [6]]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
+
+ for axis in [1, -2]:
+ self._testBinary(
+ lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
+ np.int32(axis),
+ np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
+ dtype=dtype),
+ expected=[
+ np.array([[[1]], [[3]], [[5]]], dtype=dtype),
+ np.array([[[2]], [[4]], [[6]]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
def testTile(self):
for dtype in self.numeric_types: