diff options
author | 2018-02-28 03:39:04 -0800 | |
---|---|---|
committer | 2018-02-28 03:43:21 -0800 | |
commit | 6ac343bdfc942678d64dcbfc4d4fc90c0df6a4a0 (patch) | |
tree | 6791b60d1544e7eba741a6ee90fa5528c3ac0408 /tensorflow/compiler/tests/binary_ops_test.py | |
parent | 503d9b522e28272e032bc45a10e3c0f21398a16e (diff) |
[TF:XLA] Fix SplitV implementation to support negative split_dim.
Mirror behavior of Split op when a negative split_dim is used.
PiperOrigin-RevId: 187304771
Diffstat (limited to 'tensorflow/compiler/tests/binary_ops_test.py')
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 30a6d3a74d..0e4efaed86 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1045,6 +1045,20 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) + def splitvOp(x, y): # pylint: disable=invalid-name + return array_ops.split(value=y, num_or_size_splits=[2, 3], axis=x) + for axis in [1, -1]: + self._testBinary( + splitvOp, + np.int32(axis), + np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + dtype=dtype), + expected=[ + np.array([[0, 1], [5, 6]], dtype=dtype), + np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + def testTile(self): for dtype in self.numeric_types: self._testBinary( |