diff options
author | Peter Hawkins <phawkins@google.com> | 2017-06-16 10:32:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-16 10:36:33 -0700 |
commit | a66de1eca225bc95e7972974a7089d84df8a8055 (patch) | |
tree | 7a332e8fd1f1656a0ef0caacbfff0851ab13967f /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | d06121593035687176d8b660b83bab568853deff (diff) |
[TF:XLA] Refactor handling of Resources (Variables and TensorArrays) in the XLA bridge.
* Rename "Variable" to "Resource" in many places where non-Variable resources might be used.
* Add kTensorArray to the XlaCompiler::Argument enum. Remove kUninitializedVariable and make "initialized" a separate boolean field.
* Add a kind field to XlaResource. Add checks that Variables are not used where TensorArrays are expected, and vice versa.
* Cleanups to the TensorArray operators.
PiperOrigin-RevId: 159244478
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 205 |
1 files changed, 122 insertions, 83 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index c7510bf3d2..598b341002 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -41,32 +41,36 @@ namespace { // Since the element shape is not always provided to the TensorArrayV3 operator, // we must support lazily initialization of the TensorArray at the time of the // first write. -// If a TensorArray `var` has not been initialized, constructs storage for the -// TensorArray with elements of `elem_shape`. For both initialized and +// If a TensorArray `resource` has not been initialized, constructs storage for +// the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, - XlaVariable* var, DataType dtype, + XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (var->type != dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument("Unexpected non-TensorArray resource"); + } + + if (resource->type != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(var->type), + "TensorArray dtype is ", DataTypeString(resource->type), " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(var->tensor_array_size >= 0) - << var->name << " size " << var->tensor_array_size; + TF_RET_CHECK(resource->tensor_array_size >= 0) + << resource->name << " size " << resource->tensor_array_size; TensorShape ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - if (var->value.handle() == 0) { + if (resource->value.handle() == 0) { // TensorArray has not been initialized. 
- xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type); - var->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); + resource->value = builder->Broadcast(zero, ta_shape.dim_sizes()); } else { // Checks the elem_shape matches the TensorArray shape. - auto shape_or_status = builder->GetShape(var->value); + auto shape_or_status = builder->GetShape(resource->value); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -80,6 +84,44 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, return Status::OK(); } +// Checks that the TensorArray 'resource' has been initialized, and has type +// 'dtype'. Sets 'shape' to the shape +Status CheckTensorArrayIsInitialized(const string& op_name, + const XlaResource* resource, + DataType dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument( + "Unexpected non-TensorArray resource passed " + "to ", + op_name); + } + if (resource->value.handle() == 0) { + return errors::InvalidArgument("Uninitialized TensorArray passed to ", + op_name); + } + if (resource->type != dtype) { + return errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(resource->type), + " but op has dtype ", DataTypeString(dtype), "."); + } + + return Status::OK(); +} + +Status GetTensorArrayShape(const XlaResource* resource, + xla::ComputationBuilder* builder, + TensorShape* shape) { + auto shape_or_status = builder->GetShape(resource->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + if (shape->dims() < 1) { + return errors::InvalidArgument("TensorArray rank must be >= 1"); + } + return Status::OK(); +} + // Pads 'x' with 'count' zero indices. 'x' must have 1 element. 
xla::ComputationDataHandle PadIndexWithZeros( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, @@ -125,7 +167,6 @@ class TensorArrayOp : public XlaOpKernel { errors::InvalidArgument("TensorArray size must be >= 0")); xla::ComputationBuilder* b = ctx->builder(); - b->set_die_immediately_on_error(true); // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. @@ -141,12 +182,13 @@ class TensorArrayOp : public XlaOpKernel { } XlaContext& xc = XlaContext::Get(ctx); - XlaVariable* var; + XlaResource* var; string name = strings::StrCat("TensorArray: ", tensor_array_name_); - OP_REQUIRES_OK(ctx, - xc.CreateVariable(-1, std::move(name), dtype_, value, &var)); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + dtype_, value, &var)); var->tensor_array_size = size; - ctx->SetVariableOutput(0, var); + ctx->SetResourceOutput(0, var); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } @@ -173,11 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel { // Initializes the TensorArray, if the element shape was not known at // construction time. 
- XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); @@ -191,7 +234,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written)); + resource->value = written; ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -210,20 +253,17 @@ class TensorArrayReadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(ta_type), - " but Op requested dtype ", DataTypeString(dtype_), ".")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
@@ -255,13 +295,15 @@ class TensorArrayGatherOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; + xla::ComputationBuilder* b = ctx->builder(); + + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() >= 1, @@ -269,10 +311,7 @@ class TensorArrayGatherOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); auto indices = ctx->Input(1); - xla::ComputationBuilder* b = ctx->builder(); - - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + xla::ComputationDataHandle ta = resource->value; // For each index in `indices`, add the corresponding slice to `slices`. 
std::vector<xla::ComputationDataHandle> slices(num_indices); @@ -320,11 +359,12 @@ class TensorArrayScatterOp : public XlaOpKernel { const TensorShape value_shape = ctx->InputShape(2); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); TensorShape elem_shape = value_shape; elem_shape.RemoveDim(0); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() >= 1, @@ -332,7 +372,7 @@ class TensorArrayScatterOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); const xla::ComputationDataHandle indices = ctx->Input(1); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; const xla::ComputationDataHandle value = ctx->Input(2); auto slice_dims = value_shape.dim_sizes(); @@ -355,7 +395,7 @@ class TensorArrayScatterOp : public XlaOpKernel { ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = ta; ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -374,18 +414,17 @@ class TensorArrayConcatOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + 
CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; auto ta_dims = ta_shape.dim_sizes(); std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end()); @@ -436,19 +475,20 @@ class TensorArraySplitOp : public XlaOpKernel { elem_shape.set_dim(0, length); xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); + xla::ComputationDataHandle ta = resource->value; TensorShape ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size, + OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size, errors::InvalidArgument( "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", var->tensor_array_size, ")")); + lengths.size(), " vs. ", resource->tensor_array_size, ")")); const xla::ComputationDataHandle value = ctx->Input(1); @@ -457,8 +497,7 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. 
", ta_shape.DebugString())); - ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -476,8 +515,8 @@ class TensorArraySizeOp : public XlaOpKernel { explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* var; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size); ctx->SetConstantOutput(0, size_tensor); @@ -498,31 +537,31 @@ class TensorArrayGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - DataType ta_type; + OP_REQUIRES_OK( + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); // Finds or looks up the corresponding gradient TensorArray, which stores // gradients computed during backpropagation. 
- XlaVariable*& gradient = var->tensor_array_gradient[source_]; + XlaResource*& gradient = resource->tensor_array_gradient[source_]; if (!gradient) { - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type); xla::ComputationDataHandle value = b->Broadcast(zero, ta_shape.dim_sizes()); XlaContext& xc = XlaContext::Get(ctx); - string name = strings::StrCat("TensorArrayGrad: ", var->name); - OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type, - value, &gradient)); - gradient->tensor_array_size = var->tensor_array_size; + string name = strings::StrCat("TensorArrayGrad: ", resource->name); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + resource->type, value, &gradient)); + gradient->tensor_array_size = resource->tensor_array_size; } - ctx->SetVariableOutput(0, gradient); + ctx->SetResourceOutput(0, gradient); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } |