about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-07-25 09:46:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-25 09:52:05 -0700
commit78cec04df0f714741f930ff3f234268102b71065 (patch)
tree5e571bde9ed59c97d7d52a07fe0f42e84bf8c3b6
parent3169f504faae8ede8443525a073567a512095c4f (diff)
[TF:XLA] Make the shape of a TensorArray flow value a scalar.
Previously we used an f32[0] value, since the exact flow value does not matter. However, this causes problems when a TensorArray computation is placed in a loop, since the shape of the flow value is no longer loop-invariant.

PiperOrigin-RevId: 163082452
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc14
2 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index f277314352..ac039e0162 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -57,11 +57,13 @@ class TensorArrayTest(xla_test.XLATestCase):
r0 = w2.read(0)
r1 = w2.read(1)
r2 = w2.read(2)
+ flow = w2.flow
- d0, d1, d2 = session.run([r0, r1, r2])
+ d0, d1, d2, flow_val = session.run([r0, r1, r2, flow])
self.assertAllEqual([[4.0, 5.0]], d0)
self.assertAllEqual([[1.0, 3.0]], d1)
self.assertAllEqual([[7.0, -8.5]], d2)
+ self.assertAllEqual([], flow_val.shape)
def _testTensorArrayWritePack(self, tf_dtype):
with self.test_session(), self.test_scope():
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index bdd52b7f8e..34cc8b2315 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -182,7 +182,10 @@ class TensorArrayOp : public XlaOpKernel {
dtype_, value, &var));
var->tensor_array_size = size;
ctx->SetResourceOutput(0, var);
- ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+
+ Tensor flow(DT_FLOAT, TensorShape({}));
+ flow.scalar<float>()() = 0.0f;
+ ctx->SetConstantOutput(1, flow);
}
private:
@@ -216,6 +219,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
xla::ComputationDataHandle value = ctx->Input(2);
+ xla::ComputationDataHandle flow = ctx->Input(3);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims());
@@ -228,7 +232,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
resource->value = written;
- ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ ctx->SetOutput(0, flow);
}
private:
@@ -369,6 +373,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
xla::ComputationDataHandle ta = resource->value;
const xla::ComputationDataHandle value = ctx->Input(2);
+ const xla::ComputationDataHandle flow = ctx->Input(3);
auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
@@ -394,7 +399,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
resource->value = ta;
- ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ ctx->SetOutput(0, flow);
}
private:
@@ -489,6 +494,7 @@ class TensorArraySplitOp : public XlaOpKernel {
lengths.size(), " vs. ", resource->tensor_array_size, ")"));
const xla::ComputationDataHandle value = ctx->Input(1);
+ const xla::ComputationDataHandle flow = ctx->Input(3);
OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
errors::InvalidArgument("mismatched element count ",
@@ -497,7 +503,7 @@ class TensorArraySplitOp : public XlaOpKernel {
resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
- ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ ctx->SetOutput(0, flow);
}
private: