author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-08-14 15:34:06 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-08-14 16:00:22 -0700
commit    878e6366362612a2ffba740bde51999c72a73acf (patch)
tree      0790a312c972600b7755fe29dc94d3e63ea0ad0f /tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
parent    2a792a35111dbd55757fd592a9913b5048b55468 (diff)
Refactor XLA Gather to use a common implementation for Gather, ResourceGather, etc.
PiperOrigin-RevId: 165239093
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 29
1 file changed, 3 insertions(+), 26 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 34cc8b2315..8b9ef4de76 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include <vector>
+#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -306,36 +307,12 @@ class TensorArrayGatherOp : public XlaOpKernel {
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() == 1,
errors::InvalidArgument("indices must be rank 1"));
- const int num_indices = indices_shape.dim_size(0);
auto indices = ctx->Input(1);
xla::ComputationDataHandle ta = resource->value;
- // For each index in `indices`, add the corresponding slice to `slices`.
- std::vector<xla::ComputationDataHandle> slices(num_indices);
- for (int i = 0; i < num_indices; ++i) {
- // Slices the i-th index out of `indices`, and pads it with zeros in the
- // minor dimensions to form an index into the TensorArray storage.
- auto index = b->Slice(indices, {i}, {i + 1}, {1});
-
- // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
- auto start_indices =
- XlaHelpers::PadWithZeros(b, index, ta_shape.dims() - 1);
-
- auto slice_shape = ta_shape.dim_sizes();
- slice_shape[0] = 1LL;
-
- slices[i] = b->DynamicSlice(ta, start_indices, slice_shape);
- }
-
- xla::ComputationDataHandle gather;
- if (slices.empty()) {
- auto shape = ta_shape.dim_sizes();
- shape[0] = 0;
- gather = b->Broadcast(XlaHelpers::Zero(b, dtype_), shape);
- } else {
- gather = b->ConcatInDim(slices, 0);
- }
+ xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
+ ta, ta_shape, indices, indices_shape, dtype_, b);
ctx->SetOutput(0, gather);
}
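
The loop deleted above is, in essence, what the new shared helper centralizes so that Gather, ResourceGather, and TensorArrayGather can reuse one implementation. A minimal sketch of such a gather-via-DynamicSlice helper, reconstructed from the removed code, is shown below; the name GatherDynamicSliceSketch, the exact parameter list, and the header set are illustrative assumptions, and the real XlaComputeGatherDynamicSlice declared in gather_op_helpers.h may differ in its signature and edge-case handling.

// Illustrative sketch only: mirrors the per-index loop removed in this commit.
// The actual shared helper (XlaComputeGatherDynamicSlice) may take different
// parameters and perform additional validation.
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

xla::ComputationDataHandle GatherDynamicSliceSketch(
    const xla::ComputationDataHandle& input, const TensorShape& input_shape,
    const xla::ComputationDataHandle& indices, const TensorShape& indices_shape,
    DataType dtype, xla::ComputationBuilder* b) {
  const int64 num_indices = indices_shape.dim_size(0);

  // For each index in `indices`, slice out that index, pad it with zeros in
  // the minor dimensions to form start_indices = [index, 0, ..., 0], and take
  // a size-1 DynamicSlice of `input` along dimension 0.
  std::vector<xla::ComputationDataHandle> slices(num_indices);
  for (int64 i = 0; i < num_indices; ++i) {
    auto index = b->Slice(indices, {i}, {i + 1}, {1});
    auto start_indices =
        XlaHelpers::PadWithZeros(b, index, input_shape.dims() - 1);
    auto slice_shape = input_shape.dim_sizes();
    slice_shape[0] = 1LL;
    slices[i] = b->DynamicSlice(input, start_indices, slice_shape);
  }

  // With no indices, emit an empty result of the right shape; otherwise
  // concatenate the per-index slices along dimension 0.
  if (slices.empty()) {
    auto shape = input_shape.dim_sizes();
    shape[0] = 0;
    return b->Broadcast(XlaHelpers::Zero(b, dtype), shape);
  }
  return b->ConcatInDim(slices, 0);
}

}  // namespace tensorflow

With a helper along these lines, the TensorArrayGather kernel body shrinks to the single call visible in the diff, and other gather-style kernels can delegate to the same code path.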