aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/tensor_array_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/tensor_array_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py56
1 file changed, 30 insertions, 26 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index b3067be51d..f277314352 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -139,7 +139,7 @@ class TensorArrayTest(xla_test.XLATestCase):
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
- # Unpack a matrix into vectors
+ # Unpack a matrix into vectors.
w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
r0 = w1.read(0)
r1 = w1.read(1)
@@ -180,7 +180,7 @@ class TensorArrayTest(xla_test.XLATestCase):
convert = _make_converter(tf_dtype)
- # Split an empty vector
+ # Split an empty vector.
lengths = constant_op.constant([0, 0, 0])
w0 = ta.split(convert([]), lengths=lengths)
r0 = w0.read(0)
@@ -192,7 +192,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(convert([]), d1)
self.assertAllEqual(convert([]), d2)
- # Split a vector
+ # Split a vector.
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
lengths = constant_op.constant([1, 1, 1])
@@ -206,7 +206,7 @@ class TensorArrayTest(xla_test.XLATestCase):
self.assertAllEqual(convert([2.0]), d1)
self.assertAllEqual(convert([3.0]), d2)
- # Split a matrix
+ # Split a matrix.
ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
lengths = constant_op.constant([1, 1, 1])
@@ -319,27 +319,31 @@ class TensorArrayTest(xla_test.XLATestCase):
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
- # Test writing the wrong datatype
+ # Test writing the wrong datatype.
with self.assertRaisesOpError(
"TensorArray dtype is float but op has dtype int32"):
ta.write(-1, np.int32(7)).flow.eval()
def testTensorArrayReadWrongIndexOrDataTypeFails(self):
- with self.test_session(), self.test_scope():
- ta = tensor_array_ops.TensorArray(
- dtype=dtypes.float32, tensor_array_name="foo", size=3)
-
- w0 = ta.write(0, [[4.0, 5.0]])
-
- # Test reading wrong datatype
- r0_bad = gen_data_flow_ops._tensor_array_read_v3(
- handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
- with self.assertRaisesOpError(
- "TensorArray dtype is float but op has dtype double."):
- r0_bad.eval()
-
- # Test reading from a different index than the one we wrote to
- w0.read(1)
+ # Find two different floating point types, create an array of
+ # the first type, but try to read the other type.
+ if len(self.float_types) > 1:
+ dtype1 = self.float_types[0]
+ dtype2 = self.float_types[1]
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtype1, tensor_array_name="foo", size=3)
+
+ w0 = ta.write(0, [[4.0, 5.0]])
+
+ # Test reading wrong datatype.
+ r0_bad = gen_data_flow_ops._tensor_array_read_v3(
+ handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow)
+ with self.assertRaisesOpError("TensorArray dtype is "):
+ r0_bad.eval()
+
+ # Test reading from a different index than the one we wrote to
+ w0.read(1)
def testTensorArraySplitIncompatibleShapesFails(self):
with self.test_session(), self.test_scope():
@@ -487,7 +491,7 @@ class TensorArrayTest(xla_test.XLATestCase):
r0 = w1.read(0)
s0 = w1.concat()
- # Test gradient accumulation between read(0), pack(), and concat()
+ # Test gradient accumulation between read(0), pack(), and concat().
with ops.control_dependencies([p0, r0, s0]):
grad_r = gradients_impl.gradients(
ys=[p0, r0, s0],
@@ -536,7 +540,7 @@ class TensorArrayTest(xla_test.XLATestCase):
r0_1 = w.read(0)
r1 = w.read(1)
- # Test combined gradients + aggregation of read(0)
+ # Test combined gradients + aggregation of read(0).
grad = gradients_impl.gradients(
ys=[r0, r0_1, r1],
xs=[value],
@@ -744,7 +748,7 @@ class TensorArrayTest(xla_test.XLATestCase):
grad_b_t, = session.run([grad_b])
self.assertAllEqual(grad_b_t, g0)
- # Test gradients calculated jointly
+ # Test gradients calculated jointly.
joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b])
self.assertAllEqual(joint_grad_a_t, g0)
self.assertAllEqual(joint_grad_b_t, g0)
@@ -877,7 +881,7 @@ class TensorArrayTest(xla_test.XLATestCase):
x = constant_op.constant([2.0, 3.0])
w = ta.unstack(x)
r0 = w.read(0)
- # calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0).
+ # Calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0).
grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0])
grad_r0_vals = session.run(grad_r0)[0]
self.assertAllEqual(grad_r0_vals, [1.0, 0.0])
@@ -927,7 +931,7 @@ class TensorArrayTest(xla_test.XLATestCase):
r0 = w.read(1)
r1 = w.read(8)
- # Test combined gradients + aggregation of read(0)
+ # Test combined gradients + aggregation of read(0).
grad = gradients_impl.gradients(
ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
read_vals, grad_vals = session.run([[r0, r1], grad])
@@ -951,7 +955,7 @@ class TensorArrayTest(xla_test.XLATestCase):
w = ta.unstack(values)
g = w.gather(indices)
- # Test combined gradients + aggregation of read(0)
+ # Test combined gradients + aggregation of read(0).
grad = gradients_impl.gradients(
ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]])
g_vals, grad_vals = session.run([[g], grad])