author    Eugene Brevdo <ebrevdo@gmail.com> 2016-04-05 11:54:15 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-04-05 13:01:49 -0700
commit    91a3950a162ccdacaebbd816c525eca7f5a8ea69 (patch)
tree      b308fa4352d4c507b0e742f2f1aa3c1136dd1147
parent    37baf7a5fcfaffe98f1618db0f8ab5654871b151 (diff)
Minor tweaks to TensorArray & gfile python API.
Change: 119080278
 tensorflow/python/kernel_tests/tensor_array_ops_test.py | 456
 tensorflow/python/ops/tensor_array_grad.py              |   1
 tensorflow/python/ops/tensor_array_ops.py               |   4
 3 files changed, 187 insertions(+), 274 deletions(-)
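
The bulk of the test-file change below follows a single pattern: each pair of a
private _testFoo(use_gpu=...) helper plus a public testFoo wrapper collapses
into one test method that reads a _use_gpu class attribute, and a
TensorArrayGPUTest subclass at the bottom of the file flips that attribute so
the entire suite runs once per device. A minimal sketch of that pattern, using
the TF 0.x test API seen in this file; the class and op names below are
illustrative, not from the TensorFlow codebase:

    import tensorflow as tf

    class MyOpCPUTest(tf.test.TestCase):
      _use_gpu = False

      def testAdd(self):
        with self.test_session(use_gpu=self._use_gpu):
          # The op under test runs on whichever device the class selects.
          self.assertAllEqual(5.0, tf.add(2.0, 3.0).eval())

    class MyOpGPUTest(MyOpCPUTest):
      # Inherits every test method; the test runner discovers and re-runs
      # them all with GPU placement enabled.
      _use_gpu = True

    if __name__ == "__main__":
      tf.test.main()

Subclassing removes the per-test boilerplate and makes each device variant
appear as its own test case in the runner's output.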
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 6bcfff86b7..0c68fd2d8a 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for tensorflow.ops.tensor_array_ops."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -28,14 +29,15 @@ from tensorflow.python.ops import tensor_array_grad
from tensorflow.python.ops import tensor_array_ops
-class TensorArrayTest(tf.test.TestCase):
+class TensorArrayCPUTest(tf.test.TestCase):
+ _use_gpu = False
- def _testTensorArrayWriteRead(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayWriteRead(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- w0 = h.write(0, [[4.0, 5.0]])
+ w0 = ta.write(0, [[4.0, 5.0]])
w1 = w0.write(1, [[1.0]])
w2 = w1.write(2, -3.0)
@@ -43,19 +45,15 @@ class TensorArrayTest(tf.test.TestCase):
r1 = w2.read(1)
r2 = w2.read(2)
- d0, d1, d2 = sess.run([r0, r1, r2])
+ d0, d1, d2 = session.run([r0, r1, r2])
self.assertAllEqual([[4.0, 5.0]], d0)
self.assertAllEqual([[1.0]], d1)
self.assertAllEqual(-3.0, d2)
- def testTensorArrayWriteRead(self):
- self._testTensorArrayWriteRead(use_gpu=False)
- self._testTensorArrayWriteRead(use_gpu=True)
-
- def _testTensorArrayWritePack(self, tf_dtype, use_gpu):
+ def _testTensorArrayWritePack(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
if tf_dtype == tf.string:
@@ -64,7 +62,7 @@ class TensorArrayTest(tf.test.TestCase):
else:
convert = lambda x: np.asarray(x).astype(dtype)
- w0 = h.write(0, convert([[4.0, 5.0]]))
+ w0 = ta.write(0, convert([[4.0, 5.0]]))
w1 = w0.write(1, convert([[6.0, 7.0]]))
w2 = w1.write(2, convert([[8.0, 9.0]]))
@@ -73,22 +71,18 @@ class TensorArrayTest(tf.test.TestCase):
self.assertAllEqual(
convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval())
- def _testTensorArrayWritePackWithType(self, tf_dtype):
- self._testTensorArrayWritePack(tf_dtype=tf_dtype, use_gpu=False)
- self._testTensorArrayWritePack(tf_dtype=tf_dtype, use_gpu=True)
-
def testTensorArrayWritePack(self):
- self._testTensorArrayWritePackWithType(tf.float32)
- self._testTensorArrayWritePackWithType(tf.float64)
- self._testTensorArrayWritePackWithType(tf.int32)
- self._testTensorArrayWritePackWithType(tf.int64)
- self._testTensorArrayWritePackWithType(tf.complex64)
- self._testTensorArrayWritePackWithType(tf.string)
-
- def _testTensorArrayWriteConcat(self, tf_dtype, use_gpu):
+ self._testTensorArrayWritePack(tf.float32)
+ self._testTensorArrayWritePack(tf.float64)
+ self._testTensorArrayWritePack(tf.int32)
+ self._testTensorArrayWritePack(tf.int64)
+ self._testTensorArrayWritePack(tf.complex64)
+ self._testTensorArrayWritePack(tf.string)
+
+ def _testTensorArrayWriteConcat(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
if tf_dtype == tf.string:
@@ -97,7 +91,7 @@ class TensorArrayTest(tf.test.TestCase):
else:
convert = lambda x: np.asarray(x).astype(dtype)
- w0 = h.write(0, convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0]]))
+ w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0]]))
w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]]))
w2 = w1.write(2, convert([[8.0, 9.0]]))
@@ -111,88 +105,80 @@ class TensorArrayTest(tf.test.TestCase):
[106.0, 107.0],
[8.0, 9.0]]), c0.eval())
- def _testTensorArrayWriteConcatWithType(self, tf_dtype):
- self._testTensorArrayWriteConcat(tf_dtype=tf_dtype, use_gpu=False)
- self._testTensorArrayWriteConcat(tf_dtype=tf_dtype, use_gpu=True)
-
def testTensorArrayWriteConcat(self):
- self._testTensorArrayWriteConcatWithType(tf.float32)
- self._testTensorArrayWriteConcatWithType(tf.float64)
- self._testTensorArrayWriteConcatWithType(tf.int32)
- self._testTensorArrayWriteConcatWithType(tf.int64)
- self._testTensorArrayWriteConcatWithType(tf.complex64)
- self._testTensorArrayWriteConcatWithType(tf.string)
+ self._testTensorArrayWriteConcat(tf.float32)
+ self._testTensorArrayWriteConcat(tf.float64)
+ self._testTensorArrayWriteConcat(tf.int32)
+ self._testTensorArrayWriteConcat(tf.int64)
+ self._testTensorArrayWriteConcat(tf.complex64)
+ self._testTensorArrayWriteConcat(tf.string)
def testTensorArrayUnpackWrongMajorSizeFails(self):
with self.test_session():
- h = tensor_array_ops.TensorArray(
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
with self.assertRaisesOpError(
r"Input value must have first dimension "
r"equal to the array size \(2 vs. 3\)"):
- h.unpack([1.0, 2.0]).flow.eval()
+ ta.unpack([1.0, 2.0]).flow.eval()
def testTensorArrayPackNotAllValuesAvailableFails(self):
with self.test_session():
- h = tensor_array_ops.TensorArray(
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
with self.assertRaisesOpError(
"Could not read from TensorArray index 1 "
"because it has not yet been written to."):
- h.write(0, [[4.0, 5.0]]).pack().eval()
+ ta.write(0, [[4.0, 5.0]]).pack().eval()
- def _testTensorArrayUnpackRead(self, tf_dtype, use_gpu):
+ def _testTensorArrayUnpackRead(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
- if tf_dtype == tf.string:
+ if tf_dtype is tf.string:
# In Python3, np.str is unicode, while we always want bytes
convert = lambda x: np.asarray(x).astype("|S")
else:
convert = lambda x: np.asarray(x).astype(dtype)
# Unpack a vector into scalars
- w0 = h.unpack(convert([1.0, 2.0, 3.0]))
+ w0 = ta.unpack(convert([1.0, 2.0, 3.0]))
r0 = w0.read(0)
r1 = w0.read(1)
r2 = w0.read(2)
- d0, d1, d2 = sess.run([r0, r1, r2])
+ d0, d1, d2 = session.run([r0, r1, r2])
self.assertAllEqual(convert(1.0), d0)
self.assertAllEqual(convert(2.0), d1)
self.assertAllEqual(convert(3.0), d2)
# Unpack a matrix into vectors
- w1 = h.unpack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
+ w1 = ta.unpack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
r0 = w1.read(0)
r1 = w1.read(1)
r2 = w1.read(2)
- d0, d1, d2 = sess.run([r0, r1, r2])
+ d0, d1, d2 = session.run([r0, r1, r2])
self.assertAllEqual(convert([1.0, 1.1]), d0)
self.assertAllEqual(convert([2.0, 2.1]), d1)
self.assertAllEqual(convert([3.0, 3.1]), d2)
- def _testTensorArrayUnpackReadWithType(self, tf_dtype):
- self._testTensorArrayUnpackRead(tf_dtype=tf_dtype, use_gpu=False)
- self._testTensorArrayUnpackRead(tf_dtype=tf_dtype, use_gpu=True)
-
def testTensorArrayUnpackRead(self):
- self._testTensorArrayUnpackReadWithType(tf.float32)
- self._testTensorArrayUnpackReadWithType(tf.float64)
- self._testTensorArrayUnpackReadWithType(tf.int32)
- self._testTensorArrayUnpackReadWithType(tf.int64)
- self._testTensorArrayUnpackReadWithType(tf.complex64)
- self._testTensorArrayUnpackReadWithType(tf.string)
-
- def _testTensorArraySplitRead(self, tf_dtype, use_gpu):
+ self._testTensorArrayUnpackRead(tf.float32)
+ self._testTensorArrayUnpackRead(tf.float64)
+ self._testTensorArrayUnpackRead(tf.int32)
+ self._testTensorArrayUnpackRead(tf.int64)
+ self._testTensorArrayUnpackRead(tf.complex64)
+ self._testTensorArrayUnpackRead(tf.string)
+
+ def _testTensorArraySplitRead(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
if tf_dtype == tf.string:
@@ -203,65 +189,61 @@ class TensorArrayTest(tf.test.TestCase):
# Split an empty vector
lengths = tf.constant([0, 0, 0])
- w0 = h.split(convert([]), lengths=lengths)
+ w0 = ta.split(convert([]), lengths=lengths)
r0 = w0.read(0)
r1 = w0.read(1)
r2 = w0.read(2)
- d0, d1, d2 = sess.run([r0, r1, r2])
+ d0, d1, d2 = session.run([r0, r1, r2])
self.assertAllEqual(convert([]), d0)
self.assertAllEqual(convert([]), d1)
self.assertAllEqual(convert([]), d2)
# Split a vector
lengths = tf.constant([2, 0, 1])
- w0 = h.split(
+ w0 = ta.split(
convert([1.0, 2.0, 3.0]), lengths=lengths)
r0 = w0.read(0)
r1 = w0.read(1)
r2 = w0.read(2)
- d0, d1, d2 = sess.run([r0, r1, r2])
+ d0, d1, d2 = session.run([r0, r1, r2])
self.assertAllEqual(convert([1.0, 2.0]), d0)
self.assertAllEqual(convert([]), d1)
self.assertAllEqual(convert([3.0]), d2)
# Split a matrix
lengths = tf.constant([2, 0, 1])
- w0 = h.split(
+ w0 = ta.split(
convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), lengths=lengths)
r0 = w0.read(0)
r1 = w0.read(1)
r2 = w0.read(2)
- d0, d1, d2 = sess.run([r0, r1, r2])
+ d0, d1, d2 = session.run([r0, r1, r2])
self.assertAllEqual(convert([[1.0, 101.0], [2.0, 201.0]]), d0)
self.assertAllEqual(convert([]).reshape(0, 2), d1)
self.assertAllEqual(convert([[3.0, 301.0]]), d2)
- def _testTensorArraySplitReadWithType(self, tf_dtype):
- self._testTensorArraySplitRead(tf_dtype=tf_dtype, use_gpu=False)
- self._testTensorArraySplitRead(tf_dtype=tf_dtype, use_gpu=True)
-
def testTensorArraySplitRead(self):
- self._testTensorArraySplitReadWithType(tf.float32)
- self._testTensorArraySplitReadWithType(tf.float64)
- self._testTensorArraySplitReadWithType(tf.int32)
- self._testTensorArraySplitReadWithType(tf.int64)
- self._testTensorArraySplitReadWithType(tf.complex64)
- self._testTensorArraySplitReadWithType(tf.string)
-
- def _testTensorGradArrayWriteRead(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ self._testTensorArraySplitRead(tf.float32)
+ self._testTensorArraySplitRead(tf.float64)
+ self._testTensorArraySplitRead(tf.int32)
+ self._testTensorArraySplitRead(tf.int64)
+ self._testTensorArraySplitRead(tf.complex64)
+ self._testTensorArraySplitRead(tf.string)
+
+ def testTensorGradArrayWriteRead(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- g_h = h.grad("grad")
+ g_ta = ta.grad("grad")
- w0 = h.write(0, [[4.0, 5.0]])
+ w0 = ta.write(0, [[4.0, 5.0]])
w1 = w0.write(1, [[1.0]])
w2 = w1.write(2, -3.0)
- g_w0 = g_h.write(0, [[5.0, 6.0]])
+ g_w0 = g_ta.write(0, [[5.0, 6.0]])
g_w1 = g_w0.write(1, [[2.0]])
g_w2 = g_w1.write(2, -2.0)
@@ -273,7 +255,7 @@ class TensorArrayTest(tf.test.TestCase):
g_r1 = g_w2.read(1)
g_r2 = g_w2.read(2)
- d0, d1, d2, g_d0, g_d1, g_d2 = sess.run([r0, r1, r2, g_r0, g_r1, g_r2])
+ d0, d1, d2, g_d0, g_d1, g_d2 = session.run([r0, r1, r2, g_r0, g_r1, g_r2])
self.assertAllEqual([[4.0, 5.0]], d0)
self.assertAllEqual([[1.0]], d1)
self.assertAllEqual(-3.0, d2)
@@ -281,25 +263,21 @@ class TensorArrayTest(tf.test.TestCase):
self.assertAllEqual([[2.0]], g_d1)
self.assertAllEqual(-2.0, g_d2)
- def testTensorGradArrayWriteRead(self):
- self._testTensorGradArrayWriteRead(use_gpu=False)
- self._testTensorGradArrayWriteRead(use_gpu=True)
-
- def _testTensorGradArrayDynamicWriteRead(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ def testTensorGradArrayDynamicWriteRead(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=0, dynamic_size=True)
- w0 = h.write(0, [[4.0, 5.0]])
+ w0 = ta.write(0, [[4.0, 5.0]])
w1 = w0.write(1, [[1.0]])
w2 = w1.write(2, -3.0)
- g_h = w2.grad("grad") # Get gradient array here so we know the shape
+ g_ta = w2.grad("grad") # Get gradient array here so we know the shape
s = w2.size()
- g_s = g_h.size()
+ g_s = g_ta.size()
- g_w0 = g_h.write(0, [[5.0, 6.0]])
+ g_w0 = g_ta.write(0, [[5.0, 6.0]])
g_w1 = g_w0.write(1, [[2.0]])
g_w2 = g_w1.write(2, -2.0)
@@ -311,7 +289,7 @@ class TensorArrayTest(tf.test.TestCase):
g_r1 = g_w2.read(1)
g_r2 = g_w2.read(2)
- d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = sess.run([
+ d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = session.run([
r0, r1, r2, g_r0, g_r1, g_r2, s, g_s])
self.assertAllEqual([[4.0, 5.0]], d0)
self.assertAllEqual([[1.0]], d1)
@@ -322,61 +300,50 @@ class TensorArrayTest(tf.test.TestCase):
self.assertAllEqual(3, vs)
self.assertAllEqual(3, g_vs)
- def testTensorGradArrayDynamicWriteRead(self):
- self._testTensorGradArrayDynamicWriteRead(use_gpu=False)
- self._testTensorGradArrayDynamicWriteRead(use_gpu=True)
-
- def _testTensorGradAccessTwiceReceiveSameObject(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ def testTensorGradAccessTwiceReceiveSameObject(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- g_h_0 = h.grad("grad")
- g_h_1 = h.grad("grad")
+ g_ta_0 = ta.grad("grad")
+ g_ta_1 = ta.grad("grad")
- with tf.control_dependencies([g_h_0.write(0, [[4.0, 5.0]]).flow]):
+ with tf.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]):
# Write with one gradient handle, read with another copy of it
- r1_0 = g_h_1.read(0)
+ r1_0 = g_ta_1.read(0)
- t_g_h_0, t_g_h_1, d_r1_0 = sess.run([g_h_0.handle, g_h_1.handle, r1_0])
- self.assertAllEqual(t_g_h_0, t_g_h_1)
+ t_g_ta_0, t_g_ta_1, d_r1_0 = session.run(
+ [g_ta_0.handle, g_ta_1.handle, r1_0])
+ self.assertAllEqual(t_g_ta_0, t_g_ta_1)
self.assertAllEqual([[4.0, 5.0]], d_r1_0)
- def testTensorGradAccessTwiceReceiveSameObject(self):
- self._testTensorGradAccessTwiceReceiveSameObject(False)
- self._testTensorGradAccessTwiceReceiveSameObject(True)
-
- def _testTensorArrayWriteWrongIndexOrDataTypeFails(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
# Test writing the wrong datatype
with self.assertRaisesOpError(
"TensorArray dtype is float but Op is trying to write dtype string"):
- h.write(-1, "wrong_type_scalar").flow.eval()
+ ta.write(-1, "wrong_type_scalar").flow.eval()
# Test writing to a negative index
with self.assertRaisesOpError(
"Tried to write to index -1 but array is not "
"resizeable and size is: 3"):
- h.write(-1, 3.0).flow.eval()
+ ta.write(-1, 3.0).flow.eval()
# Test reading from too large an index
with self.assertRaisesOpError(
"Tried to write to index 3 but array is not "
"resizeable and size is: 3"):
- h.write(3, 3.0).flow.eval()
+ ta.write(3, 3.0).flow.eval()
- def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
- self._testTensorArrayWriteWrongIndexOrDataTypeFails(use_gpu=False)
- self._testTensorArrayWriteWrongIndexOrDataTypeFails(use_gpu=True)
-
- def _testTensorArrayReadWrongIndexOrDataTypeFails(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayReadWrongIndexOrDataTypeFails(self):
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- w0 = h.write(0, [[4.0, 5.0]])
+ w0 = ta.write(0, [[4.0, 5.0]])
# Test reading wrong datatype
r0_bad = gen_data_flow_ops._tensor_array_read(
@@ -395,37 +362,29 @@ class TensorArrayTest(tf.test.TestCase):
# Test reading from a negative index
with self.assertRaisesOpError(
r"Tried to read from index -1 but array size is: 3"):
- h.read(-1).eval()
+ ta.read(-1).eval()
# Test reading from too large an index
with self.assertRaisesOpError(
"Tried to read from index 3 but array size is: 3"):
- h.read(3).eval()
+ ta.read(3).eval()
- def testTensorArrayReadWrongIndexOrDataTypeFails(self):
- self._testTensorArrayReadWrongIndexOrDataTypeFails(use_gpu=False)
- self._testTensorArrayReadWrongIndexOrDataTypeFails(use_gpu=True)
-
- def _testTensorArrayWriteMultipleFails(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayWriteMultipleFails(self):
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
with self.assertRaisesOpError(
"Could not write to TensorArray index 2 because "
"it has already been written to."):
- h.write(2, 3.0).write(2, 3.0).flow.eval()
+ ta.write(2, 3.0).write(2, 3.0).flow.eval()
- def testTensorArrayWriteMultipleFails(self):
- self._testTensorArrayWriteMultipleFails(use_gpu=False)
- self._testTensorArrayWriteMultipleFails(use_gpu=True)
-
- def _testTensorArrayConcatIncompatibleShapesFails(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayConcatIncompatibleShapesFails(self):
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- w1 = h.write(0, 3.0)
+ w1 = ta.write(0, 3.0)
w2 = w1.write(1, 4.0)
w3 = w2.write(2, [3.0])
@@ -433,10 +392,10 @@ class TensorArrayTest(tf.test.TestCase):
"Concat saw a scalar shape at index 0 but requires at least vectors"):
w3.concat().eval()
- h = tensor_array_ops.TensorArray(
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- w1 = h.write(0, [3.0])
+ w1 = ta.write(0, [3.0])
w2 = w1.write(1, [4.0])
w3 = w2.write(2, [[3.0]])
@@ -446,44 +405,36 @@ class TensorArrayTest(tf.test.TestCase):
r"dimension 0\) shape: \[1\]"):
w3.concat().eval()
- def testTensorArrayConcatIncompatibleShapesFails(self):
- self._testTensorArrayConcatIncompatibleShapesFails(use_gpu=False)
- self._testTensorArrayConcatIncompatibleShapesFails(use_gpu=True)
-
- def _testTensorArraySplitIncompatibleShapesFails(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ def testTensorArraySplitIncompatibleShapesFails(self):
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
with self.assertRaisesOpError(
r"Expected lengths to be a vector, received shape: \[\]"):
lengths = tf.placeholder(tf.int64)
- h.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
+ ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
with self.assertRaisesOpError(
r"Expected sum of lengths to be equal to values.shape\[0\], "
r"but sum of lengths is 1 and value's shape is: \[3\]"):
- h.split([1.0, 2.0, 3.0], [1]).flow.eval()
+ ta.split([1.0, 2.0, 3.0], [1]).flow.eval()
with self.assertRaisesOpError(
r"Expected value to be at least a vector, but received shape: \[\]"):
- h.split(1.0, [1]).flow.eval()
+ ta.split(1.0, [1]).flow.eval()
- h = tensor_array_ops.TensorArray(
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2)
with self.assertRaisesOpError(
r"TensorArray's size is not equal to the size of lengths "
r"\(2 vs. 1\), and the TensorArray is not marked as "
r"dynamically resizeable"):
- h.split([1.0], [1]).flow.eval()
-
- def testTensorArraySplitIncompatibleShapesFails(self):
- self._testTensorArraySplitIncompatibleShapesFails(use_gpu=False)
- self._testTensorArraySplitIncompatibleShapesFails(use_gpu=True)
+ ta.split([1.0], [1]).flow.eval()
- def _testMultiTensorArray(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
+ def testMultiTensorArray(self):
+ with self.test_session(use_gpu=self._use_gpu):
h1 = tensor_array_ops.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="foo")
w1 = h1.write(0, 4.0)
@@ -497,12 +448,8 @@ class TensorArrayTest(tf.test.TestCase):
r = r1 + r2
self.assertAllClose(9.0, r.eval())
- def testMultiTensorArray(self):
- self._testMultiTensorArray(use_gpu=False)
- self._testMultiTensorArray(use_gpu=True)
-
- def _testDuplicateTensorArrayFails(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
+ def testDuplicateTensorArrayFails(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
h1 = tensor_array_ops.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="foo")
c1 = h1.write(0, 4.0)
@@ -510,15 +457,11 @@ class TensorArrayTest(tf.test.TestCase):
size=1, dtype=tf.float32, tensor_array_name="foo")
c2 = h2.write(0, 5.0)
with self.assertRaises(errors.AlreadyExistsError):
- sess.run([c1.flow, c2.flow])
-
- def testDuplicateTensorArrayFails(self):
- self._testDuplicateTensorArrayFails(use_gpu=False)
- self._testDuplicateTensorArrayFails(use_gpu=True)
+ session.run([c1.flow, c2.flow])
- def _testTensorArrayGradientWriteReadType(self, use_gpu, dtype):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ def _testTensorArrayGradientWriteReadType(self, dtype):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.as_dtype(dtype), tensor_array_name="foo", size=3)
c = lambda x: np.array(x, dtype=dtype)
@@ -526,7 +469,7 @@ class TensorArrayTest(tf.test.TestCase):
value_0 = tf.constant(c([[4.0, 5.0]]))
value_1 = tf.constant(c(3.0))
- w0 = h.write(0, value_0)
+ w0 = ta.write(0, value_0)
w1 = w0.write(1, value_1)
r0 = w1.read(0)
r1 = w1.read(1)
@@ -534,128 +477,104 @@ class TensorArrayTest(tf.test.TestCase):
# Test individual components' gradients
grad_just_r0 = tf.gradients(
ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])])
- grad_just_r0_vals = sess.run(grad_just_r0)
+ grad_just_r0_vals = session.run(grad_just_r0)
self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0])
grad_just_r1 = tf.gradients(
ys=[r1], xs=[value_1], grad_ys=[c(-2.0)])
- grad_just_r1_vals = sess.run(grad_just_r1)
+ grad_just_r1_vals = session.run(grad_just_r1)
self.assertAllEqual(c(-2.0), grad_just_r1_vals[0])
# Test combined gradients
grad = tf.gradients(
ys=[r0, r1], xs=[value_0, value_1],
grad_ys=[c(-1.0), c([[2.0, 3.0]])])
- grad_vals = sess.run(grad)
+ grad_vals = session.run(grad)
self.assertEqual(len(grad_vals), 2)
self.assertAllEqual(c(-1.0), grad_vals[0])
self.assertAllEqual(c([[2.0, 3.0]]), grad_vals[1])
- def _testTensorArrayGradientWriteRead(self, use_gpu):
- for dtype in (np.float32, np.float64, np.int32, np.int64, np.complex64):
- self._testTensorArrayGradientWriteReadType(use_gpu, dtype)
-
def testTensorArrayGradientWriteRead(self):
- self._testTensorArrayGradientWriteRead(False)
- self._testTensorArrayGradientWriteRead(True)
+ for dtype in (np.float32, np.float64, np.int32, np.int64, np.complex64):
+ self._testTensorArrayGradientWriteReadType(dtype)
- def _testTensorArrayGradientUnpackRead(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayGradientUnpackRead(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2)
value = tf.constant([[1.0, -1.0], [10.0, -10.0]])
- w = h.unpack(value)
+ w = ta.unpack(value)
r0 = w.read(0)
r1 = w.read(1)
# Test combined gradients + aggregation of read(0)
grad = tf.gradients(
ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
- grad_vals = sess.run(grad)
+ grad_vals = session.run(grad)
self.assertEqual(len(grad_vals), 1)
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
- def testTensorArrayGradientUnpackRead(self):
- self._testTensorArrayGradientUnpackRead(False)
- self._testTensorArrayGradientUnpackRead(True)
-
- def _testTensorArrayGradientSplitConcat(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayGradientSplitConcat(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2)
value = tf.constant([[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
- w = h.split(value, [2, 1])
+ w = ta.split(value, [2, 1])
r = w.concat()
# Test combined gradients
grad = tf.gradients(
ys=[r], xs=[value],
grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0]]])
- grad_vals = sess.run(grad)
+ grad_vals = session.run(grad)
self.assertEqual(len(grad_vals), 1)
self.assertAllEqual(
[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0]], grad_vals[0])
- def testTensorArrayGradientSplitConcat(self):
- self._testTensorArrayGradientSplitConcat(False)
- self._testTensorArrayGradientSplitConcat(True)
-
- def _testTensorArrayGradientDynamicUnpackRead(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
+ def testTensorArrayGradientDynamicUnpackRead(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=0, dynamic_size=True)
value = tf.constant([[1.0, -1.0], [10.0, -10.0]])
- w = h.unpack(value)
+ w = ta.unpack(value)
r0 = w.read(0)
r1 = w.read(1)
# Test combined gradients + aggregation of read(0)
grad = tf.gradients(
ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
- grad_vals = sess.run(grad)
+ grad_vals = session.run(grad)
self.assertEqual(len(grad_vals), 1)
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
- def testTensorArrayGradientDynamicUnpackRead(self):
- self._testTensorArrayGradientDynamicUnpackRead(False)
- self._testTensorArrayGradientDynamicUnpackRead(True)
-
- def _testCloseTensorArray(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
- h = tensor_array_ops.TensorArray(
- dtype=tf.float32, tensor_array_name="foo", size=3)
- c1 = h.close()
- sess.run(c1)
-
def testCloseTensorArray(self):
- self._testCloseTensorArray(use_gpu=False)
- self._testCloseTensorArray(use_gpu=True)
-
- def _testSizeTensorArray(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=self._use_gpu) as session:
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- s = h.size()
- self.assertAllEqual(3, s.eval())
+ c1 = ta.close()
+ session.run(c1)
def testSizeTensorArray(self):
- self._testSizeTensorArray(use_gpu=False)
- self._testSizeTensorArray(use_gpu=True)
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf.float32, tensor_array_name="foo", size=3)
+ s = ta.size()
+ self.assertAllEqual(3, s.eval())
- def _testWriteCloseTensorArray(self, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- h = tensor_array_ops.TensorArray(
+ def testWriteCloseTensorArray(self):
+ with self.test_session(use_gpu=self._use_gpu):
+ ta = tensor_array_ops.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
- w0 = h.write(0, [[4.0, 5.0]])
+ w0 = ta.write(0, [[4.0, 5.0]])
w1 = w0.write(1, [3.0])
w1.close().run() # Expected to run without problems
@@ -664,33 +583,29 @@ class TensorArrayTest(tf.test.TestCase):
with tf.control_dependencies([w1.close()]):
w1.write(2, 3.0).flow.eval()
- def testWriteCloseTensorArray(self):
- self._testWriteCloseTensorArray(use_gpu=False)
- self._testWriteCloseTensorArray(use_gpu=True)
-
- def _testWhileLoopWritePackGradients(self, dynamic_size, dtype, use_gpu):
+ def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
np_dtype = dtype.as_numpy_dtype
- with self.test_session(use_gpu=use_gpu) as sess:
+ with self.test_session(use_gpu=self._use_gpu) as session:
v0 = tf.identity(np.arange(3*5, dtype=np_dtype).reshape(3, 5))
var = tf.Variable(np.arange(100, 105, dtype=np_dtype))
state0 = tf.identity(np.array([1] * 5, dtype=np_dtype))
- h = tensor_array_ops.TensorArray(
+ ta = tensor_array_ops.TensorArray(
dtype=dtype, tensor_array_name="foo",
size=0 if dynamic_size else 3, dynamic_size=dynamic_size)
time_0 = tf.identity(0)
- def body(time, h_t, state):
+ def body(time, ta_t, state):
sliced = tf.slice(v0, begin=tf.pack([time, 0]), size=[1, -1])
sliced = tf.squeeze(sliced)
out = sliced + var + state
state += sliced
- h_t = h_t.write(time, out)
- return (time+1, h_t, state)
+ ta_t = ta_t.write(time, out)
+ return (time+1, ta_t, state)
(unused_0, h_final, unused_2) = control_flow_ops.While(
cond=lambda time, unused_1, unused_2: time < 3,
body=body,
- loop_vars=(time_0, h, state0),
+ loop_vars=(time_0, ta, state0),
parallel_iterations=3)
vout = h_final.pack()
@@ -701,8 +616,8 @@ class TensorArrayTest(tf.test.TestCase):
tf.initialize_all_variables().run()
state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
- sess.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad]))
- just_v0_grad_t, = sess.run([v0_grad])
+ session.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad]))
+ just_v0_grad_t, = session.run([v0_grad])
# state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ]
# vout = [ v0[0] + var + state[0] |
@@ -739,23 +654,17 @@ class TensorArrayTest(tf.test.TestCase):
def testWhileLoopWritePackGradients(self):
self._testWhileLoopWritePackGradients(
- dynamic_size=False, dtype=tf.float32, use_gpu=False)
- self._testWhileLoopWritePackGradients(
- dynamic_size=False, dtype=tf.float32, use_gpu=True)
+ dynamic_size=False, dtype=tf.float32)
# TODO(ebrevdo): re-enable when While supports non-float32 gradients.
# self._testWhileLoopWritePackGradients(
- # dynamic_size=False, dtype=tf.int64, use_gpu=False)
- # self._testWhileLoopWritePackGradients(
- # dynamic_size=False, dtype=tf.int64, use_gpu=True)
+ # dynamic_size=False, dtype=tf.int64)
def testWhileLoopDynamicWritePackGradients(self):
self._testWhileLoopWritePackGradients(
- dynamic_size=True, dtype=tf.float32, use_gpu=False)
- self._testWhileLoopWritePackGradients(
- dynamic_size=True, dtype=tf.float32, use_gpu=True)
+ dynamic_size=True, dtype=tf.float32)
- def _testSumOfTwoReadVariablesWithoutRepeatGrad(self, use_gpu):
- with self.test_session(use_gpu=use_gpu) as sess:
+ def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
+ with self.test_session(use_gpu=self._use_gpu) as session:
a = tf.identity(np.arange(3*5, dtype=np.float32).reshape(3, 5) + 1)
b = tf.identity(np.arange(3*5, dtype=np.float32).reshape(3, 5) + 1 + 3*5)
ta = tensor_array_ops.TensorArray(dtype=tf.float32, size=2)
@@ -768,21 +677,17 @@ class TensorArrayTest(tf.test.TestCase):
grad_b = tf.gradients([c], [b], [g0])[0] # d(a+b)/db = 1
# Test gradients calculated individually
- grad_a_t, = sess.run([grad_a])
+ grad_a_t, = session.run([grad_a])
self.assertAllEqual(grad_a_t, g0)
- grad_b_t, = sess.run([grad_b])
+ grad_b_t, = session.run([grad_b])
self.assertAllEqual(grad_b_t, g0)
# Test gradients calculated jointly
- joint_grad_a_t, joint_grad_b_t = sess.run([grad_a, grad_b])
+ joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b])
self.assertAllEqual(joint_grad_a_t, g0)
self.assertAllEqual(joint_grad_b_t, g0)
- def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
- self._testSumOfTwoReadVariablesWithoutRepeatGrad(use_gpu=False)
- self._testSumOfTwoReadVariablesWithoutRepeatGrad(use_gpu=True)
-
def _grad_source_for_name(self, name):
return tensor_array_grad._GetGradSource(tf.constant(0, name=name))
@@ -826,5 +731,10 @@ class TensorArrayTest(tf.test.TestCase):
"foo/gradients",
self._grad_source_for_name("foo/gradients/bar/gradients_0/baz"))
+
+class TensorArrayGPUTest(TensorArrayCPUTest):
+ _use_gpu = True
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py
index d138006d33..a2a8b59267 100644
--- a/tensorflow/python/ops/tensor_array_grad.py
+++ b/tensorflow/python/ops/tensor_array_grad.py
@@ -211,4 +211,3 @@ def _TensorArraySplitGrad(op, flow):
grad = g.concat()
# handle, value, lengths, flow_in
return [None, grad, None, flow]
-# pylint: enable=protected-access
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index e17757daae..35f63db23f 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Data Flow Operations."""
+# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
@@ -27,6 +28,8 @@ from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+# TensorArray object accesses many of the hidden generated ops, but is
+# in fact built to wrap these methods.
# pylint: disable=protected-access
class TensorArray(object):
"""Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.
@@ -275,4 +278,5 @@ def _TensorArrayUnpackShape(op):
op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
# flow_out
return [tensor_shape.scalar()]
+
# pylint: enable=protected-access
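
For reference, the pylint pragmas documented in this last hunk are scoped from
the comment onward: a top-level # pylint: disable=... stays in effect until a
matching enable or the end of the module, which is also why the trailing
# pylint: enable=protected-access removed from tensor_array_grad.py was
redundant (nothing followed it). A small sketch of the scoping, with an
illustrative _private_helper name that is not a real TensorFlow symbol:

    # pylint: disable=protected-access
    def use_hidden_op(module):
      # Exempt from the protected-access check: the disable above is still
      # in effect at this point in the module.
      return module._private_helper()
    # pylint: enable=protected-access
    # From here on, protected-access warnings fire again.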