diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-12 15:48:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-12 15:54:02 -0700 |
commit | 9526b27d7d904acc9e1a7a1990e1320235a8720c (patch) | |
tree | a9582fe78a48ef33b4e8610b0e27231b6876c3ab /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | 8f2de3847fdab53926b940fee99d7b95e7dfc6c7 (diff) |
[TF:XLA] Implementing ResourceGather in TF2XLA.
PiperOrigin-RevId: 161730154
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 23 |
1 files changed, 8 insertions, 15 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index d720496f74..bdd52b7f8e 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -125,16 +125,6 @@ Status GetTensorArrayShape(const XlaResource* resource, return Status::OK(); } -// Pads 'x' with 'count' zero indices. 'x' must have 1 element. -xla::ComputationDataHandle PadIndexWithZeros( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - int count) { - xla::ComputationDataHandle zero = builder->ConstantR1<int32>({0}); - std::vector<xla::ComputationDataHandle> xs(count + 1, zero); - xs[0] = builder->Reshape(x, {1}); - return builder->ConcatInDim(xs, 0); -} - // Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the // relevant slice of 'operand'. xla::ComputationDataHandle DynamicAddSlice( @@ -228,7 +218,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle value = ctx->Input(2); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + auto start_indices = XlaHelpers::PadWithZeros(b, index, elem_shape.dims()); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); @@ -270,7 +260,8 @@ class TensorArrayReadOp : public XlaOpKernel { xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + auto start_indices = + XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; @@ -309,7 +300,7 @@ class TensorArrayGatherOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); const TensorShape indices_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, indices_shape.dims() >= 1, + OP_REQUIRES(ctx, indices_shape.dims() == 1, errors::InvalidArgument("indices must be rank 1")); const int num_indices = indices_shape.dim_size(0); auto indices = ctx->Input(1); @@ -324,7 +315,8 @@ class TensorArrayGatherOp : public XlaOpKernel { auto index = b->Slice(indices, {i}, {i + 1}, {1}); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1); + auto start_indices = + XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; @@ -396,7 +388,8 @@ class TensorArrayScatterOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims()); + auto start_indices = + XlaHelpers::PadWithZeros(b, index, elem_shape.dims()); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } |