author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-30 17:41:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-04-30 17:43:59 -0700
commit     45bafe9a3589fc735c22c3c703f8689ea9c1e71e (patch)
tree       e39723521a1ca68e9c2c74d1a9d3ac5ef2e8abc4 /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent     c89a1d9605427d74079774af7da37933f9ca153c (diff)
[XLA] Redesign: migrate tensorflow/compiler/tf2xla, tensorflow/compiler/aot:
- xla::ComputationBuilder -> xla::XlaBuilder
- xla::ComputationDataHandle -> xla::XlaOp
- xla::Computation -> xla::XlaComputation
- xla::CompileOnlyClient::AotComputationInstance -> xla::CompileOnlyClient::AotXlaComputationInstance
- xla::SessionModule -> xla::HloSnapshot

PiperOrigin-RevId: 194874462
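The change is mechanical: the handle and builder types are renamed while call sites keep the same member-function API. A minimal sketch of the pattern, using only builder calls that already appear in this file (the helper name AddSliceSketch is hypothetical, shown only to illustrate the rename):

    // Before the migration:
    //   xla::ComputationDataHandle AddSliceSketch(
    //       xla::ComputationBuilder* builder, ...);
    // After the migration only the types change; the op-building calls
    // (DynamicSlice, Add, DynamicUpdateSlice) are untouched:
    xla::XlaOp AddSliceSketch(xla::XlaBuilder* builder,
                              const xla::XlaOp& operand,
                              const xla::XlaOp& update,
                              const gtl::ArraySlice<int64>& dims,
                              const xla::XlaOp& start) {
      // Read the current slice, add the update, write the sum back.
      xla::XlaOp current = builder->DynamicSlice(operand, start, dims);
      xla::XlaOp sum = builder->Add(current, update);
      return builder->DynamicUpdateSlice(operand, sum, start);
    }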
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 82
1 file changed, 38 insertions(+), 44 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 000b50af6b..9adee78a1f 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -47,7 +47,7 @@ namespace {
// the TensorArray with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
-Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
+Status MaybeInitializeTensorArray(xla::XlaBuilder* builder,
XlaResource* resource, DataType dtype,
const TensorShape& elem_shape) {
if (resource->kind() != XlaResource::kTensorArray) {
@@ -64,9 +64,6 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
<< resource->name() << " size " << resource->tensor_array_size();
if (!resource->initialized()) {
- xla::ComputationDataHandle zero =
- XlaHelpers::Zero(builder, resource->type());
-
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
@@ -77,7 +74,7 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
}
TensorShape shape;
TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+ XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
@@ -114,23 +111,21 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
}
Status GetTensorArrayShape(const XlaResource* resource,
- xla::ComputationBuilder* builder,
- TensorShape* shape) {
+ xla::XlaBuilder* builder, TensorShape* shape) {
*shape = resource->shape();
shape->InsertDim(0, resource->tensor_array_size());
return Status::OK();
}
-// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
+// Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the
// relevant slice of 'operand'.
-xla::ComputationDataHandle DynamicAddSlice(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand,
- const xla::ComputationDataHandle& update,
- const gtl::ArraySlice<int64>& update_dims,
- const xla::ComputationDataHandle& start_indices) {
- xla::ComputationDataHandle current =
+xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
+ const xla::XlaOp& update,
+ const gtl::ArraySlice<int64>& update_dims,
+ const xla::XlaOp& start_indices) {
+ xla::XlaOp current =
builder->DynamicSlice(operand, start_indices, update_dims);
- xla::ComputationDataHandle sum = builder->Add(current, update);
+ xla::XlaOp sum = builder->Add(current, update);
return builder->DynamicUpdateSlice(operand, sum, start_indices);
}
@@ -155,18 +150,18 @@ class TensorArrayOp : public XlaOpKernel {
OP_REQUIRES(ctx, size >= 0,
errors::InvalidArgument("TensorArray size must be >= 0"));
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
- xla::ComputationDataHandle value;
+ xla::XlaOp value;
TensorShape shape;
if (element_shape_.IsFullyDefined()) {
CHECK(element_shape_.AsTensorShape(&shape));
TensorShape ta_shape;
ta_shape.AddDim(size);
ta_shape.AppendShape(shape);
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_);
+ xla::XlaOp zero = XlaHelpers::Zero(b, dtype_);
value = b->Broadcast(zero, ta_shape.dim_sizes());
}
@@ -202,7 +197,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
TensorShape elem_shape = ctx->InputShape(2);
@@ -213,10 +208,10 @@ class TensorArrayWriteOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value();
- xla::ComputationDataHandle index = ctx->Input(1);
- xla::ComputationDataHandle value = ctx->Input(2);
- xla::ComputationDataHandle flow = ctx->Input(3);
+ xla::XlaOp ta = resource->value();
+ xla::XlaOp index = ctx->Input(1);
+ xla::XlaOp value = ctx->Input(2);
+ xla::XlaOp flow = ctx->Input(3);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -227,7 +222,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
slice_shape.InsertDim(0, 1LL);
auto update = b->Reshape(value, slice_shape.dim_sizes());
- xla::ComputationDataHandle written =
+ xla::XlaOp written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
OP_REQUIRES_OK(ctx, resource->SetValue(written));
@@ -249,7 +244,7 @@ class TensorArrayReadOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -259,8 +254,8 @@ class TensorArrayReadOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value();
- xla::ComputationDataHandle index = ctx->Input(1);
+ xla::XlaOp ta = resource->value();
+ xla::XlaOp index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
@@ -270,8 +265,7 @@ class TensorArrayReadOp : public XlaOpKernel {
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
- xla::ComputationDataHandle read =
- b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
@@ -293,7 +287,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -309,7 +303,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
auto indices = ctx->Input(1);
DataType index_type = ctx->input_type(1);
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
// Look for the case where the gather takes a simple slice from the
// tensor array (0, 1, 2, 3, 4, ..., N)
@@ -337,7 +331,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
}
- xla::ComputationDataHandle gather;
+ xla::XlaOp gather;
OP_REQUIRES_OK(
ctx,
XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0,
@@ -360,7 +354,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
const TensorShape value_shape = ctx->InputShape(2);
@@ -375,11 +369,11 @@ class TensorArrayScatterOp : public XlaOpKernel {
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
errors::InvalidArgument("indices must be rank 1"));
const int num_indices = indices_shape.dim_size(0);
- const xla::ComputationDataHandle indices = ctx->Input(1);
+ const xla::XlaOp indices = ctx->Input(1);
- xla::ComputationDataHandle ta = resource->value();
- const xla::ComputationDataHandle value = ctx->Input(2);
- const xla::ComputationDataHandle flow = ctx->Input(3);
+ xla::XlaOp ta = resource->value();
+ const xla::XlaOp value = ctx->Input(2);
+ const xla::XlaOp flow = ctx->Input(3);
// Look for the case where the scatter is for each sub-tensor in order. The
// tensor array implementation allows for this to be a straight addition.
@@ -443,7 +437,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
@@ -453,7 +447,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
TensorShape ta_shape;
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@@ -503,12 +497,12 @@ class TensorArraySplitOp : public XlaOpKernel {
TensorShape elem_shape = value_shape;
elem_shape.set_dim(0, length);
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
OP_REQUIRES_OK(ctx,
MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = resource->value();
+ xla::XlaOp ta = resource->value();
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
@@ -520,8 +514,8 @@ class TensorArraySplitOp : public XlaOpKernel {
"TensorArray's size is not equal to the size of lengths (",
lengths.size(), " vs. ", resource->tensor_array_size(), ")"));
- const xla::ComputationDataHandle value = ctx->Input(1);
- const xla::ComputationDataHandle flow = ctx->Input(3);
+ const xla::XlaOp value = ctx->Input(1);
+ const xla::XlaOp flow = ctx->Input(3);
OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(),
errors::InvalidArgument("mismatched element count ",
@@ -569,7 +563,7 @@ class TensorArrayGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationBuilder* b = ctx->builder();
+ xla::XlaBuilder* b = ctx->builder();
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
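For context on the write path touched above: TensorArrayWriteOp reshapes the incoming element to a rank-(r+1) slice with a leading dimension of 1, builds start_indices of the form [index, 0, ..., 0], and then applies DynamicAddSlice so that repeated writes to the same index accumulate rather than overwrite. A condensed sketch of that data flow, paraphrasing the code in this diff (shapes in the comments are illustrative, not taken from the source):

    // ta    : f32[size, d0, ..., dk]   current TensorArray value
    // value : f32[d0, ..., dk]         element being written at `index`
    auto update = b->Reshape(value, slice_shape.dim_sizes());  // f32[1, d0, ..., dk]
    // start_indices = [index, 0, ..., 0]; DynamicAddSlice reads the old
    // slice at that offset, adds `update`, and writes the sum back:
    xla::XlaOp written =
        DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);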