aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-06-07 14:55:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-07 14:59:59 -0700
commitc19e6cac0413b0b93d5a15f9d4dc7c861aa1c734 (patch)
tree0ab76d177a8a3e9c4b1b92051257b772ce3c211a
parentb5e8d308655a027e8c163c3fe3bd3445e09e9d23 (diff)
[TF:XLA] Initial implementation of TensorArray ops.
The XLA implementation of TensorArrays is more restrictive than regular TensorArrays: * XLA TensorArrays must have dynamic_size=False. * all elements in an XLA TensorArray must have the same shape. * writes always add their values to any existing values; neither reads nor writes ever issue errors. Out-of-bounds writes currently wrap. Refactor Variable handling in the TF/XLA bridge. Use a XlaVariable* to refer to variables inside compilation rather than a numerical ID. Allow for variables that don't correspond to variables known to the user. Also use XlaVariable to handle TensorArrays. PiperOrigin-RevId: 158322041
-rw-r--r--tensorflow/compiler/tests/BUILD19
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py1018
-rw-r--r--tensorflow/compiler/tests/xla_test.py16
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/arg_op.cc13
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc538
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.cc2
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.h39
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc39
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h4
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc29
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h44
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc41
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h11
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc2
16 files changed, 1710 insertions, 108 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 19f7ff8354..d18e51e32c 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -347,6 +347,25 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "tensor_array_ops_test",
+ size = "small",
+ srcs = ["tensor_array_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:nn_ops_gen",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:tensor_array_grad",
+ "//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
srcs = ["ternary_ops_test.py"],
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
new file mode 100644
index 0000000000..27a2977305
--- /dev/null
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -0,0 +1,1018 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for XLA TensorArray Ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def _make_converter(dtype):
+ def _converter(x):
+ return np.asarray(x).astype(dtype.as_numpy_dtype)
+ return _converter
+
+
+class TensorArrayTest(xla_test.XLATestCase):
+
+ def testTensorArrayWriteRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3)
+
+ w0 = ta.write(0, [[4.0, 5.0]])
+ w1 = w0.write(1, [[1.0, 3.0]])
+ w2 = w1.write(2, [[7.0, -8.5]])
+
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual([[4.0, 5.0]], d0)
+ self.assertAllEqual([[1.0, 3.0]], d1)
+ self.assertAllEqual([[7.0, -8.5]], d2)
+
+ def _testTensorArrayWritePack(self, tf_dtype):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ w0 = ta.write(0, convert([[4.0, 5.0]]))
+ w1 = w0.write(1, convert([[6.0, 7.0]]))
+ w2 = w1.write(2, convert([[8.0, 9.0]]))
+
+ c0 = w2.stack()
+
+ self.assertAllEqual(
+ convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval())
+
+ def testTensorArrayWritePack(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayWritePack(dtype)
+
+ def testEmptyTensorArrayPack(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+
+ empty_element = np.zeros((0, 1), dtype=np.float32)
+ w0 = ta.write(0, empty_element)
+ w1 = w0.write(1, empty_element)
+ w2 = w1.write(2, empty_element)
+
+ c0 = w2.stack()
+
+ self.assertAllEqual([3, 0, 1], c0.eval().shape)
+
+ def _testTensorArrayWriteConcat(self, tf_dtype):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ w0 = ta.write(0, convert([[4.0, 5.0], [104.0, 105.0]]))
+ w1 = w0.write(1, convert([[6.0, 7.0], [106.0, 107.0]]))
+ w2 = w1.write(2, convert([[8.0, 9.0], [204.0, 205.0]]))
+
+ c0 = w2.concat()
+
+ self.assertAllEqual(
+ convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0],
+ [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval())
+
+ def testTensorArrayWriteConcat(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayWriteConcat(dtype)
+
+ def _testTensorArrayUnpackRead(self, tf_dtype):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ # Unpack a vector into scalars
+ w0 = ta.unstack(convert([1.0, 2.0, 3.0]))
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert(1.0), d0)
+ self.assertAllEqual(convert(2.0), d1)
+ self.assertAllEqual(convert(3.0), d2)
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ # Unpack a matrix into vectors
+ w1 = ta.unstack(convert([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]))
+ r0 = w1.read(0)
+ r1 = w1.read(1)
+ r2 = w1.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([1.0, 1.1]), d0)
+ self.assertAllEqual(convert([2.0, 2.1]), d1)
+ self.assertAllEqual(convert([3.0, 3.1]), d2)
+
+ # Reset ta because we're going to change the shape, else shape
+ # inference will throw an error.
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ # Try unpacking an empty matrix, which should not cause an error.
+ w2 = ta.unstack(convert([[], [], []]))
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([]), d0)
+ self.assertAllEqual(convert([]), d1)
+ self.assertAllEqual(convert([]), d2)
+
+ def _testTensorArrayUnpackReadMaybeLegacy(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayUnpackRead(dtype)
+
+ def testTensorArrayUnpackRead(self):
+ self._testTensorArrayUnpackReadMaybeLegacy()
+
+ def _testTensorArraySplitRead(self, tf_dtype):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+
+ convert = _make_converter(tf_dtype)
+
+ # Split an empty vector
+ lengths = constant_op.constant([0, 0, 0])
+ w0 = ta.split(convert([]), lengths=lengths)
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([]), d0)
+ self.assertAllEqual(convert([]), d1)
+ self.assertAllEqual(convert([]), d2)
+
+ # Split a vector
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+ lengths = constant_op.constant([1, 1, 1])
+ w0 = ta.split(convert([1.0, 2.0, 3.0]), lengths=lengths)
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([1.0]), d0)
+ self.assertAllEqual(convert([2.0]), d1)
+ self.assertAllEqual(convert([3.0]), d2)
+
+ # Split a matrix
+ ta = tensor_array_ops.TensorArray(
+ dtype=tf_dtype, tensor_array_name="foo", size=3)
+ lengths = constant_op.constant([1, 1, 1])
+ w0 = ta.split(
+ convert([[1.0, 101.0], [2.0, 201.0], [3.0, 301.0]]), lengths=lengths)
+ r0 = w0.read(0)
+ r1 = w0.read(1)
+ r2 = w0.read(2)
+
+ d0, d1, d2 = session.run([r0, r1, r2])
+ self.assertAllEqual(convert([[1.0, 101.0]]), d0)
+ self.assertAllEqual(convert([[2.0, 201.0]]), d1)
+ self.assertAllEqual(convert([[3.0, 301.0]]), d2)
+
+ def testTensorArraySplitRead(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArraySplitRead(dtype)
+
+ def testTensorGradArrayWriteRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3)
+
+ w0 = ta.write(0, [[4.0]])
+ w1 = w0.write(1, [[1.0]])
+ w2 = w1.write(2, [[-3.0]])
+
+ g_ta = w2.grad("grad")
+
+ g_w0 = g_ta.write(0, [[5.0]])
+ g_w1 = g_w0.write(1, [[2.0]])
+ g_w2 = g_w1.write(2, [[-2.0]])
+
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ g_r0 = g_w2.read(0)
+ g_r1 = g_w2.read(1)
+ g_r2 = g_w2.read(2)
+
+ d0, d1, d2, g_d0, g_d1, g_d2 = session.run([r0, r1, r2, g_r0, g_r1, g_r2])
+ self.assertAllEqual([[4.0]], d0)
+ self.assertAllEqual([[1.0]], d1)
+ self.assertAllEqual([[-3.0]], d2)
+ self.assertAllEqual([[5.0]], g_d0)
+ self.assertAllEqual([[2.0]], g_d1)
+ self.assertAllEqual([[-2.0]], g_d2)
+
+ def testTensorGradArrayDynamicWriteRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3)
+
+ w0 = ta.write(0, [[4.0]])
+ w1 = w0.write(1, [[1.0]])
+ w2 = w1.write(2, [[-3.0]])
+
+ g_ta = w2.grad("grad") # Get gradient array here so we know the shape
+
+ s = w2.size()
+ g_s = g_ta.size()
+
+ g_w0 = g_ta.write(0, [[5.0]])
+ g_w1 = g_w0.write(1, [[2.0]])
+ g_w2 = g_w1.write(2, [[-2.0]])
+
+ r0 = w2.read(0)
+ r1 = w2.read(1)
+ r2 = w2.read(2)
+
+ g_r0 = g_w2.read(0)
+ g_r1 = g_w2.read(1)
+ g_r2 = g_w2.read(2)
+
+ d0, d1, d2, g_d0, g_d1, g_d2, vs, g_vs = session.run(
+ [r0, r1, r2, g_r0, g_r1, g_r2, s, g_s])
+ self.assertAllEqual([[4.0]], d0)
+ self.assertAllEqual([[1.0]], d1)
+ self.assertAllEqual([[-3.0]], d2)
+ self.assertAllEqual([[5.0]], g_d0)
+ self.assertAllEqual([[2.0]], g_d1)
+ self.assertAllEqual([[-2.0]], g_d2)
+ self.assertAllEqual(3, vs)
+ self.assertAllEqual(3, g_vs)
+
+ def testTensorGradAccessTwiceReceiveSameObject(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3,
+ element_shape=[1, 2])
+ g_ta_0 = ta.grad("grad")
+ g_ta_1 = ta.grad("grad")
+
+ with ops.control_dependencies([g_ta_0.write(0, [[4.0, 5.0]]).flow]):
+ # Write with one gradient handle, read with another copy of it
+ r1_0 = g_ta_1.read(0)
+
+ t_g_ta_0, t_g_ta_1, d_r1_0 = session.run(
+ [g_ta_0.handle.op, g_ta_1.handle.op, r1_0])
+ self.assertAllEqual(t_g_ta_0, t_g_ta_1)
+ self.assertAllEqual([[4.0, 5.0]], d_r1_0)
+
+ def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+
+ # Test writing the wrong datatype
+ with self.assertRaisesOpError(
+ "TensorArray dtype is float but op has dtype int32"):
+ ta.write(-1, np.int32(7)).flow.eval()
+
+ def testTensorArrayReadWrongIndexOrDataTypeFails(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+
+ w0 = ta.write(0, [[4.0, 5.0]])
+
+ # Test reading wrong datatype
+ r0_bad = gen_data_flow_ops._tensor_array_read_v3(
+ handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
+ with self.assertRaisesOpError(
+ "TensorArray dtype is float but Op requested dtype double."):
+ r0_bad.eval()
+
+ # Test reading from a different index than the one we wrote to
+ w0.read(1)
+
+ def testTensorArraySplitIncompatibleShapesFails(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3,
+ infer_shape=False)
+
+ with self.assertRaisesOpError(
+ r"value is not 1D"):
+ lengths = array_ops.placeholder(dtypes.int64)
+ ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
+
+ with self.assertRaisesOpError(
+ r"lengths must be equal: 1 vs. 2"):
+ ta.split([1.0, 2.0, 3.0], [1, 2, 3]).flow.eval()
+
+ with self.assertRaisesOpError(
+ r"value must have rank >= 1"):
+ ta.split(1.0, [1]).flow.eval()
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ infer_shape=False)
+
+ with self.assertRaisesOpError(
+ r"TensorArray's size is not equal to the size of lengths "
+ r"\(1 vs. 2\)"):
+ ta.split([1.0], [1]).flow.eval()
+
+ def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False)
+
+ c = lambda x: np.asarray(x, dtype=dtype.as_numpy_dtype)
+
+ w0 = ta.write(2, c(3.0))
+ w1 = w0.write(2, c(4.0))
+
+ ta_grad = w1.grad("grad")
+
+ w0_grad = ta_grad.write(2, c(3.0))
+ w1_grad = w0_grad.write(2, c(4.0))
+ w2_grad = w1_grad.write(2, c(5.0))
+
+ # Assert that aggregation works correctly
+ self.assertAllEqual(c(12.00), w2_grad.read(2).eval())
+
+ # Using differing shapes causes an exception
+ wb0_grad = ta_grad.write(1, c(1.0))
+ wb1_grad = wb0_grad.write(1, c([1.0]))
+
+ with self.assertRaisesOpError(
+ r"Mismatched TensorArray sizes"):
+ wb1_grad.flow.eval()
+
+ def testTensorArrayWriteGradientAddMultipleAdds(self):
+ for dtype in self.numeric_tf_types:
+ self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
+
+ def testMultiTensorArray(self):
+ with self.test_session(), self.test_scope():
+ h1 = tensor_array_ops.TensorArray(
+ size=1, dtype=dtypes.float32, tensor_array_name="foo")
+ w1 = h1.write(0, 4.0)
+ r1 = w1.read(0)
+
+ h2 = tensor_array_ops.TensorArray(
+ size=1, dtype=dtypes.float32, tensor_array_name="bar")
+
+ w2 = h2.write(0, 5.0)
+ r2 = w2.read(0)
+ r = r1 + r2
+ self.assertAllClose(9.0, r.eval())
+
+ def _testTensorArrayGradientWriteReadType(self, dtype):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.as_dtype(dtype),
+ tensor_array_name="foo",
+ size=3,
+ infer_shape=False)
+
+ c = lambda x: np.array(x, dtype=dtype)
+
+ value_0 = constant_op.constant(c([[4.0, 5.0]]))
+ value_1 = constant_op.constant(c([[3.0, 3.5]]))
+
+ w0 = ta.write(0, value_0)
+ w1 = w0.write(1, value_1)
+ r0 = w1.read(0)
+ r1 = w1.read(1)
+ r0_2 = w1.read(0)
+
+ # Test individual components' gradients
+ grad_just_r0 = gradients_impl.gradients(
+ ys=[r0], xs=[value_0], grad_ys=[c([[2.0, 3.0]])])
+ grad_just_r0_vals = session.run(grad_just_r0)
+ self.assertAllEqual(c([[2.0, 3.0]]), grad_just_r0_vals[0])
+
+ grad_r0_r0_2 = gradients_impl.gradients(
+ ys=[r0, r0_2],
+ xs=[value_0],
+ grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]])])
+ grad_r0_r0_2_vals = session.run(grad_r0_r0_2)
+ self.assertAllEqual(c([[3.0, 2.0]]), grad_r0_r0_2_vals[0])
+
+ grad_just_r1 = gradients_impl.gradients(
+ ys=[r1], xs=[value_1], grad_ys=[c([[-2.0, -4.0]])])
+ grad_just_r1_vals = session.run(grad_just_r1)
+ self.assertAllEqual(c([[-2.0, -4.0]]), grad_just_r1_vals[0])
+
+ # Test combined gradients
+ grad = gradients_impl.gradients(
+ ys=[r0, r0_2, r1],
+ xs=[value_0, value_1],
+ grad_ys=[c([[2.0, 3.0]]), c([[1.0, -1.0]]), c([[-2.0, -10.0]])])
+ grad_vals = session.run(grad)
+ self.assertEqual(len(grad_vals), 2)
+ self.assertAllEqual(c([[3.0, 2.0]]), grad_vals[0])
+ self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1])
+
+ def testTensorArrayGradientWriteRead(self):
+ for dtype in self.numeric_types:
+ self._testTensorArrayGradientWriteReadType(dtype)
+
+ def _testTensorArrayGradientWritePackConcatAndRead(self):
+ with self.test_session() as sess, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ clear_after_read=False)
+
+ value_0 = constant_op.constant([-1.0, 1.0])
+ value_1 = constant_op.constant([-10.0, 10.0])
+
+ w0 = ta.write(0, value_0)
+ w1 = w0.write(1, value_1)
+ p0 = w1.stack()
+ r0 = w1.read(0)
+ s0 = w1.concat()
+
+ # Test gradient accumulation between read(0), pack(), and concat()
+ with ops.control_dependencies([p0, r0, s0]):
+ grad_r = gradients_impl.gradients(
+ ys=[p0, r0, s0],
+ xs=[value_0, value_1],
+ grad_ys=[
+ [[2.0, 3.0], [4.0, 5.0]], # stack gradient
+ [-0.5, 1.5], # read(0) gradient
+ [20.0, 30.0, 40.0, 50.0], # concat gradient
+ ])
+ grad_vals = sess.run(grad_r) # 2 + 2 entries
+
+ self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0])
+ self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1])
+
+ def testTensorArrayGradientWritePackConcatAndRead(self):
+ self._testTensorArrayGradientWritePackConcatAndRead()
+
+ def testTensorArrayReadTwice(self):
+ with self.test_session(), self.test_scope():
+ value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+
+ ta_readtwice = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ clear_after_read=False)
+ w_readtwice = ta_readtwice.unstack(value)
+ r0_readtwice = w_readtwice.read(0)
+ with ops.control_dependencies([r0_readtwice]):
+ r1_readtwice = w_readtwice.read(0)
+
+ self.assertAllEqual([1.0, -1.0], r1_readtwice.eval())
+
+ def _testTensorArrayGradientUnpackRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=2,
+ clear_after_read=False)
+
+ value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+
+ w = ta.unstack(value)
+ r0 = w.read(0)
+ r0_1 = w.read(0)
+ r1 = w.read(1)
+
+ # Test combined gradients + aggregation of read(0)
+ grad = gradients_impl.gradients(
+ ys=[r0, r0_1, r1],
+ xs=[value],
+ grad_ys=[[2.0, 3.0], [-1.5, 1.5], [4.0, 5.0]])
+ grad_vals = session.run(grad)
+
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([[2.0 - 1.5, 3.0 + 1.5], [4.0, 5.0]], grad_vals[0])
+
+ def testTensorArrayGradientUnpackRead(self):
+ self._testTensorArrayGradientUnpackRead()
+
+ def testTensorArrayGradientSplitConcat(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=2)
+
+ value = constant_op.constant(
+ [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0], [1000.0, -1000.0]])
+
+ w = ta.split(value, [2, 2])
+ r = w.concat()
+
+ # Test combined gradients
+ grad = gradients_impl.gradients(
+ ys=[r],
+ xs=[value],
+ grad_ys=[[[2.0, -2.0], [20.0, -20.0], [200.0, -200.0],
+ [2000.0, -2000.0]]])
+ grad_vals = session.run(grad)
+
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([[2.0, -2.0], [20.0, -20.0], [200.0, -200.0],
+ [2000.0, -2000.0]],
+ grad_vals[0])
+
+ # TODO(phawkins): implement TensorArrayClose
+ # def testCloseTensorArray(self):
+ # with self.test_session() as session, self.test_scope():
+ # ta = tensor_array_ops.TensorArray(
+ # dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ # c1 = ta.close()
+ # session.run(c1)
+
+ def testSizeTensorArray(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ s = ta.size()
+ self.assertAllEqual(3, s.eval())
+
+ # TODO(phawkins): implement TensorArrayClose
+ # def testWriteCloseTensorArray(self):
+ # with self.test_session(), self.test_scope():
+ # ta = tensor_array_ops.TensorArray(
+ # dtype=dtypes.float32,
+ # tensor_array_name="foo",
+ # size=3,
+ # infer_shape=False)
+ # w0 = ta.write(0, [[4.0, 5.0]])
+ # w1 = w0.write(1, [3.0])
+ # w1.close().run() # Expected to run without problems
+
+ # TODO(phawkins): implement while loops.
+ # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
+ # np_dtype = dtype.as_numpy_dtype
+ # with self.test_session() as session, self.test_scope():
+ # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5))
+ # var = variables.Variable(np.arange(100, 105, dtype=np_dtype))
+ # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype))
+ # ta = tensor_array_ops.TensorArray(
+ # dtype=dtype,
+ # tensor_array_name="foo",
+ # size=0 if dynamic_size else 3,
+ # dynamic_size=dynamic_size)
+ # time_0 = array_ops.identity(0)
+
+ # def body(time, ta_t, state):
+ # sliced = array_ops.slice(
+ # v0, begin=array_ops.stack([time, 0]), size=[1, -1])
+ # sliced = array_ops.squeeze(sliced)
+ # out = sliced + var + state
+ # state += sliced
+ # ta_t = ta_t.write(time, out)
+ # return (time + 1, ta_t, state)
+
+ # (unused_0, h_final, unused_2) = control_flow_ops.while_loop(
+ # cond=lambda time, unused_1, unused_2: time < 3,
+ # body=body,
+ # loop_vars=(time_0, ta, state0),
+ # shape_invariants=(time_0.get_shape(), tensor_shape.unknown_shape(),
+ # tensor_shape.unknown_shape()),
+ # parallel_iterations=3)
+ # vout = h_final.stack()
+
+ # grad_val = -np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)
+ # v0_grad = gradients_impl.gradients([vout], [v0], [grad_val])[0]
+ # state0_grad = gradients_impl.gradients([vout], [state0], [grad_val])[0]
+ # var_grad = gradients_impl.gradients([vout], [var], [grad_val])[0]
+
+ # variables.global_variables_initializer().run()
+ # state0_t, var_t, v0_t, vout_t, v0_grad_t, var_grad_t, state0_grad_t = (
+ # session.run([state0, var, v0, vout, v0_grad, var_grad, state0_grad])
+ # )
+ # just_v0_grad_t, = session.run([v0_grad])
+
+ # # state = [ state0 | state0 + v0[0] | state0 + v0[0] + v0[1] ]
+ # # vout = [ v0[0] + var + state[0] |
+ # # v0[1] + var + state[1] |
+ # # v0[2] + var + state[2] ]
+ # # = [ v0[0] + var + state0 |
+ # # v0[1] + var + state0 + v0[0] |
+ # # v0[2] + var + state0 + v0[0] + v0[1] ]
+ # #
+ # # d(vout[0])/d(v0) = [1 | 0 | 0 ]
+ # # d(vout[1])/d(v0) = [1 | 1 | 0 ]
+ # # d(vout[2])/d(v0) = [1 | 1 | 1 ]
+ # # d(vout)/d(var) = [1 | 1 | 1]
+ # # d(vout)/d(state0) = [ 1 | 1 | 1 ]
+
+ # state_per_time = np.array(
+ # [state0_t, state0_t + v0_t[0, :],
+ # state0_t + v0_t[0, :] + v0_t[1, :]])
+
+ # # Compare forward prop
+ # self.assertAllClose(v0_t + var_t + state_per_time, vout_t)
+
+ # # Compare backward prop
+ # expected_v0_grad_t = np.array([
+ # grad_val[0, :] + grad_val[1, :] + grad_val[2, :],
+ # grad_val[1, :] + grad_val[2, :], grad_val[2, :]
+ # ])
+
+ # self.assertAllEqual(expected_v0_grad_t, v0_grad_t)
+ # self.assertAllEqual(expected_v0_grad_t, just_v0_grad_t)
+ # self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
+ # self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)
+
+ # def testWhileLoopWritePackGradients(self):
+ # self._testWhileLoopWritePackGradients(
+ # dynamic_size=False, dtype=dtypes.float32)
+ # # TODO(ebrevdo): re-enable when While supports non-float32 gradients.
+ # # self._testWhileLoopWritePackGradients(
+ # # dynamic_size=False, dtype=tf.int64)
+
+ # def testWhileLoopDynamicWritePackGradients(self):
+ # self._testWhileLoopWritePackGradients(
+ # dynamic_size=True, dtype=dtypes.float32)
+
+ # def testGradSerialTwoLoops(self):
+ # with self.test_session(), self.test_scope():
+ # num_steps = 100
+ # acc = tensor_array_ops.TensorArray(
+ # dtype=dtypes.float32,
+ # size=num_steps,
+ # clear_after_read=False,
+ # element_shape=tensor_shape.scalar())
+ # i = constant_op.constant(0, name="i")
+ # x = constant_op.constant(2.0, name="x")
+
+ # c = lambda i, acc: i < 5
+
+ # def b(i, acc):
+ # x1 = control_flow_ops.cond(
+ # math_ops.equal(i, 0), lambda: x,
+ # lambda: math_ops.multiply(acc.read(i - 1), 2.0))
+ # return i + 1, acc.write(i, x1)
+
+ # i1, acc1 = control_flow_ops.while_loop(c, b, [i, acc])
+
+ # z = constant_op.constant(0.0)
+
+ # def fn(i, acc):
+ # return i + 1, acc.write(i, z)
+
+ # _, acc2 = control_flow_ops.while_loop(lambda i, acc: i < num_steps, fn,
+ # [i1, acc1])
+
+ # r = acc2.stack()
+ # grad = gradients_impl.gradients(r, [x])[0]
+ # self.assertAllClose(31.0, grad.eval())
+
+ def testSumOfTwoReadVariablesWithoutRepeatGrad(self):
+ with self.test_session() as session, self.test_scope():
+ a = array_ops.identity(
+ np.arange(
+ 3 * 5, dtype=np.float32).reshape(3, 5) + 1)
+ b = array_ops.identity(
+ np.arange(
+ 3 * 5, dtype=np.float32).reshape(3, 5) + 1 + 3 * 5)
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
+ ta = ta.write(0, a, name="write_a")
+ ta = ta.write(1, b, name="write_b")
+ c = (
+ ta.read(
+ 0, name="read_a_0") + # a + b
+ ta.read(
+ 1, name="read_b_0"))
+ g0 = -(np.arange(3 * 5, dtype=np.float32).reshape(3, 5) + 1)
+ grad_a = gradients_impl.gradients([c], [a], [g0])[0] # d(a+b)/da = 1
+ grad_b = gradients_impl.gradients([c], [b], [g0])[0] # d(a+b)/db = 1
+
+ # Test gradients calculated individually
+ grad_a_t, = session.run([grad_a])
+ self.assertAllEqual(grad_a_t, g0)
+
+ grad_b_t, = session.run([grad_b])
+ self.assertAllEqual(grad_b_t, g0)
+
+ # Test gradients calculated jointly
+ joint_grad_a_t, joint_grad_b_t = session.run([grad_a, grad_b])
+ self.assertAllEqual(joint_grad_a_t, g0)
+ self.assertAllEqual(joint_grad_b_t, g0)
+
+ def testWriteShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ c0 = constant_op.constant([4.0, 5.0])
+ w0 = ta.write(0, c0)
+ r0 = w0.read(0)
+ self.assertAllEqual(c0.get_shape(), r0.get_shape())
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ c1 = constant_op.constant([6.0, 7.0])
+ w1 = w0.write(1, c1)
+ r0 = w1.read(0)
+ r1 = w1.read(1)
+ self.assertAllEqual(c0.get_shape(), r0.get_shape())
+ self.assertAllEqual(c1.get_shape(), r1.get_shape())
+
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=3)
+ c2 = constant_op.constant([4.0, 5.0, 6.0])
+ with self.assertRaises(ValueError):
+ w0.write(0, c2)
+
+ def testPartlyUnknownShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, tensor_array_name="foo", size=6)
+
+ c0 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
+ w0 = ta.write(0, c0)
+ r0 = w0.read(0)
+ self.assertAllEqual([None, None, None, 3], r0.get_shape().as_list())
+
+ c1 = array_ops.placeholder(dtypes.float32, [None, None, None, 3])
+ w1 = w0.write(1, c1)
+ r1 = w1.read(0)
+ self.assertAllEqual([None, None, None, 3], r1.get_shape().as_list())
+
+ # Writing less specific shape (doesn't change type.)
+ c2 = array_ops.placeholder(dtypes.float32, [None, None, None, None])
+ w2 = w1.write(2, c2)
+ r2 = w2.read(0)
+ self.assertAllEqual([None, None, None, 3], r2.get_shape().as_list())
+
+ # Writing more specific shape in one dimension and less specific in
+ # another.
+ c3 = array_ops.placeholder(dtypes.float32, [None, None, 2, None])
+ w3 = w2.write(3, c3)
+ r3 = w3.read(0)
+ self.assertAllEqual([None, None, 2, 3], r3.get_shape().as_list())
+
+ # Writing partly defined shape using TensorArray.scatter.
+ c4 = array_ops.placeholder(dtypes.float32, [2, None, 4, 2, 3])
+ w4 = w3.scatter([4, 5], c4)
+ r4 = w4.read(0)
+ self.assertAllEqual([None, 4, 2, 3], r4.get_shape().as_list())
+
+ # Writing fully defined shape using TensorArray.split.
+ c5 = array_ops.placeholder(dtypes.float32, [10, 4, 2, 3])
+ w5 = w4.split(c5, constant_op.constant([5, 5]))
+ r5 = w5.read(0)
+ self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
+
+ def _testUnpackShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=0,
+ infer_shape=True)
+ value = constant_op.constant(
+ [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]])
+ w0 = ta.unstack(value)
+ r0 = w0.read(0)
+ self.assertAllEqual((2,), r0.get_shape())
+
+ c1 = constant_op.constant([4.0, 5.0])
+ w1 = w0.write(3, c1)
+ r1 = w1.read(0)
+ self.assertAllEqual(c1.get_shape(), r1.get_shape())
+
+ c2 = constant_op.constant([4.0, 5.0, 6.0])
+ with self.assertRaises(ValueError):
+ w1.write(4, c2)
+
+ def testUnpackShape(self):
+ self._testUnpackShape()
+
+ def testSplitShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=0,
+ infer_shape=True)
+ value = constant_op.constant([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0]])
+ w0 = ta.split(value, [1, 1, 1])
+ r0 = w0.read(0)
+ self.assertAllEqual((1, 2), r0.get_shape())
+
+ ta1 = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo1",
+ size=0,
+ infer_shape=True)
+ w0 = ta1.split(value, [1, 2])
+ r0 = w0.read(0)
+ self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
+
+ def testWriteUnknownShape(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3,
+ infer_shape=True)
+ c0 = array_ops.placeholder(dtypes.float32)
+ w0 = ta.write(0, c0)
+ r0 = w0.read(0)
+ self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape())
+
+ def _testGradientWhenNotAllComponentsRead(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2)
+ x = constant_op.constant([2.0, 3.0])
+ w = ta.unstack(x)
+ r0 = w.read(0)
+ # calculate (dr0/dx0, dr0/dx1). since r0 = x0, gradients are (1, 0).
+ grad_r0 = gradients_impl.gradients(ys=[r0], xs=[x], grad_ys=[1.0])
+ grad_r0_vals = session.run(grad_r0)[0]
+ self.assertAllEqual(grad_r0_vals, [1.0, 0.0])
+
+ def testGradientWhenNotAllComponentsRead(self):
+ self._testGradientWhenNotAllComponentsRead()
+
+ def _testTensorArrayEvalEmpty(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, size=0, infer_shape=False)
+ with self.assertRaisesOpError(
+ "TensorArray has size zero, but element shape <unknown> is not fully "
+ "defined. Currently only static shapes are supported when packing "
+ "zero-size TensorArrays."):
+ ta.stack().eval()
+
+ def testTensorArrayEvalEmpty(self):
+ self._testTensorArrayEvalEmpty()
+
+ def _testTensorArrayEvalEmptyWithDefault(self):
+ with self.test_session(), self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32, size=0, infer_shape=True)
+ self.assertEqual(0, ta.size().eval())
+ ta = ta.unstack(array_ops.zeros([0, 3, 5]))
+ packed = ta.stack()
+ self.assertAllEqual([0, 3, 5], packed.eval().shape)
+ # Concatenating zero tensors along their first dimension gives a
+ # first dimension of zero
+ self.assertAllEqual([0, 5], ta.concat().eval().shape)
+
+ def testTensorArrayEvalEmptyWithDefault(self):
+ self._testTensorArrayEvalEmptyWithDefault()
+
+ def testTensorArrayScatterReadAndGradients(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=10)
+
+ indices = constant_op.constant([1, 8])
+ value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
+
+ w = ta.scatter(indices, value)
+ r0 = w.read(1)
+ r1 = w.read(8)
+
+ # Test combined gradients + aggregation of read(0)
+ grad = gradients_impl.gradients(
+ ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]])
+ read_vals, grad_vals = session.run([[r0, r1], grad])
+
+ self.assertEqual(len(read_vals), 2)
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([1.0, -1.0], read_vals[0])
+ self.assertAllEqual([10.0, -10.0], read_vals[1])
+ self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
+
+ def testTensorArrayWriteGatherAndGradients(self):
+ with self.test_session() as session, self.test_scope():
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=10)
+
+ values = constant_op.constant([[1.0 * x, -1.0 * x] for x in range(10)])
+ indices = constant_op.constant([1, 8])
+
+ w = ta.unstack(values)
+ g = w.gather(indices)
+
+ # Test combined gradients + aggregation of read(0)
+ grad = gradients_impl.gradients(
+ ys=[g], xs=[values], grad_ys=[[[2.0, 3.0], [4.0, 5.0]]])
+ g_vals, grad_vals = session.run([[g], grad])
+
+ # Gradients for 8 of the 10 unread components are zero.
+ expected_grad = np.zeros((10, 2))
+ expected_grad[1] = [2.0, 3.0]
+ expected_grad[8] = [4.0, 5.0]
+
+ self.assertEqual(len(g_vals), 1)
+ self.assertEqual(len(grad_vals), 1)
+ self.assertAllEqual([[1.0, -1.0], [8.0, -8.0]], g_vals[0])
+ self.assertAllEqual(expected_grad, grad_vals[0])
+
+ def testTensorArrayIdentity(self):
+ with self.test_session() as session, self.test_scope():
+ ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
+ infer_shape=False)
+ ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4,
+ infer_shape=True)
+
+ ta0 = ta0.write(0, 0.)
+ ta1 = ta1.write(0, 1)
+
+ v0 = resource_variable_ops.ResourceVariable(0)
+ v1 = resource_variable_ops.ResourceVariable(0)
+
+ with ops.control_dependencies([v0.assign_add(1)]):
+ ta0 = ta0.identity()
+
+ with ops.control_dependencies([v1.assign_add(1)]):
+ ta1 = ta1.identity()
+
+ read0 = ta0.read(0)
+ read1 = ta1.read(0)
+
+ size0 = ta0.size()
+ size1 = ta1.size()
+
+ # Tests correct properties on new TensorArrays.
+ self.assertEqual(dtypes.float32, ta0.dtype)
+ self.assertEqual(dtypes.int32, ta1.dtype)
+ self.assertEqual(tensor_shape.unknown_shape(), read0.get_shape())
+ self.assertEqual(tensor_shape.scalar(), read1.get_shape())
+
+ variables.global_variables_initializer().run()
+
+ read0_v, read1_v, size0_v, size1_v = session.run(
+ (read0, read1, size0, size1))
+
+ # Tests that the control dependencies was added and executed.
+ self.assertEqual(1, v0.eval())
+ self.assertEqual(1, v1.eval())
+
+ # Tests correct TensorArray.
+ self.assertEqual(read0_v, 0)
+ self.assertEqual(read1_v, 1)
+ self.assertEqual(size0_v, 2)
+ self.assertEqual(size1_v, 4)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index f7fe186cf8..79549644ea 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -54,16 +54,20 @@ class XLATestCase(test.TestCase):
self.device = FLAGS.test_device
self.has_custom_call = (self.device == 'XLA_CPU')
self.all_tf_types = [
- dtypes.DType(types_pb2.DataType.Value(name))
+ dtypes.as_dtype(types_pb2.DataType.Value(name))
for name in FLAGS.types.split(',')
]
- self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
- self.int_types = [
- dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer
+ self.int_tf_types = [
+ dtype for dtype in self.all_tf_types if dtype.is_integer
]
- self.float_types = [
- dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating
+ self.float_tf_types = [
+ dtype for dtype in self.all_tf_types if dtype.is_floating
]
+ self.numeric_tf_types = self.int_tf_types + self.float_tf_types
+
+ self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
+ self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
+ self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
self.numeric_types = self.int_types + self.float_types
# Parse the manifest file, if any, into a regex identifying tests to
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index c4cbaebb25..36a6c90af4 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -89,6 +89,8 @@ Status BackwardsConstAnalysis(const Graph& g,
{"StridedSliceGrad", "end"},
{"StridedSliceGrad", "strides"},
{"Sum", "reduction_indices"},
+ {"TensorArrayV3", "size"},
+ {"TensorArraySplitV3", "lengths"},
{"Tile", "multiples"},
{"Transpose", "perm"}};
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 81b065689d..a434c74680 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -55,6 +55,7 @@ tf_kernel_library(
"spacetobatch_op.cc",
"split_op.cc",
"strided_slice_op.cc",
+ "tensor_array_ops.cc",
"tile_ops.cc",
"training_ops.cc",
"transpose_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index d6897d6e33..620fc84437 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -49,14 +49,15 @@ class ArgOp : public XlaOpKernel {
return;
}
- XlaContext& tc = XlaContext::Get(ctx);
- const XlaContext::Argument& arg = tc.args()[index_];
+ XlaContext& xc = XlaContext::Get(ctx);
+ const XlaContext::Argument& arg = xc.args()[index_];
if (arg.is_variable) {
- // We use the argument position of the variable input as a unique ID.
// TODO(phawkins): this code assumes that variables do not alias.
- OP_REQUIRES_OK(ctx, tc.CreateVariable(index_, arg.name, arg.value.type,
- arg.value.handle));
- ctx->SetVariableOutput(0, index_);
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type,
+ arg.value.handle, &var));
+ var->tensor_array_size = arg.tensor_array_size;
+ ctx->SetVariableOutput(0, var);
} else if (arg.value.is_constant) {
ctx->SetConstantOutput(0, arg.value.constant_value);
} else {
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
new file mode 100644
index 0000000000..de542d55e8
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -0,0 +1,538 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA TensorArray operators.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+// Since the element shape is not always provided to the TensorArrayV3 operator,
+// we must support lazily initialization of the TensorArray at the time of the
+// first write.
+// If a TensorArray `var` has not been initialized, constructs storage for the
+// TensorArray with elements of `elem_shape`. For both initialized and
+// uninitialized TensorArrays, checks that the tensor has a type compatible with
+// 'dtype' and shape compatible with 'elem_shape'.
+Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
+ XlaVariable* var, DataType dtype,
+ const TensorShape& elem_shape) {
+ if (var->type != dtype) {
+ return errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(var->type),
+ " but op has dtype ", DataTypeString(dtype), ".");
+ }
+
+ TF_RET_CHECK(var->tensor_array_size >= 0)
+ << var->name << " size " << var->tensor_array_size;
+ TensorShape ta_shape;
+ ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AppendShape(elem_shape);
+
+ if (var->value.handle() == 0) {
+ // TensorArray has not been initialized.
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
+ var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
+ } else {
+ // Checks the elem_shape matches the TensorArray shape.
+ auto shape_or_status = builder->GetShape(var->value);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
+ if (ta_shape != shape) {
+ return errors::InvalidArgument(
+ "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
+ shape.DebugString());
+ }
+ }
+ return Status::OK();
+}
+
+// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
+xla::ComputationDataHandle PadIndexWithZeros(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ int count) {
+ xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});
+ std::vector<xla::ComputationDataHandle> xs(count + 1, zero);
+ xs[0] = builder->Reshape(x, {1});
+ return builder->ConcatInDim(xs, 0);
+}
+
+// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
+// relevant slice of 'operand'.
+xla::ComputationDataHandle DynamicAddSlice(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
+ const xla::ComputationDataHandle& update,
+ const gtl::ArraySlice<int64>& update_dims,
+ const xla::ComputationDataHandle& start_indices) {
+ xla::ComputationDataHandle current =
+ builder->DynamicSlice(operand, start_indices, update_dims);
+ xla::ComputationDataHandle sum = builder->Add(current, update);
+ return builder->DynamicUpdateSlice(operand, sum, start_indices);
+}
+
+class TensorArrayOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ bool dynamic_size;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size));
+ OP_REQUIRES(
+ ctx, !dynamic_size,
+ errors::Unimplemented(
+ "TensorArrays with dynamic size are not supported by XLA."));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ int64 size;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
+ OP_REQUIRES(ctx, size >= 0,
+ errors::InvalidArgument("TensorArray size must be >= 0"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+ b->set_die_immediately_on_error(true);
+
+ // Initializes the TensorArray value if we know the element shape.
+ // Otherwise, defer initialization to the first write.
+ xla::ComputationDataHandle value;
+ if (element_shape_.IsFullyDefined()) {
+ TensorShape shape;
+ CHECK(element_shape_.AsTensorShape(&shape));
+ TensorShape ta_shape;
+ ta_shape.AddDim(size);
+ ta_shape.AppendShape(shape);
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
+ value = b->Broadcast(zero, ta_shape.dim_sizes());
+ }
+
+ XlaContext& xc = XlaContext::Get(ctx);
+ XlaVariable* var;
+ string name = strings::StrCat("TensorArray: ", tensor_array_name_);
+ OP_REQUIRES_OK(ctx,
+ xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
+ var->tensor_array_size = size;
+ ctx->SetVariableOutput(0, var);
+ ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+ }
+
+ private:
+ PartialTensorShape element_shape_;
+ DataType dtype_;
+ string tensor_array_name_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp);
+
+class TensorArrayWriteOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ TensorShape elem_shape = ctx->InputShape(2);
+
+ // Initializes the TensorArray, if the element shape was not known at
+ // construction time.
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+
+ xla::ComputationDataHandle ta = var->value;
+ xla::ComputationDataHandle index = ctx->Input(1);
+ xla::ComputationDataHandle value = ctx->Input(2);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+
+ TensorShape slice_shape = elem_shape;
+ slice_shape.InsertDim(0, 1LL);
+ auto update = b->Reshape(value, slice_shape.dim_sizes());
+
+ xla::ComputationDataHandle written =
+ DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp);
+
+class TensorArrayReadOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(ta_type),
+ " but Op requested dtype ", DataTypeString(dtype_), "."));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ xla::ComputationDataHandle index = ctx->Input(1);
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+
+ auto slice_shape = ta_shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ xla::ComputationDataHandle read =
+ b->DynamicSlice(ta, start_indices, slice_shape);
+
+ // Remove the leading '1' dimension.
+ std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
+ ctx->SetOutput(0, b->Reshape(read, value_shape));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp);
+
+class TensorArrayGatherOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument("TensorArray type mismatch"));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ const TensorShape indices_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, indices_shape.dims() >= 1,
+ errors::InvalidArgument("indices must be rank 1"));
+ const int num_indices = indices_shape.dim_size(0);
+ auto indices = ctx->Input(1);
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+
+ // For each index in `indices`, add the corresponding slice to `slices`.
+ std::vector<xla::ComputationDataHandle> slices(num_indices);
+ for (int i = 0; i < num_indices; ++i) {
+ // Slices the i-th index out of `indices`, and pads it with zeros in the
+ // minor dimensions to form an index into the TensorArray storage.
+ auto index = b->Slice(indices, {i}, {i + 1});
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+
+ auto slice_shape = ta_shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ slices[i] = b->DynamicSlice(ta, start_indices, slice_shape);
+ }
+
+ xla::ComputationDataHandle gather;
+ if (slices.empty()) {
+ auto shape = ta_shape.dim_sizes();
+ shape[0] = 0;
+ gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape);
+ } else {
+ gather = b->ConcatInDim(slices, 0);
+ }
+ ctx->SetOutput(0, gather);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp);
+
+class TensorArrayScatterOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ const TensorShape value_shape = ctx->InputShape(2);
+
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ TensorShape elem_shape = value_shape;
+ elem_shape.RemoveDim(0);
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+
+ const TensorShape indices_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, indices_shape.dims() >= 1,
+ errors::InvalidArgument("indices must be rank 1"));
+ const int num_indices = indices_shape.dim_size(0);
+ const xla::ComputationDataHandle indices = ctx->Input(1);
+
+ xla::ComputationDataHandle ta = var->value;
+ const xla::ComputationDataHandle value = ctx->Input(2);
+
+ auto slice_dims = value_shape.dim_sizes();
+ slice_dims[0] = 1LL;
+
+ std::vector<int64> value_starts(value_shape.dims(), 0);
+ auto value_ends = value_shape.dim_sizes();
+
+ // For every (index, value) pair, update the corresponding TensorArray
+ // storage.
+ for (int i = 0; i < num_indices; ++i) {
+ // Slice out part of the value.
+ value_starts[0] = i;
+ value_ends[0] = i + 1;
+ auto slice = b->Slice(value, value_starts, value_ends);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto index = b->Slice(indices, {i}, {i + 1});
+ auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+ ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
+ }
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp);
+
+class TensorArrayConcatOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument("TensorArray type mismatch"));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+
+ auto ta_dims = ta_shape.dim_sizes();
+ std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
+ shape[0] *= ta_shape.dim_size(0);
+ ctx->SetOutput(0, b->Reshape(ta, shape));
+
+ Tensor lengths(DT_INT64, {ta_dims[0]});
+ auto lengths_vec = lengths.vec<int64>();
+ for (int i = 0; i < ta_dims[0]; ++i) {
+ lengths_vec(i) = ta_dims[1];
+ }
+ ctx->SetConstantOutput(1, lengths);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp);
+
+class TensorArraySplitOp : public XlaOpKernel {
+ public:
+ explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> lengths;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
+
+ int64 length = 0;
+ if (!lengths.empty()) {
+ length = lengths[0];
+ for (int i = 1; i < lengths.size(); ++i) {
+ OP_REQUIRES(ctx, lengths[i] == length,
+ errors::InvalidArgument("lengths must be equal: ", length,
+ " vs. ", lengths[i]));
+ }
+ }
+
+ TensorShape value_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, value_shape.dims() >= 1,
+ errors::InvalidArgument("value must have rank >= 1, got ",
+ value_shape.DebugString()));
+ TensorShape elem_shape = value_shape;
+ elem_shape.set_dim(0, length);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+ xla::ComputationDataHandle ta = var->value;
+
+ TensorShape ta_shape;
+ ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AppendShape(elem_shape);
+
+ OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
+ errors::InvalidArgument(
+ "TensorArray's size is not equal to the size of lengths (",
+ lengths.size(), " vs. ", var->tensor_array_size, ")"));
+
+ const xla::ComputationDataHandle value = ctx->Input(1);
+
+ OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
+ errors::InvalidArgument("mismatched element count ",
+ value_shape.DebugString(), " vs. ",
+ ta_shape.DebugString()));
+
+ ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp);
+
+class TensorArraySizeOp : public XlaOpKernel {
+ public:
+ explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ Tensor size_tensor(DT_INT32, {});
+ size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
+ ctx->SetConstantOutput(0, size_tensor);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp);
+
+class TensorArrayGradOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ // Finds or looks up the corresponding gradient TensorArray, which stores
+ // gradients computed during backpropagation.
+ XlaVariable*& gradient = var->tensor_array_gradient[source_];
+ if (!gradient) {
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
+ xla::ComputationDataHandle value =
+ b->Broadcast(zero, ta_shape.dim_sizes());
+
+ XlaContext& xc = XlaContext::Get(ctx);
+ string name = strings::StrCat("TensorArrayGrad: ", var->name);
+ OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
+ value, &gradient));
+ gradient->tensor_array_size = var->tensor_array_size;
+ }
+
+ ctx->SetVariableOutput(0, gradient);
+ ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+ }
+
+ private:
+ string source_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index 362a101895..1d0098591e 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -119,6 +119,4 @@ void XlaExpression::set_constant_value(Tensor value) {
constant_value_ = std::move(value);
}
-void XlaExpression::set_variable_id(int id) { variable_id_ = id; }
-
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index 1ee96e5e6c..75630bee39 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -64,6 +64,39 @@ class XlaCompilationDevice : public LocalDevice {
std::unique_ptr<XlaCompilationAllocator> allocator_;
};
+struct XlaVariable {
+ // If this variable is visible externally, what was its argument number?
+ int arg_num = -1;
+
+ // A descriptive name for the variable, used in error messages.
+ string name;
+
+ // Current type and value of the variable. Uninitialized variables are
+ // represented by a default (zero) handle and type DT_INVALID.
+ // While the type of a variable is notionally fixed during execution, when
+ // a variable is first initialized we do not yet know its type, so we keep
+ // track of its type dynamically.
+ DataType type = DT_INVALID;
+ xla::ComputationDataHandle value;
+
+ // Value of the variable at computation entry. Used to detect which
+ // variables have new values that need to be written back.
+ xla::ComputationDataHandle initial_value;
+
+ // We treat TensorArrays as a Variable with some extra metadata.
+
+ // 'tensor_array_size' stores the expected size of the TensorArray. We need
+ // to store this since sometimes TensorArrays must be initialized lazily since
+ // we do not know the element shape at construction time.
+ int64 tensor_array_size = -1;
+
+ // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
+ // to an XlaVariable containing the gradient TensorArrays. We store a pointer
+ // here since there should only be one gradient TensorArray per 'source'
+ // string, irrespective of the number of calls to TensorArrayGrad.
+ std::unordered_map<string, XlaVariable*> tensor_array_gradient;
+};
+
// A XlaExpression wraps an XLA computation. Each Tensor on an
// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
// matches the shape of the subcomputation in the ComputationDataHandle. Each
@@ -82,8 +115,8 @@ class XlaExpression {
bool has_constant_value() const { return has_constant_value_; }
const Tensor& constant_value() const { return constant_value_; }
- void set_variable_id(int id);
- int variable_id() const { return variable_id_; }
+ void set_variable(XlaVariable* variable) { variable_ = variable; }
+ XlaVariable* variable() const { return variable_; }
private:
// The XLA handle of the expression's computation.
@@ -95,7 +128,7 @@ class XlaExpression {
bool has_constant_value_ = false;
Tensor constant_value_;
- int variable_id_ = -1;
+ XlaVariable* variable_ = nullptr; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 820e8dd56f..580ce3d802 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -59,8 +59,9 @@ Status CheckSignature(const DataTypeVector& types,
bool XlaCompiler::Argument::operator==(
const XlaCompiler::Argument& other) const {
- if (std::tie(kind, type, shape, name) !=
- std::tie(other.kind, other.type, other.shape, other.name)) {
+ if (std::tie(kind, type, shape, name, tensor_array_size) !=
+ std::tie(other.kind, other.type, other.shape, other.name,
+ other.tensor_array_size)) {
return false;
}
if (constant_value.shape() != other.constant_value.shape()) {
@@ -264,8 +265,9 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
switch (args[i].kind) {
case XlaCompiler::Argument::kVariable:
variables.push_back(i);
- context_arg.value.is_constant = false;
context_arg.is_variable = true;
+ context_arg.value.is_constant = false;
+ context_arg.tensor_array_size = args[i].tensor_array_size;
break;
case XlaCompiler::Argument::kParameter:
parameters.push_back(i);
@@ -274,6 +276,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
case XlaCompiler::Argument::kUninitializedVariable:
context_arg.is_variable = true;
context_arg.value.is_constant = true;
+ context_arg.tensor_array_size = args[i].tensor_array_size;
break;
case XlaCompiler::Argument::kConstant:
context_arg.value.is_constant = true;
@@ -337,7 +340,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// type of the final output.
Status BuildComputation(
const std::vector<XlaContext::HandleOrConstant>& retvals,
- const std::unordered_map<int, XlaContext::Variable>& variable_map,
+ const std::vector<std::unique_ptr<XlaVariable>>& variables,
bool has_side_effects, bool return_updated_values_for_all_variables,
xla::ComputationBuilder* builder, xla::Computation* computation,
int* num_nonconst_outputs,
@@ -352,27 +355,27 @@ Status BuildComputation(
*num_nonconst_outputs = elems.size();
// Add return values for variables whose values have changed.
- std::vector<std::pair<int, const XlaContext::Variable*>> variables;
- variables.reserve(variable_map.size());
- for (const auto& entry : variable_map) {
- variables.emplace_back(entry.first, &entry.second);
+ std::vector<const XlaVariable*> arg_vars;
+ arg_vars.reserve(variables.size());
+ for (const auto& var : variables) {
+ if (var->arg_num >= 0) {
+ arg_vars.push_back(var.get());
+ }
}
- std::sort(variables.begin(), variables.end(),
- [](const std::pair<int, const XlaContext::Variable*>& a,
- const std::pair<int, const XlaContext::Variable*>& b) {
- return a.first < b.first;
+ std::sort(arg_vars.begin(), arg_vars.end(),
+ [](const XlaVariable* a, const XlaVariable* b) {
+ return a->arg_num < b->arg_num;
});
- for (const auto& entry : variables) {
- bool modified =
- entry.second->value.handle() != entry.second->initial_value.handle();
+ for (const XlaVariable* var : arg_vars) {
+ bool modified = var->value.handle() != var->initial_value.handle();
if (return_updated_values_for_all_variables || modified) {
variable_updates->emplace_back();
XlaCompiler::VariableUpdate& update = variable_updates->back();
- update.input_index = entry.first;
- update.type = entry.second->type;
+ update.input_index = var->arg_num;
+ update.type = var->type;
update.modified = modified;
- elems.push_back(entry.second->value);
+ elems.push_back(var->value);
}
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 15f723ad78..1314305532 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -114,6 +114,10 @@ class XlaCompiler {
// The name of this argument, used for debugging.
string name;
+ // For a kVariable or kUninitializedVariable corresponding to a TensorArray,
+ // what is the tensor array's declared size?
+ int64 tensor_array_size = -1;
+
bool operator==(const Argument& other) const;
};
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 3592680303..4440b53069 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
@@ -53,6 +54,10 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context";
return *context;
}
+/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) {
+ return Get(ctx->op_kernel_context());
+}
+
void XlaContext::set_args(std::vector<Argument> args) {
args_ = std::move(args);
}
@@ -124,29 +129,19 @@ void XlaContext::AddSideEffects() {
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
-Status XlaContext::CreateVariable(int variable_id, string name, DataType type,
- const xla::ComputationDataHandle& handle) {
- auto result = variables_.emplace(variable_id, Variable());
- if (!result.second) {
- return errors::InvalidArgument("Duplicate ID ", variable_id,
- " for variable ", name);
- }
- Variable& var = result.first->second;
+Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
+ const xla::ComputationDataHandle& handle,
+ XlaVariable** variable) {
+ variables_.emplace_back(new XlaVariable);
+ *variable = variables_.back().get();
+ XlaVariable& var = **variable;
+ var.arg_num = arg_num;
var.name = std::move(name);
var.type = type;
var.initial_value = var.value = handle;
return Status::OK();
}
-Status XlaContext::GetVariable(int variable_id, Variable** variable) {
- auto it = variables_.find(variable_id);
- if (it == variables_.end()) {
- return errors::InvalidArgument("Unknown variable ID ", variable_id);
- }
- *variable = &it->second;
- return Status::OK();
-}
-
const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
return LookupOrCreate(type, &max_func_, [this, type] {
const string type_string = DataTypeString(type);
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 657ead5391..3978baaf63 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -21,7 +21,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -31,6 +30,8 @@ limitations under the License.
namespace tensorflow {
+class XlaOpKernelContext;
+
// The XlaContext is the data structure that holds the state of an XLA
// compilation, that is accessible from OpKernelContexts when compiling a
// subgraph of Ops using XLA.
@@ -55,16 +56,16 @@ class XlaContext : public ResourceBase {
string name;
// Is this a variable?
- bool is_variable;
+ bool is_variable = false;
HandleOrConstant value;
+
+ int64 tensor_array_size = -1;
};
// Retrieves the XlaContext of the current compilation.
static XlaContext& Get(const OpKernelContext* ctx);
- static XlaContext& Get(const XlaOpKernelContext* ctx) {
- return Get(ctx->op_kernel_context());
- }
+ static XlaContext& Get(const XlaOpKernelContext* ctx);
// Creates a new XlaContext.
XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
@@ -105,33 +106,16 @@ class XlaContext : public ResourceBase {
bool has_side_effects() const { return has_side_effects_; }
- struct Variable {
- // A descriptive name for the variable, used in error messages.
- string name;
-
- // Current type and value of the variable. Uninitialized variables are
- // represented by a default (zero) handle and type DT_INVALID.
- // While the type of a variable is notionally fixed during execution, when
- // a variable is first initialized we do not yet know its type, so we keep
- // track of its type dynamically.
- DataType type = DT_INVALID;
- xla::ComputationDataHandle value;
-
- // Value of the variable at computation entry. Used to detect which
- // variables have new values that need to be written back.
- xla::ComputationDataHandle initial_value;
- };
-
// Creates a variable with variable `variable_id` and initial type `type` and
// value `handle`. `name` is a descriptive name for use in error messages.
// Fails if the variable already exists.
- Status CreateVariable(int variable_id, string name, DataType type,
- const xla::ComputationDataHandle& handle);
+ Status CreateVariable(int arg_num, string name, DataType type,
+ const xla::ComputationDataHandle& handle,
+ XlaVariable** variable);
- // Retrieves variable `variable_id`. Fails if the variable does not exist.
- Status GetVariable(int variable_id, Variable** variable);
-
- const std::unordered_map<int, Variable>& variables() { return variables_; }
+ const std::vector<std::unique_ptr<XlaVariable>>& variables() {
+ return variables_;
+ }
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -182,8 +166,8 @@ class XlaContext : public ResourceBase {
// Does the computation have side effects, i.e., Send() calls?
bool has_side_effects_ = false;
- // Map from variable ID to the current value of each variable.
- std::unordered_map<int, Variable> variables_;
+ // Holds ownership of variables. The variables are not ordered.
+ std::vector<std::unique_ptr<XlaVariable>> variables_;
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::Computation>;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 4de69ee43c..3272b1efa1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -38,7 +38,8 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const {
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
- CHECK(expression->handle().handle() != 0 || expression->variable_id() >= 0);
+ CHECK(expression->handle().handle() != 0 ||
+ expression->variable() != nullptr);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
@@ -251,11 +252,8 @@ Status XlaOpKernelContext::ReadVariableInput(
int index, xla::ComputationDataHandle* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- int variable_id = expression->variable_id();
-
- XlaContext::Variable* variable;
- XlaContext& context = XlaContext::Get(this);
- TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable));
+ XlaVariable* variable = expression->variable();
+ TF_RET_CHECK(variable != nullptr);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@@ -267,11 +265,8 @@ Status XlaOpKernelContext::ReadVariableInput(
string XlaOpKernelContext::VariableDebugString(int index) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- int variable_id = expression->variable_id();
-
- XlaContext::Variable* variable;
- XlaContext& context = XlaContext::Get(this);
- if (!context.GetVariable(variable_id, &variable).ok()) {
+ XlaVariable* variable = expression->variable();
+ if (!variable) {
return "<invalid variable ID>";
}
return variable->name;
@@ -281,11 +276,8 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- int variable_id = expression->variable_id();
-
- XlaContext::Variable* variable;
- XlaContext& context = XlaContext::Get(this);
- TF_RETURN_IF_ERROR(context.GetVariable(variable_id, &variable));
+ XlaVariable* variable = expression->variable();
+ TF_RET_CHECK(variable != nullptr);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@@ -345,14 +337,22 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
expression->set_constant_value(constant);
}
-void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) {
+void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) {
Tensor* output = nullptr;
// The shape of the output tensor is the shape of the variable resource
// (i.e., a scalar), not the shape of the variable's value.
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape(), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
- expression->set_variable_id(variable_id);
+ expression->set_variable(variable);
+}
+
+Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
+ const XlaExpression* expression =
+ CastExpressionFromTensor(context_->input(index));
+ TF_RET_CHECK(expression->variable() != nullptr);
+ *variable = expression->variable();
+ return Status::OK();
}
Status XlaOpKernelContext::AssignVariable(
@@ -362,9 +362,8 @@ Status XlaOpKernelContext::AssignVariable(
const XlaExpression* expression =
CastExpressionFromTensor(context_->input(index));
- XlaContext& context = XlaContext::Get(this);
- XlaContext::Variable* variable;
- TF_RETURN_IF_ERROR(context.GetVariable(expression->variable_id(), &variable));
+ XlaVariable* variable = expression->variable();
+ TF_RET_CHECK(variable != nullptr);
if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
(variable->type == type))) {
return errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 0a8a928418..a25774c3a6 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -157,15 +157,18 @@ class XlaOpKernelContext {
// 'index'.
Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
- // Sets output 'index' to be a reference to variable 'variable_id'. Used
- // to propagate resource variables through the compilation.
- void SetVariableOutput(int index, int variable_id);
-
// Assigns the value `handle` to the variable referenced by input
// `variable_index`. Marks the operator as having side effects.
Status AssignVariable(int variable_index, DataType type,
const xla::ComputationDataHandle& handle);
+ // Sets '*variable' to the variable associated with input `index`.
+ Status GetVariableInput(int index, XlaVariable** variable);
+
+ // Sets output 'index' to be a reference to variable 'variable'. Used
+ // to propagate resource variables through the compilation.
+ void SetVariableOutput(int index, XlaVariable* variable);
+
// Returns a human-readable debug string describing 'variable_index'.
string VariableDebugString(int variable_index);
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index f007581e8d..97d0800d12 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1221,7 +1221,7 @@ of the forward TensorArray is known when this operation is called.
TensorArray gradient calls use an accumulator TensorArray object. If
multiple gradients are calculated and run in the same session, the multiple
-gradient nodes may accidentally flow throuth the same accumulator TensorArray.
+gradient nodes may accidentally flow through the same accumulator TensorArray.
This double counts and generally breaks the TensorArray gradient flow.
The solution is to identify which gradient call this particular