From 32a3642ef448d93706ab22e894637b2dd0c197c7 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 12 Sep 2018 15:28:47 -0700 Subject: Export the XLA dynamic-slice HLO as a TF op I need this in a subsequent CL where I'll rewrite the Slice TF op to DynamicSlice in some cases. PiperOrigin-RevId: 212715067 --- tensorflow/compiler/tests/xla_ops_test.py | 41 ++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) (limited to 'tensorflow/compiler/tests') diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 0f3843dc1e..1e600c44e9 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -34,7 +35,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): def _assertOpOutputMatchesExpected(self, op, args, expected, equality_fn=None): - with self.cached_session() as session: + with self.test_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._assertOpOutputMatchesExpected( lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + def testDynamicSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_slice, + args=(np.arange(1000, + dtype=np.int32).astype(dtype).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3, 2])), + expected=np.array( + np.array([[[573, 574], [583, 584], [593, 594]], + [[673, 674], [683, 684], [693, 694]]]), + dtype=dtype)) + + def testDynamicSliceWithIncorrectStartIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7]), np.array([2, 3, 4])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^start_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and start_indices has shape \[2\].*')) + + def testDynamicSliceWithIncorrectSizeIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^size_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and size_indices has shape \[2\].*')) + if __name__ == '__main__': googletest.main() -- cgit v1.2.3