author    Eugene Brevdo <ebrevdo@google.com> 2016-11-23 13:43:52 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-11-23 14:05:06 -0800
commit    3572910e60d18540bab1acb4a53bd05d2c82b7ad (patch)
tree      76deebccb12f650b3f8a3390a74d9b9589826f2d
parent    4c6dc271dfedc569cc2ff8d6b6fae32f20fbbb94 (diff)
Lazy device setting for TensorArrays.
Prior to this change, TensorArrays were always created on the device set by the enclosing device scope (if any), which is not necessarily the device on which the Tensors written to the TensorArray reside. Since TensorArrays have strong colocation requirements, this often meant expensive round-trips to write and read Tensors. With this change, TensorArray ops are created with no device set; the first write/unpack/split call that passes a Tensor bound to a particular device sets the TensorArray's device to match.

Change: 140067532
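A minimal sketch of the new behavior, adapted from the tests added in this change (assumes graph construction with a multi-GPU machine; the device strings are illustrative):

    import tensorflow as tf

    ta = tf.TensorArray(dtype=tf.float32, size=2)
    # The TensorArray op is now created with no device set.
    assert ta.handle.device == ""

    with tf.device("/gpu:0"):
      # The first write retroactively pins the TensorArray to the
      # device of the Tensor being written.
      ta = ta.write(0, 1.0)
    assert "gpu:0" in ta.handle.device.lower()

    with tf.device("/gpu:1"):
      # Subsequent writes do not move it; colocation is preserved.
      ta = ta.write(1, 2.0)
    assert "gpu:0" in ta.handle.device.lower()

These checks run at graph-construction time; no session is needed.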
-rw-r--r--  tensorflow/python/kernel_tests/tensor_array_ops_test.py | 264
-rw-r--r--  tensorflow/python/ops/tensor_array_ops.py               |  78
2 files changed, 222 insertions(+), 120 deletions(-)
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 14011e88cc..b21dcaf8e8 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -25,15 +25,13 @@ import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import tensor_array_grad
-from tensorflow.python.ops import tensor_array_ops
-class TensorArrayCPUTest(tf.test.TestCase):
- _use_gpu = False
+class TensorArrayTest(tf.test.TestCase):
def testTensorArrayWriteRead(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3, infer_shape=False)
w0 = ta.write(0, [[4.0, 5.0]])
@@ -51,8 +49,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
def _testTensorArrayWritePack(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
if tf_dtype == tf.string:
@@ -84,8 +82,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
def _testTensorArrayWriteConcat(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3, infer_shape=False)
if tf_dtype == tf.string:
@@ -119,7 +117,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
def _testTensorArrayPackNotAllValuesAvailableFails(self):
with self.test_session():
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
with self.assertRaisesOpError(
@@ -132,8 +130,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
def _testTensorArrayUnpackRead(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
if tf_dtype is tf.string:
@@ -153,7 +151,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual(convert(2.0), d1)
self.assertAllEqual(convert(3.0), d2)
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
# Unpack a matrix into vectors
@@ -169,7 +167,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
# Reset ta because we're going to change the shape, else shape
# inference will throw an error.
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3)
# Try unpacking an empty matrix, which should not cause an error.
@@ -197,8 +195,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
def _testTensorArraySplitRead(self, tf_dtype):
dtype = tf_dtype.as_numpy_dtype()
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf_dtype, tensor_array_name="foo", size=3, infer_shape=False)
if tf_dtype == tf.string:
@@ -255,8 +253,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArraySplitRead(tf.string)
def testTensorGradArrayWriteRead(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3, infer_shape=False)
g_ta = ta.grad("grad")
@@ -285,8 +283,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual(-2.0, g_d2)
def testTensorGradArrayDynamicWriteRead(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=0, dynamic_size=True,
infer_shape=False)
@@ -323,8 +321,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual(3, g_vs)
def testTensorGradAccessTwiceReceiveSameObject(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
g_ta_0 = ta.grad("grad")
g_ta_1 = ta.grad("grad")
@@ -339,8 +337,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual([[4.0, 5.0]], d_r1_0)
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
# Test writing the wrong datatype
@@ -361,8 +359,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
ta.write(3, 3.0).flow.eval()
def testTensorArrayReadWrongIndexOrDataTypeFails(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
w0 = ta.write(0, [[4.0, 5.0]])
@@ -392,8 +390,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
ta.read(3).eval()
def testTensorArrayWriteMultipleFails(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
with self.assertRaisesOpError(
@@ -402,8 +400,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
ta.write(2, 3.0).write(2, 3.0).flow.eval()
def testTensorArrayConcatIncompatibleShapesFails(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3, infer_shape=False)
w1 = ta.write(0, 3.0)
@@ -414,7 +412,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
"Concat saw a scalar shape at index 0 but requires at least vectors"):
w3.concat().eval()
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3, infer_shape=False)
w1 = ta.write(0, [3.0])
@@ -428,8 +426,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
w3.concat().eval()
def testTensorArraySplitIncompatibleShapesFails(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3, infer_shape=False)
with self.assertRaisesOpError(
@@ -446,7 +444,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
r"Expected value to be at least a vector, but received shape: \[\]"):
ta.split(1.0, [1]).flow.eval()
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2, infer_shape=False)
with self.assertRaisesOpError(
@@ -456,8 +454,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
ta.split([1.0], [1]).flow.eval()
def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
ta_grad = ta.grad("grad")
@@ -495,13 +493,13 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
def testMultiTensorArray(self):
- with self.test_session(use_gpu=self._use_gpu):
- h1 = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ h1 = tf.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="foo")
w1 = h1.write(0, 4.0)
r1 = w1.read(0)
- h2 = tensor_array_ops.TensorArray(
+ h2 = tf.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="bar")
w2 = h2.write(0, 5.0)
@@ -510,11 +508,11 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllClose(9.0, r.eval())
def testDuplicateTensorArrayHasDifferentName(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- h1 = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ h1 = tf.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="foo")
c1 = h1.write(0, 4.0)
- h2 = tensor_array_ops.TensorArray(
+ h2 = tf.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="foo")
c2 = h2.write(0, 5.0)
_, _, c1h, c2h = session.run([c1.flow, c2.flow, c1.handle, c2.handle])
@@ -527,8 +525,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertNotEqual(c1h[1], c2h[1])
def _testTensorArrayGradientWriteReadType(self, dtype):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.as_dtype(dtype), tensor_array_name="foo", size=3,
infer_shape=False)
@@ -575,8 +573,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayGradientWriteReadType(dtype)
def _testTensorArrayGradientWritePackConcatAndRead(self):
- with self.test_session(use_gpu=self._use_gpu) as sess:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as sess:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2,
clear_after_read=False)
@@ -606,10 +604,10 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayGradientWritePackConcatAndRead()
def testTensorArrayReadTwice(self):
- with self.test_session(use_gpu=self._use_gpu):
+ with self.test_session(use_gpu=True):
value = tf.constant([[1.0, -1.0], [10.0, -10.0]])
- ta_readonce = tensor_array_ops.TensorArray(
+ ta_readonce = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2)
w_readonce = ta_readonce.unpack(value)
@@ -622,7 +620,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
r"previous read \(perhaps try setting clear_after_read = false\?\)"):
r1_readonce.eval()
- ta_readtwice = tensor_array_ops.TensorArray(
+ ta_readtwice = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2,
clear_after_read=False)
w_readtwice = ta_readtwice.unpack(value)
@@ -633,8 +631,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual([1.0, -1.0], r1_readtwice.eval())
def _testTensorArrayGradientUnpackRead(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2,
clear_after_read=False)
@@ -658,8 +656,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayGradientUnpackRead()
def testTensorArrayGradientSplitConcat(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=2)
value = tf.constant([[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
@@ -678,8 +676,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0]], grad_vals[0])
def _testTensorArrayGradientDynamicUnpackRead(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=0, dynamic_size=True)
value = tf.constant([[1.0, -1.0], [10.0, -10.0]])
@@ -700,22 +698,22 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayGradientDynamicUnpackRead()
def testCloseTensorArray(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
c1 = ta.close()
session.run(c1)
def testSizeTensorArray(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
s = ta.size()
self.assertAllEqual(3, s.eval())
def testWriteCloseTensorArray(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3, infer_shape=False)
w0 = ta.write(0, [[4.0, 5.0]])
w1 = w0.write(1, [3.0])
@@ -723,11 +721,11 @@ class TensorArrayCPUTest(tf.test.TestCase):
def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
np_dtype = dtype.as_numpy_dtype
- with self.test_session(use_gpu=self._use_gpu) as session:
+ with self.test_session(use_gpu=True) as session:
v0 = tf.identity(np.arange(3*5, dtype=np_dtype).reshape(3, 5))
var = tf.Variable(np.arange(100, 105, dtype=np_dtype))
state0 = tf.identity(np.array([1] * 5, dtype=np_dtype))
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=dtype, tensor_array_name="foo",
size=0 if dynamic_size else 3, dynamic_size=dynamic_size)
time_0 = tf.identity(0)
@@ -831,10 +829,10 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllClose(31.0, grad.eval())
def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
+ with self.test_session(use_gpu=True) as session:
a = tf.identity(np.arange(3*5, dtype=np.float32).reshape(3, 5) + 1)
b = tf.identity(np.arange(3*5, dtype=np.float32).reshape(3, 5) + 1 + 3*5)
- ta = tensor_array_ops.TensorArray(dtype=tf.float32, size=2)
+ ta = tf.TensorArray(dtype=tf.float32, size=2)
ta = ta.write(0, a, name="write_a")
ta = ta.write(1, b, name="write_b")
c = (ta.read(0, name="read_a_0") + # a + b
@@ -900,14 +898,14 @@ class TensorArrayCPUTest(tf.test.TestCase):
def testWriteShape(self):
with self.test_session():
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
c0 = tf.constant([4.0, 5.0])
w0 = ta.write(0, c0)
r0 = w0.read(0)
self.assertAllEqual(c0.get_shape(), r0.get_shape())
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
c1 = tf.constant([6.0, 7.0])
w1 = w0.write(1, c1)
@@ -916,7 +914,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual(c0.get_shape(), r0.get_shape())
self.assertAllEqual(c1.get_shape(), r1.get_shape())
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3)
c2 = tf.constant([4.0, 5.0, 6.0])
with self.assertRaises(ValueError):
@@ -924,7 +922,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
def _testUnpackShape(self):
with self.test_session():
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo",
size=0, dynamic_size=True, infer_shape=True)
value = tf.constant([[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
@@ -946,7 +944,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
def testSplitShape(self):
with self.test_session():
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo",
size=0, dynamic_size=True, infer_shape=True)
value = tf.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]])
@@ -954,7 +952,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
r0 = w0.read(0)
self.assertAllEqual((1, 2), r0.get_shape())
- ta1 = tensor_array_ops.TensorArray(
+ ta1 = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo1",
size=0, dynamic_size=True, infer_shape=True)
w0 = ta1.split(value, [1, 2])
@@ -963,7 +961,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
def testWriteUnknownShape(self):
with self.test_session():
- ta = tensor_array_ops.TensorArray(
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=3, infer_shape=True)
c0 = tf.placeholder(tf.float32)
w0 = ta.write(0, c0)
@@ -971,8 +969,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
def _testGradientWhenNotAllComponentsRead(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(dtype=tf.float32, size=2)
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(dtype=tf.float32, size=2)
x = tf.constant([2.0, 3.0])
w = ta.unpack(x)
r0 = w.read(0)
@@ -985,9 +983,9 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testGradientWhenNotAllComponentsRead()
def _testTensorArrayUnpackDynamic(self):
- with self.test_session(use_gpu=self._use_gpu) as sess:
- ta = tensor_array_ops.TensorArray(dtype=tf.float32, size=3,
- dynamic_size=True)
+ with self.test_session(use_gpu=True) as sess:
+ ta = tf.TensorArray(dtype=tf.float32, size=3,
+ dynamic_size=True)
x = tf.constant([1.0, 2.0, 3.0])
w0 = ta.unpack(x)
w1 = w0.write(3, 4.0)
@@ -1001,9 +999,9 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayUnpackDynamic()
def testTensorArraySplitDynamic(self):
- with self.test_session(use_gpu=self._use_gpu) as sess:
- ta = tensor_array_ops.TensorArray(dtype=tf.float32, size=3,
- dynamic_size=True)
+ with self.test_session(use_gpu=True) as sess:
+ ta = tf.TensorArray(dtype=tf.float32, size=3,
+ dynamic_size=True)
x = tf.constant([1.0, 2.0, 3.0])
w0 = ta.split(x, [1, 1, 1])
w1 = w0.write(3, [4.0])
@@ -1014,11 +1012,11 @@ class TensorArrayCPUTest(tf.test.TestCase):
sess.run(grad)[0])
def _testTensorArrayEvalEmpty(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(dtype=tf.float32,
- size=0,
- dynamic_size=False,
- infer_shape=False)
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(dtype=tf.float32,
+ size=0,
+ dynamic_size=False,
+ infer_shape=False)
with self.assertRaisesOpError(
"TensorArray has size zero, but element shape <unknown> is not fully "
"defined. Currently only static shapes are supported when packing "
@@ -1029,11 +1027,11 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayEvalEmpty()
def _testTensorArrayEvalEmptyWithDefault(self):
- with self.test_session(use_gpu=self._use_gpu):
- ta = tensor_array_ops.TensorArray(dtype=tf.float32,
- size=0,
- dynamic_size=False,
- infer_shape=True)
+ with self.test_session(use_gpu=True):
+ ta = tf.TensorArray(dtype=tf.float32,
+ size=0,
+ dynamic_size=False,
+ infer_shape=True)
self.assertEqual(0, ta.size().eval())
# Don't actually perform the pack. This stores the static shape.
ta.unpack(tf.zeros([0, 3, 5]))
@@ -1047,8 +1045,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self._testTensorArrayEvalEmptyWithDefault()
def testTensorArrayScatterReadAndGradients(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=0, dynamic_size=True)
indices = tf.constant([1, 8])
@@ -1070,8 +1068,8 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
def testTensorArrayWriteGatherAndGradients(self):
- with self.test_session(use_gpu=self._use_gpu) as session:
- ta = tensor_array_ops.TensorArray(
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(
dtype=tf.float32, tensor_array_name="foo", size=0, dynamic_size=True)
values = tf.constant([[1.0*x, -1.0*x] for x in range(10)])
@@ -1095,9 +1093,83 @@ class TensorArrayCPUTest(tf.test.TestCase):
self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0])
self.assertAllEqual(expected_grad, grad_vals[0])
+ def testTensorArrayGetsDeviceFromFirstWrite(self):
+ with tf.device("/gpu:1"):
+ ta = tf.TensorArray(dtype=tf.float32, size=2)
+ # parent device was ignored when creating the TensorArray
+ self.assertEqual(ta.handle.device, "")
+ self.assertEqual(ta.flow.device, "")
+ with tf.device("/gpu:0"):
+ # the first write sets the op's device
+ ta = ta.write(0, 1.0)
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+ self.assertTrue("gpu:0" in ta.flow.device.lower())
+ with tf.device("/gpu:1"):
+ # subsequent writes do not modify the op's device
+ ta = ta.write(1, 1.0)
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+ self.assertTrue("gpu:0" in ta.flow.device.lower())
+
+ ta_grad = ta.grad("grad")
+ self.assertTrue("gpu:0" in ta_grad.handle.device.lower())
+ self.assertTrue("gpu:0" in ta_grad.flow.device.lower())
+
+ # Similar tests for unpack and split
+ ta = tf.TensorArray(dtype=tf.float32, size=2)
+ self.assertEqual(ta.handle.device, "")
+ self.assertEqual(ta.flow.device, "")
+ with tf.device("/gpu:0"):
+ ta = ta.unpack([1.0, 2.0])
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+ self.assertTrue("gpu:0" in ta.flow.device.lower())
+ with tf.device("/gpu:1"):
+ ta = ta.unpack([1.0, 2.0])
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+ self.assertTrue("gpu:0" in ta.flow.device.lower())
+
+ ta = tf.TensorArray(dtype=tf.float32, size=2)
+ self.assertEqual(ta.handle.device, "")
+ self.assertEqual(ta.flow.device, "")
+ with tf.device("/gpu:0"):
+ ta = ta.split([1.0, 2.0], [1, 1])
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+ self.assertTrue("gpu:0" in ta.flow.device.lower())
+ with tf.device("/gpu:1"):
+ ta = ta.split([1.0, 2.0], [1, 1])
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+ self.assertTrue("gpu:0" in ta.flow.device.lower())
+
+ def testTensorArrayGetsDeviceFromFirstWriteInWhileLoop(self):
+ ta = tf.TensorArray(dtype=tf.float32, size=2)
+ def _body(i, ta_i):
+ with tf.device("/gpu:0"):
+ return i + 1, ta_i.write(i, 0.0)
+
+ self.assertEqual(ta.handle.device, "")
+ self.assertEqual(ta.flow.device, "")
+
+ _, ta_out = tf.while_loop(
+ lambda i, ta: i < 2, _body, loop_vars=[0, ta])
+
+ self.assertTrue("gpu:0" in ta_out.handle.device.lower())
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+
+ def testTensorArrayLazyDeviceSettingDoesNotConfuseInitialAccess(self):
+ with self.test_session(use_gpu=True) as session:
+ ta = tf.TensorArray(dtype=tf.float32, size=2)
+ self.assertEqual(ta.handle.device, "")
+
+ with tf.device("/cpu:0"):
+ size = ta.size()
+ with tf.device("/gpu:0"):
+ ta = ta.write(0, 0.0)
+
+ self.assertTrue("gpu:0" in ta.handle.device.lower())
+
+ # This should use the TensorArray on /gpu:0
+ size_value, _ = session.run((size, ta.flow))
+ self.assertEqual(2, size_value)
-class TensorArrayGPUTest(TensorArrayCPUTest):
- _use_gpu = True
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index a343051114..8d03a1e23a 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -35,6 +35,24 @@ from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+def _maybe_set_device(handle_op, value_t):
+ # NOTE(ebrevdo): Do not try this at home, kids
+ # _______________________________________________
+ # | I WILL NOT ACCESS PRIVATE METHODS ^^^^^^^^\ |
+ # | I WILL NOT ACCESS PRIVATE METHODS | | |
+ # | I WILL NOT ACCESS PRIVATE METHODS |_ __ | |
+ # | I WILL NOT ACCESS PRIVATE METHODS (.(. ) | |
+ # | I WILL NOT ACCESS PRIVATE (_ ) |
+ # | \\ /___/' / |
+ # | _\\_ \ | |
+ # | (( ) /====| |
+ # | \ <.__._- \ |
+ # |___________________________ <//___. ||
+ #
+ if not handle_op.device and value_t.device:
+ handle_op._set_device(value_t.device) # pylint: disable=protected-access
+
+
# TensorArray object accesses many of the hidden generated ops, but is
# in fact built to wrap these methods.
# pylint: disable=protected-access
@@ -142,10 +160,14 @@ class TensorArray(object):
clear_after_read=clear_after_read,
tensor_array_name=tensor_array_name, name=scope)
else:
- self._handle = gen_data_flow_ops._tensor_array_v2(
- dtype=dtype, size=size, dynamic_size=dynamic_size,
- clear_after_read=clear_after_read,
- tensor_array_name=tensor_array_name, name=scope)
+ # Construct the TensorArray with an empty device. The first
+ # write into the TensorArray from a Tensor with a set device
+ # will retroactively set the device value of this op.
+ with ops.device(None), ops.colocate_with(None, ignore_existing=True):
+ self._handle = gen_data_flow_ops._tensor_array_v2(
+ dtype=dtype, size=size, dynamic_size=dynamic_size,
+ clear_after_read=clear_after_read,
+ tensor_array_name=tensor_array_name, name=scope)
if flow is not None:
self._flow = flow
else:
@@ -218,10 +240,13 @@ class TensorArray(object):
Raises:
ValueError: if there are more writers than specified.
"""
- with ops.colocate_with(self._handle):
- flow_out = gen_data_flow_ops._tensor_array_write_v2(
- handle=self._handle, index=index, value=value, flow_in=self._flow,
- name=name)
+ with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
+ value = ops.convert_to_tensor(value, name="value")
+ _maybe_set_device(self._handle.op, value)
+ with ops.colocate_with(self._handle):
+ flow_out = gen_data_flow_ops._tensor_array_write_v2(
+ handle=self._handle, index=index, value=value, flow_in=self._flow,
+ name=name)
ta = TensorArray(dtype=self._dtype, handle=self._handle)
ta._flow = flow_out
ta._infer_shape = self._infer_shape
@@ -324,11 +349,10 @@ class TensorArray(object):
Raises:
ValueError: if the shape inference fails.
"""
- with ops.colocate_with(self._handle):
- with ops.name_scope(name, "TensorArrayPack", [self._handle, value]):
- num_elements = array_ops.shape(value)[0]
- return self.scatter(
- indices=math_ops.range(0, num_elements), value=value, name=name)
+ with ops.name_scope(name, "TensorArrayPack", [self._handle, value]):
+ num_elements = array_ops.shape(value)[0]
+ return self.scatter(
+ indices=math_ops.range(0, num_elements), value=value, name=name)
def scatter(self, indices, value, name=None):
"""Scatter the values of a `Tensor` in specific indices of a `TensorArray`.
@@ -346,10 +370,14 @@ class TensorArray(object):
Raises:
ValueError: if the shape inference fails.
"""
- with ops.colocate_with(self._handle):
- flow_out = gen_data_flow_ops._tensor_array_scatter_v2(
- handle=self._handle, indices=indices, value=value, flow_in=self._flow,
- name=name)
+ with ops.name_scope(name, "TensorArrayScatter",
+ [self._handle, value, indices]):
+ value = ops.convert_to_tensor(value, name="value")
+ _maybe_set_device(self._handle.op, value)
+ with ops.colocate_with(self._handle):
+ flow_out = gen_data_flow_ops._tensor_array_scatter_v2(
+ handle=self._handle, indices=indices, value=value,
+ flow_in=self._flow, name=name)
ta = TensorArray(dtype=self._dtype, handle=self._handle)
ta._flow = flow_out
ta._infer_shape = self._infer_shape
@@ -384,13 +412,15 @@ class TensorArray(object):
Raises:
ValueError: if the shape inference fails.
"""
- with ops.colocate_with(self._handle):
- with ops.name_scope(name, "TensorArraySplit",
- [self._handle, value, lengths]):
- lengths_64 = math_ops.to_int64(lengths)
- flow_out = gen_data_flow_ops._tensor_array_split_v2(
- handle=self._handle, value=value, lengths=lengths_64,
- flow_in=self._flow, name=name)
+ with ops.name_scope(name, "TensorArraySplit",
+ [self._handle, value, lengths]):
+ value = ops.convert_to_tensor(value, name="value")
+ _maybe_set_device(self._handle.op, value)
+ lengths_64 = math_ops.to_int64(lengths)
+ with ops.colocate_with(self._handle):
+ flow_out = gen_data_flow_ops._tensor_array_split_v2(
+ handle=self._handle, value=value, lengths=lengths_64,
+ flow_in=self._flow, name=name)
ta = TensorArray(dtype=self._dtype, handle=self._handle)
ta._flow = flow_out
ta._infer_shape = self._infer_shape