diff options
author | Peter Hawkins <phawkins@google.com> | 2017-09-19 19:08:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-19 19:12:43 -0700 |
commit | 1f20a786d69c4b91a4015fe3f4df8c23bd345f40 (patch) | |
tree | 9175a24a490a21587cb899b4dda98f11f83f948c /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | 5ce3523bcc844217b47e7f862c1bed894cbaa34e (diff) |
[TF:XLA] Add support for reading and writing TensorArray gradients in a while loop.
Previously, there was no code to handle propagating the values of a TensorArray's gradients into and out of loops. This change passes TensorArray gradients into and out of loops by packing them up as a (base array, gradient values...) tuple.
PiperOrigin-RevId: 169338418
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 24 |
1 files changed, 5 insertions, 19 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 7f1597e9ad..c42d8b97ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -114,12 +115,7 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::ComputationBuilder* builder, TensorShape* shape) { - auto shape_or_status = builder->GetShape(resource->value); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + TF_RETURN_IF_ERROR(resource->GetShape(builder, shape)); if (shape->dims() < 1) { return errors::InvalidArgument("TensorArray rank must be >= 1"); } @@ -532,19 +528,9 @@ class TensorArrayGradOp : public XlaOpKernel { // Finds or looks up the corresponding gradient TensorArray, which stores // gradients computed during backpropagation. - XlaResource*& gradient = resource->tensor_array_gradient[source_]; - if (!gradient) { - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type); - xla::ComputationDataHandle value = - b->Broadcast(zero, ta_shape.dim_sizes()); - - XlaContext& xc = XlaContext::Get(ctx); - string name = strings::StrCat("TensorArrayGrad: ", resource->name); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - resource->type, value, &gradient)); - gradient->tensor_array_size = resource->tensor_array_size; - } + XlaResource* gradient; + OP_REQUIRES_OK( + ctx, resource->GetOrCreateTensorArrayGradient(source_, b, &gradient)); ctx->SetResourceOutput(0, gradient); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); |