author     Peter Hawkins <phawkins@google.com>              2017-06-16 10:32:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-06-16 10:36:33 -0700
commit     a66de1eca225bc95e7972974a7089d84df8a8055
tree       7a332e8fd1f1656a0ef0caacbfff0851ab13967f /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent     d06121593035687176d8b660b83bab568853deff
[TF:XLA] Refactor handling of Resources (Variables and TensorArrays) in the XLA bridge.
* Rename "Variable" to "Resource" in many places where non-Variable resources might be used. * Add kTensorArray to the XlaCompiler::Argument enum. Remove kUninitializedVariable and make "initialized" a separate boolean field. * Add a kind field to XlaResource. Add checks that Variables are not used where TensorArrays are expected, and vice-versa. * Clean ups to the TensorArray operators. PiperOrigin-RevId: 159244478
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc  205
1 file changed, 122 insertions(+), 83 deletions(-)
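
The heart of the change is that every resource now carries an explicit kind, and the TensorArray kernels validate that kind (plus initialization and dtype) before touching the value. The following is a minimal, self-contained C++ sketch of that validation pattern; Resource, Kind, and CheckIsTensorArray are simplified stand-ins for illustration, not the real XlaResource/Status types used in the bridge.

#include <iostream>
#include <string>

// Simplified stand-ins for XlaResource and its new `kind` field.
enum class Kind { kVariable, kTensorArray };

struct Resource {
  Kind kind;
  std::string name;
  std::string dtype;  // stands in for DataType
  bool initialized;   // stands in for `value.handle() != 0`
};

// Same checks as CheckTensorArrayIsInitialized in the patch: reject
// non-TensorArray resources, uninitialized arrays, and dtype mismatches.
std::string CheckIsTensorArray(const std::string& op_name, const Resource& r,
                               const std::string& dtype) {
  if (r.kind != Kind::kTensorArray)
    return "Unexpected non-TensorArray resource passed to " + op_name;
  if (!r.initialized)
    return "Uninitialized TensorArray passed to " + op_name;
  if (r.dtype != dtype)
    return "TensorArray dtype is " + r.dtype + " but op has dtype " + dtype;
  return "OK";
}

int main() {
  Resource var{Kind::kVariable, "v", "float", true};
  Resource ta{Kind::kTensorArray, "ta", "float", true};
  std::cout << CheckIsTensorArray("TensorArrayReadV3", var, "float") << "\n";
  std::cout << CheckIsTensorArray("TensorArrayReadV3", ta, "float") << "\n";
}
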
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index c7510bf3d2..598b341002 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -41,32 +41,36 @@ namespace {
// Since the element shape is not always provided to the TensorArrayV3 operator,
// we must support lazy initialization of the TensorArray at the time of the
// first write.
-// If a TensorArray `var` has not been initialized, constructs storage for the
-// TensorArray with elements of `elem_shape`. For both initialized and
+// If a TensorArray `resource` has not been initialized, constructs storage for
+// the TensorArray with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
- XlaVariable* var, DataType dtype,
+ XlaResource* resource, DataType dtype,
const TensorShape& elem_shape) {
- if (var->type != dtype) {
+ if (resource->kind != XlaResource::kTensorArray) {
+ return errors::InvalidArgument("Unexpected non-TensorArray resource");
+ }
+
+ if (resource->type != dtype) {
return errors::InvalidArgument(
- "TensorArray dtype is ", DataTypeString(var->type),
+ "TensorArray dtype is ", DataTypeString(resource->type),
" but op has dtype ", DataTypeString(dtype), ".");
}
- TF_RET_CHECK(var->tensor_array_size >= 0)
- << var->name << " size " << var->tensor_array_size;
+ TF_RET_CHECK(resource->tensor_array_size >= 0)
+ << resource->name << " size " << resource->tensor_array_size;
TensorShape ta_shape;
- ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AddDim(resource->tensor_array_size);
ta_shape.AppendShape(elem_shape);
- if (var->value.handle() == 0) {
+ if (resource->value.handle() == 0) {
// TensorArray has not been initialized.
- xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
- var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
+ resource->value = builder->Broadcast(zero, ta_shape.dim_sizes());
} else {
// Checks the elem_shape matches the TensorArray shape.
- auto shape_or_status = builder->GetShape(var->value);
+ auto shape_or_status = builder->GetShape(resource->value);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@@ -80,6 +84,44 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
return Status::OK();
}
+// Checks that the TensorArray 'resource' has been initialized, and has type
+// 'dtype'.
+Status CheckTensorArrayIsInitialized(const string& op_name,
+ const XlaResource* resource,
+ DataType dtype) {
+ if (resource->kind != XlaResource::kTensorArray) {
+ return errors::InvalidArgument(
+ "Unexpected non-TensorArray resource passed "
+ "to ",
+ op_name);
+ }
+ if (resource->value.handle() == 0) {
+ return errors::InvalidArgument("Uninitialized TensorArray passed to ",
+ op_name);
+ }
+ if (resource->type != dtype) {
+ return errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(resource->type),
+ " but op has dtype ", DataTypeString(dtype), ".");
+ }
+
+ return Status::OK();
+}
+
+Status GetTensorArrayShape(const XlaResource* resource,
+ xla::ComputationBuilder* builder,
+ TensorShape* shape) {
+ auto shape_or_status = builder->GetShape(resource->value);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
+ if (shape->dims() < 1) {
+ return errors::InvalidArgument("TensorArray rank must be >= 1");
+ }
+ return Status::OK();
+}
+
// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
xla::ComputationDataHandle PadIndexWithZeros(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
@@ -125,7 +167,6 @@ class TensorArrayOp : public XlaOpKernel {
errors::InvalidArgument("TensorArray size must be >= 0"));
xla::ComputationBuilder* b = ctx->builder();
- b->set_die_immediately_on_error(true);
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
@@ -141,12 +182,13 @@ class TensorArrayOp : public XlaOpKernel {
}
XlaContext& xc = XlaContext::Get(ctx);
- XlaVariable* var;
+ XlaResource* var;
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
- OP_REQUIRES_OK(ctx,
- xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
+ OP_REQUIRES_OK(
+ ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
+ dtype_, value, &var));
var->tensor_array_size = size;
- ctx->SetVariableOutput(0, var);
+ ctx->SetResourceOutput(0, var);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
}
@@ -173,11 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {
// Initializes the TensorArray, if the element shape was not known at
// construction time.
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+ OP_REQUIRES_OK(ctx,
+ MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = var->value;
+ xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
xla::ComputationDataHandle value = ctx->Input(2);
@@ -191,7 +234,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
+ resource->value = written;
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@@ -210,20 +253,17 @@ class TensorArrayReadOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType ta_type;
- TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_type == dtype_,
- errors::InvalidArgument(
- "TensorArray dtype is ", DataTypeString(ta_type),
- " but Op requested dtype ", DataTypeString(dtype_), "."));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
-
xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle ta;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+ OP_REQUIRES_OK(ctx,
+ CheckTensorArrayIsInitialized(name(), resource, dtype_));
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
+
+ xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@@ -255,13 +295,15 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType ta_type;
+ xla::ComputationBuilder* b = ctx->builder();
+
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+ OP_REQUIRES_OK(ctx,
+ CheckTensorArrayIsInitialized(name(), resource, dtype_));
TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_type == dtype_,
- errors::InvalidArgument("TensorArray type mismatch"));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@@ -269,10 +311,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
const int num_indices = indices_shape.dim_size(0);
auto indices = ctx->Input(1);
- xla::ComputationBuilder* b = ctx->builder();
-
- xla::ComputationDataHandle ta;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ xla::ComputationDataHandle ta = resource->value;
// For each index in `indices`, add the corresponding slice to `slices`.
std::vector<xla::ComputationDataHandle> slices(num_indices);
@@ -320,11 +359,12 @@ class TensorArrayScatterOp : public XlaOpKernel {
const TensorShape value_shape = ctx->InputShape(2);
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
TensorShape elem_shape = value_shape;
elem_shape.RemoveDim(0);
- OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+ OP_REQUIRES_OK(ctx,
+ MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@@ -332,7 +372,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
const int num_indices = indices_shape.dim_size(0);
const xla::ComputationDataHandle indices = ctx->Input(1);
- xla::ComputationDataHandle ta = var->value;
+ xla::ComputationDataHandle ta = resource->value;
const xla::ComputationDataHandle value = ctx->Input(2);
auto slice_dims = value_shape.dim_sizes();
@@ -355,7 +395,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+ resource->value = ta;
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@@ -374,18 +414,17 @@ class TensorArrayConcatOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType ta_type;
- TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_type == dtype_,
- errors::InvalidArgument("TensorArray type mismatch"));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
-
xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle ta;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+ OP_REQUIRES_OK(ctx,
+ CheckTensorArrayIsInitialized(name(), resource, dtype_));
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
+
+ xla::ComputationDataHandle ta = resource->value;
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@@ -436,19 +475,20 @@ class TensorArraySplitOp : public XlaOpKernel {
elem_shape.set_dim(0, length);
xla::ComputationBuilder* b = ctx->builder();
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
- xla::ComputationDataHandle ta = var->value;
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+ OP_REQUIRES_OK(ctx,
+ MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
+ xla::ComputationDataHandle ta = resource->value;
TensorShape ta_shape;
- ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AddDim(resource->tensor_array_size);
ta_shape.AppendShape(elem_shape);
- OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
+ OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size,
errors::InvalidArgument(
"TensorArray's size is not equal to the size of lengths (",
- lengths.size(), " vs. ", var->tensor_array_size, ")"));
+ lengths.size(), " vs. ", resource->tensor_array_size, ")"));
const xla::ComputationDataHandle value = ctx->Input(1);
@@ -457,8 +497,7 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+ resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@@ -476,8 +515,8 @@ class TensorArraySizeOp : public XlaOpKernel {
explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ XlaResource* var;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var));
Tensor size_tensor(DT_INT32, {});
size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
ctx->SetConstantOutput(0, size_tensor);
@@ -498,31 +537,31 @@ class TensorArrayGradOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
- DataType ta_type;
+ OP_REQUIRES_OK(
+ ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type));
TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
// Finds or looks up the corresponding gradient TensorArray, which stores
// gradients computed during backpropagation.
- XlaVariable*& gradient = var->tensor_array_gradient[source_];
+ XlaResource*& gradient = resource->tensor_array_gradient[source_];
if (!gradient) {
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type);
xla::ComputationDataHandle value =
b->Broadcast(zero, ta_shape.dim_sizes());
XlaContext& xc = XlaContext::Get(ctx);
- string name = strings::StrCat("TensorArrayGrad: ", var->name);
- OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
- value, &gradient));
- gradient->tensor_array_size = var->tensor_array_size;
+ string name = strings::StrCat("TensorArrayGrad: ", resource->name);
+ OP_REQUIRES_OK(
+ ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
+ resource->type, value, &gradient));
+ gradient->tensor_array_size = resource->tensor_array_size;
}
- ctx->SetVariableOutput(0, gradient);
+ ctx->SetResourceOutput(0, gradient);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
}
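
As a side note on MaybeInitializeTensorArray above: it implements lazy initialization, creating the zero-filled backing buffer of shape [tensor_array_size] + elem_shape on the first write and only validating the element shape afterwards. A rough stand-alone sketch of that idea, using a flat std::vector in place of the XLA value handle (the real code broadcasts a zero via xla::ComputationBuilder), under the assumption that the element shape is fixed by whichever write happens first:

#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for a TensorArray whose storage is created on the first write.
struct LazyTensorArray {
  int size = 0;                 // tensor_array_size
  bool initialized = false;     // stands in for `value.handle() != 0`
  std::vector<int> elem_shape;  // element shape, fixed at first write
  std::vector<float> storage;   // flat zero-filled buffer of size * elem_count

  // Allocate on first use; afterwards only check that the element shape matches.
  bool MaybeInitialize(const std::vector<int>& shape) {
    if (!initialized) {
      std::size_t elem_count = 1;
      for (int d : shape) elem_count *= static_cast<std::size_t>(d);
      elem_shape = shape;
      storage.assign(static_cast<std::size_t>(size) * elem_count, 0.0f);
      initialized = true;
      return true;
    }
    return elem_shape == shape;  // mismatch -> caller raises InvalidArgument
  }
};

int main() {
  LazyTensorArray ta;
  ta.size = 4;
  assert(ta.MaybeInitialize({2, 3}));  // first write allocates 4 * 6 zeros
  assert(ta.storage.size() == 24);
  assert(!ta.MaybeInitialize({5}));    // later write with a wrong shape fails
}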