aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/binary_ops_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-28 03:39:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-28 03:43:21 -0800
commit6ac343bdfc942678d64dcbfc4d4fc90c0df6a4a0 (patch)
tree6791b60d1544e7eba741a6ee90fa5528c3ac0408 /tensorflow/compiler/tests/binary_ops_test.py
parent503d9b522e28272e032bc45a10e3c0f21398a16e (diff)
[TF:XLA] Fix SplitV implementation to support negative split_dim.
Mirror behavior of Split op when a negative split_dim is used. PiperOrigin-RevId: 187304771
Diffstat (limited to 'tensorflow/compiler/tests/binary_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 30a6d3a74d..0e4efaed86 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1045,6 +1045,20 @@ class BinaryOpsTest(XLATestCase):
],
equality_test=self.ListsAreClose)
+ def splitvOp(x, y): # pylint: disable=invalid-name
+ return array_ops.split(value=y, num_or_size_splits=[2, 3], axis=x)
+ for axis in [1, -1]:
+ self._testBinary(
+ splitvOp,
+ np.int32(axis),
+ np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
+ dtype=dtype),
+ expected=[
+ np.array([[0, 1], [5, 6]], dtype=dtype),
+ np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
+
def testTile(self):
for dtype in self.numeric_types:
self._testBinary(