author     Eugene Brevdo <ebrevdo@google.com>  2016-06-16 14:37:30 -0800
committer  Martin Wicke <wicke@google.com>     2016-06-19 14:08:37 -0700
commit  99e52a8ad1fb708b703f1913d1d98994426bf653 (patch)
tree    d6f6da5745b6bb28385937487a17e6b8997ec36c
parent  1bea99abc531bf2ce47c1f6767f4796be7168f02 (diff)
Bugfixes to TensorArray and functional ops:
- TensorArray shape inference now works correctly for scalar elements.
- TensorArrays each now get a unique name at runtime, per step. This means that they can be used in nested functional ops (e.g. tf.scan(tf.scan(...))).

Change: 125110643
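
A minimal usage sketch of the nested-functional-ops case this fix enables, mirroring the new testScanFoldl_Nested test further down; it assumes the TF 0.9-era graph/session API used throughout this commit:

    import tensorflow as tf

    with tf.Session() as sess:
      elems = tf.constant([1.0, 2.0, 3.0, 4.0])
      inner_elems = tf.constant([0.5, 0.5])

      def r_inner(a, x):
        # The inner foldl creates its own TensorArrays on every step of the
        # outer scan; each one now receives a unique runtime name, so the
        # nesting no longer collides.
        return tf.foldl(lambda b, y: b * y * x, inner_elems, initializer=a)

      r = tf.scan(r_inner, elems)
      print(sess.run(r))  # [1., 1., 2.25, 9.]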
-rw-r--r--  tensorflow/core/kernels/tensor_array.cc                  |  2
-rw-r--r--  tensorflow/core/kernels/tensor_array.h                   |  2
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc              |  8
-rw-r--r--  tensorflow/python/kernel_tests/functional_ops_test.py    | 22
-rw-r--r--  tensorflow/python/kernel_tests/tensor_array_ops_test.py  | 22
-rw-r--r--  tensorflow/python/ops/tensor_array_ops.py                | 15
6 files changed, 53 insertions(+), 18 deletions(-)
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 85f1299038..dc1b14ec36 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -75,6 +75,8 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
} // namespace tensor_array
+std::atomic<int64> TensorArray::tensor_array_counter{0};
+
Status TensorArray::CopyShapesFrom(TensorArray* rhs) {
mutex_lock l(mu_);
mutex_lock l_rhs(*rhs->mu());
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index b993f0dfe2..35f199be91 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -124,6 +124,8 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
//
class TensorArray : public ResourceBase {
public:
+ static std::atomic<int64> tensor_array_counter;
+
// Construct a TensorArray for holding Tensors of type 'dtype' with
// 'N' elements. While the underlying storage is a std::vector and
// can hold more than MAX_INT entries, in practice we do not expect
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index eafd28639d..95734dbae2 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -147,14 +147,18 @@ class TensorArrayOp : public TensorArrayCreationOp {
const int32 size = tensor_size->scalar<int32>()();
auto handle = tensor_array_output_handle->flat<string>();
+ string unique_tensor_array_name =
+ strings::StrCat(tensor_array_name_, "_",
+ TensorArray::tensor_array_counter.fetch_add(1));
handle(0) = "_tensor_arrays";
- handle(1) = tensor_array_name_;
+ handle(1) = unique_tensor_array_name;
TensorArray* tensor_array = new TensorArray(
dtype_, *tensor_array_output_handle, size, dynamic_size_,
false /* multiple_writes_aggregate */, clear_after_read_);
- TF_RETURN_IF_ERROR(rm->Create(handle(0), tensor_array_name_, tensor_array));
+ TF_RETURN_IF_ERROR(
+ rm->Create(handle(0), unique_tensor_array_name, tensor_array));
*output_tensor_array = tensor_array;
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index d0fc17c059..7a59b07bce 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -180,6 +180,28 @@ class FunctionalOpsTest(tf.test.TestCase):
results = np.array([6, 16, 38, 84, 178, 368])
self.assertAllEqual(results, r.eval())
+ def testScanFoldl_Nested(self):
+ with self.test_session():
+ elems = tf.constant([1.0, 2.0, 3.0, 4.0], name="data")
+ inner_elems = tf.constant([0.5, 0.5], name="data")
+
+ def r_inner(a, x):
+ return tf.foldl(lambda b, y: b * y * x, inner_elems, initializer=a)
+
+ r = tf.scan(r_inner, elems)
+
+ # t == 0 (returns 1)
+ # t == 1, a == 1, x == 2 (returns 1)
+ # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
+ # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
+ # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
+ # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
+ # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
+ # t == 3, a == 2.25, x == 4 (returns 9)
+ # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
+ # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
+ self.assertAllClose([1., 1., 2.25, 9.], r.eval())
+
def testScan_Control(self):
with self.test_session() as sess:
s = tf.placeholder(tf.float32, shape=[None])
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 25f582eb2c..6038d8993f 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
-from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import tensor_array_grad
@@ -462,7 +461,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
# Assert that if multiple_writes_aggregate is not enabled,
# multiple writes raise an exception.
with self.assertRaisesOpError(
- r"TensorArray foo: Could not write to TensorArray index 2 because "
+ r"TensorArray foo_.*: Could not write to TensorArray index 2 because "
r"it has already been written to."):
w1.flow.eval()
@@ -495,7 +494,7 @@ class TensorArrayCPUTest(tf.test.TestCase):
r = r1 + r2
self.assertAllClose(9.0, r.eval())
- def testDuplicateTensorArrayFails(self):
+ def testDuplicateTensorArrayHasDifferentName(self):
with self.test_session(use_gpu=self._use_gpu) as session:
h1 = tensor_array_ops.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="foo")
@@ -503,8 +502,14 @@ class TensorArrayCPUTest(tf.test.TestCase):
h2 = tensor_array_ops.TensorArray(
size=1, dtype=tf.float32, tensor_array_name="foo")
c2 = h2.write(0, 5.0)
- with self.assertRaises(errors.AlreadyExistsError):
- session.run([c1.flow, c2.flow])
+ _, _, c1h, c2h = session.run([c1.flow, c2.flow, c1.handle, c2.handle])
+ c1h = [x.decode("ascii") for x in c1h]
+ c2h = [x.decode("ascii") for x in c2h]
+ self.assertEqual(c1h[0], "_tensor_arrays")
+ self.assertEqual(c2h[0], "_tensor_arrays")
+ self.assertTrue(c1h[1].startswith("foo_"))
+ self.assertTrue(c2h[1].startswith("foo_"))
+ self.assertNotEqual(c1h[1], c2h[1])
def _testTensorArrayGradientWriteReadType(self, dtype):
with self.test_session(use_gpu=self._use_gpu) as session:
@@ -692,13 +697,6 @@ class TensorArrayCPUTest(tf.test.TestCase):
w1 = w0.write(1, [3.0])
w1.close().run() # Expected to run without problems
- ta = tensor_array_ops.TensorArray(
- dtype=tf.float32, tensor_array_name="foo", size=3)
- with self.assertRaisesOpError(
- r"TensorArray foo has already been closed."):
- with tf.control_dependencies([w1.close()]):
- w1.write(2, 3.0).flow.eval()
-
def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
np_dtype = dtype.as_numpy_dtype
with self.test_session(use_gpu=self._use_gpu) as session:
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 08f71bd8e5..e3a6b3db7c 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -64,6 +64,13 @@ class TensorArray(object):
flow=None, infer_shape=True, name=None):
"""Construct a new TensorArray or wrap an existing TensorArray handle.
+ A note about the parameter `name`:
+
+ The name of the `TensorArray` (even if passed in) is uniquified: each time
+ a new `TensorArray` is created at runtime it is assigned its own name for
+ the duration of the run. This avoids name collisions if a `TensorArray`
+ is created within a `while_loop`.
+
Args:
dtype: (required) data type of the TensorArray.
size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
@@ -235,7 +242,7 @@ class TensorArray(object):
value = gen_data_flow_ops._tensor_array_pack(
handle=self._handle, flow_in=self._flow, dtype=self._dtype,
name=name)
- if self._elem_shape and self._elem_shape[0].dims:
+ if self._elem_shape and self._elem_shape[0].dims is not None:
value.set_shape([None] + self._elem_shape[0].dims)
return value
@@ -255,7 +262,7 @@ class TensorArray(object):
value, _ = gen_data_flow_ops._tensor_array_concat(
handle=self._handle, flow_in=self._flow, dtype=self._dtype,
name=name)
- if self._elem_shape and self._elem_shape[0].dims:
+ if self._elem_shape and self._elem_shape[0].dims is not None:
value.set_shape([None] + self._elem_shape[0].dims[1:])
return value
@@ -284,7 +291,7 @@ class TensorArray(object):
if ta._infer_shape:
val_shape = flow_out.op.inputs[1].get_shape()
elem_shape = tensor_shape.unknown_shape()
- if val_shape.dims:
+ if val_shape.dims is not None:
elem_shape = tensor_shape.TensorShape(val_shape.dims[1:])
if ta._elem_shape:
if not elem_shape == ta._elem_shape[0]:
@@ -326,7 +333,7 @@ class TensorArray(object):
val_shape = flow_out.op.inputs[1].get_shape()
clengths = tensor_util.constant_value(flow_out.op.inputs[2])
elem_shape = tensor_shape.unknown_shape()
- if val_shape.dims:
+ if val_shape.dims is not None:
if clengths is not None and clengths.max() == clengths.min():
elem_shape = tensor_shape.TensorShape(
[clengths[0]] + val_shape.dims[1:])
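
A companion sketch of the duplicate-name behavior exercised by the updated testDuplicateTensorArrayHasDifferentName test: creating two TensorArrays with the same tensor_array_name no longer raises AlreadyExistsError, because each handle is uniquified at runtime (again assuming the TF 0.9-era session API):

    import tensorflow as tf
    from tensorflow.python.ops import tensor_array_ops

    with tf.Session() as session:
      # Both TensorArrays share the user-supplied name "foo".
      h1 = tensor_array_ops.TensorArray(size=1, dtype=tf.float32,
                                        tensor_array_name="foo")
      h2 = tensor_array_ops.TensorArray(size=1, dtype=tf.float32,
                                        tensor_array_name="foo")
      c1 = h1.write(0, 4.0)
      c2 = h2.write(0, 5.0)
      # Previously the second creation failed; now the runtime names differ,
      # e.g. handles of the form "foo_0" and "foo_1".
      _, _, c1h, c2h = session.run([c1.flow, c2.flow, c1.handle, c2.handle])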