author Peter Hawkins <phawkins@google.com> 2017-06-07 14:55:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-06-07 14:59:59 -0700
commitc19e6cac0413b0b93d5a15f9d4dc7c861aa1c734 (patch)
tree0ab76d177a8a3e9c4b1b92051257b772ce3c211a /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parentb5e8d308655a027e8c163c3fe3bd3445e09e9d23 (diff)
[TF:XLA] Initial implementation of TensorArray ops.
The XLA implementation of TensorArrays is more restrictive than regular TensorArrays:
* XLA TensorArrays must have dynamic_size=False.
* All elements in an XLA TensorArray must have the same shape.
* Writes always add their values to any existing values; neither reads nor writes ever issue errors. Out-of-bounds writes currently wrap.

Refactor Variable handling in the TF/XLA bridge. Use an XlaVariable* to refer to variables inside compilation rather than a numerical ID. Allow for variables that don't correspond to variables known to the user. Also use XlaVariable to handle TensorArrays.

PiperOrigin-RevId: 158322041
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc  538
1 file changed, 538 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
new file mode 100644
index 0000000000..de542d55e8
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -0,0 +1,538 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA TensorArray operators.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+// Since the element shape is not always provided to the TensorArrayV3
+// operator, we must support lazy initialization of the TensorArray at the
+// time of the first write.
+// If a TensorArray `var` has not been initialized, constructs storage for the
+// TensorArray with elements of `elem_shape`. For both initialized and
+// uninitialized TensorArrays, checks that the tensor has a type compatible
+// with `dtype` and a shape compatible with `elem_shape`.
+Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
+ XlaVariable* var, DataType dtype,
+ const TensorShape& elem_shape) {
+ if (var->type != dtype) {
+ return errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(var->type),
+ " but op has dtype ", DataTypeString(dtype), ".");
+ }
+
+ TF_RET_CHECK(var->tensor_array_size >= 0)
+ << var->name << " size " << var->tensor_array_size;
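+
+  // An XLA TensorArray is stored as one dense value of shape
+  // [tensor_array_size, elem_shape...]; element i occupies the slice
+  // [i:i+1, ...] of that value.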
+ TensorShape ta_shape;
+ ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AppendShape(elem_shape);
+
+ if (var->value.handle() == 0) {
+ // TensorArray has not been initialized.
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
+ var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
+ } else {
+    // Check that `elem_shape` matches the shape of the existing TensorArray
+    // storage.
+ auto shape_or_status = builder->GetShape(var->value);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ TensorShape shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
+ if (ta_shape != shape) {
+ return errors::InvalidArgument(
+ "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
+ shape.DebugString());
+ }
+ }
+ return Status::OK();
+}
+
+// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
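+// For example, with count == 2, x == {i} is padded to {i, 0, 0}, which can
+// then be used as start_indices into [size, elem_shape...] storage.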
+xla::ComputationDataHandle PadIndexWithZeros(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
+ int count) {
+ xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});
+ std::vector<xla::ComputationDataHandle> xs(count + 1, zero);
+ xs[0] = builder->Reshape(x, {1});
+ return builder->ConcatInDim(xs, 0);
+}
+
+// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
+// relevant slice of 'operand'.
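+// For example, with operand {1, 2, 3}, update {10} and start_indices {1}, the
+// result is {1, 12, 3}.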
+xla::ComputationDataHandle DynamicAddSlice(
+ xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
+ const xla::ComputationDataHandle& update,
+ const gtl::ArraySlice<int64>& update_dims,
+ const xla::ComputationDataHandle& start_indices) {
+ xla::ComputationDataHandle current =
+ builder->DynamicSlice(operand, start_indices, update_dims);
+ xla::ComputationDataHandle sum = builder->Add(current, update);
+ return builder->DynamicUpdateSlice(operand, sum, start_indices);
+}
+
+class TensorArrayOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_shape", &element_shape_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ bool dynamic_size;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dynamic_size", &dynamic_size));
+ OP_REQUIRES(
+ ctx, !dynamic_size,
+ errors::Unimplemented(
+ "TensorArrays with dynamic size are not supported by XLA."));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_array_name", &tensor_array_name_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ int64 size;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
+ OP_REQUIRES(ctx, size >= 0,
+ errors::InvalidArgument("TensorArray size must be >= 0"));
+
+ xla::ComputationBuilder* b = ctx->builder();
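+    // Fail fast on builder errors so they are reported at the op that caused
+    // them rather than deferred to computation-build time.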
+ b->set_die_immediately_on_error(true);
+
+    // If the element shape is known, initialize the TensorArray storage now;
+    // otherwise, defer initialization to the first write.
+ xla::ComputationDataHandle value;
+ if (element_shape_.IsFullyDefined()) {
+ TensorShape shape;
+ CHECK(element_shape_.AsTensorShape(&shape));
+ TensorShape ta_shape;
+ ta_shape.AddDim(size);
+ ta_shape.AppendShape(shape);
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
+ value = b->Broadcast(zero, ta_shape.dim_sizes());
+ }
+
+ XlaContext& xc = XlaContext::Get(ctx);
+ XlaVariable* var;
+ string name = strings::StrCat("TensorArray: ", tensor_array_name_);
+ OP_REQUIRES_OK(ctx,
+ xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
+ var->tensor_array_size = size;
+ ctx->SetVariableOutput(0, var);
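+    // Output 1 is the "flow" scalar TensorFlow threads between TensorArray
+    // ops to order them; the XLA bridge tracks state through the variable
+    // instead, so a dummy constant suffices.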
+ ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+ }
+
+ private:
+ PartialTensorShape element_shape_;
+ DataType dtype_;
+ string tensor_array_name_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp);
+
+class TensorArrayWriteOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayWriteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ TensorShape elem_shape = ctx->InputShape(2);
+
+    // Initialize the TensorArray if the element shape was not known at
+    // construction time.
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+
+ xla::ComputationDataHandle ta = var->value;
+ xla::ComputationDataHandle index = ctx->Input(1);
+ xla::ComputationDataHandle value = ctx->Input(2);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+
+ TensorShape slice_shape = elem_shape;
+ slice_shape.InsertDim(0, 1LL);
+ auto update = b->Reshape(value, slice_shape.dim_sizes());
+
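+    // Per the XLA TensorArray semantics, the write adds `update` to any value
+    // already present at `index`; repeated writes accumulate rather than
+    // raising an error.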
+ xla::ComputationDataHandle written =
+ DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayWriteOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayWriteV3"), TensorArrayWriteOp);
+
+class TensorArrayReadOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayReadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(ta_type),
+ " but Op requested dtype ", DataTypeString(dtype_), "."));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ xla::ComputationDataHandle index = ctx->Input(1);
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+
+ auto slice_shape = ta_shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ xla::ComputationDataHandle read =
+ b->DynamicSlice(ta, start_indices, slice_shape);
+
+ // Remove the leading '1' dimension.
+ std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
+ ctx->SetOutput(0, b->Reshape(read, value_shape));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayReadOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayReadV3"), TensorArrayReadOp);
+
+class TensorArrayGatherOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument("TensorArray type mismatch"));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ const TensorShape indices_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, indices_shape.dims() >= 1,
+                errors::InvalidArgument("indices must have rank >= 1"));
+ const int num_indices = indices_shape.dim_size(0);
+ auto indices = ctx->Input(1);
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+
+ // For each index in `indices`, add the corresponding slice to `slices`.
+ std::vector<xla::ComputationDataHandle> slices(num_indices);
+ for (int i = 0; i < num_indices; ++i) {
+ // Slices the i-th index out of `indices`, and pads it with zeros in the
+ // minor dimensions to form an index into the TensorArray storage.
+ auto index = b->Slice(indices, {i}, {i + 1});
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+
+ auto slice_shape = ta_shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ slices[i] = b->DynamicSlice(ta, start_indices, slice_shape);
+ }
+
+ xla::ComputationDataHandle gather;
+ if (slices.empty()) {
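+      // Gathering zero indices produces an empty tensor with a leading
+      // dimension of 0.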
+ auto shape = ta_shape.dim_sizes();
+ shape[0] = 0;
+ gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape);
+ } else {
+ gather = b->ConcatInDim(slices, 0);
+ }
+ ctx->SetOutput(0, gather);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGatherOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayGatherV3"), TensorArrayGatherOp);
+
+class TensorArrayScatterOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayScatterOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ const TensorShape value_shape = ctx->InputShape(2);
+
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ TensorShape elem_shape = value_shape;
+ elem_shape.RemoveDim(0);
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+
+ const TensorShape indices_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, indices_shape.dims() >= 1,
+                errors::InvalidArgument("indices must have rank >= 1"));
+ const int num_indices = indices_shape.dim_size(0);
+ const xla::ComputationDataHandle indices = ctx->Input(1);
+
+ xla::ComputationDataHandle ta = var->value;
+ const xla::ComputationDataHandle value = ctx->Input(2);
+
+ auto slice_dims = value_shape.dim_sizes();
+ slice_dims[0] = 1LL;
+
+ std::vector<int64> value_starts(value_shape.dims(), 0);
+ auto value_ends = value_shape.dim_sizes();
+
+ // For every (index, value) pair, update the corresponding TensorArray
+ // storage.
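+    // As with TensorArrayWrite, each value is added to the existing contents;
+    // repeated indices accumulate.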
+ for (int i = 0; i < num_indices; ++i) {
+ // Slice out part of the value.
+ value_starts[0] = i;
+ value_ends[0] = i + 1;
+ auto slice = b->Slice(value, value_starts, value_ends);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto index = b->Slice(indices, {i}, {i + 1});
+ auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+ ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
+ }
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayScatterOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayScatterV3"), TensorArrayScatterOp);
+
+class TensorArrayConcatOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_type == dtype_,
+ errors::InvalidArgument("TensorArray type mismatch"));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ xla::ComputationDataHandle ta;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+
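+    // All elements share the same shape, so concatenation along dimension 0
+    // is just a reshape that folds the leading [size] dimension into the
+    // first element dimension.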
+ auto ta_dims = ta_shape.dim_sizes();
+ std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
+ shape[0] *= ta_shape.dim_size(0);
+ ctx->SetOutput(0, b->Reshape(ta, shape));
+
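+    // Each element contributes the same length, so the lengths output is a
+    // compile-time constant vector.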
+ Tensor lengths(DT_INT64, {ta_dims[0]});
+ auto lengths_vec = lengths.vec<int64>();
+ for (int i = 0; i < ta_dims[0]; ++i) {
+ lengths_vec(i) = ta_dims[1];
+ }
+ ctx->SetConstantOutput(1, lengths);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayConcatOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayConcatV3"), TensorArrayConcatOp);
+
+class TensorArraySplitOp : public XlaOpKernel {
+ public:
+ explicit TensorArraySplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<int64> lengths;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
+
+ int64 length = 0;
+ if (!lengths.empty()) {
+ length = lengths[0];
+ for (int i = 1; i < lengths.size(); ++i) {
+ OP_REQUIRES(ctx, lengths[i] == length,
+ errors::InvalidArgument("lengths must be equal: ", length,
+ " vs. ", lengths[i]));
+ }
+ }
+
+ TensorShape value_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, value_shape.dims() >= 1,
+ errors::InvalidArgument("value must have rank >= 1, got ",
+ value_shape.DebugString()));
+ TensorShape elem_shape = value_shape;
+ elem_shape.set_dim(0, length);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+ xla::ComputationDataHandle ta = var->value;
+
+ TensorShape ta_shape;
+ ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AppendShape(elem_shape);
+
+ OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
+ errors::InvalidArgument(
+ "TensorArray's size is not equal to the size of lengths (",
+ lengths.size(), " vs. ", var->tensor_array_size, ")"));
+
+ const xla::ComputationDataHandle value = ctx->Input(1);
+
+ OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
+ errors::InvalidArgument("mismatched element count ",
+ value_shape.DebugString(), " vs. ",
+ ta_shape.DebugString()));
+
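+    // As with writes, split adds the incoming values to the existing contents
+    // of the TensorArray.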
+ ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+
+ ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp);
+
+class TensorArraySizeOp : public XlaOpKernel {
+ public:
+ explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
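+    // dynamic_size=False, so the size of an XLA TensorArray is a compile-time
+    // constant and can be emitted as a constant output.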
+ Tensor size_tensor(DT_INT32, {});
+ size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
+ ctx->SetConstantOutput(0, size_tensor);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySizeOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArraySizeV3"), TensorArraySizeOp);
+
+class TensorArrayGradOp : public XlaOpKernel {
+ public:
+ explicit TensorArrayGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("source", &source_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* b = ctx->builder();
+
+ XlaVariable* var;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+
+ DataType ta_type;
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
+ OP_REQUIRES(ctx, ta_shape.dims() >= 1,
+ errors::InvalidArgument("TensorArray rank must be >= 1"));
+
+    // Finds or creates the corresponding gradient TensorArray, which stores
+    // gradients computed during backpropagation.
+ XlaVariable*& gradient = var->tensor_array_gradient[source_];
+ if (!gradient) {
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
+ xla::ComputationDataHandle value =
+ b->Broadcast(zero, ta_shape.dim_sizes());
+
+ XlaContext& xc = XlaContext::Get(ctx);
+ string name = strings::StrCat("TensorArrayGrad: ", var->name);
+ OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
+ value, &gradient));
+ gradient->tensor_array_size = var->tensor_array_size;
+ }
+
+ ctx->SetVariableOutput(0, gradient);
+ ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
+ }
+
+ private:
+ string source_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp);
+};
+
+REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp);
+
+} // anonymous namespace
+} // namespace tensorflow