From ee7c9597f4ab8e586e921f9fe3e3c1383417169c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 10 Oct 2018 02:22:32 -0700 Subject: Emit xla::Or in TensorArrayScatterV3 for PRED types instead of xla::Add Previosuly we emitted xla::Add what isn't supported by some XLA backend on PRED types. PiperOrigin-RevId: 216497939 --- tensorflow/compiler/tests/tensor_array_ops_test.py | 37 ++++++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) (limited to 'tensorflow/compiler/tests/tensor_array_ops_test.py') diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 78244d0b36..46ca371c8a 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -920,6 +920,34 @@ class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayEvalEmptyWithDefault(self): self._testTensorArrayEvalEmptyWithDefault() + def _testTensorArrayScatterRead(self, tf_dtype): + with self.cached_session() as session, self.test_scope(): + convert = _make_converter(tf_dtype) + + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, + tensor_array_name="foo", + size=10) + + indices = constant_op.constant([1, 8]) + value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]])) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) + + w = ta.scatter(indices, value) + r0 = w.read(id0) + r1 = w.read(id1) + + # Test aggregation of read + read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8}) + self.assertAllEqual(convert([1.0, -1.0]), read_vals[0]) + self.assertAllEqual(convert([10.0, -10.0]), read_vals[1]) + + def testTensorArrayScatterRead(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayScatterRead(dtype) + self._testTensorArrayScatterRead(dtypes.bool) + def testTensorArrayScatterReadAndGradients(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -929,15 +957,18 @@ class TensorArrayTest(xla_test.XLATestCase): indices = constant_op.constant([1, 8]) value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) w = ta.scatter(indices, value) - r0 = w.read(1) - r1 = w.read(8) + r0 = w.read(id0) + r1 = w.read(id1) # Test combined gradients + aggregation of read(0). grad = gradients_impl.gradients( ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) - read_vals, grad_vals = session.run([[r0, r1], grad]) + read_vals, grad_vals = session.run([[r0, r1], grad], + feed_dict={id0: 1, id1: 8}) self.assertEqual(len(read_vals), 2) self.assertEqual(len(grad_vals), 1) -- cgit v1.2.3