diff options
author | Peter Hawkins <phawkins@google.com> | 2017-07-25 09:46:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-25 09:52:05 -0700 |
commit | 78cec04df0f714741f930ff3f234268102b71065 (patch) | |
tree | 5e571bde9ed59c97d7d52a07fe0f42e84bf8c3b6 | |
parent | 3169f504faae8ede8443525a073567a512095c4f (diff) |
[TF:XLA] Make the shape of a TensorArray flow value a scalar.
Previously we used an f32[0] value, since the exact flow value does not matter, however this causes problems when a TensorArray computation is placed in a loop since the shape of the flow value is no longer loop invariant.
PiperOrigin-RevId: 163082452
-rw-r--r-- | tensorflow/compiler/tests/tensor_array_ops_test.py | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 14 |
2 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index f277314352..ac039e0162 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -57,11 +57,13 @@ class TensorArrayTest(xla_test.XLATestCase): r0 = w2.read(0) r1 = w2.read(1) r2 = w2.read(2) + flow = w2.flow - d0, d1, d2 = session.run([r0, r1, r2]) + d0, d1, d2, flow_val = session.run([r0, r1, r2, flow]) self.assertAllEqual([[4.0, 5.0]], d0) self.assertAllEqual([[1.0, 3.0]], d1) self.assertAllEqual([[7.0, -8.5]], d2) + self.assertAllEqual([], flow_val.shape) def _testTensorArrayWritePack(self, tf_dtype): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index bdd52b7f8e..34cc8b2315 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -182,7 +182,10 @@ class TensorArrayOp : public XlaOpKernel { dtype_, value, &var)); var->tensor_array_size = size; ctx->SetResourceOutput(0, var); - ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); + + Tensor flow(DT_FLOAT, TensorShape({})); + flow.scalar<float>()() = 0.0f; + ctx->SetConstantOutput(1, flow); } private: @@ -216,6 +219,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); + xla::ComputationDataHandle flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims()); @@ -228,7 +232,7 @@ class TensorArrayWriteOp : public XlaOpKernel { DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); resource->value = written; - ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + ctx->SetOutput(0, flow); } private: @@ -369,6 +373,7 @@ class TensorArrayScatterOp : public XlaOpKernel { xla::ComputationDataHandle ta = resource->value; const xla::ComputationDataHandle value = ctx->Input(2); + const xla::ComputationDataHandle flow = ctx->Input(3); auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -394,7 +399,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } resource->value = ta; - ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + ctx->SetOutput(0, flow); } private: @@ -489,6 +494,7 @@ class TensorArraySplitOp : public XlaOpKernel { lengths.size(), " vs. ", resource->tensor_array_size, ")")); const xla::ComputationDataHandle value = ctx->Input(1); + const xla::ComputationDataHandle flow = ctx->Input(3); OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), errors::InvalidArgument("mismatched element count ", @@ -497,7 +503,7 @@ class TensorArraySplitOp : public XlaOpKernel { resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); - ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); + ctx->SetOutput(0, flow); } private: |