aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-12 15:28:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 15:32:20 -0700
commit32a3642ef448d93706ab22e894637b2dd0c197c7 (patch)
treec873d9ec097cc297e7032667d7d803715e3bfa52 /tensorflow/compiler/tests
parent90876942a3f4403ebae7d1c9223c241e006eeaaa (diff)
Export the XLA dynamic-slice HLO as a TF op
I need this in a subsequent CL where I'll rewrite the Slice TF op to DynamicSlice in some cases. PiperOrigin-RevId: 212715067
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py41
1 files changed, 40 insertions, 1 deletions
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 0f3843dc1e..1e600c44e9 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
@@ -34,7 +35,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected,
equality_fn=None):
- with self.cached_session() as session:
+ with self.test_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self._assertOpOutputMatchesExpected(
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+ def testDynamicSlice(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.dynamic_slice,
+ args=(np.arange(1000,
+ dtype=np.int32).astype(dtype).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3, 2])),
+ expected=np.array(
+ np.array([[[573, 574], [583, 584], [593, 594]],
+ [[673, 674], [683, 684], [693, 694]]]),
+ dtype=dtype))
+
+ def testDynamicSliceWithIncorrectStartIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7]), np.array([2, 3, 4]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^start_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and start_indices has shape \[2\].*'))
+
+ def testDynamicSliceWithIncorrectSizeIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^size_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and size_indices has shape \[2\].*'))
+
if __name__ == '__main__':
googletest.main()