author	Peter Hawkins <phawkins@google.com>	2018-02-02 11:31:06 -0800
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-02-02 11:36:32 -0800
commit	22116459b258d5753aa76410ab6f4d3cbc928a5a (patch)
tree	d2f906ab2dafa09b6d40028a951c68c4be409c85 /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent	224874002f93fec471e401488e23d97d4f36c4fc (diff)
[TF:XLA] Improve/refactor the handling of resource types/shapes.
Previously we used an xla::Shape to track the shape of a resource (Variable, TensorArray, Stack). The xla::Shape described how the resource was represented to XLA, e.g., as a (buffer, size) pair for a Stack resource. Instead, separate the TensorFlow abstract shape representation from the XLA shape representation and track it separately. This leads to simpler and more readable code.

PiperOrigin-RevId: 184310694
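The hunks below reflect this split: a resource now carries its TensorFlow element shape and its tensor_array_size separately, and the full buffer shape is derived only where needed (see the new GetTensorArrayShape, which prepends the size dimension with InsertDim(0, ...)). The following is a minimal, standalone C++ sketch of that derivation; ResourceShape and FullTensorArrayShape are illustrative stand-ins, not the actual XlaResource/TensorShape API.

```cpp
// Standalone sketch: derive the full TensorArray buffer shape from the
// abstract element shape plus the array size, mirroring the refactored
// GetTensorArrayShape logic (InsertDim(0, tensor_array_size)).
// Types and names here are illustrative, not the TensorFlow API.
#include <cstdint>
#include <iostream>
#include <vector>

struct ResourceShape {
  std::vector<int64_t> element_dims;  // abstract TF shape of one element
  int64_t tensor_array_size = -1;     // number of elements, -1 if unknown
};

// Prepend the array size as dimension 0, analogous to
// shape->InsertDim(0, resource->tensor_array_size()).
std::vector<int64_t> FullTensorArrayShape(const ResourceShape& r) {
  std::vector<int64_t> dims;
  dims.reserve(r.element_dims.size() + 1);
  dims.push_back(r.tensor_array_size);
  dims.insert(dims.end(), r.element_dims.begin(), r.element_dims.end());
  return dims;
}

int main() {
  ResourceShape r{{128, 64}, /*tensor_array_size=*/10};
  for (int64_t d : FullTensorArrayShape(r)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 10 128 64
  return 0;
}
```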
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--	tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc	33
1 file changed, 16 insertions(+), 17 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 9224072a3c..7cf9b796b9 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
TF_RET_CHECK(resource->tensor_array_size() >= 0)
<< resource->name() << " size " << resource->tensor_array_size();
- TensorShape ta_shape;
- ta_shape.AddDim(resource->tensor_array_size());
- ta_shape.AppendShape(elem_shape);
if (!resource->initialized()) {
xla::ComputationDataHandle zero =
XlaHelpers::Zero(builder, resource->type());
- TF_RETURN_IF_ERROR(resource->SetValue(
- dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));
+
+ TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+ TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
// Checks the elem_shape matches the TensorArray shape.
auto shape_or_status = builder->GetShape(resource->value());
@@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
TensorShape shape;
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+ TensorShape ta_shape;
+ ta_shape.AddDim(resource->tensor_array_size());
+ ta_shape.AppendShape(elem_shape);
if (ta_shape != shape) {
return errors::InvalidArgument(
"Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
@@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
Status GetTensorArrayShape(const XlaResource* resource,
xla::ComputationBuilder* builder,
TensorShape* shape) {
- TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
- if (shape->dims() < 1) {
- return errors::InvalidArgument("TensorArray rank must be >= 1");
- }
+ *shape = resource->shape();
+ shape->InsertDim(0, resource->tensor_array_size());
return Status::OK();
}
@@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel {
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
xla::ComputationDataHandle value;
+ TensorShape shape;
if (element_shape_.IsFullyDefined()) {
- TensorShape shape;
CHECK(element_shape_.AsTensorShape(&shape));
TensorShape ta_shape;
ta_shape.AddDim(size);
@@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel {
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
- dtype_, value, &var));
- var->set_tensor_array_size(size);
+ dtype_, shape, value, /*tensor_array_size=*/size,
+ /*tensor_array_gradients=*/{}, &var));
ctx->SetResourceOutput(0, var);
Tensor flow(DT_FLOAT, TensorShape({}));
@@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
+ OP_REQUIRES_OK(ctx, resource->SetValue(written));
ctx->SetOutput(0, flow);
}
@@ -421,7 +421,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
}
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
+ OP_REQUIRES_OK(ctx, resource->SetValue(ta));
ctx->SetOutput(0, flow);
}
@@ -525,9 +525,8 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- OP_REQUIRES_OK(
- ctx, resource->SetValue(
- dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
+ OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
+ ta, b->Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}