aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/tensor_array_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/tensor_array_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py37
1 files changed, 34 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 78244d0b36..46ca371c8a 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -920,6 +920,34 @@ class TensorArrayTest(xla_test.XLATestCase):
def testTensorArrayEvalEmptyWithDefault(self):
self._testTensorArrayEvalEmptyWithDefault()
+ def _testTensorArrayScatterRead(self, tf_dtype):
+ with self.cached_session() as session, self.test_scope():
+ convert = _make_converter(tf_dtype)
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype,
+ tensor_array_name="foo",
+ size=10)
+
+ indices = constant_op.constant([1, 8])
+ value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]]))
+ id0 = array_ops.placeholder(dtypes.int32)
+ id1 = array_ops.placeholder(dtypes.int32)
+
+ w = ta.scatter(indices, value)
+ r0 = w.read(id0)
+ r1 = w.read(id1)
+
+ # Test aggregation of read
+ read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8})
+ self.assertAllEqual(convert([1.0, -1.0]), read_vals[0])
+ self.assertAllEqual(convert([10.0, -10.0]), read_vals[1])
+
+ def testTensorArrayScatterRead(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayScatterRead(dtype)
+ self._testTensorArrayScatterRead(dtypes.bool)
+
def testTensorArrayScatterReadAndGradients(self):
with self.cached_session() as session, self.test_scope():
ta = tensor_array_ops.TensorArray(
@@ -929,15 +957,18 @@ class TensorArrayTest(xla_test.XLATestCase):
indices = constant_op.constant([1, 8])
value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+ id0 = array_ops.placeholder(dtypes.int32)
+ id1 = array_ops.placeholder(dtypes.int32)
w = ta.scatter(indices, value)
- r0 = w.read(1)
- r1 = w.read(8)
+ r0 = w.read(id0)
+ r1 = w.read(id1)
# Test combined gradients + aggregation of read(0).
grad = gradients_impl.gradients(
ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
- read_vals, grad_vals = session.run([[r0, r1], grad])
+ read_vals, grad_vals = session.run([[r0, r1], grad],
+ feed_dict={id0: 1, id1: 8})
self.assertEqual(len(read_vals), 2)
self.assertEqual(len(grad_vals), 1)