diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-14 15:34:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-14 16:00:22 -0700 |
commit | 878e6366362612a2ffba740bde51999c72a73acf (patch) | |
tree | 0790a312c972600b7755fe29dc94d3e63ea0ad0f /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | |
parent | 2a792a35111dbd55757fd592a9913b5048b55468 (diff) |
Refactor XLA Gather to use a common implementation for Gather, ResourceGather, etc.
PiperOrigin-RevId: 165239093
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 29 |
1 files changed, 3 insertions, 26 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 34cc8b2315..8b9ef4de76 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include <limits> #include <vector> +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -306,36 +307,12 @@ class TensorArrayGatherOp : public XlaOpKernel { const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() == 1, errors::InvalidArgument("indices must be rank 1")); - const int num_indices = indices_shape.dim_size(0); auto indices = ctx->Input(1); xla::ComputationDataHandle ta = resource->value; - // For each index in `indices`, add the corresponding slice to `slices`. - std::vector<xla::ComputationDataHandle> slices(num_indices); - for (int i = 0; i < num_indices; ++i) { - // Slices the i-th index out of `indices`, and pads it with zeros in the - // minor dimensions to form an index into the TensorArray storage. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); - - // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. - auto start_indices = - XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1); - - auto slice_shape = ta_shape.dim_sizes(); - slice_shape[0] = 1LL; - - slices[i] = b->DynamicSlice(ta, start_indices, slice_shape); - } - - xla::ComputationDataHandle gather; - if (slices.empty()) { - auto shape = ta_shape.dim_sizes(); - shape[0] = 0; - gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape); - } else { - gather = b->ConcatInDim(slices, 0); - } + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( + ta, ta_shape, indices, indices_shape, dtype_, b); ctx->SetOutput(0, gather); } |