diff options
author | Peter Hawkins <phawkins@google.com> | 2017-12-20 14:32:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-20 14:35:57 -0800 |
commit | bf5326a75412e59985b727b26f5cad01315b6c89 (patch) | |
tree | e9e2a5d7d62a4d19955eab0f5cc8fb2fc563d672 /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | 1279bb10b9bd76f15637074c6518a3464916e007 (diff) |
[TF:XLA] Move XlaResource into its own file, and refactor it into a better-abstracted class. No functional changes intended.
PiperOrigin-RevId: 179734920
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 75 |
1 files changed, 39 insertions, 36 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 03c22354a9..8a742ff11c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -50,29 +50,30 @@ namespace { Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (resource->kind != XlaResource::kTensorArray) { + if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument("Unexpected non-TensorArray resource"); } - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(resource->type), + "TensorArray dtype is ", DataTypeString(resource->type()), " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(resource->tensor_array_size >= 0) - << resource->name << " size " << resource->tensor_array_size; + TF_RET_CHECK(resource->tensor_array_size() >= 0) + << resource->name() << " size " << resource->tensor_array_size(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size()); ta_shape.AppendShape(elem_shape); - if (resource->value.handle() == 0) { - // TensorArray has not been initialized. - xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); - resource->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + if (!resource->initialized()) { + xla::ComputationDataHandle zero = + XlaHelpers::Zero(builder, resource->type()); + TF_RETURN_IF_ERROR(resource->SetValue( + dtype, builder->Broadcast(zero, ta_shape.dim_sizes()))); } else { // Checks the elem_shape matches the TensorArray shape. - auto shape_or_status = builder->GetShape(resource->value); + auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -93,19 +94,17 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, Status CheckTensorArrayIsInitialized(const string& op_name, const XlaResource* resource, DataType dtype) { - if (resource->kind != XlaResource::kTensorArray) { + if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument( - "Unexpected non-TensorArray resource passed " - "to ", - op_name); + "Unexpected non-TensorArray resource passed to ", op_name); } - if (resource->value.handle() == 0) { + if (!resource->initialized()) { return errors::InvalidArgument("Uninitialized TensorArray passed to ", op_name); } - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(resource->type), + "TensorArray dtype is ", DataTypeString(resource->type()), " but op has dtype ", DataTypeString(dtype), "."); } @@ -177,7 +176,7 @@ class TensorArrayOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), dtype_, value, &var)); - var->tensor_array_size = size; + var->set_tensor_array_size(size); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -213,7 +212,7 @@ class TensorArrayWriteOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); xla::ComputationDataHandle flow = ctx->Input(3); @@ -230,7 +229,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - resource->value = written; + OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written)); ctx->SetOutput(0, flow); } @@ -259,7 +258,7 @@ class TensorArrayReadOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. @@ -309,7 +308,7 @@ class TensorArrayGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); // Look for the case where the gather takes a simple slice from the // tensor array (0, 1, 2, 3, 4, ..., N) @@ -374,7 +373,7 @@ class TensorArrayScatterOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); const xla::ComputationDataHandle indices = ctx->Input(1); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); const xla::ComputationDataHandle value = ctx->Input(2); const xla::ComputationDataHandle flow = ctx->Input(3); @@ -421,7 +420,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } } - resource->value = ta; + OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta)); ctx->SetOutput(0, flow); } @@ -450,7 +449,7 @@ class TensorArrayConcatOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end()); @@ -505,16 +504,17 @@ class TensorArraySplitOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size()); ta_shape.AppendShape(elem_shape); - OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size, - errors::InvalidArgument( - "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", resource->tensor_array_size, ")")); + OP_REQUIRES( + ctx, lengths.size() == resource->tensor_array_size(), + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. ", resource->tensor_array_size(), ")")); const xla::ComputationDataHandle value = ctx->Input(1); const xla::ComputationDataHandle flow = ctx->Input(3); @@ -524,7 +524,9 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); + OP_REQUIRES_OK( + ctx, resource->SetValue( + dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } @@ -545,7 +547,8 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size); + size_tensor.scalar<int32>()() = + static_cast<int32>(var->tensor_array_size()); ctx->SetConstantOutput(0, size_tensor); } @@ -568,7 +571,7 @@ class TensorArrayGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK( - ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type())); TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); |