aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/xla_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/xla_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 3f928a1bea..1e600c44e9 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
@@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self._assertOpOutputMatchesExpected(
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+ def testDynamicSlice(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.dynamic_slice,
+ args=(np.arange(1000,
+ dtype=np.int32).astype(dtype).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3, 2])),
+ expected=np.array(
+ np.array([[[573, 574], [583, 584], [593, 594]],
+ [[673, 674], [683, 684], [693, 694]]]),
+ dtype=dtype))
+
+ def testDynamicSliceWithIncorrectStartIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7]), np.array([2, 3, 4]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^start_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and start_indices has shape \[2\].*'))
+
+ def testDynamicSliceWithIncorrectSizeIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^size_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and size_indices has shape \[2\].*'))
+
if __name__ == '__main__':
googletest.main()