aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-12 15:48:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-12 15:54:02 -0700
commit9526b27d7d904acc9e1a7a1990e1320235a8720c (patch)
treea9582fe78a48ef33b4e8610b0e27231b6876c3ab /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent8f2de3847fdab53926b940fee99d7b95e7dfc6c7 (diff)
[TF:XLA] Implementing ResourceGather in TF2XLA.
PiperOrigin-RevId: 161730154
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc23
1 files changed, 8 insertions, 15 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index d720496f74..bdd52b7f8e 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -125,16 +125,6 @@ Status GetTensorArrayShape(const XlaResource* resource,
return Status::OK();
}
-// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
-xla::ComputationDataHandle PadIndexWithZeros(
- xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
- int count) {
- xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0});
- std::vector<xla::ComputationDataHandle> xs(count + 1, zero);
- xs[0] = builder->Reshape(x, {1});
- return builder->ConcatInDim(xs, 0);
-}
-
// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the
// relevant slice of 'operand'.
xla::ComputationDataHandle DynamicAddSlice(
@@ -228,7 +218,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle value = ctx->Input(2);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+ auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims());
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
@@ -270,7 +260,8 @@ class TensorArrayReadOp : public XlaOpKernel {
xla::ComputationDataHandle index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
- auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+ auto start_indices =
+ XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1);
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
@@ -309,7 +300,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
const TensorShape indices_shape = ctx->InputShape(1);
- OP_REQUIRES(ctx, indices_shape.dims() >= 1,
+ OP_REQUIRES(ctx, indices_shape.dims() == 1,
errors::InvalidArgument("indices must be rank 1"));
const int num_indices = indices_shape.dim_size(0);
auto indices = ctx->Input(1);
@@ -324,7 +315,8 @@ class TensorArrayGatherOp : public XlaOpKernel {
auto index = b->Slice(indices, {i}, {i + 1}, {1});
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
- auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
+ auto start_indices =
+ XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1);
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
@@ -396,7 +388,8 @@ class TensorArrayScatterOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto index = b->Slice(indices, {i}, {i + 1}, {1});
- auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
+ auto start_indices =
+ XlaHelpers::PadWithZeros(b, index, elem_shape.dims());
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}