author    Peter Hawkins <phawkins@google.com>  2017-12-20 14:32:34 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-12-20 14:35:57 -0800
commit    bf5326a75412e59985b727b26f5cad01315b6c89 (patch)
tree      e9e2a5d7d62a4d19955eab0f5cc8fb2fc563d672 /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent    1279bb10b9bd76f15637074c6518a3464916e007 (diff)
[TF:XLA] Move XlaResource into its own file, and refactor it into a better-abstracted class. No functional changes intended.
PiperOrigin-RevId: 179734920
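The call sites below suggest the shape of the refactored class: raw field accesses (resource->kind, resource->type, resource->value, resource->tensor_array_size) become accessor calls, and direct writes to the value handle become a Status-returning SetValue. What follows is a minimal sketch inferred purely from those call sites, not the actual contents of xla_resource.h; the Kind values other than kTensorArray, the member defaults, and the exact checks inside SetValue are assumptions.

// Hypothetical sketch of the refactored XlaResource interface, reconstructed
// from the accessors used in this diff. The real class in xla_resource.h
// may differ in members, enum values, and validation logic.
class XlaResource {
 public:
  // kTensorArray appears in this diff; the other values are assumed.
  enum Kind { kInvalid, kVariable, kTensorArray, kStack };

  Kind kind() const { return kind_; }
  DataType type() const { return type_; }
  const string& name() const { return name_; }

  // A resource counts as initialized once it holds a valid value handle,
  // replacing the open-coded `value.handle() == 0` checks removed below.
  bool initialized() const { return value_.handle() != 0; }

  const xla::ComputationDataHandle& value() const { return value_; }

  // Replaces the current value; assumed to reject writes whose dtype does
  // not match the resource's declared type, which is why call sites now
  // wrap it in OP_REQUIRES_OK / TF_RETURN_IF_ERROR.
  Status SetValue(DataType type, const xla::ComputationDataHandle& value) {
    if (type_ != type) {
      return errors::InvalidArgument("Type mismatch for resource ", name_);
    }
    value_ = value;
    return Status::OK();
  }

  int64 tensor_array_size() const { return tensor_array_size_; }
  void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }

 private:
  Kind kind_ = kInvalid;
  DataType type_ = DT_INVALID;
  string name_;
  xla::ComputationDataHandle value_;
  int64 tensor_array_size_ = -1;  // -1 until set, matching TF_RET_CHECK below.
};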
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 75
1 file changed, 39 insertions(+), 36 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 03c22354a9..8a742ff11c 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -21,10 +21,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -50,29 +50,30 @@ namespace {
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
XlaResource* resource, DataType dtype,
const TensorShape& elem_shape) {
- if (resource->kind != XlaResource::kTensorArray) {
+ if (resource->kind() != XlaResource::kTensorArray) {
return errors::InvalidArgument("Unexpected non-TensorArray resource");
}
- if (resource->type != dtype) {
+ if (resource->type() != dtype) {
return errors::InvalidArgument(
- "TensorArray dtype is ", DataTypeString(resource->type),
+ "TensorArray dtype is ", DataTypeString(resource->type()),
" but op has dtype ", DataTypeString(dtype), ".");
}
- TF_RET_CHECK(resource->tensor_array_size >= 0)
- << resource->name << " size " << resource->tensor_array_size;
+ TF_RET_CHECK(resource->tensor_array_size() >= 0)
+ << resource->name() << " size " << resource->tensor_array_size();
TensorShape ta_shape;
- ta_shape.AddDim(resource->tensor_array_size);
+ ta_shape.AddDim(resource->tensor_array_size());
ta_shape.AppendShape(elem_shape);
- if (resource->value.handle() == 0) {
- // TensorArray has not been initialized.
- xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
- resource->value = builder->Broadcast(zero, ta_shape.dim_sizes());
+ if (!resource->initialized()) {
+ xla::ComputationDataHandle zero =
+ XlaHelpers::Zero(builder, resource->type());
+ TF_RETURN_IF_ERROR(resource->SetValue(
+ dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));
} else {
// Checks the elem_shape matches the TensorArray shape.
- auto shape_or_status = builder->GetShape(resource->value);
+ auto shape_or_status = builder->GetShape(resource->value());
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@@ -93,19 +94,17 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
Status CheckTensorArrayIsInitialized(const string& op_name,
const XlaResource* resource,
DataType dtype) {
- if (resource->kind != XlaResource::kTensorArray) {
+ if (resource->kind() != XlaResource::kTensorArray) {
return errors::InvalidArgument(
- "Unexpected non-TensorArray resource passed "
- "to ",
- op_name);
+ "Unexpected non-TensorArray resource passed to ", op_name);
}
- if (resource->value.handle() == 0) {
+ if (!resource->initialized()) {
return errors::InvalidArgument("Uninitialized TensorArray passed to ",
op_name);
}
- if (resource->type != dtype) {
+ if (resource->type() != dtype) {
return errors::InvalidArgument(
- "TensorArray dtype is ", DataTypeString(resource->type),
+ "TensorArray dtype is ", DataTypeString(resource->type()),
" but op has dtype ", DataTypeString(dtype), ".");
}
@@ -177,7 +176,7 @@ class TensorArrayOp : public XlaOpKernel {
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
dtype_, value, &var));
- var->tensor_array_size = size;
+ var->set_tensor_array_size(size);
ctx->SetResourceOutput(0, var);
Tensor flow(DT_FLOAT, TensorShape({}));
@@ -213,7 +212,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value;
+ xla::ComputationDataHandle ta = resource->value();
xla::ComputationDataHandle index = ctx->Input(1);
xla::ComputationDataHandle value = ctx->Input(2);
xla::ComputationDataHandle flow = ctx->Input(3);
@@ -230,7 +229,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
- resource->value = written;
+ OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
ctx->SetOutput(0, flow);
}
@@ -259,7 +258,7 @@ class TensorArrayReadOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value;
+ xla::ComputationDataHandle ta = resource->value();
xla::ComputationDataHandle index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@@ -309,7 +308,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
auto indices = ctx->Input(1);
DataType index_type = ctx->input_type(1);
- xla::ComputationDataHandle ta = resource->value;
+ xla::ComputationDataHandle ta = resource->value();
// Look for the case where the gather takes a simple slice from the
// tensor array (0, 1, 2, 3, 4, ..., N)
@@ -374,7 +373,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
const int num_indices = indices_shape.dim_size(0);
const xla::ComputationDataHandle indices = ctx->Input(1);
- xla::ComputationDataHandle ta = resource->value;
+ xla::ComputationDataHandle ta = resource->value();
const xla::ComputationDataHandle value = ctx->Input(2);
const xla::ComputationDataHandle flow = ctx->Input(3);
@@ -421,7 +420,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
}
- resource->value = ta;
+ OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
ctx->SetOutput(0, flow);
}
@@ -450,7 +449,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value;
+ xla::ComputationDataHandle ta = resource->value();
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@@ -505,16 +504,17 @@ class TensorArraySplitOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value;
+ xla::ComputationDataHandle ta = resource->value();
TensorShape ta_shape;
- ta_shape.AddDim(resource->tensor_array_size);
+ ta_shape.AddDim(resource->tensor_array_size());
ta_shape.AppendShape(elem_shape);
- OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size,
- errors::InvalidArgument(
- "TensorArray's size is not equal to the size of lengths (",
- lengths.size(), " vs. ", resource->tensor_array_size, ")"));
+ OP_REQUIRES(
+ ctx, lengths.size() == resource->tensor_array_size(),
+ errors::InvalidArgument(
+ "TensorArray's size is not equal to the size of lengths (",
+ lengths.size(), " vs. ", resource->tensor_array_size(), ")"));
const xla::ComputationDataHandle value = ctx->Input(1);
const xla::ComputationDataHandle flow = ctx->Input(3);
@@ -524,7 +524,9 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
+ OP_REQUIRES_OK(
+ ctx, resource->SetValue(
+ dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}
@@ -545,7 +547,8 @@ class TensorArraySizeOp : public XlaOpKernel {
XlaResource* var;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var));
Tensor size_tensor(DT_INT32, {});
- size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
+ size_tensor.scalar<int32>()() =
+ static_cast<int32>(var->tensor_array_size());
ctx->SetConstantOutput(0, size_tensor);
}
@@ -568,7 +571,7 @@ class TensorArrayGradOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(
- ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type));
+ ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type()));
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
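
The recurring pattern across every hunk above: a direct field write becomes a mediated call whose Status is propagated to the op, presumably so a dtype mismatch surfaces as an op error instead of a silent write. A condensed before/after, assuming the SetValue semantics sketched earlier:

// Before: direct field write, no validation.
resource->value = written;

// After: mediated through the class; any error aborts the kernel.
OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));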