about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
diff options
context:
space:
mode:
author: Peter Hawkins <phawkins@google.com>  2017-09-19 19:08:19 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>  2017-09-19 19:12:43 -0700
commit 1f20a786d69c4b91a4015fe3f4df8c23bd345f40 (patch)
tree 9175a24a490a21587cb899b4dda98f11f83f948c /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent 5ce3523bcc844217b47e7f862c1bed894cbaa34e (diff)
[TF:XLA] Add support for reading and writing TensorArray gradients in a while loop.
Previously, there was no code to handle propagating the values of a TensorArray's gradients into and out of loops. This change passes TensorArray gradients into and out of loops by packing them up as a (base array, gradient values...) tuple. PiperOrigin-RevId: 169338418
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 24
1 file changed, 5 insertions(+), 19 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 7f1597e9ad..c42d8b97ea 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -114,12 +115,7 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
Status GetTensorArrayShape(const XlaResource* resource,
xla::ComputationBuilder* builder,
TensorShape* shape) {
- auto shape_or_status = builder->GetShape(resource->value);
- if (!shape_or_status.ok()) {
- return shape_or_status.status();
- }
- TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
+ TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
if (shape->dims() < 1) {
return errors::InvalidArgument("TensorArray rank must be >= 1");
}
@@ -532,19 +528,9 @@ class TensorArrayGradOp : public XlaOpKernel {
// Finds or looks up the corresponding gradient TensorArray, which stores
// gradients computed during backpropagation.
- XlaResource*& gradient = resource->tensor_array_gradient[source_];
- if (!gradient) {
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type);
- xla::ComputationDataHandle value =
- b->Broadcast(zero, ta_shape.dim_sizes());
-
- XlaContext& xc = XlaContext::Get(ctx);
- string name = strings::StrCat("TensorArrayGrad: ", resource->name);
- OP_REQUIRES_OK(
- ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
- resource->type, value, &gradient));
- gradient->tensor_array_size = resource->tensor_array_size;
- }
+ XlaResource* gradient;
+ OP_REQUIRES_OK(
+ ctx, resource->GetOrCreateTensorArrayGradient(source_, b, &gradient));
ctx->SetResourceOutput(0, gradient);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));