diff options
author | 2016-06-16 14:37:30 -0800 | |
---|---|---|
committer | 2016-06-19 14:08:37 -0700 | |
commit | 99e52a8ad1fb708b703f1913d1d98994426bf653 (patch) | |
tree | d6f6da5745b6bb28385937487a17e6b8997ec36c | |
parent | 1bea99abc531bf2ce47c1f6767f4796be7168f02 (diff) |
Bugfixes to TensorArray and functional ops:
- TensorArray shape inference now works correctly for scalar elements
- TensorArrays each now get a unique name at runtime, per step. This means
that they can be used in nested functional ops (e.g. tf.scan(tf.scan(...)))
Change: 125110643
-rw-r--r-- | tensorflow/core/kernels/tensor_array.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/tensor_array.h | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/tensor_array_ops.cc | 8 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/functional_ops_test.py | 22 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/tensor_array_ops_test.py | 22 | ||||
-rw-r--r-- | tensorflow/python/ops/tensor_array_ops.py | 15 |
6 files changed, 53 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc index 85f1299038..dc1b14ec36 100644 --- a/tensorflow/core/kernels/tensor_array.cc +++ b/tensorflow/core/kernels/tensor_array.cc @@ -75,6 +75,8 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU); } // namespace tensor_array +std::atomic<int64> TensorArray::tensor_array_counter{0}; + Status TensorArray::CopyShapesFrom(TensorArray* rhs) { mutex_lock l(mu_); mutex_lock l_rhs(*rhs->mu()); diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index b993f0dfe2..35f199be91 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -124,6 +124,8 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU); // class TensorArray : public ResourceBase { public: + static std::atomic<int64> tensor_array_counter; + // Construct a TensorArray for holding Tensors of type 'dtype' with // 'N' elements. While the underlying storage is a std::vector and // can hold more than MAX_INT entries, in practice we do not expect diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index eafd28639d..95734dbae2 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -147,14 +147,18 @@ class TensorArrayOp : public TensorArrayCreationOp { const int32 size = tensor_size->scalar<int32>()(); auto handle = tensor_array_output_handle->flat<string>(); + string unique_tensor_array_name = + strings::StrCat(tensor_array_name_, "_", + TensorArray::tensor_array_counter.fetch_add(1)); handle(0) = "_tensor_arrays"; - handle(1) = tensor_array_name_; + handle(1) = unique_tensor_array_name; TensorArray* tensor_array = new TensorArray( dtype_, *tensor_array_output_handle, size, dynamic_size_, false /* multiple_writes_aggregate */, clear_after_read_); - TF_RETURN_IF_ERROR(rm->Create(handle(0), tensor_array_name_, tensor_array)); + TF_RETURN_IF_ERROR( + rm->Create(handle(0), unique_tensor_array_name, tensor_array)); *output_tensor_array = tensor_array; diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index d0fc17c059..7a59b07bce 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -180,6 +180,28 @@ class FunctionalOpsTest(tf.test.TestCase): results = np.array([6, 16, 38, 84, 178, 368]) self.assertAllEqual(results, r.eval()) + def testScanFoldl_Nested(self): + with self.test_session(): + elems = tf.constant([1.0, 2.0, 3.0, 4.0], name="data") + inner_elems = tf.constant([0.5, 0.5], name="data") + + def r_inner(a, x): + return tf.foldl(lambda b, y: b * y * x, inner_elems, initializer=a) + + r = tf.scan(r_inner, elems) + + # t == 0 (returns 1) + # t == 1, a == 1, x == 2 (returns 1) + # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1 + # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1 + # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25) + # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5 + # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5 + # t == 3, a == 2.25, x == 4 (returns 9) + # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5 + # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9 + self.assertAllClose([1., 1., 2.25, 9.], r.eval()) + def testScan_Control(self): with self.test_session() as sess: s = tf.placeholder(tf.float32, shape=[None]) diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 25f582eb2c..6038d8993f 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -22,7 +22,6 @@ from __future__ import print_function import numpy as np import tensorflow as tf -from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import tensor_array_grad @@ -462,7 +461,7 @@ class TensorArrayCPUTest(tf.test.TestCase): # Assert that if multiple_writes_aggregate is not enabled, # multiple writes raise an exception. with self.assertRaisesOpError( - r"TensorArray foo: Could not write to TensorArray index 2 because " + r"TensorArray foo_.*: Could not write to TensorArray index 2 because " r"it has already been written to."): w1.flow.eval() @@ -495,7 +494,7 @@ class TensorArrayCPUTest(tf.test.TestCase): r = r1 + r2 self.assertAllClose(9.0, r.eval()) - def testDuplicateTensorArrayFails(self): + def testDuplicateTensorArrayHasDifferentName(self): with self.test_session(use_gpu=self._use_gpu) as session: h1 = tensor_array_ops.TensorArray( size=1, dtype=tf.float32, tensor_array_name="foo") @@ -503,8 +502,14 @@ class TensorArrayCPUTest(tf.test.TestCase): h2 = tensor_array_ops.TensorArray( size=1, dtype=tf.float32, tensor_array_name="foo") c2 = h2.write(0, 5.0) - with self.assertRaises(errors.AlreadyExistsError): - session.run([c1.flow, c2.flow]) + _, _, c1h, c2h = session.run([c1.flow, c2.flow, c1.handle, c2.handle]) + c1h = [x.decode("ascii") for x in c1h] + c2h = [x.decode("ascii") for x in c2h] + self.assertEqual(c1h[0], "_tensor_arrays") + self.assertEqual(c2h[0], "_tensor_arrays") + self.assertTrue(c1h[1].startswith("foo_")) + self.assertTrue(c2h[1].startswith("foo_")) + self.assertNotEqual(c1h[1], c2h[1]) def _testTensorArrayGradientWriteReadType(self, dtype): with self.test_session(use_gpu=self._use_gpu) as session: @@ -692,13 +697,6 @@ class TensorArrayCPUTest(tf.test.TestCase): w1 = w0.write(1, [3.0]) w1.close().run() # Expected to run without problems - ta = tensor_array_ops.TensorArray( - dtype=tf.float32, tensor_array_name="foo", size=3) - with self.assertRaisesOpError( - r"TensorArray foo has already been closed."): - with tf.control_dependencies([w1.close()]): - w1.write(2, 3.0).flow.eval() - def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): np_dtype = dtype.as_numpy_dtype with self.test_session(use_gpu=self._use_gpu) as session: diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 08f71bd8e5..e3a6b3db7c 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -64,6 +64,13 @@ class TensorArray(object): flow=None, infer_shape=True, name=None): """Construct a new TensorArray or wrap an existing TensorArray handle. + A note about the parameter `name`: + + The name of the `TensorArray` (even if passed in) is uniquified: each time + a new `TensorArray` is created at runtime it is assigned its own name for + the duration of the run. This avoids name collissions if a `TensorArray` + is created within a `while_loop`. + Args: dtype: (required) data type of the TensorArray. size: (optional) int32 scalar `Tensor`: the size of the TensorArray. @@ -235,7 +242,7 @@ class TensorArray(object): value = gen_data_flow_ops._tensor_array_pack( handle=self._handle, flow_in=self._flow, dtype=self._dtype, name=name) - if self._elem_shape and self._elem_shape[0].dims: + if self._elem_shape and self._elem_shape[0].dims is not None: value.set_shape([None] + self._elem_shape[0].dims) return value @@ -255,7 +262,7 @@ class TensorArray(object): value, _ = gen_data_flow_ops._tensor_array_concat( handle=self._handle, flow_in=self._flow, dtype=self._dtype, name=name) - if self._elem_shape and self._elem_shape[0].dims: + if self._elem_shape and self._elem_shape[0].dims is not None: value.set_shape([None] + self._elem_shape[0].dims[1:]) return value @@ -284,7 +291,7 @@ class TensorArray(object): if ta._infer_shape: val_shape = flow_out.op.inputs[1].get_shape() elem_shape = tensor_shape.unknown_shape() - if val_shape.dims: + if val_shape.dims is not None: elem_shape = tensor_shape.TensorShape(val_shape.dims[1:]) if ta._elem_shape: if not elem_shape == ta._elem_shape[0]: @@ -326,7 +333,7 @@ class TensorArray(object): val_shape = flow_out.op.inputs[1].get_shape() clengths = tensor_util.constant_value(flow_out.op.inputs[2]) elem_shape = tensor_shape.unknown_shape() - if val_shape.dims: + if val_shape.dims is not None: if clengths is not None and clengths.max() == clengths.min(): elem_shape = tensor_shape.TensorShape( [clengths[0]] + val_shape.dims[1:]) |